解决pytorch训练的过程中内存一直增加的问题

来自:解决pytorch训练的过程中内存一直增加的问题 - 知乎

pytorch训练中内存一直增加的原因(部分)

  • 代码中存在累加loss,但每步的loss没加item()

    import torch
    import torch.nn as nn
    from collections import defaultdict

    if torch.cuda.is_available():
    device = 'cuda'
    else:
    device = 'cpu'

    model = nn.Linear(100, 400).to(device)
    criterion = nn.L1Loss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    train_loss = defaultdict(float)
    eval_loss = defaultdict(float)

    for i in range(10000):
    model.train()
    x = torch.rand(50, 100, device=device)
    y_pred = model(x) # 50 * 400
    y_tgt = torch.rand(50, 400, device=device)

    复制代码
      loss = criterion(y_pred, y_tgt)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      
      # 会导致内存一直增加,需改为train_loss['loss'] += loss.item()
      train_loss['loss'] += loss
    
      if i % 100 == 0:
          train_loss = defaultdict(float)
          model.eval()
          x = torch.rand(50, 100, device=device)
          y_pred = model(x) # 50 * 400
    
          y_tgt = torch.rand(50, 400, device=device)
          loss = criterion(y_pred, y_tgt)
    
          # 会导致内存一直增加,需改为eval_loss['loss'] += loss.item()
          eval_loss['loss'] += loss

以上代码会导致内存占用越来越大,解决的方法是:train_l oss['loss'] += loss.item() 以及 eval_loss['loss'] += loss.item()。值得注意的是,要复现内存越来越大的问题,模型中需要切换model.train() 和 model.eval(),train_loss以及eval_loss的作用是保存模型的平均误差(这里是累积误差),保存到tensorboard中。

相关推荐
TDengine (老段)3 分钟前
人力减 60%:时序数据库 TDengine 助力桂冠电力实现 AI 智能巡检
java·大数据·数据库·人工智能·时序数据库·tdengine·涛思数据
CNRio4 分钟前
智驭天象:人工智能重塑气象科技新纪元
人工智能·科技
szxinmai主板定制专家5 分钟前
JETSON orin+FPGA+GMSL+AI协作机器人视觉感知
网络·arm开发·人工智能·嵌入式硬件·fpga开发·机器人
羊羊小栈11 分钟前
基于「YOLO姿态识别 + AI大模型分析」的智能健身辅助系统(vue+flask+AI算法)
vue.js·人工智能·yolo·毕业设计·创业创新·大作业
秋邱12 分钟前
AR 技术创新与商业化新方向:AI+AR 融合,抢占 2025 高潜力赛道
前端·人工智能·后端·python·html·restful
咚咚王14 分钟前
人工智能之数据分析 Pandas:第八章 数据可视化
人工智能·数据分析·pandas
前端九哥18 分钟前
如何让AI设计出Apple风格的顶级UI?
前端·人工智能
星诺算法备案20 分钟前
AI小程序合规指南:从上线要求到标识的“双保险”
人工智能·算法·推荐算法·备案
ar012320 分钟前
AR技术如何助力工业制造验收智能化
人工智能·ar