pytorch训练权重转化为tensorflow模型的教训

模型构建时候有时候在工程量比较大的时候,不可避免使用迭代算法,迭代算法本身会让错误的追踪更加困难,因此掌握基本的框架之间的差异非常重要。以下均是在模型转换过程中出现的错误。

shuffle operation(shuffle 操作)

这个操作原本是用来将各个通道之间的信息进行打乱后,此时面临重要的问题就是,如果将通道打乱,在pytorch里面与tensorflow中间,两种通道排序是不一样的,是采用不同的通道数据排列进行的。

python 复制代码
import tensorflow as tf

def channel_shuffle(x, groups):
    _, h, w, c = x.shape
    # c 通道进行划分
    x = tf.reshape(x, [-1, h, w, groups, c // groups])
    # 通道为基本单位的情况下,多group均采样重组
    x = tf.transpose(x, [0, 1, 2, 4, 3])  # 调整通道维度顺序

    # 混洗采样重组后再reshape变成之前的通道
    x = tf.reshape(x, [-1, h, w, c])
    return x

# 示例张量
x = tf.random.normal((2, 3, 3, 8))
print("Original tensor:\n", x.numpy())

# 进行通道混洗
shuffled_x = channel_shuffle(x, groups=2)
print("Shuffled tensor:\n", shuffled_x.numpy())

Pytorch下的GSBottleneck采用的Sequential具有差异

python 复制代码
class GSBottleneck(nn.Module):
    # GS Bottleneck https://github.com/AlanLi1997/slim-neck-by-gsconv
    def __init__(self, c1, c2, k=3, s=1):
        super().__init__()
        c_ = c2 // 2
        # for lighting
        self.conv_lighting = nn.Sequential(
            GSConv(c1, c_, 1, 1),
            GSConv(c_, c2, 1, 1, act=False))
        # for receptive field
        self.conv = nn.Sequential(
            GSConv(c1, c_, 3, 1),
            GSConv(c_, c2, 3, 1, act=False))
        self.shortcut = nn.Identity()

    def forward(self, x):
        return self.conv_lighting(x)

我遇到的坑为:

python 复制代码
class TFGSBottleneck(tf.keras.layers.Layer):
    # GS Bottleneck https://github.com/AlanLi1997/slim-neck-by-gsconv
    def __init__(self, c1, c2, k=3, s=1,w=None):
        super().__init__()
        c_ = c2 // 2
        # example
        # self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        # self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)

        self.conv_lighting = tf.keras.Sequential(
            TFGSConv(c1, c_, 1, 1,w=w.conv_lighting[0]),
            TFGSConv(c_, c2, 1, 1, act=False, w=w.conv_lighting[1])
        )

        # for receptive field
        self.conv = tf.keras.Sequential(
            TFGSConv(c1, c_, 3, 1,w=w.conv[0]),
            TFGSConv(c_, c2, 3, 1, act=False,w=w.conv[1])
        )

        self.shortcut = tf.keras.layers.Lambda(lambda x: x)

    def call(self, x):
        print("TFGSBottleneck input: ",x.shape)
        print("TFGSBottleneck output: ", self.conv_lighting(x).shape)
        return self.conv_lighting(x)

有以下错误

python 复制代码
Traceback (most recent call last):
  File "D:\TEST\yolov5\models\tf.py", line 1078, in <module>
    main(opt)
  File "D:\TEST\yolov5\models\tf.py", line 1073, in main
    run(**vars(opt))
  File "D:\TEST\yolov5\models\tf.py", line 1044, in run
    _ = tf_model.predict(im)  # inference
        ^^^^^^^^^^^^^^^^^^^^
  File "D:\TEST\yolov5\models\tf.py", line 922, in predict
    x = m(x)  # run
        ^^^^
  File "C:\Users\Zhuliang\.conda\envs\exportyolo2\Lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "D:\TEST\yolov5\models\tf.py", line 383, in call
    m_x1 = self.m(x1)
           ^^^^^^^^^^
  File "D:\TEST\yolov5\models\tf.py", line 361, in call
    print("TFGSBottleneck output: ", self.conv_lighting(x).shape)
                                     ^^^^^^^^^^^^^^^^^^^^^
ValueError: Exception encountered when calling layer 'tfgs_bottleneck' (type TFGSBottleneck).

name for name_scope must be a string.

Call arguments received by layer 'tfgs_bottleneck' (type TFGSBottleneck):
  • x=tf.Tensor(shape=(1, 136, 136, 8), dtype=float32)

解决方案为

错误地将两个层作为位置参数传递给tf.keras.Sequential构造函数,导致第二个参数被误解为name参数,而name必须是一个字符串,但用户传递了一个层实例。这导致在调用层时,TensorFlow无法正确创建name_scope,从而引发错误。解决方法是把这两个层放在一个列表中,作为Sequential构造函数的第一个参数。

要解决此错误,需要将传递给tf.keras.Sequential的层放在列表中,确保它们被正确解析为层序列而不是其他参数。以下是修改后的代码:

tf.keras.Sequential与tensorflow中的pytorch

python 复制代码
class TFGSBottleneck(tf.keras.layers.Layer):
    # GS Bottleneck https://github.com/AlanLi1997/slim-neck-by-gsconv
    def __init__(self, c1, c2, k=3, s=1, w=None):
        super().__init__()
        c_ = c2 // 2

        # 使用列表包裹层以正确传递
        self.conv_lighting = tf.keras.Sequential([
            TFGSConv(c1, c_, 1, 1, w=w.conv_lighting[0]),
            TFGSConv(c_, c2, 1, 1, act=False, w=w.conv_lighting[1])
        ])

        self.conv = tf.keras.Sequential([
            TFGSConv(c1, c_, 3, 1, w=w.conv[0]),
            TFGSConv(c_, c2, 3, 1, act=False, w=w.conv[1])
        ])

        self.shortcut = tf.keras.layers.Lambda(lambda x: x)

    def call(self, x):
        return self.conv_lighting(x)
相关推荐
陈橘又青9 小时前
100% AI 写的开源项目三周多已获得 800 star 了
人工智能·后端·ai·restful·数据
松岛雾奈.2309 小时前
深度学习--TensorFlow框架使用
深度学习·tensorflow·neo4j
中杯可乐多加冰9 小时前
逻辑控制案例详解|基于smardaten实现OA一体化办公系统逻辑交互
人工智能·深度学习·低代码·oa办公·无代码·一体化平台·逻辑控制
IT_陈寒10 小时前
Redis实战:5个高频应用场景下的性能优化技巧,让你的QPS提升50%
前端·人工智能·后端
龙智DevSecOps解决方案10 小时前
Perforce《2025游戏技术现状报告》Part 1:游戏引擎技术的广泛影响以及生成式AI的成熟之路
人工智能·unity·游戏引擎·游戏开发·perforce
大佬,救命!!!10 小时前
更换适配python版本直接进行机器学习深度学习等相关环境配置(非仿真环境)
人工智能·python·深度学习·机器学习·学习笔记·详细配置
星空的资源小屋10 小时前
VNote:程序员必备Markdown笔记神器
javascript·人工智能·笔记·django
梵得儿SHI10 小时前
(第七篇)Spring AI 基础入门总结:四层技术栈全景图 + 三大坑根治方案 + RAG 进阶预告
java·人工智能·spring·springai的四大核心能力·向量维度·prompt模板化·向量存储检索
亚马逊云开发者10 小时前
Amazon Bedrock助力飞书深诺电商广告分类
人工智能
2301_8234380210 小时前
解析论文《复杂海上救援环境中无人机群的双阶段协作路径规划与任务分配》
人工智能·算法·无人机