模型初始化:加载分词器和模型

在模型训练中,一般会定义一个函数 init_model,用于初始化模型和分词器。它加载了一个预训练的分词器,初始化了一个自定义的 MiniMindLM 模型,并将其移动到指定的设备上(CPU 或 GPU)。最后,它还会记录模型的总参数量。

先看代码:

python 复制代码
def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') # 使用 AutoTokenizer 对象上的原生方法 from_pertrained 加载分词器,这一步开发人员不需要做额外的事情
    model = MiniMindLM(lm_config).to(args.device) # 这一步是获取我们的模型,并设置训练模型的设备信息
    
    Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    
    return  model, tokenizer

逐行详细解释

python 复制代码
def init_model(lm_config):
  • 定义函数 :这行代码定义了一个名为 init_model 的函数,它接受一个参数 lm_config,这个参数应该是包含模型配置信息的对象。
python 复制代码
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') 
  • 加载分词器 :这行代码使用 AutoTokenizer 类的 from_pretrained 方法加载了一个预训练的分词器。AutoTokenizer 是 Hugging Face Transformers 库中的一个通用分词器类,它可以根据提供的模型名称或路径自动选择并加载相应的分词器实现。
  • ./model/minimind_tokenizer 是一个路径,表示分词器保存的本地目录。AutoTokenizer.from_pretrained 会在这个目录中查找分词器的相关文件,并加载它们以创建一个分词器实例。
python 复制代码
model = MiniMindLM(lm_config).to(args.device)
  • 初始化模型 :这行代码创建了一个 MiniMindLM 模型实例。MiniMindLM 是一个自定义的模型类,lm_config 是传递给模型构造函数的配置参数,用于定义模型的结构和超参数。
  • 移动模型到设备.to(args.device) 将模型移动到指定的设备上,args.device 通常是一个字符串,表示使用哪个设备进行训练(如 'cuda' 表示 GPU,'cpu' 表示 CPU)。这一步是必要的,因为它确保了模型的参数和计算都在正确的设备上进行。
python 复制代码
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
  • 记录模型参数量 :这行代码使用 Logger 对象记录模型的总参数量。sum(p.numel() for p in model.parameters() if p.requires_grad) 计算了模型中所有可训练参数的总数。p.numel() 返回张量 p 中的元素数量,model.parameters() 返回模型中所有的参数张量。if p.requires_grad 确保只计算需要梯度的参数(即可训练的参数)。
  • / 1e6 将参数数量转换为以百万为单位,:.3f 格式化输出,保留三位小数。这个信息对于了解模型的复杂度和资源占用非常有用。
python 复制代码
return model, tokenizer
  • 返回模型和分词器:这行代码返回初始化好的模型和分词器,以便在其他地方使用。

总结

初始化模型的过程可以概括为以下几个步骤:

  1. 加载预训练分词器:从指定的本地路径加载一个预训练的分词器,用于处理文本数据。

  2. 初始化模型 :根据提供的配置创建一个自定义的 MiniMindLM 模型实例,并将其移动到指定的计算设备上(CPU 或 GPU)。

  3. 记录模型参数量:计算并记录模型中可训练参数的总数,以帮助开发人员了解模型的规模和复杂度。

  4. 返回模型和分词器:将初始化好的模型和分词器返回给调用者,以便在后续的训练或推理过程中使用。

相关推荐
程序员二叉3 小时前
【Java】集合面试全套精讲|HashMap/ArrayList高频考点完整版
java·面试·哈希算法
不懂数据的小白5 小时前
面试题一:【三】AB实验入门(验证)
面试
我叫黑大帅5 小时前
通过php 中的Route:: 的写法了解什么是静态类调用
后端·面试·php
Aphasia3116 小时前
从输入URL到页面展示全流程
前端·面试
2601_961845426 小时前
高考真题试卷电子版|2025高考全科试卷分类下载
考研·面试·蓝桥杯·远程工作·程序员创富·高考
我叫黑大帅6 小时前
前端如何竖屏固定视口背景
前端·javascript·面试
折哥的程序人生 · 物流技术专研6 小时前
《Java 100 天进阶之路》第95篇:消息队列基础(RocketMQ/Kafka)(2026版)
java·面试·kafka·rocketmq·java-rocketmq·求职招聘
不会敲代码17 小时前
我花了三天时间,终于把 Cookie、XSS、CSRF 和浏览器存储给整明白了
javascript·面试
swipe7 小时前
Mem0 x Agent 实战系列:分层记忆 + 三路召回,搭建真正可用的长期记忆层
前端·javascript·面试
Lee川7 小时前
Event Loop 面试通关:从原理到口述再到实战
前端·面试