Lora训练的参数和性能

复制代码
主要为了测试模型增加Lora模块后,参数量和训练速度的变化情况。
结论:正常情况下,增加Lora模块是会增加参数量的,因此前向传播和反向传播的时间也会增加。
但是,在大语言模型训练的情况下,因为基础模型本身参数量非常大,Lora模块增加的参数量相对非常小。并且,基础模型不参与梯度更新,可以做模型量化,实际上是能减少模型训练时间和显存使用量的。
以下是实验脚本和运行结果:
复制代码
#部分参考https://zhuanlan.zhihu.com/p/666000885
import time
import torch
from torch import nn
from peft import LoraConfig, get_peft_model, PeftModel
from torchsummary import summary


x_train = torch.randn((1000, 10))
y_train = torch.randn((1000, 1))

net = nn.Sequential(
    nn.Linear(10,20),
    nn.Sigmoid(),
    nn.Linear(20,30),
    nn.Sigmoid(),
    nn.Linear(30,1)
)
summary(net, (1,10))

config = LoraConfig(target_modules=["0"], r=2)
model = get_peft_model(net, config)
criterion = torch.nn.MSELoss(reduction='mean')            # 定义损失函数,采用均方误差
optimizer = torch.optim.Adam(model.parameters(), lr=0.3)  # 定义优化器,采用Adam
summary(model, (1,10))


# base 前向计算时间
start = time.time()
for i in range(100000):
    y_pre = net(x_train)            # 前向传播
print("base 前向计算时间: ", time.time() - start)

# lora 前向计算时间
start = time.time()
for i in range(100000):
    y_pre = model(x_train)            # 前向传播
print("lora 前向计算时间", time.time() - start)

# base 反向传播时间
start = time.time()
for i in range(1000):
    y_pre = net(x_train)            # 前向传播
    loss = criterion(y_pre, y_train)  # 计算损失
    optimizer.zero_grad()             # 梯度清零
    loss.backward()                   # 反向传播
    optimizer.step()                  # 使用优化器更新梯度
print("base loss after training: ", loss.item())
print("base 反向计算时间", time.time() - start)

# lora 反向传播时间
start = time.time()
for i in range(1000):
    y_pre = model(x_train)            # 前向传播
    loss = criterion(y_pre, y_train)  # 计算损失
    optimizer.zero_grad()             # 梯度清零
    loss.backward()                   # 反向传播
    optimizer.step()                  # 使用优化器更新梯度
print("lora loss after training: ", loss.item())
print("lora 反向计算时间", time.time() - start)

运行代码输出:

复制代码
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 20]             220
           Sigmoid-2                [-1, 1, 20]               0
            Linear-3                [-1, 1, 30]             630
           Sigmoid-4                [-1, 1, 30]               0
            Linear-5                 [-1, 1, 1]              31
================================================================
Total params: 881
Trainable params: 881
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 20]             220
          Identity-2                [-1, 1, 10]               0
            Linear-3                 [-1, 1, 2]              20
            Linear-4                [-1, 1, 20]              40
            Linear-5                [-1, 1, 20]             220
           Sigmoid-6                [-1, 1, 20]               0
            Linear-7                [-1, 1, 30]             630
           Sigmoid-8                [-1, 1, 30]               0
            Linear-9                 [-1, 1, 1]              31
================================================================
Total params: 1,161
Trainable params: 60
Non-trainable params: 1,101
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.01
----------------------------------------------------------------
base loss after training:  1.0724023580551147
base 反向计算时间 2.9570980072021484
lora loss after training:  1.0643658638000488
lora 反向计算时间 3.053032159805298
相关推荐
小oo呆40 分钟前
【学习心得】Jupyter 如何在conda的base环境中其他虚拟环境内核
python·jupyter·conda
小白学大数据2 小时前
Scrapy框架下地图爬虫的进度监控与优化策略
开发语言·爬虫·python·scrapy·数据分析
浊酒南街2 小时前
TensorFlow之微分求导
人工智能·python·tensorflow
立秋67892 小时前
用Python绘制梦幻星空
开发语言·python·pygame
alpszero2 小时前
YOLO11解决方案之对象裁剪探索
人工智能·python·计算机视觉·yolo11
白云千载尽2 小时前
相机、雷达标定工具,以及雷达自动标定的思路
python·自动驾驶·ros
咕噜咕噜啦啦3 小时前
python爬虫实战训练
爬虫·python
盛夏绽放3 小时前
Python字符串常用内置函数详解
服务器·开发语言·python
我想睡觉2613 小时前
Python训练营打卡DAY27
开发语言·python·机器学习
蹦蹦跳跳真可爱5893 小时前
Python----神经网络(基于DNN的风电功率预测)
人工智能·pytorch·python·深度学习·神经网络·dnn