【pytorch可视化工具】

TensorboardX

在PyTorch中,模型训练的可视化通常通过TensorBoard或Visdom等工具实现。以下是如何使用TensorBoard进行模型训练可视化的步骤:

使用TensorboardX与PyTorch配合

确保已经安装了tensorboardtensorboardX库。

bash 复制代码
pip install tensorboard
pip install tensorboardX

在训练过程中记录损失、准确率等指标:

python 复制代码
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn

# 假设你已经有了模型、优化器、损失函数以及数据加载器
model = ...  # 你的模型实例
criterion = nn.CrossEntropyLoss()  # 或者其他适合的损失函数
optimizer = torch.optim.Adam(model.parameters())
dataloader = DataLoader(...)  # 你的数据加载器

# 创建一个SummaryWriter对象来写入日志文件
writer = SummaryWriter()

num_epochs = 100
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
    
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        # 记录每批次的损失到TensorBoard
        writer.add_scalar('Training Loss', loss.item(), epoch * len(dataloader) + len(inputs))

    # 在每个epoch结束时记录其他评估指标(例如验证集上的精度)
    with torch.no_grad():
        val_loss = validate_your_model(model, validation_loader)
        writer.add_scalar('Validation Loss', val_loss, epoch)

# 在所有训练完成后关闭writer
writer.close()

# 然后运行tensorboard服务并打开可视化界面
%tensorboard --logdir=runs  # Jupyter notebook内
# 或在终端执行
tensorboard --logdir=runs

计算模型参数量和浮点数:

python 复制代码
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params}")

评价指标通常根据任务类型有所不同,例如分类任务中的精度、召回率、F1分数等,回归任务中的均方误差(MSE)、平均绝对误差(MAE)等。可以将这些指标也记录到TensorBoard中,就像记录损失那样。

例如,对于分类任务,假设有预测输出和真实标签:

python 复制代码
from sklearn.metrics import accuracy_score

# 验证集上的预测
predictions = torch.argmax(model(val_inputs), dim=1)
true_labels = val_targets

accuracy = accuracy_score(true_labels.cpu().numpy(), predictions.cpu().numpy())
writer.add_scalar('Validation Accuracy', accuracy, epoch

Visdom

这是一个Web-based实时数据可视化工具,可以与PyTorch一起使用来监控训练过程。下面是使用Visdom的基本代码示例:

python 复制代码
import visdom
vis = visdom.Visdom()

# 记录损失值
vis.line(Y=[loss], X=[epoch], win='Loss', update='append')

# 显示图像等其他类型的数据也类似,需要根据Visdom API操作
  1. 对于模型参数量的计算,可以通过torch.nn.Module的子类实例直接统计:
python 复制代码
import torch
from your_model_module import YourModelClass

model = YourModelClass()
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")
  1. 浮点数计算通常指的是模型占用的内存大小,可以通过下面的方式来估算(单位为MB):
python 复制代码
param_size = sum(torch.prod(torch.tensor(p.size())) * p.element_size() for p in model.parameters())
print(f"Estimated memory usage: {param_size / (1024 ** 2):.2f} MB")
相关推荐
ALex_zry几秒前
让 Python 脚本在后台持续运行:架构级解决方案与工业级实践指南
开发语言·python·架构
永洪科技2 分钟前
AI领域再突破,永洪科技荣获“2025人工智能+创新案例”奖
大数据·人工智能·科技·数据分析·数据可视化
拓端研究室TRL2 分钟前
Python对Airbnb北京与上海链家租房数据用逻辑回归、决策树、岭回归、Lasso、随机森林、XGBoost、神经网络、聚类
python·决策树·随机森林·回归·逻辑回归
潇湘夜雨6972 分钟前
第十四届蓝桥杯大赛软件赛国赛Python大学B组题解
python·蓝桥杯
that's boy2 分钟前
Google 发布 Sec-Gemini v1:用 AI 重塑网络安全防御格局?
人工智能·安全·web安全·chatgpt·midjourney·ai编程·ai写作
Sui_Network3 分钟前
Crossmint 与 Walrus 合作,将协议集成至其跨链铸造 API 中
人工智能·物联网·游戏·区块链·智能合约
liruiqiang053 分钟前
循环神经网络 - 长短期记忆网络
人工智能·rnn·深度学习·神经网络·机器学习·ai·lstm
小杨4046 分钟前
python入门系列十六(网络编程)
人工智能·python·网络协议
Elastic 中国社区官方博客7 分钟前
Elasticsearch 向量数据库,原生支持 Google Cloud Vertex AI 平台
大数据·数据库·人工智能·elasticsearch·搜索引擎·语言模型·自然语言处理
北极星6号7 分钟前
python manimgl数学动画演示_微积分_线性代数原理_ubuntu安装问题[已解决]
python·ubuntu·opengl·数学动画·manimgl