如何使用 AutoModel 保存与加载自定义模型

介绍

在原始BERT模型的基础上额外添加一些层,并把新架构的模型保存到本地。

然后使用 AutoModel 加载模型,这样更方便一点。我们不需要在本地一直保存这个模型的自定义的网络结构的 python文件。

custom.py

要保证 model 的python文件足够的干净,只有模型架构的代码。不然容易报错,导致模型权重保存到本地失败。

python 复制代码
from torch import nn
from transformers import BertModel, BertPreTrainedModel


class CustomBERTModel(BertPreTrainedModel):
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        logits = outputs[0]
        new_output = self.linear(logits)
        return new_output

# # 注册自定义模型到 AutoModel
# MODEL_MAPPING.register(AutoConfig, CustomBERTModel)

save_model.py

修改 linear 的权重,再把模型参数保存到本地。

python 复制代码
import torch
from torch import nn
from transformers import AutoTokenizer

from custom_BERT_model import CustomBERTModel

CustomBERTModel.register_for_auto_class("AutoModel")
model_name = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = CustomBERTModel.from_pretrained(model_name)
config = model.config
model.linear.weight = nn.Parameter(torch.zeros(config.hidden_size, config.hidden_size))
model.linear.bias = nn.Parameter(torch.ones(config.hidden_size))
model.save_pretrained("custom_bert_model")
tokenizer.save_pretrained("custom_bert_model")

保存的文件如下:

复制代码
.
├── config.json
├── custom_BERT_model.py
├── model.safetensors
├── special_tokens_map.json
├── tokenizer_config.json
├── tokenizer.json
└── vocab.txt

上述代码注意
CustomBERTModel.register_for_auto_class("AutoModel")。只有写了这行代码,保存的文件里面才会有custom_BERT_model.py

模型加载

【注意】:
AutoModel.from_pretrained(save_directory, trust_remote_code=True),trust_remote_code 设置为True才能加载自定义模型结构的模型,否则加载的就是原始的BERT的结构。

python 复制代码
from transformers import AutoModel, AutoTokenizer
save_directory = "custom_bert_model"
tokenizer = AutoTokenizer.from_pretrained(save_directory)
loaded_model = AutoModel.from_pretrained(save_directory, trust_remote_code=True)

text = "Hello World!"
tokenized_text = tokenizer(text, return_tensors="pt")
print(tokenized_text)
print(loaded_model(**tokenized_text))

输出:

复制代码
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]], grad_fn=<ViewBackward0>)

无论输入的文本是什么,输出都是全1的矩阵,这是因为在前面我们修改linear层的参数。

相关推荐
Alter12302 分钟前
从“力大砖飞”到“拟态共生”,新华三定义AI基础设施的系统级进化
大数据·运维·人工智能
哔哩哔哩技术14 分钟前
bili-fe-workflow —商业化智能开发工作流实践
人工智能
王木风16 分钟前
终端里的编程副驾:DeepSeek-TUI-项目深度拆解,实测与原理分析
linux·运维·人工智能·rust·node.js
IT_陈寒16 分钟前
为什么你应该学习JavaScript?
前端·人工智能·后端
Java技术小馆25 分钟前
我用 30 分钟构建了 100% 数据主权的私有化健康库
人工智能
tq108641 分钟前
认知连续性与组织墙的崩塌:AI原生时代的架构重构
人工智能·架构
Phodal43 分钟前
AI 解决繁杂任务:从 /goal 到长时间异步 Agent 运行
人工智能
tedcloud1231 小时前
ppt-master部署教程:快速搭建智能演示文稿系统
服务器·人工智能·系统架构·游戏引擎·powerpoint
闵孚龙1 小时前
Claude Code 工具提示词全拆解:AI Agent、Prompt Engineering、工具调用、上下文工程、自动化编程的底层逻辑
人工智能·自动化·prompt
白鲸开源1 小时前
杀疯了!SeaTunnel AI CLI 解锁数据集成新玩法
大数据·人工智能·github