visionTransformer window平台下报错

  • 错误:
clike 复制代码
KeyError: 'Transformer/encoderblock_0/MlpBlock_3/Dense_0kernel is not a file in the archive'
  • 解决方法:

修改这个函数即可,主要原因是Linux系统与window系统路径分隔符不一样导致

clike 复制代码
    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            # query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            # key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            # value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            # out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            query_weight = np2th(weights[(ROOT + '/' + ATTENTION_Q + "/kernel")]).view(self.hidden_size,self.hidden_size).t()
            key_weight = np2th(weights[(ROOT + '/' + ATTENTION_K + "/kernel")]).view(self.hidden_size,self.hidden_size).t()
            value_weight = np2th(weights[(ROOT + '/' + ATTENTION_V + "/kernel")]).view(self.hidden_size,self.hidden_size).t()
            out_weight = np2th(weights[(ROOT + '/' + ATTENTION_OUT + "/kernel")]).view(self.hidden_size,self.hidden_size).t()

            # query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            # key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            # value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            # out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
            query_bias = np2th(weights[(ROOT + '/' + ATTENTION_Q + "/bias")]).view(-1)
            key_bias = np2th(weights[(ROOT + '/' + ATTENTION_K + "/bias")]).view(-1)
            value_bias = np2th(weights[(ROOT + '/' + ATTENTION_V + "/bias")]).view(-1)
            out_bias = np2th(weights[(ROOT + '/' + ATTENTION_OUT + "/bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[(ROOT + '/' + FC_0 + "/kernel")]).t()
            mlp_weight_1 = np2th(weights[(ROOT + '/' + FC_1 + "/kernel")]).t()
            mlp_bias_0 = np2th(weights[(ROOT + '/' + FC_0 +"/bias")]).t()
            mlp_bias_1 = np2th(weights[(ROOT + '/' + FC_1 + "/bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[(ROOT + '/' + ATTENTION_NORM + "/scale")]))
            self.attention_norm.bias.copy_(np2th(weights[(ROOT + '/' + ATTENTION_NORM +  "/bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[(ROOT + '/' + MLP_NORM + "/scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[(ROOT + '/' +  MLP_NORM + "/bias")]))
相关推荐
xingshanchang2 分钟前
PyTorch 不支持旧GPU的异常状态与解决方案:CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH
人工智能·pytorch·python
reddingtons1 小时前
Adobe Firefly AI驱动设计:实用技巧与创新思维路径
大数据·人工智能·adobe·illustrator·photoshop·premiere·indesign
CertiK1 小时前
IBW 2025: CertiK首席商务官出席,探讨AI与Web3融合带来的安全挑战
人工智能·安全·web3
Deepoch2 小时前
Deepoc 大模型在无人机行业应用效果的方法
人工智能·科技·ai·语言模型·无人机
Deepoch2 小时前
Deepoc 大模型:无人机行业的智能变革引擎
人工智能·科技·算法·ai·动态规划·无人机
kngines2 小时前
【字节跳动】数据挖掘面试题0003:有一个文件,每一行是一个数字,如何用 MapReduce 进行排序和求每个用户每个页面停留时间
人工智能·数据挖掘·mapreduce·面试题
Binary_ey2 小时前
AR衍射光波导设计遇瓶颈,OAS 光学软件来破局
人工智能·软件需求·光学软件·光波导
昵称是6硬币3 小时前
YOLOv11: AN OVERVIEW OF THE KEY ARCHITECTURAL ENHANCEMENTS目标检测论文精读(逐段解析)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
平和男人杨争争3 小时前
机器学习2——贝叶斯理论下
人工智能·机器学习
静心问道3 小时前
XLSR-Wav2Vec2:用于语音识别的无监督跨语言表示学习
人工智能·学习·语音识别