Pytorch ddp切换forward函数 验证ddp是否生效

DDP及其在pytorch中应用

ddp默认调用forward函数,有些模型无法使用forward函数,可以对模型包装一下。

python 复制代码
class modelWraper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model.rlhf(*args, **kwargs)

有时ddp跑起来,不确定是否生效,loss backward后不同rank进程的梯度应该一样的,可以通过print 梯度确认。

python 复制代码
loss.backward()
grad_flag = raw_model.lm_head.weight.grad[0,:3]
print(f"grad {ddp_rank} {grad_flag}")


grad 1 tensor([2.9296e-04, 6.2223e-05, 1.0089e-03], device='cuda:1')
grad 0 tensor([2.9296e-04, 6.2223e-05, 1.0089e-03], device='cuda:0')

pytorch分布式系列2------DistributedDataParallel是如何做同步的?

相关推荐
爱笑的眼睛113 分钟前
从零构建与深度优化:PyTorch训练循环的工程化实践
java·人工智能·python·ai
古城小栈4 分钟前
Spring Boot 4.0 虚拟线程启用配置与性能测试全解析
spring boot·后端·python
c#上位机4 分钟前
halcon刚性变换(平移+旋转)——vector_angle_to_rigid
人工智能·计算机视觉·c#·上位机·halcon·机器视觉
liliangcsdn4 分钟前
如何使用pytorch模拟Pearson loss训练模型
人工智能·pytorch·python
做cv的小昊8 分钟前
VLM相关论文阅读:【LoRA】Low-rank Adaptation of Large Language Models
论文阅读·人工智能·深度学习·计算机视觉·语言模型·自然语言处理·transformer
VertGrow AI销冠9 分钟前
AI获客软件VertGrow AI销冠的自动化功能测评
人工智能
TextIn智能文档云平台9 分钟前
抽取出的JSON结构混乱,如何设计后处理规则来标准化输出?
人工智能·json
百罹鸟10 分钟前
在langchain Next 项目中使用 shadcn/ui 的记录
前端·css·人工智能
MediaTea12 分钟前
Python 的设计哲学P08:可读性与人类语言
开发语言·python