解决pytorch训练的过程中内存一直增加的问题

来自:解决pytorch训练的过程中内存一直增加的问题 - 知乎

pytorch训练中内存一直增加的原因(部分)

  • 代码中存在累加loss,但每步的loss没加item()

    import torch
    import torch.nn as nn
    from collections import defaultdict

    if torch.cuda.is_available():
    device = 'cuda'
    else:
    device = 'cpu'

    model = nn.Linear(100, 400).to(device)
    criterion = nn.L1Loss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    train_loss = defaultdict(float)
    eval_loss = defaultdict(float)

    for i in range(10000):
    model.train()
    x = torch.rand(50, 100, device=device)
    y_pred = model(x) # 50 * 400
    y_tgt = torch.rand(50, 400, device=device)

    复制代码
      loss = criterion(y_pred, y_tgt)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      
      # 会导致内存一直增加,需改为train_loss['loss'] += loss.item()
      train_loss['loss'] += loss
    
      if i % 100 == 0:
          train_loss = defaultdict(float)
          model.eval()
          x = torch.rand(50, 100, device=device)
          y_pred = model(x) # 50 * 400
    
          y_tgt = torch.rand(50, 400, device=device)
          loss = criterion(y_pred, y_tgt)
    
          # 会导致内存一直增加,需改为eval_loss['loss'] += loss.item()
          eval_loss['loss'] += loss

以上代码会导致内存占用越来越大,解决的方法是:train_l oss['loss'] += loss.item() 以及 eval_loss['loss'] += loss.item()。值得注意的是,要复现内存越来越大的问题,模型中需要切换model.train() 和 model.eval(),train_loss以及eval_loss的作用是保存模型的平均误差(这里是累积误差),保存到tensorboard中。

相关推荐
Ares-Wang17 分钟前
图像》》仿射变换和透视变换放 、图像分割、目标检测
人工智能·计算机视觉
艾醒(AiXing-w)18 分钟前
从Prompt到Harness:AI Agent三代工程化全解析
人工智能
空中湖21 分钟前
AI 指数级进化 · 一场跨越千年的智能之旅
人工智能
大空大地202626 分钟前
# C#基础语法总结
人工智能·计算机视觉
Volunteer Technology30 分钟前
SpringAI Chat Client (四)
人工智能·spring
城事漫游Molly34 分钟前
案例研究:如何明智地选择案例、精巧地界定边界、深刻地进行分析?
大数据·人工智能·ai写作·论文笔记
易观Analysys1 小时前
范式革命已至:OpenClaw引爆中国AI“行动时代”——《重构与崛起—OpenClaw时代的中国Agent产业生态报告》解读一
人工智能·重构
CoCo的编程之路1 小时前
2026 前端效能飞跃:深度解析智能助手的页面构建最大化方案
前端·人工智能·ai编程·智能编程助手·文心快码baiducomate
豹哥学前端1 小时前
agent智能体经典范式构建
人工智能·后端
纤纡.1 小时前
从零搭建 AI 智能 PDF 问答工具:Streamlit+LangChain + 千问大模型实战
人工智能·阿里云·语言模型·langchain