程序员转行学习 AI 大模型: 踩坑记录:服务器内存不够,程序被killed

本文是程序员转行学习AI大模型的踩坑记录分享。

当前阶段:还在学习知识点,由点及面,从 0 到 1 搭建 AI 大模型知识体系中。

系列更新,关注我,后续会持续记录分享转行经历~

踩坑问题

我是在阿里云上购买了一个 99 元入门级云服务器 ECS,本来就考虑到内存比较小,所以我选的是小模型,"distilbert/distilgpt2"。但是即使是这个小模型,程序执行的时候,还是 killed,应该是环境内存不够,这台服务器内存是 2g 的。

为了后续根据环境,选择合适的模型,查看了如何计算训练需要的内存大小。

如何根据模型计算所需内存

  1. 模型参数量的内存计算

FP32(32 位浮点数)

plain 复制代码
模型内存 = 参数量 * 4字节

FP16(16 位浮点数)

plain 复制代码
模型内存 = 参数量 * 2字节

INT8(8 位整数)

plain 复制代码
模型内存 = 参数量 * 1字节

可以在代码中,使用 weight.dtype 查看使用的数据类型。INT8 需要 GPU 环境。

不同数据类型对比:

关键要点:

plain 复制代码
1. FP32 是默认 :不指定时自动使用
2. FP16 节省 25% :权重和梯度减半
3. INT8 需要GPU :CPU上训练不稳定
  1. 训练时的内存占用组件

训练时的内存由以下部分组成

复制代码
1. 模型权重(Model Weights)
plain 复制代码
模型权重内存 = 参数量 * 数据类型大小

如 82M 参数,FP32,内存 = 82,000,000 * 4 字节 = 328MB

复制代码
2. 梯度(Gradients)
plain 复制代码
梯度内存 = 参数量 * 数据类型大小

说明:

  • 每个参数都需要存储梯度
  • 梯度的大小与参数相同
    3. 优化器状态

Adam 优化器(最常用):

plain 复制代码
一阶动量(m)= 参数量 × 4 字节
二阶动量(v)= 参数量 × 4 字节
优化器总内存 = 2 × 参数量 × 4 字节 = 参数量 × 8 字节

SGD 优化器:

plain 复制代码
优化器内存 = 0(不需要额外存储)
复制代码
4. 激活值
plain 复制代码
激活值内存 = batch_size × seq_len × hidden_dim × 层数 × 数据类型大小

简化估算:

plain 复制代码
激活值内存 ≈ 模型权重内存 × (0.5 - 2.0)

影响因素:

  • batch_size:越大,激活值越大
  • seq_len:序列越长,激活值越大
  • 模型深度:层数越深,激活值越大
  • hidden_dim:隐藏层维度越大,激活值越大
    5. PyTorch 开销(PyTorch Overhead)

包括:

  • CUDA 上下文(即使不用 GPU):100-200MB
  • DataLoader 缓存:100-200MB
  • 其他临时变量:100-300MB
plain 复制代码
PyTorch 开销 ≈ 300-700 MB
  1. 总内存计算公式

训练时总内存:

plain 复制代码
总内存 = 模型权重 + 梯度 + 优化器状态 + 激活值 + PyTorch 开销

详细公式:

plain 复制代码
总内存 = 参数量 × 4 字节(FP32 权重)
       + 参数量 × 4 字节(FP32 梯度)
       + 参数量 × 8 字节(Adam 优化器)
       + 激活值内存(100-500 MB)
       + PyTorch 开销(300-700 MB)

简化公式(估算):

plain 复制代码
总内存 ≈ 参数量 × 16 字节 + 500-1200 MB

服务器内存查看

检查服务器内存:

plain 复制代码
free -h
相关推荐
qq_334563551 天前
html标签怎么表示用户输入_kbd标签键盘快捷键标注【介绍】.txt
jvm·数据库·python
小陈工1 天前
数据库Operator开发实战:以PostgreSQL为例
开发语言·数据库·人工智能·python·安全·postgresql·开源
weixin_586061461 天前
SQL报表星型模型优化_事实表索引设计
jvm·数据库·python
Zn_lunar1 天前
autodl tizi+codex cli
运维·服务器·网络
耿雨飞1 天前
Python 后端开发技术博客专栏 | 第 07 篇 元类与类的创建过程 -- Python 最深层的魔法
开发语言·python
热爱生活的五柒1 天前
度量学习-Radar Signal Deinterleaving Using Transformer Encoder and HDBSCAN 论文解析
深度学习·学习·transformer
慕涯AI1 天前
Agent 30 课程开发指南 - 第21课
人工智能·python
@insist1231 天前
网络工程师-实战配置篇(一):深入 BGP 与 VRRP,构建高可靠网络
服务器·网络·php·网络工程师·软件水平考试
源码之家1 天前
计算机毕业设计:Python城市天气数据挖掘与预测系统 Flask框架 随机森林 K-Means 可视化 数据分析 大数据 机器学习 深度学习(建议收藏)✅
人工智能·爬虫·python·深度学习·机器学习·数据挖掘·课程设计
Dxy12393102161 天前
Python在图片上画多边形:从简单轮廓到复杂区域标注
开发语言·python