Lora训练的参数和性能

复制代码
主要为了测试模型增加Lora模块后,参数量和训练速度的变化情况。
结论:正常情况下,增加Lora模块是会增加参数量的,因此前向传播和反向传播的时间也会增加。
但是,在大语言模型训练的情况下,因为基础模型本身参数量非常大,Lora模块增加的参数量相对非常小。并且,基础模型不参与梯度更新,可以做模型量化,实际上是能减少模型训练时间和显存使用量的。
以下是实验脚本和运行结果:
复制代码
#部分参考https://zhuanlan.zhihu.com/p/666000885
import time
import torch
from torch import nn
from peft import LoraConfig, get_peft_model, PeftModel
from torchsummary import summary


x_train = torch.randn((1000, 10))
y_train = torch.randn((1000, 1))

net = nn.Sequential(
    nn.Linear(10,20),
    nn.Sigmoid(),
    nn.Linear(20,30),
    nn.Sigmoid(),
    nn.Linear(30,1)
)
summary(net, (1,10))

config = LoraConfig(target_modules=["0"], r=2)
model = get_peft_model(net, config)
criterion = torch.nn.MSELoss(reduction='mean')            # 定义损失函数,采用均方误差
optimizer = torch.optim.Adam(model.parameters(), lr=0.3)  # 定义优化器,采用Adam
summary(model, (1,10))


# base 前向计算时间
start = time.time()
for i in range(100000):
    y_pre = net(x_train)            # 前向传播
print("base 前向计算时间: ", time.time() - start)

# lora 前向计算时间
start = time.time()
for i in range(100000):
    y_pre = model(x_train)            # 前向传播
print("lora 前向计算时间", time.time() - start)

# base 反向传播时间
start = time.time()
for i in range(1000):
    y_pre = net(x_train)            # 前向传播
    loss = criterion(y_pre, y_train)  # 计算损失
    optimizer.zero_grad()             # 梯度清零
    loss.backward()                   # 反向传播
    optimizer.step()                  # 使用优化器更新梯度
print("base loss after training: ", loss.item())
print("base 反向计算时间", time.time() - start)

# lora 反向传播时间
start = time.time()
for i in range(1000):
    y_pre = model(x_train)            # 前向传播
    loss = criterion(y_pre, y_train)  # 计算损失
    optimizer.zero_grad()             # 梯度清零
    loss.backward()                   # 反向传播
    optimizer.step()                  # 使用优化器更新梯度
print("lora loss after training: ", loss.item())
print("lora 反向计算时间", time.time() - start)

运行代码输出:

复制代码
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 20]             220
           Sigmoid-2                [-1, 1, 20]               0
            Linear-3                [-1, 1, 30]             630
           Sigmoid-4                [-1, 1, 30]               0
            Linear-5                 [-1, 1, 1]              31
================================================================
Total params: 881
Trainable params: 881
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 20]             220
          Identity-2                [-1, 1, 10]               0
            Linear-3                 [-1, 1, 2]              20
            Linear-4                [-1, 1, 20]              40
            Linear-5                [-1, 1, 20]             220
           Sigmoid-6                [-1, 1, 20]               0
            Linear-7                [-1, 1, 30]             630
           Sigmoid-8                [-1, 1, 30]               0
            Linear-9                 [-1, 1, 1]              31
================================================================
Total params: 1,161
Trainable params: 60
Non-trainable params: 1,101
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.01
----------------------------------------------------------------
base loss after training:  1.0724023580551147
base 反向计算时间 2.9570980072021484
lora loss after training:  1.0643658638000488
lora 反向计算时间 3.053032159805298
相关推荐
Learn Beyond Limits4 分钟前
Transfer Learning|迁移学习
人工智能·python·深度学习·神经网络·机器学习·ai·吴恩达
love530love2 小时前
【保姆级教程】阿里 Wan2.1-T2V-14B 模型本地部署全流程:从环境配置到视频生成(附避坑指南)
人工智能·windows·python·开源·大模型·github·音视频
He1955012 小时前
Go初级之十:错误处理与程序健壮性
开发语言·python·golang
森之鸟2 小时前
寻找AI——初识3D建模AI
ai·aigc
和鲸社区3 小时前
《斯坦福CS336》作业1开源,从0手搓大模型|代码复现+免环境配置
人工智能·python·深度学习·计算机视觉·语言模型·自然语言处理·nlp
豌豆花下猫3 小时前
Python 潮流周刊#118:Python 异步为何不够流行?(摘要)
后端·python·ai
THMAIL4 小时前
深度学习从入门到精通 - LSTM与GRU深度剖析:破解长序列记忆遗忘困境
人工智能·python·深度学习·算法·机器学习·逻辑回归·lstm
wheeldown4 小时前
【数学建模】数据预处理入门:从理论到动手操作
python·数学建模·matlab·python3.11
YF云飞4 小时前
数据仓库进化:Agent驱动数智化新范式
数据仓库·人工智能·ai
多打代码4 小时前
2025.09.05 用队列实现栈 & 有效的括号 & 删除字符串中的所有相邻重复项
python·算法