【pytorch深度学习 应用篇02】训练过程可视化

参考

安装相关库

bash 复制代码
pip install tensorboardX
pip install tensorboard

训练代码

python 复制代码
from tensorboardX import SummaryWriter
logger = SummaryWriter(log_dir="data/log")
 
# 获取优化器和损失函数
optimizer = torch.optim.Adam(MyConvNet.parameters(), lr=3e-4)
loss_func = nn.CrossEntropyLoss()
log_step_interval = 100# 记录的步数间隔
 
for epoch in range(5):
     print("epoch:", epoch)
     # 每一轮都遍历一遍数据加载器
     for step, (x, y) in enumerate(train_loader):
         # 前向计算->计算损失函数->(从损失函数)反向传播->更新网络
         predict = MyConvNet(x)
         loss = loss_func(predict, y)
         optimizer.zero_grad()   # 清空梯度(可以不写)
         loss.backward()     # 反向传播计算梯度
         optimizer.step()    # 更新网络
         global_iter_num = epoch * len(train_loader) + step + 1# 计算当前是从训练开始时的第几步(全局迭代次数)
         if global_iter_num % log_step_interval == 0:
             # 控制台输出一下
             print("global_step:{}, loss:{:.2}".format(global_iter_num, loss.item()))
             # 添加的第一条日志:损失函数-全局迭代次数
             logger.add_scalar("train loss", loss.item() ,global_step=global_iter_num)
             # 在测试集上预测并计算正确率
             test_predict = MyConvNet(test_data_x)
             _, predict_idx = torch.max(test_predict, 1)     # 计算softmax后的最大值的索引,即预测结果
             acc = accuracy_score(test_data_y, predict_idx)
             # 添加第二条日志:正确率-全局迭代次数
             logger.add_scalar("test accuary", acc.item(), global_step=global_iter_num)
             # 添加第三条日志:这个batch下的128张图像
             img = vutils.make_grid(x, nrow=12)
             logger.add_image("train image sample", img, global_step=global_iter_num)
             # 添加第三条日志:网络中的参数分布直方图
             for name, param in MyConvNet.named_parameters():
                 logger.add_histogram(name, param.data.numpy(), global_step=global_iter_num)

绑定端口

如果直接在服务器查看

bash 复制代码
$ tensorboard --logdir=xxx --port=6006

如果想在客户端查看

bash 复制代码
# ssh -L 服务器端口:127.0.0.1:客户端端口 服务器中你的用户名name@服务器的ip  
# 有的服务器做了端口映射 所以-p后面添加你服务器的连接端口号,默认是22端口 
$ ssh -L 16006:127.0.0.1:6006  name@ip -p 22
$ export LC_ALL=C
$ tensorboard --logdir=path --port=6006 # path是你服务器上保存的tensorboard文件

查看

在本地浏览器中访问http://localhost:6006/

相关推荐
袋鼠云数栈UED团队16 小时前
基于 OpenSpec 实现规范驱动开发
前端·人工智能
Raink老师16 小时前
【AI面试临阵磨枪】什么是 Tokenization?子词分词(Subword)的优缺点?
人工智能·ai 面试
迷你可可小生16 小时前
面经(三)
人工智能·rnn·lstm
云烟成雨TD16 小时前
Spring AI Alibaba 1.x 系列【28】Nacos Skill 管理中心功能说明
java·人工智能·spring
AI医影跨模态组学16 小时前
Cancer Letters(IF=10.1)中科院自动化研究所田捷等团队:整合纵向MRI与活检全切片图像用于乳腺癌新辅助治疗反应的早期预测及个体化管理
人工智能·深度学习·论文·医学·医学影像
oioihoii16 小时前
Graphify 简明指南
人工智能
王飞飞不会飞17 小时前
Mac 安装Hermes Agent 过程记录
运维·深度学习·机器学习
数字供应链安全产品选型17 小时前
AI全生命周期安全:从开发到下线,悬镜安全灵境AIDR如何覆盖智能体每一个环节?
人工智能
2501_9333295517 小时前
企业舆情处置实战:Infoseek数字公关AI中台技术架构与功能解析
大数据·人工智能·架构·数据库开发
带娃的IT创业者17 小时前
深度解析 Claude Design:如何利用 Anthropic 最新设计范式构建 AI 原生应用
人工智能·python·llm·claude·应用开发·anthropic·ai原生应用