模型初始化:加载分词器和模型

在模型训练中,一般会定义一个函数 init_model,用于初始化模型和分词器。它加载了一个预训练的分词器,初始化了一个自定义的 MiniMindLM 模型,并将其移动到指定的设备上(CPU 或 GPU)。最后,它还会记录模型的总参数量。

先看代码:

python 复制代码
def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') # 使用 AutoTokenizer 对象上的原生方法 from_pertrained 加载分词器,这一步开发人员不需要做额外的事情
    model = MiniMindLM(lm_config).to(args.device) # 这一步是获取我们的模型,并设置训练模型的设备信息
    
    Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    
    return  model, tokenizer

逐行详细解释

python 复制代码
def init_model(lm_config):
  • 定义函数 :这行代码定义了一个名为 init_model 的函数,它接受一个参数 lm_config,这个参数应该是包含模型配置信息的对象。
python 复制代码
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') 
  • 加载分词器 :这行代码使用 AutoTokenizer 类的 from_pretrained 方法加载了一个预训练的分词器。AutoTokenizer 是 Hugging Face Transformers 库中的一个通用分词器类,它可以根据提供的模型名称或路径自动选择并加载相应的分词器实现。
  • ./model/minimind_tokenizer 是一个路径,表示分词器保存的本地目录。AutoTokenizer.from_pretrained 会在这个目录中查找分词器的相关文件,并加载它们以创建一个分词器实例。
python 复制代码
model = MiniMindLM(lm_config).to(args.device)
  • 初始化模型 :这行代码创建了一个 MiniMindLM 模型实例。MiniMindLM 是一个自定义的模型类,lm_config 是传递给模型构造函数的配置参数,用于定义模型的结构和超参数。
  • 移动模型到设备.to(args.device) 将模型移动到指定的设备上,args.device 通常是一个字符串,表示使用哪个设备进行训练(如 'cuda' 表示 GPU,'cpu' 表示 CPU)。这一步是必要的,因为它确保了模型的参数和计算都在正确的设备上进行。
python 复制代码
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
  • 记录模型参数量 :这行代码使用 Logger 对象记录模型的总参数量。sum(p.numel() for p in model.parameters() if p.requires_grad) 计算了模型中所有可训练参数的总数。p.numel() 返回张量 p 中的元素数量,model.parameters() 返回模型中所有的参数张量。if p.requires_grad 确保只计算需要梯度的参数(即可训练的参数)。
  • / 1e6 将参数数量转换为以百万为单位,:.3f 格式化输出,保留三位小数。这个信息对于了解模型的复杂度和资源占用非常有用。
python 复制代码
return model, tokenizer
  • 返回模型和分词器:这行代码返回初始化好的模型和分词器,以便在其他地方使用。

总结

初始化模型的过程可以概括为以下几个步骤:

  1. 加载预训练分词器:从指定的本地路径加载一个预训练的分词器,用于处理文本数据。

  2. 初始化模型 :根据提供的配置创建一个自定义的 MiniMindLM 模型实例,并将其移动到指定的计算设备上(CPU 或 GPU)。

  3. 记录模型参数量:计算并记录模型中可训练参数的总数,以帮助开发人员了解模型的规模和复杂度。

  4. 返回模型和分词器:将初始化好的模型和分词器返回给调用者,以便在后续的训练或推理过程中使用。

相关推荐
Dream it possible!22 分钟前
LeetCode 面试经典 150_二叉树_二叉树展开为链表(74_114_C++_中等)
c++·leetcode·链表·面试·二叉树
牛客企业服务1 小时前
2025年AI面试防作弊指南:技术笔试如何识别异常行为
人工智能·面试·职场和发展
TT哇2 小时前
【面经 每日一题】面试题16.25.LRU缓存(medium)
java·算法·缓存·面试
9号达人5 小时前
接口设计中的扩展与组合:一次Code Review引发的思考
java·后端·面试
xhxxx6 小时前
《大厂面试:从手写 Ajax 到封装 getJSON,再到理解 Promise 与 sleep》
ajax·面试
yoke菜籽8 小时前
面试150——二叉树
面试·职场和发展
程序员小寒8 小时前
前端高频面试题之Vuex篇
前端·javascript·面试
程序员爱钓鱼10 小时前
Python 编程实战 · 实用工具与库 — Django 项目结构简介
后端·python·面试
许强0xq19 小时前
Q3: create 和 create2 有什么区别?
面试·web3·区块链·智能合约·solidity·dapp·evm
han_19 小时前
前端高频面试题之Vuex篇
前端·vue.js·面试