在Pytorch中使用Tensorboard可视化训练过程

【在Pytorch中使用Tensorboard可视化训练过程】 https://www.bilibili.com/video/BV1Qf4y1C7kz/?share_source=copy_web\&vd_source=f00bfb41b3b450c3767070ed82f30ac8

主要功能:

1.保存网络结构图

2.保存训练集的损失Loss,验证集的正确性Accuracy以及学习率变化等

3.保存训练的权重

4.保存预测图片的相关信息

使用方法:

①summarywriter来自于torch.utils.tensorboard模块中导入

python 复制代码
from torch.utils.tensorboard import SummaryWriter

②首先需要实例化summarywriter对象,需要定义一个将tensorboard文件保存路径

在实例化后会自动创建文件

python 复制代码
# 实例化SummaryWriter对象
tb_writer = SummaryWriter(log_dir="runs/flower_experiment")

③想要看到模型结构图需要在实例化模型后,创建init_img,使图形进行正向传播;通过模型的正向传播得到结构图

python 复制代码
# 实例化模型
model = resnet34(num_classes=args.num_classes).to(device)

# 将模型写入tensorboard
init_img = torch.zeros((1, 3, 224, 224), device=device)
tb_writer.add_graph(model, init_img)

通过实例化模型add_graph函数将模型和初始图片传入

④在每个训练的epoch之后,在验证完模型后,会保存当前轮数的训练集平均损失Loss和验证集的Accuracy以及learning rate

python 复制代码
# add loss, acc and lr into tensorboard
print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
tags = ["train_loss", "accuracy", "learning_rate"]
tb_writer.add_scalar(tags[0], mean_loss, epoch)
tb_writer.add_scalar(tags[1], acc, epoch)
tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

⑤添加预测图片,使用add_figure将结果保存为图片存储

python 复制代码
if fig is not None:
    tb_writer.add_figure("predictions vs. actuals",
        figure=fig,
        global_step=epoch)

⑥添加直方图,使用add_histogram

python 复制代码
tb_writer.add_histogram(tag="conv1",
        values=model.conv1.weight,
        global_step=epoch)

效果:

展示的网络架构图,按层显示

values可以传入很多格式,包括torch.tensor,numpy.array,string,blockname

打开方式:

结果保存于事先设定的路径

在终端进入路径,或在文件夹按住shift和鼠标右键打开终端

在终端输入命令,需要加一个后面的参数打开指定数目的图片,否则会显示默认值

python 复制代码
tensorboard.exe --logdir=./ --samples_per_plugin=images=50

显示信息:

scalars中的显示信息

images中的显示信息

可以看到随着训练预测的结果越来越精准

graphs中保存每一个网络层结构中的信息

histogram中保存的为直方图

横坐标数值,纵坐标对应出现的次数,在中间分布最密集,随着不断迭代次数会变

点击左侧overlay切换

在distributions中展示权重变换

相关推荐
冬天给予的预感1 小时前
DAY 54 Inception网络及其思考
网络·python·深度学习
说私域1 小时前
互联网生态下赢家群体的崛起与“开源AI智能名片链动2+1模式S2B2C商城小程序“的赋能效应
人工智能·小程序·开源
钢铁男儿1 小时前
PyQt5高级界而控件(容器:装载更多的控件QDockWidget)
数据库·python·qt
董厂长4 小时前
langchain :记忆组件混淆概念澄清 & 创建Conversational ReAct后显示指定 记忆组件
人工智能·深度学习·langchain·llm
亿牛云爬虫专家5 小时前
Kubernetes下的分布式采集系统设计与实战:趋势监测失效引发的架构进化
分布式·python·架构·kubernetes·爬虫代理·监测·采集
G皮T8 小时前
【人工智能】ChatGPT、DeepSeek-R1、DeepSeek-V3 辨析
人工智能·chatgpt·llm·大语言模型·deepseek·deepseek-v3·deepseek-r1
九年义务漏网鲨鱼8 小时前
【大模型学习 | MINIGPT-4原理】
人工智能·深度学习·学习·语言模型·多模态
元宇宙时间8 小时前
Playfun即将开启大型Web3线上活动,打造沉浸式GameFi体验生态
人工智能·去中心化·区块链
开发者工具分享8 小时前
文本音频违规识别工具排行榜(12选)
人工智能·音视频
产品经理独孤虾8 小时前
人工智能大模型如何助力电商产品经理打造高效的商品工业属性画像
人工智能·机器学习·ai·大模型·产品经理·商品画像·商品工业属性