LlamaFactory——如何使用魔改后的模型

需求来源:有时我们可能想在llamafactory框架支持的模型上进行一些改动,例如修改forward()方法等,修改方法我们可以通过继承Transformers库中相应的class并重写相应的方法即可,那我们如何使用自己的模型呢?

首先,我们需要定位模型初始化的相关代码,具体路径为:LLaMA-Factory-main/src/llamafactory/model/loader.py

python 复制代码
# 大致在169行的位置
model = load_class.from_pretrained(**init_kwargs)

上述代码实现了模型的初始化,其中load_class是OrderDict的一个子类,功能主要是根据config的类型找到对应模型class,例如Qwen2VLConfig(源码:transformers/models/qwen2_vl/configuration_qwen2_vl.py)对应Qwen2VLForConditionalGeneration(源码:transformers/models/qwen2_vl/modeling_qwen2_vl.py),本质上类似于字典,那我们只需要把相应的值替换为我们自己的模型即可,具体代码如下:

python 复制代码
load_class.register(type(config), YourCustomModelClass, exist_ok=True)
model = load_class.from_pretrained(**init_kwargs)

使用load_class的register()方法,把模型class替换为自己的模型即可,一定注意参数exist_ok要设置为True,才能覆写已有Config类对应的模型,不然会报错。

相关推荐
阡之尘埃几秒前
Python使用MD5码加密手机号等敏感信息
python·数据挖掘·数据分析·哈希算法·md5·加密算法
Ashlee_code6 分钟前
TRS收益互换平台开发实践:从需求分析到系统实现
java·数据结构·c++·python·架构·php·需求分析
编程自留地38 分钟前
第12次04 :首页展示用户名
数据库·python·django·商城
kooboo china.43 分钟前
Tailwind CSS 实战,基于 Kooboo 构建 AI 对话框页面(四):语音识别输入功能
前端·css·人工智能·ui·html·交互·语音识别
Deng9452013141 小时前
员工管理系统 (Python实现)
开发语言·python
禺垣1 小时前
循环神经网络(RNN)模型
人工智能·深度学习·机器学习·循环神经网络·序列数据
pen-ai1 小时前
【深度学习-pytorch篇】1. Pytorch矩阵操作与DataSet创建
pytorch·深度学习·矩阵
时空无限1 小时前
nvidia could not select device driver ““ with capabilities: [[gpu]]
人工智能
whaosoft-1431 小时前
51c视觉~3D~合集3
人工智能