大语言模型占显存的计算和优化

可以优化的地方:

per_device_train_batch_size(相当于batch size,越小显存占的越小)

gradient_accumulation_steps(per_device_train_batch_size*gradient_accumulation_steps=计算梯度的数据数)

gradient_checkpointing(前项激活值里面有很多是不需要存的,可以在反向传播再次计算的)

optim(可以改为adafactor)

冻结参数(只训练下游任务的参数)

将max_length减小

参考代码:

复制代码
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹
                               per_device_train_batch_size=1,   # 训练时的batch_size
                               gradient_accumulation_steps=32,  # *** 梯度累加 ***
                               gradient_checkpointing=True,     # *** 梯度检查点 *** 前项激活值里面有很多是不需要存的,可以在反向传播再次计算的
                               optim="adafactor",               # *** adafactor优化器 *** 
                               per_device_eval_batch_size=1,    # 验证时的batch_size
                               num_train_epochs=1,              # 训练轮数
                               logging_steps=10,                # log 打印的频率
                               evaluation_strategy="epoch",     # 评估策略
                               save_strategy="epoch",           # 保存策略
                               save_total_limit=3,              # 最大保存数
                               learning_rate=2e-5,              # 学习率
                               weight_decay=0.01,               # weight_decay
                               metric_for_best_model="f1",      # 设定评估指标
                               load_best_model_at_end=True)     # 训练完成后加载最优模型

for name, param in model.bert.named_parameters():
    param.requires_grad = False

tokenized_examples = tokenizer(examples["review"], max_length=32, truncation=True, padding="max_length")
相关推荐
Yolanda94几秒前
【人工智能】《从零搭建AI问答助手项目(四):API调用实战》
人工智能
AI服务老曹2 分钟前
从GB28181接入到边缘NPU算力调度:深度解析支持异构计算的工业级AI视频管理平台架构
人工智能·架构·音视频
workflower4 分钟前
机器人应用-高空立面清洁
人工智能·深度学习·设计模式·机器人·软件工程·软件构建
唐兴通个人5 分钟前
唐兴通受邀华润医药高管培训:AI时代OTC与处方药营销逻辑全面重构数字化转型与创新思维
大数据·人工智能
SariHcr1238 分钟前
Openarm机器人双臂模型仿真从零部署
c++·人工智能·python·机器人·bash·openarm
格林威8 分钟前
面阵相机 vs 线阵相机:堡盟与Basler选型差异全解析 +C++ 实战演示
开发语言·c++·人工智能·数码相机·计算机视觉·视觉检测·工业相机
eastyuxiao13 分钟前
MMM 工具一键去水印+检测 批处理脚本(Windows/Mac 双版本)
人工智能·windows·macos·ai音乐去水印
ACCELERATOR_LLC20 分钟前
【DataWhale组队学习】DIY-LLM Task4 GPU和GPU相关的优化
人工智能·深度学习·大模型·transformer·gpu
Chase_______23 分钟前
【2026】NotebookLM 快速指南:从入门到精通的AI知识管理实战
人工智能·notebooklm