Pytorch:torch.utils.tensorboard.SummaryWriter

SummaryWriter 通常是指在TensorBoard日志记录库中的一个类,这是TensorFlow的一个组件,可以用于记录和展示神经网络训练过程中的各种参数。PyTorch也提供了与TensorBoard兼容的工具torch.utils.tensorboard.SummaryWriter

SummaryWriter 主要用于在训练神经网络时捕获和存储指标,比如损失、精度、模型内部的权重和梯度等。随后,这些数据可以被TensorBoard读取并显示为图形,帮助开发者可视化训练过程,从而更好地理解模型的训练和性能。

以下是一个使用PyTorch中SummaryWriter的基本示例:

python 复制代码
from torch.utils.tensorboard import SummaryWriter

# 实例化SummaryWriter
writer = SummaryWriter('runs/experiment_1')

for epoch in range(num_epochs):
    # ...
    # 在这里实现模型的训练逻辑
    # ...

    # 记录损失
    writer.add_scalar('Loss/train', loss_value, epoch)

    # 记录模型权重直方图
    for tag, value in model.named_parameters():
        tag = tag.replace('.', '/')
        writer.add_histogram(tag, value.data.cpu().numpy(), epoch)

    # 记录图像
    if epoch % 10 == 0:
        # 假设images是一个批次的图像张量
        writer.add_images('Train/images', images, epoch)

# 训练结束后关闭SummaryWriter
writer.close()

使用这样的代码,你可以在TensorBoard中查看损失和权重直方图的变化,以及每十个epoch保存的图像。

以下是它常用的几个功能:

标量记录(Scalars):

python 复制代码
   # 记录损失
   for epoch in range(num_epochs):
       writer.add_scalar('training_loss', loss_value, epoch)

直方图记录(Histograms):

python 复制代码
# 记录权重分布
   for name, param in model.named_parameters():
       writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

图结构(Graphs):

python 复制代码
# 记录模型结构
   inputs = torch.rand(1, 3, 224, 224)  # 随机生成输入张量
   writer.add_graph(model, inputs)

图像记录(Images):

python 复制代码
   # 记录一批图像
   images, labels = next(iter(dataloader))
   writer.add_images('four_fashion_mnist_images', images, epoch)

音频记录(Audio):

python 复制代码
 # 记录一段音频
   audio = torch.randn(1, 44100)  # 随机生成音频数据
   writer.add_audio('my_audio', audio, epoch, sample_rate=44100)

文本记录(Text):

python 复制代码
  # 记录文本信息
   writer.add_text('my_text', 'This is an example of adding text to TensorBoard', epoch)

嵌入记录(Embeddings):

python 复制代码
   # 记录嵌入
   features = torch.randn(100, 256)  # 假设有100个256维的特征向量
   writer.add_embedding(features, metadata=None, label_img=None, global_step=epoch)

超参数调优(HParams):

python 复制代码
# 记录超参数与指标
   hparams = {'lr': 0.1, 'bsize': 1}
   metrics = {'hparam/accuracy': 0.99}

   writer.add_hparams(hparam_dict=hparams, metric_dict=metrics)

PR曲线(Precision-Recall Curves):

python 复制代码
   # 记录PR曲线
   from sklearn.metrics import precision_recall_curve
   y_true = np.array([0, 1, 1, 0, 1])  # 实际标签
   y_scores = np.array([0.1, 0.4, 0.35, 0.8, 0.7])  # 预测得分
   precision, recall, _ = precision_recall_curve(y_true, y_scores)
   writer.add_pr_curve('pr_curve', y_true, y_scores, epoch)

使用上述代码的方式,利用SummaryWriter 的各种方法,可以帮助你记录和可视化模型训练时产生的各种数据。这些数据记录之后可以通过TensorBoard进行查看。

记得在代码中导入必要的模块,并适当调整参数和数据要与你的实际模型训练情况相符。

相关推荐
小雷FansUnion44 分钟前
深入理解MCP架构:智能服务编排、上下文管理与动态路由实战
人工智能·架构·大模型·mcp
资讯分享周1 小时前
扣子空间PPT生产力升级:AI智能生成与多模态创作新时代
人工智能·powerpoint
思则变2 小时前
[Pytest] [Part 2]增加 log功能
开发语言·python·pytest
叶子爱分享2 小时前
计算机视觉与图像处理的关系
图像处理·人工智能·计算机视觉
鱼摆摆拜拜2 小时前
第 3 章:神经网络如何学习
人工智能·神经网络·学习
一只鹿鹿鹿2 小时前
信息化项目验收,软件工程评审和检查表单
大数据·人工智能·后端·智慧城市·软件工程
张较瘦_2 小时前
[论文阅读] 人工智能 | 深度学习系统崩溃恢复新方案:DaiFu框架的原位修复技术
论文阅读·人工智能·深度学习
cver1232 小时前
野生动物检测数据集介绍-5,138张图片 野生动物保护监测 智能狩猎相机系统 生态研究与调查
人工智能·pytorch·深度学习·目标检测·计算机视觉·目标跟踪
漫谈网络2 小时前
WebSocket 在前后端的完整使用流程
javascript·python·websocket
学技术的大胜嗷2 小时前
离线迁移 Conda 环境到 Windows 服务器:用 conda-pack 摆脱硬路径限制
人工智能·深度学习·yolo·目标检测·机器学习