解决pytorch训练的过程中内存一直增加的问题

来自:解决pytorch训练的过程中内存一直增加的问题 - 知乎

pytorch训练中内存一直增加的原因(部分)

  • 代码中存在累加loss,但每步的loss没加item()

    import torch
    import torch.nn as nn
    from collections import defaultdict

    if torch.cuda.is_available():
    device = 'cuda'
    else:
    device = 'cpu'

    model = nn.Linear(100, 400).to(device)
    criterion = nn.L1Loss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    train_loss = defaultdict(float)
    eval_loss = defaultdict(float)

    for i in range(10000):
    model.train()
    x = torch.rand(50, 100, device=device)
    y_pred = model(x) # 50 * 400
    y_tgt = torch.rand(50, 400, device=device)

    复制代码
      loss = criterion(y_pred, y_tgt)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      
      # 会导致内存一直增加,需改为train_loss['loss'] += loss.item()
      train_loss['loss'] += loss
    
      if i % 100 == 0:
          train_loss = defaultdict(float)
          model.eval()
          x = torch.rand(50, 100, device=device)
          y_pred = model(x) # 50 * 400
    
          y_tgt = torch.rand(50, 400, device=device)
          loss = criterion(y_pred, y_tgt)
    
          # 会导致内存一直增加,需改为eval_loss['loss'] += loss.item()
          eval_loss['loss'] += loss

以上代码会导致内存占用越来越大,解决的方法是:train_l oss['loss'] += loss.item() 以及 eval_loss['loss'] += loss.item()。值得注意的是,要复现内存越来越大的问题,模型中需要切换model.train() 和 model.eval(),train_loss以及eval_loss的作用是保存模型的平均误差(这里是累积误差),保存到tensorboard中。

相关推荐
没有梦想的咸鱼185-1037-1663几秒前
AI大模型支持下的顶刊绘图|散点图、气泡图、柱状图、热力图、柱状图、热力图、箱线图、热力图、云雨图、韦恩图、瀑布图、神经网络图、时间序列或分布展示
人工智能·神经网络·arcgis·信息可视化·数据分析·r语言·ai写作
魔乐社区几秒前
从0到1:魔乐社区贡献者丁一超的大模型量化实战指南
人工智能·大模型·量化
weixin_408099671 分钟前
【保姆级教程】易语言调用 OCR 文字识别 API(从0到1完整实战 + 示例源码)
图像处理·人工智能·后端·ocr·api·文字识别·易语言
七七powerful1 分钟前
AI实战--MiroFish:群体智能引擎,预测万物
人工智能·microfish
斯~内克3 分钟前
AI 推理提示工程技术
人工智能
鱼干~4 分钟前
【全栈知识点】全栈开发知识点
前端·人工智能·c#
白小筠4 分钟前
自然语言处理之迁移学习
人工智能·自然语言处理·迁移学习
语戚7 分钟前
Stable Diffusion 核心模块深度拆解:CLIP、U-Net 与 VAE 原理全解析
人工智能·ai·stable diffusion·aigc·模型
枫叶林FYL7 分钟前
【自然语言处理 NLP】8.3 长文本推理评估与针在大海堆任务
人工智能·算法
TDengine (老段)7 分钟前
TDengine IDMP 事件 —— 事件模板
大数据·数据库·人工智能·时序数据库·tdengine·涛思数据