【Java深度学习】PyTorch On Java 系列课程 第八章 17 :模型评估【AI Infra 3.0】[PyTorch Java 硕士研一课程]

训练/评估期间的指标记录

尽管 TensorBoard 提供丰富的可视化功能,但有效监控的根本在于训练和评估代码中的系统性记录。仅仅运行循环是不够的;你需要记录主要的性能数据,以了解训练是如何进展的,并在问题出现时进行诊断。这里说明如何在 PyTorch 训练和评估例程中直接实现基础的指标记录。

为什么要记录指标?

记录指标有几个重要目的:

  1. 性能追踪: 观察损失和准确度(或其他相关指标)在不同训练轮次中的趋势,以判断模型是否有效学习。
  2. 调试: 及早发现潜在问题。损失是否在下降?是否停滞不前?验证性能是否在改善或变差?记录指标中的异常情况常常指向一些根本问题,比如之前讨论过的学习率问题或过拟合。
  3. 比较: 记录的指标允许在不同模型架构、超参数或训练运行之间进行客观比较。
  4. 可视化的根本: 像 TensorBoard 这样的工具依赖这些记录的值来生成图表和仪表板。

要记录的指标

具体指标取决于你的任务,但常用指标包括:

  • 损失: 这是你的优化器试图最小化的值。你几乎总是应该记录损失。追踪训练损失(在训练循环中基于训练数据计算)和验证损失(在评估循环中基于单独的验证数据集计算)很重要。比较这两者对于识别过拟合非常重要。
  • 准确度: 对于分类任务,准确度(正确分类样本的比例)是一个标准且易于理解的指标。
  • 其他任务特定指标: 根据问题类型,你可能记录精确度、召回率、F1分数(用于分类),平均绝对误差(MAE)或均方根误差(RMSE)(用于回归),交并比(IoU)(用于分割)等。

在训练循环中实现记录

在训练期间,你通常希望追踪每个训练轮次中的平均损失和准确度。为每个批次记录指标可能产生噪音,并且对整体趋势的指示性较差,尽管有时它对调试不稳定性有用。

下面是修改标准训练轮次函数以包含记录的方法:

scala 复制代码
import torch

// 假设模型、train_dataloader、loss_fn、optimizer 已定义

def train_one_epoch(model, train_dataloader, loss_fn, optimizer, device):
    model.train() // 设置模型为训练模式
    var running_loss = 0.0
    var correct_predictions = 0
    var total_samples = 0

    for batch_idx, (inputs, labels) <- train_dataloader:
        var inputs, labels = inputs.to(device), labels.to(device)

        // 1. 清零梯度
        optimizer.zero_grad()

        // 2. 前向传播
        var outputs = model(inputs)

        // 3. 计算损失
        var loss = loss_fn(outputs, labels)

        // 4. 反向传播
        loss.backward()

        // 5. 优化器步骤
        optimizer.step()

        // --- 记录步骤 ---
        // 累加损失(使用 .item() 获取 Python 数字)
        running_loss += loss.item() * inputs.size(0) // 按批次大小加权

        // 累加准确度(分类示例)
        val _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()
        // --- 记录步骤结束 ---

    // 计算当前轮次的平均损失和准确度
    val epoch_loss = running_loss / total_samples
    val epoch_acc = correct_predictions / total_samples

    println(f"Training Epoch: Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

    // 返回指标以供后续记录或分析
    return epoch_loss, epoch_acc
java 复制代码
package vals;



import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;

/**
 * 单Epoch训练函数Java实现(适配JavaRandomDataLoader):
 * 1. 模型训练模式切换、批次遍历与设备迁移
 * 2. 梯度清零→前向传播→损失计算→反向传播→优化器步进
 * 3. 按批次大小加权累加损失、统计分类准确率
 * 4. 计算并返回Epoch平均损失和准确率
 */
public class TrainOneEpoch {

    /**
     * 单Epoch训练核心函数(等效Scala的train_one_epoch)
     * @param model 训练的模型
     * @param trainDataloader 训练数据加载器(JavaRandomDataLoader)
     * @param lossFn 损失函数(分类任务用CrossEntropyLoss)
     * @param optimizer 优化器(如Adam/SGD)
     * @param device 训练设备(CPU/GPU)
     * @return 包含Epoch平均损失和准确率的数组 [epochLoss, epochAcc]
     */
    public static double[] trainOneEpoch(
            Module model,
            JavaRandomDataLoader trainDataloader,
            CrossEntropyLossImpl lossFn,
            Optimizer optimizer,
            Device device
    ) {
        // 1. 设置模型为训练模式(启用Dropout/BatchNorm训练行为)
        model.train(true);

        // 2. 初始化统计变量
        double runningLoss = 0.0;       // 加权累加损失(按批次大小)
        long correctPredictions = 0;    // 正确预测数
        long totalSamples = 0;          // 总样本数

        // 3. 遍历训练数据加载器(迭代器方式,适配JavaRandomDataLoader)
        var trainIter = trainDataloader.begin();
        var trainEnd = trainDataloader.end();
        while (!trainIter.equals(trainEnd)) {
            // 3.1 获取批次数据并解析输入/标签
            ExampleVector batch = trainIter.access();
            Tensor inputs = stackData(batch).to(device, torch.ScalarType.Float);  // 输入移至目标设备
            Tensor labels = stackTarget(batch).to(device, torch.ScalarType.Long); // 标签移至目标设备

            // 3.2 梯度清零(必须在每个批次前执行)
            optimizer.zero_grad();

            // 3.3 前向传播
            Tensor outputs = model.asSequential().forward(inputs);

            // 3.4 计算损失
            Tensor loss = lossFn.forward(outputs, labels);

            // 3.5 反向传播(计算梯度)
            loss.backward();

            // 3.6 优化器步进(更新模型参数)
            optimizer.step();

            // ======================== 指标统计 ========================
            // 3.7 累加加权损失:loss.item() * 批次大小(等效Scala loss.item() * inputs.size(0))
            runningLoss += loss.item().toDouble() * inputs.size(0);

            // 3.8 计算分类准确率
            // torch.max(outputs.data, 1) → 获取预测类别索引
            Tensor predicted = torch.max(outputs.data(), 1).get1().data();
            // (predicted == labels).sum() → 统计正确预测数
            Tensor correctMask = torch.eq(predicted, labels);
            correctPredictions += correctMask.sum().item().toLong();
            // 累加总样本数
            totalSamples += labels.size(0);

            // ======================== 资源释放 ========================
            inputs.close();
            labels.close();
            outputs.close();
            loss.close();
            predicted.close();
            correctMask.close();
            batch.close();

            // 3.9 迭代器移动到下一个批次
            trainIter.increment();
        }

        // 4. 计算Epoch级指标
        double epochLoss = runningLoss / totalSamples;  // 平均损失
        double epochAcc = (double) correctPredictions / totalSamples; // 准确率(0-1之间)

        // 5. 打印Epoch训练结果
        System.out.printf("Training Epoch: Loss: %.4f, Accuracy: %.4f%n", epochLoss, epochAcc);

        // 6. 返回Epoch指标(损失+准确率)
        return new double[]{epochLoss, epochAcc};
    }

    // ======================== 辅助方法(复用你参考代码中的实现) ========================
    /**
     * 将ExampleVector中的输入数据堆叠为批次张量
     */
    public static Tensor stackData(ExampleVector batch) {
        TensorVector tensorList = new TensorVector();
        for (long i = 0; i < batch.size(); i++) {
            tensorList.push_back(batch.get(i).data());
        }
        Tensor stacked = torch.stack(tensorList);
        tensorList.close();
        return stacked;
    }

    /**
     * 将ExampleVector中的标签数据堆叠为批次张量
     */
    public static Tensor stackTarget(ExampleVector batch) {
        TensorVector tensorList = new TensorVector();
        for (long i = 0; i < batch.size(); i++) {
            tensorList.push_back(batch.get(i).target());
        }
        Tensor stacked = torch.stack(tensorList).flatten(); // 展平为一维标签
        tensorList.close();
        return stacked;
    }

    // ======================== 测试示例(可直接运行) ========================
    public static void main(String[] args) {
        // 1. 初始化设备(优先GPU)
        Device device = torch.cuda_is_available()
                ? new Device(torch.DeviceType.CUDA)
                : new Device(torch.DeviceType.CPU);

        // 2. 初始化模型(线性分类模型,10维输入→5类输出)
        LinearImpl model = new LinearImpl(10, 5);
        model.to(device, false); // 模型移至目标设备

        // 3. 初始化训练数据加载器(1000样本,批次32)
        JavaRandomDataLoader trainLoader = createDummyDataLoader(1000, 32);

        // 4. 初始化损失函数(交叉熵损失,适配分类任务)
        CrossEntropyLossImpl lossFn = new CrossEntropyLossImpl();

        // 5. 初始化优化器(Adam,学习率0.001)
        AdamOptions options = new AdamOptions(0.001f);
        Optimizer optimizer = new Adam(model.parameters(), options);

        // 6. 执行单Epoch训练
        double[] epochMetrics = trainOneEpoch(model, trainLoader, lossFn, optimizer, device);

        // 7. 打印返回结果
        System.out.printf("返回的Epoch损失: %.4f, 返回的Epoch准确率: %.4f%n",
                epochMetrics[0], epochMetrics[1]);

        // 8. 资源释放
        model.close();
        lossFn.close();
        optimizer.close();
        trainLoader.close();
        device.close();
    }

    // ======================== 模拟JavaRandomDataLoader创建 ========================
    private static JavaRandomDataLoader createDummyDataLoader(long totalSamples, long batchSize) {
        // 模拟数据集:输入10维,标签0-4分类
        JavaDataset dummyDataset = new JavaDataset() {
            @Override
            public SizeTOptional size() {
                return new SizeTOptional(totalSamples);
            }

            @Override
            public Example get(long index) {
                // 生成单样本:input(10维) + label(1维)
                Tensor input = torch.randn(10).to(torch.ScalarType.Float);
                Tensor label = torch.randint(0, 5, new long[]{1}).to(torch.ScalarType.Long);
                return new Example(input, label);
            }
        };

        // 随机采样器(批次大小)
        RandomSampler sampler = new RandomSampler(batchSize);

        // DataLoader配置
        DataLoaderOptions options = new DataLoaderOptions();
        options.batch_size().put(batchSize);
        options.enforce_ordering().put(false);
        options.workers().put(0); // 单线程

        return new JavaRandomDataLoader(dummyDataset, sampler, options);
    }
}

记录实现中的要点:

  • 在轮次开始时初始化累加器(running_losscorrect_predictionstotal_samples)。
  • 在批次循环内部,计算损失和预测后,更新累加器。
    • 使用 loss.item() 获取当前批次损失张量的标量值,以防止计算图被保留。如果你想在最后平均之前得到总损失,则乘以 inputs.size(0)(批次大小);否则,你可以平均批次损失,但如果最后一个批次较小,按批次大小加权会更准确。
    • 计算批次的准确度(或其他指标),并添加到累加总数中。在这里也使用 .item()
  • 循环结束后,通过将累加总数除以样本数量,计算整个轮次的平均损失和准确度。
  • 打印或存储这些轮次级别的指标。

在评估循环中实现记录

在评估循环中进行记录是类似的,但存在重要区别:

  • 它在 torch.no_grad() 下运行,以禁用梯度计算。
  • 模型应处于评估模式(model.eval()),以禁用 dropout 并使用训练期间学习到的批归一化统计量。
  • 没有反向传播或优化器步骤。
scala 复制代码
import torch.*

// 假设模型、val_dataloader、loss_fn 已定义

def evaluate_model(model, val_dataloader, loss_fn, device):
    model.eval() # 设置模型为评估模式
    var running_loss = 0.0
    var correct_predictions = 0
    var total_samples = 0

    with torch.no_grad(): // 禁用梯度计算
        for inputs, labels <- val_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            // 前向传播
            var outputs = model(inputs)

            // 计算损失
            var loss = loss_fn(outputs, labels)

            // --- 记录步骤 ---
            running_loss += loss.item() * inputs.size(0)

            val _, predicted = torch.max(outputs.data, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()
            // --- 记录步骤结束 ---

    // 计算当前轮次的平均损失和准确度
    val epoch_loss = running_loss / total_samples
    val epoch_acc = correct_predictions / total_samples

    println(f"Validation: Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")
    return epoch_loss, epoch_acc
java 复制代码
package vals;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;

/**
 * 模型验证/评估函数Java实现(适配JavaRandomDataLoader):
 * 1. 模型评估模式切换、无梯度计算上下文
 * 2. 批次遍历与设备迁移、仅前向传播(无反向传播/优化)
 * 3. 按批次大小加权累加损失、统计分类准确率
 * 4. 计算并返回验证Epoch平均损失和准确率
 */
public class EvaluateModel {

    /**
     * 模型评估核心函数(等效Scala的evaluate_model)
     * @param model 待评估的模型
     * @param valDataloader 验证数据加载器(JavaRandomDataLoader)
     * @param lossFn 损失函数(分类任务用CrossEntropyLoss)
     * @param device 评估设备(CPU/GPU)
     * @return 包含验证平均损失和准确率的数组 [epochLoss, epochAcc]
     */
    public static double[] evaluateModel(
            Module model,
            JavaRandomDataLoader valDataloader,
            CrossEntropyLossImpl lossFn,
            Device device
    ) {
        // 1. 设置模型为评估模式(禁用Dropout/BatchNorm训练行为)
        model.eval();

        // 2. 初始化统计变量
        double runningLoss = 0.0;       // 加权累加损失(按批次大小)
        long correctPredictions = 0;    // 正确预测数
        long totalSamples = 0;          // 总样本数

        // 3. 禁用梯度计算(等效Scala with torch.no_grad())
        NoGradGuard noGradGuard = new NoGradGuard();
        try {
            // 4. 遍历验证数据加载器
            var valIter = valDataloader.begin();
            var valEnd = valDataloader.end();
            while (!valIter.equals(valEnd)) {
                // 4.1 获取批次数据并解析输入/标签
                ExampleVector batch = valIter.access();
                Tensor inputs = TrainOneEpoch.stackData(batch).to(device, torch.ScalarType.Float);  // 输入移至目标设备
                Tensor labels = TrainOneEpoch.stackTarget(batch).to(device, torch.ScalarType.Long); // 标签移至目标设备

                // 4.2 仅前向传播(无梯度计算)
                Tensor outputs = model.asSequential().forward(inputs);

                // 4.3 计算损失(仅用于统计,无反向传播)
                Tensor loss = lossFn.forward(outputs, labels);

                // ======================== 指标统计 ========================
                // 4.4 累加加权损失:loss.item() * 批次大小
                runningLoss += loss.item().toDouble() * inputs.size(0);

                // 4.5 计算分类准确率
                // torch.max(outputs.data, 1) → 获取预测类别索引
                Tensor predicted = torch.max(outputs.data(), 1).get1().data();
                // (predicted == labels).sum() → 统计正确预测数
                Tensor correctMask = torch.eq(predicted, labels);
                correctPredictions += correctMask.sum().item().toLong();
                // 累加总样本数
                totalSamples += labels.size(0);

                // ======================== 资源释放 ========================
                inputs.close();
                labels.close();
                outputs.close();
                loss.close();
                predicted.close();
                correctMask.close();
                batch.close();

                // 4.6 迭代器移动到下一个批次
                valIter.increment();
            }
        } finally {
            // 确保无梯度上下文释放(避免资源泄漏)
            noGradGuard.close();
        }

        // 5. 计算验证Epoch级指标
        double epochLoss = runningLoss / totalSamples;  // 平均验证损失
        double epochAcc = (double) correctPredictions / totalSamples; // 验证准确率(0-1之间)

        // 6. 打印验证结果
        System.out.printf("Validation: Loss: %.4f, Accuracy: %.4f%n", epochLoss, epochAcc);

        // 7. 返回验证指标(损失+准确率)
        return new double[]{epochLoss, epochAcc};
    }

    // ======================== 测试示例(可直接运行,依赖TrainOneEpoch) ========================
    public static void main(String[] args) {
        // 1. 初始化设备(优先GPU)
        Device device = torch.cuda_is_available()
                ? new Device(torch.DeviceType.CUDA)
                : new Device(torch.DeviceType.CPU);

        // 2. 初始化模型(线性分类模型,10维输入→5类输出)
        LinearImpl model = new LinearImpl(10, 5);
        model.to(device, false); // 模型移至目标设备

        // 3. 初始化验证数据加载器(200样本,批次32)
        JavaRandomDataLoader valLoader = createDummyDataLoader(200, 32);

        // 4. 初始化损失函数(交叉熵损失,适配分类任务)
        CrossEntropyLossImpl lossFn = new CrossEntropyLossImpl();

        // 5. 执行模型评估
        double[] valMetrics = evaluateModel(model, valLoader, lossFn, device);

        // 6. 打印返回结果
        System.out.printf("返回的验证损失: %.4f, 返回的验证准确率: %.4f%n",
                valMetrics[0], valMetrics[1]);

        // 7. 资源释放
        model.close();
        lossFn.close();
        valLoader.close();
        device.close();
    }

    // ======================== 辅助方法:创建模拟JavaRandomDataLoader ========================
    private static JavaRandomDataLoader createDummyDataLoader(long totalSamples, long batchSize) {
        // 模拟数据集:输入10维,标签0-4分类
        JavaDataset dummyDataset = new JavaDataset() {
            @Override
            public SizeTOptional size() {
                return new SizeTOptional(totalSamples);
            }

            @Override
            public Example get(long index) {
                // 生成单样本:input(10维) + label(1维)
                Tensor input = torch.randn(10).to(torch.ScalarType.Float);
                Tensor label = torch.randint(0, 5, new long[]{1}).to(torch.ScalarType.Long);
                return new Example(input, label);
            }
        };

        // 随机采样器(批次大小)
        RandomSampler sampler = new RandomSampler(batchSize);

        // DataLoader配置
        DataLoaderOptions options = new DataLoaderOptions();
        options.batch_size().put(batchSize);
        options.enforce_ordering().put(false);
        options.workers().put(0); // 单线程

        return new JavaRandomDataLoader(dummyDataset, sampler, options);
    }
}

存储和使用记录的指标

将指标打印到控制台对于即时反馈很有用。为了更系统的分析或可视化,你会希望存储它们。简单的 Python 列表或字典效果很好:

scala 复制代码
// --- 在你的主训练脚本中 ---
val num_epochs = 10
val device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
// ... 初始化模型、数据、损失函数、优化器 ...

// 存储指标
var train_losses = List.empty[Double]
var train_accuracies = List.empty[Double]
var val_losses = List.empty[Double]
var val_accuracies = List.empty[Double]

for epoch <- Range(num_epochs):
    println(f"--- Epoch {epoch+1}/{num_epochs} ---")
    val train_loss, train_acc = train_one_epoch(model, train_dataloader, loss_fn, optimizer, device)
    val val_loss, val_acc = evaluate_model(model, val_dataloader, loss_fn, device)

    // 存储指标
    train_losses = train_losses :+ train_loss
    train_accuracies = train_accuracies :+ train_acc
    val_losses = val_losses :+ val_loss
    val_accuracies = val_accuracies :+ val_acc

    // 可选:根据验证性能在此处保存模型检查点
    // 例如:
    // if val_acc > best_val_acc:
    //     best_val_acc = val_acc
    //     torch.save(model.state_dict(), 'best_model.pth')

    // 可选:将指标记录到 TensorBoard(使用上述存储的值)
    // writer.add_scalar('Loss/train', train_loss, epoch)
    // writer.add_scalar('Loss/validation', val_loss, epoch)
    // ... 等等。

println("Training finished.")
// 现在你可以分析这些列表:train_losses、val_losses 等。
// 例如,将它们保存到文件或绘制图表。
java 复制代码
package vals;


import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import org.tensorboard.writer.SummaryWriter;

import java.util.ArrayList;
import java.util.List;

/**
 * 多Epoch主训练脚本Java实现:
 * 1. 多Epoch循环执行训练+验证流程
 * 2. 存储训练/验证的损失和准确率指标
 * 3. 最佳模型保存(基于验证准确率)
 * 4. TensorBoard日志记录训练/验证指标
 * 5. 完整的资源管理和工程化设计
 */
public class MainTrainingScript {

    public static void main(String[] args) {
        // ======================== 1. 初始化训练配置 ========================
        int numEpochs = 10; // 训练轮数
        // 选择设备(优先GPU)
        Device device = torch.cuda_is_available()
                ? new Device(torch.DeviceType.CUDA)
                : new Device(torch.DeviceType.CPU);
        System.out.println("Training on device: " + (device.type() == torch.DeviceType.CUDA ? "cuda:0" : "cpu"));

        // ======================== 2. 初始化核心组件(替换为实际业务逻辑) ========================
        // 2.1 初始化模型(线性分类模型:10维输入→5类输出)
        LinearImpl model = new LinearImpl(10, 5);
        model.to(device, false); // 模型移至目标设备

        // 2.2 初始化数据加载器
        JavaRandomDataLoader trainDataloader = TrainOneEpoch.createDummyDataLoader(1000, 32); // 训练集:1000样本
        JavaRandomDataLoader valDataloader = EvaluateModel.createDummyDataLoader(200, 32);    // 验证集:200样本

        // 2.3 初始化损失函数(交叉熵损失)
        CrossEntropyLossImpl lossFn = new CrossEntropyLossImpl();

        AdamOptions options = new AdamOptions(0.001f);
        // 2.4 初始化优化器(Adam,学习率0.001)
        Optimizer optimizer = new Adam(model.parameters(), options);

        // 2.5 初始化TensorBoard写入器(日志保存路径)
        SummaryWriter writer = new SummaryWriter("./tb_logs/main_training");

        // ======================== 3. 初始化指标存储容器(等效Scala的List) ========================
        List<Double> trainLosses = new ArrayList<>();    // 训练损失列表
        List<Double> trainAccuracies = new ArrayList<>();// 训练准确率列表
        List<Double> valLosses = new ArrayList<>();      // 验证损失列表
        List<Double> valAccuracies = new ArrayList<>();  // 验证准确率列表

        // ======================== 4. 最佳模型保存相关变量 ========================
        double bestValAcc = 0.0; // 最佳验证准确率
        String bestModelPath = "best_model.pth"; // 最佳模型保存路径

        // ======================== 5. 多Epoch训练循环 ========================
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            System.out.printf("--- Epoch %d/%d ---%n", epoch + 1, numEpochs);

            // 5.1 执行单Epoch训练,获取训练指标
            double[] trainMetrics = TrainOneEpoch.trainOneEpoch(model, trainDataloader, lossFn, optimizer, device);
            double trainLoss = trainMetrics[0];
            double trainAcc = trainMetrics[1];

            // 5.2 执行模型验证,获取验证指标
            double[] valMetrics = EvaluateModel.evaluateModel(model, valDataloader, lossFn, device);
            double valLoss = valMetrics[0];
            double valAcc = valMetrics[1];

            // 5.3 存储指标(等效Scala的 :+ 追加操作)
            trainLosses.add(trainLoss);
            trainAccuracies.add(trainAcc);
            valLosses.add(valLoss);
            valAccuracies.add(valAcc);

            // 5.4 保存最佳模型(基于验证准确率)
            if (valAcc > bestValAcc) {
                bestValAcc = valAcc;
                // 保存模型权重(state_dict)
//                torch.save(model.state_dict(), bestModelPath);
                System.out.printf("Best model saved! New best val accuracy: %.4f%n", bestValAcc);
            }

            // 5.5 记录TensorBoard日志(按Epoch索引)
            writer.addScalar("Loss/train", trainLoss, epoch);
            writer.addScalar("Accuracy/train", trainAcc, epoch);
            writer.addScalar("Loss/validation", valLoss, epoch);
            writer.addScalar("Accuracy/validation", valAcc, epoch);
        }

        // ======================== 6. 训练完成后处理 ========================
        System.out.println("Training finished.");

        // 6.1 打印指标列表(示例:打印前3个Epoch的指标)
        System.out.println("\n=== 训练指标汇总(前3个Epoch) ===");
        for (int i = 0; i < Math.min(3, numEpochs); i++) {
            System.out.printf("Epoch %d: Train Loss=%.4f, Train Acc=%.4f, Val Loss=%.4f, Val Acc=%.4f%n",
                    i + 1,
                    trainLosses.get(i),
                    trainAccuracies.get(i),
                    valLosses.get(i),
                    valAccuracies.get(i)
            );
        }

        // 6.2 后续操作建议:
        // - 将指标保存到CSV文件:可使用Apache Commons CSV等库
        // - 绘制训练曲线:可使用JFreeChart等库
        // - 加载最佳模型:model.load_state_dict(torch.load(bestModelPath));

        // ======================== 7. 资源释放 ========================
//        writer.close(); // 关闭TensorBoard写入器
        model.close();
        lossFn.close();
        optimizer.close();
        trainDataloader.close();
        valDataloader.close();
        device.close();
    }
}

这种结构使你能够收集整个训练过程中的性能数据。这些存储的列表(train_lossesval_losses等)正是你可以输入到 Matplotlib 等绘图库或传递给 TensorBoard SummaryWriter(如前一节所述)以创建如下可视化图表的内容。

训练和验证损失曲线在10个轮次上绘制。观察这些趋势有助于诊断过拟合(验证损失增加而训练损失减少)或欠拟合(两种损失都保持高位)。

通过在训练和评估期间持续记录指标,你对模型的行为有很好的了解,从而能够就超参数调整、模型调整和调试策略做出明智的决定。

在 PyTorch 中使用 Python 调试器 (pdb)

收藏

尽管像 TensorBoard 这样的工具可以帮助监视训练趋势,并且可视化网络图可以提供架构理解,但有时你需要在程序执行到某个特定点时,仔细检查其确切状态。形状不匹配可能在模型的 forward 传递中发生,梯度可能意外地变成 NaN,或者张量值可能无故发散。对于这些情况,使用调试器逐行执行代码通常是找出根本原因最直接的方法。

Python 的内置调试器 pdb 是一个强大的文本工具,它能与 PyTorch 代码一同使用。它允许你暂停执行、检查变量(包括张量)、逐行执行代码,并在问题发生时准确了解程序的流程。

启用调试器

使用 pdb 启动调试会话最常见的方法是,在你希望执行暂停的位置,直接将以下两行代码插入到你的 Python 脚本中:

scala 复制代码
import pdb
pdb.set_trace()

当 Python 解释器遇到 pdb.set_trace() 时,它会停止执行,并将你带入终端中的 pdb 交互式控制台。(Pdb) 提示符表示你当前处于调试器中。

你应该把 pdb.set_trace() 放在哪里?

  • 在可能出错的代码前: 如果某行代码(例如,矩阵乘法、损失计算)引发错误,请将 pdb.set_trace() 放在紧邻该行之前。这能让你检查该操作的输入。
  • 在模型的 forward 方法中: 为了了解数据在经过层时如何变化,请将跟踪点放在 forward 方法内。你可以逐层执行应用,并检查张量的形状和值。
  • 在训练循环中: 为了查看特定迭代或 epoch 的状态,请将其放在循环内。你可以将其包装在条件语句中(例如,if batch_idx == problematic_index: import pdb; pdb.set_trace())。
  • loss.backward() 之后: 为了在梯度计算完毕但优化器步骤执行之前检查梯度,请将跟踪点放在 backward() 调用之后。

此外,你也可以从命令行在 pdb 的控制下启动整个脚本,这会在脚本的第一行启动调试器:

bash 复制代码
python -m pdb your_pytorch_script.py

这对于诊断脚本执行早期发生的问题很有用,例如导入错误或设置问题。

PDB 基本命令

一旦你在 (Pdb) 提示符下,就可以使用各种命令来控制执行和检查状态。下面是一些最常用的命令:

  • n (next):执行当前行,并在当前函数的下一行停止。如果当前行是一个函数调用,n 会执行整个函数并在它返回停止。
  • s (step):与 n 类似,但如果当前行是一个函数调用,s进入函数,并在其第一行停止。
  • c (continue):恢复正常执行,直到遇到下一个断点(或 pdb.set_trace() 调用),或者直到脚本结束或出错。
  • l (list):显示当前执行行周围的源代码。使用 l . 再次列出以当前行为中心的源代码。
  • p <expression> (print):在当前上下文中评估 <expression> 并打印其值。这可以说是调试 PyTorch 最重要的命令。你可以检查张量、变量、模型参数等。
    • p my_tensor.shape
    • p my_tensor.dtype
    • p my_tensor.device
    • p my_tensor (打印张量本身;可能很大)
    • p model.layer1.weight.grad (在 backward() 之后)
    • p loss.item()
  • a (args):打印当前函数的参数列表。
  • r (return):继续执行直到当前函数返回。
  • b <line_number> (breakpoint):在当前文件的特定 <line_number> 处设置断点。当执行到达该行时会暂停。你也可以在其他文件 (b path/to/file.py:<line_number>) 或方法上 (b self.my_method) 指定断点。
  • clclear:清除所有断点。cl <bp_number> 清除特定断点。
  • q (quit):退出调试器并立即终止脚本。
  • h (help):显示可用命令列表。h <command> 提供特定命令的帮助。

使用 PDB 调试 PyTorch 代码:示例

让我们看看 pdb 如何帮助处理常见的 PyTorch 调试情况。

情况 1:调试模型中的形状不匹配问题

假设你有一个简单模型,并且在前向传播时遇到了形状不匹配错误。

scala 复制代码
import torch
import torch.nn as nn
import pdb

class SimpleNet extends nn.Module:
    def __init__(self):
        super().__init__()
        val layer1 = nn.Linear(10, 20)
        val activation = nn.ReLU()
        // 潜在错误:输入尺寸与 layer1 的输出不匹配
        val layer2 = nn.Linear(20, 5) // 错误可能在此处 (20 != 25)

    def forward(x: torch.Tensor):
        println(f"Initial shape: {x.shape}")
        x = layer1(x)
        println(f"After layer1: {x.shape}")
        x = activation(x)
        println(f"After activation: {x.shape}")
        // 在 layer2 之前进行调试
        pdb.set_trace()
        // 这行代码很可能会导致运行时错误
        x = layer2(x)
        println(f"After layer2: {x.shape}")
        return x

// 示例用法
val net = SimpleNet()
// 创建一个虚拟输入张量
val input_tensor = torch.randn(32, 10) // 批次大小 32,特征数 10
val output = net(input_tensor)
java 复制代码
package vals;


import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;

/**
 * 含潜在维度错误的SimpleNet模型Java实现:
 * 1. 自定义Module,模拟Scala代码中的层定义(含潜在维度错误)
 * 2. 前向传播中打印各阶段张量形状,定位维度不匹配问题
 * 3. 模拟pdb断点调试(Java调试器适配)
 * 4. 复现维度不匹配的运行时错误场景
 */
public class SimpleNetWithDebug {

    /**
     * 自定义SimpleNet模型(含潜在维度错误:与Scala代码完全一致)
     */
    public static class SimpleNet extends Module {
        // 定义网络层(对应Scala的layer1/activation/layer2)
        private final LinearImpl layer1;
        private final ReLUImpl activation;
        private final LinearImpl layer2; // 潜在错误:若此处写成LinearImpl(25,5)则维度不匹配

        /**
         * 构造函数(等效Scala __init__)
         */
        public SimpleNet() {
            super("SimpleNet");
            // 初始化层:与Scala代码一致(layer2故意保留"潜在错误"的注释)
            layer1 = new LinearImpl(10, 20); // in_features=10, out_features=20
            activation = new ReLUImpl();    // ReLU激活函数
            // 潜在错误:若将20改为25,会导致layer1输出(20)与layer2输入(25)不匹配
            layer2 = new LinearImpl(20, 5); // 注释标注:错误可能在此处 (20 != 25)
        }

        /**
         * 前向传播方法(含维度打印+调试断点)
         */
//        @Override
        public Tensor forward(Tensor x) {
            // 打印初始张量形状(等效Scala println(f"Initial shape: {x.shape}"))
            System.out.println("Initial shape: " + getTensorShapeStr(x));

            // layer1前向传播
            x = layer1.forward(x);
            System.out.println("After layer1: " + getTensorShapeStr(x));

            // ReLU激活
            x = activation.forward(x);
            System.out.println("After activation: " + getTensorShapeStr(x));

            // ======================== 模拟pdb.set_trace() 调试断点 ========================
            // 方案1:手动打印调试信息(基础调试)
            System.out.println("\n=== 调试断点:layer2执行前 ===");
            System.out.println("当前张量形状: " + getTensorShapeStr(x));
            System.out.println("layer2输入维度要求: " + layer2.options().in_features() + ", 输出维度: " + layer2.options().out_features());

            // 方案2:Java调试器断点(等效pdb.set_trace())
            // 在此处添加断点(IDEA/Eclipse中点击行号旁的红点),运行时会暂停执行
            // 可在调试模式下查看x的形状、layer2的参数等信息
            // 【关键】若要复现错误,将layer2的构造参数改为25→触发维度不匹配

            // layer2前向传播(可能触发运行时错误)
            x = layer2.forward(x);
            System.out.println("After layer2: " + getTensorShapeStr(x));

            return x;
        }

        /**
         * 辅助方法:获取张量形状的可读字符串(模拟Scala的x.shape输出)
         * 示例:输入(32,10) → 输出 "[32, 10]"
         */
        private String getTensorShapeStr(Tensor tensor) {
            long[] shape = tensor.sizes().vec().get(); // 获取张量维度数组
            StringBuilder sb = new StringBuilder("[");
            for (int i = 0; i < shape.length; i++) {
                sb.append(shape[i]);
                if (i < shape.length - 1) {
                    sb.append(", ");
                }
            }
            sb.append("]");
            return sb.toString();
        }

        /**
         * 资源释放:重写close方法释放层资源
         */
        @Override
        public void close() {
            layer1.close();
            activation.close();
            layer2.close();
//            super.close();
        }
    }

    // ======================== 主函数:示例用法(复现错误场景) ========================
    public static void main(String[] args) {
        try {
            // 1. 实例化模型(含潜在维度错误)
            SimpleNet net = new SimpleNet();
            System.out.println("SimpleNet模型初始化完成\n");

            // 2. 创建虚拟输入张量(批次大小32,特征数10)
            Tensor inputTensor = torch.randn(new long[]{32, 10});
            System.out.println("虚拟输入张量创建完成:" + net.getTensorShapeStr(inputTensor) + "\n");

            // 3. 前向传播(触发维度检查/错误)
            Tensor output = net.forward(inputTensor);
            System.out.println("\n模型输出张量形状:" + net.getTensorShapeStr(output));

            // 4. 资源释放
            inputTensor.close();
            output.close();
            net.close();

        } catch (Exception e) {
            // 捕获维度不匹配的运行时错误(若layer2参数错误)
            System.err.println("\n=== 运行时错误:维度不匹配 ===");
            e.printStackTrace();
        }
    }
}

当你运行这段代码时,它会打印形状,然后在 pdb.set_trace() 处停止。

text 复制代码
Initial shape: torch.Size([32, 10])
After layer1: torch.Size([32, 20])
After activation: torch.Size([32, 20])
-> x = self.layer2(x)
(Pdb)

(Pdb) 提示符下,你可以检查:

  • p x.shape:这将输出 torch.Size([32, 20])
  • p self.layer2:这将显示定义 Linear(in_features=25, out_features=5, bias=True)
  • p self.layer2.in_features:这将输出 25

通过比较输入形状 [32, 20](特别是特征维度 20)与 layer2.in_features25),不匹配之处变得显而易见。然后你可以退出 (q) 并修正 nn.Linear 的定义。

情况 2:检查梯度

假设你的损失没有下降,并且你怀疑出现了梯度消失或梯度爆炸。你可以在反向传播后检查它们。

scala 复制代码
// 假设 model、data (inputs, targets)、loss_fn、optimizer 都已定义

// 前向传播
val outputs = model(inputs)
val loss = loss_fn(outputs, targets)

// 反向传播
optimizer.zero_grad()
loss.backward()

# 在此处插入调试器以检查梯度
import pdb
pdb.set_trace()

// 优化器步骤(将在调试后发生)
optimizer.step()

当执行暂停时,你可以检查特定参数的梯度:

  • p model.some_layer.weight.grad:检查 some_layer 权重部分的梯度张量。查看是否存在 NaN 值、非常大的值(爆炸)或非常小的值(消失)。
  • p model.some_layer.weight.grad.abs().mean():计算梯度的平均绝对值,以了解其大小。
  • p loss.item():提醒自己当前的损失值。

有效使用 PDB 的建议

  • 有针对性:pdb.set_trace() 放置在离你怀疑问题所在位置尽可能近的地方。
  • 使用 Print (p): 广泛运用 p 命令来检查张量形状、数据类型、设备位置和实际值。
  • 谨慎执行: 使用 n(next)逐行移动。只有当你需要检查你编写的函数调用时才使用 s(step)(检查 PyTorch 的内部函数可能会很冗长)。
  • 移除跟踪: 在最终确定代码之前,请记得移除或注释掉 import pdb; pdb.set_trace() 调用,尤其是在提交到版本控制或部署之前。
  • 考虑长时间运行的替代方案: pdb 会停止执行。对于仅在训练数小时后才出现的问题进行调试,交互式调试可能不切实际。在这种情况下,定期记录详细信息(张量形状、损失值、梯度范数)可能更适合,或许可以结合条件断点或断言检查。

有效使用 pdb 需要一些练习,但它是理解 PyTorch 代码详细的、逐步执行过程以及解决许多从堆栈跟踪或高级指标中不明显问题的必不可少的工具。

实践:调试与可视化

收藏

动手练习旨在巩固应用常见调试难点和可视化工具的技能。这些练习侧重于识别错误、检查模型行为,并使用 TensorBoard 和标准调试方法监控训练进度。涵盖的场景包括形状不匹配、设备放置错误和设置可视化。

练习 1:修正形状不匹配

形状不匹配是构建或修改神经网络时常见的错误。请看下面这个简单模型,它旨在处理 28x28 灰度图像(如 MNIST),并将其展平为 784 元素的向量:

scala 复制代码
import torch
import torch.nn as nn

class SimpleMLP extends nn.Module:
    def __init__(self):
        super().__init__()
        val layer1 = nn.Linear(784, 128) // 输入 784,输出 128
        val activation = nn.ReLU()
        val layer2 = nn.Linear(128, 64)  // 输入 128,输出 64
        // layer3 的输入大小不正确 - 应该是 64
        val layer3 = nn.Linear(64, 10)  // 输入 64,输出 10

    def forward(x: torch.Tensor):
        x = layer1(x)
        x = activation(x)
        x = layer2(x)
        x = activation(x)
        // 这行代码将导致错误
        x = layer3(x)
        return x

// 创建一个模拟输入批次(批次大小 4,特征 784)
val dummy_input = torch.randn(4, 784)
val model = SimpleMLP()

// 尝试前向传播
try:
    val output = model(dummy_input)
    println("模型运行成功!")
catch RuntimeError as e:
    println(s"捕获到错误:$e")
java 复制代码
package vals;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;

/**
 * 含维度错误的SimpleMLP模型Java实现:
 * 1. 定义三层MLP(故意保留维度错误注释,与Scala代码对齐)
 * 2. 前向传播逻辑完全复刻
 * 3. 模拟输入批次创建 + 异常捕获(等效Scala的try-catch)
 * 4. 复现维度不匹配的RuntimeError并打印错误信息
 */
public class SimpleMLPWithErrorHandling {

    /**
     * 自定义SimpleMLP模型(三层全连接网络,含维度错误注释)
     */
    public static class SimpleMLP extends Module {
        // 定义网络层(对应Scala的layer1/layer2/layer3/activation)
        private final LinearImpl layer1;
        private final ReLUImpl activation;
        private final LinearImpl layer2;
        private final LinearImpl layer3; // 注释标注维度错误(示例中已修正,可手动改错复现)

        /**
         * 构造函数(等效Scala __init__)
         */
        public SimpleMLP() {
            super("SimpleMLP");
            // 初始化层:与Scala代码完全对齐(保留维度错误注释)
            layer1 = new LinearImpl(784, 128); // 输入784,输出128
            activation = new ReLUImpl();       // ReLU激活函数
            layer2 = new LinearImpl(128, 64);  // 输入128,输出64
            // 注释:layer3 的输入大小不正确 - 应该是 64(示例中已正确设置,改为其他值可触发错误)
            layer3 = new LinearImpl(64, 10);   // 输入64,输出10(改64为其他值如128可复现错误)
        }

        /**
         * 前向传播方法(等效Scala forward)
         */

        public Tensor forward(Tensor x) {
            // 逐层前向传播(与Scala逻辑完全一致)
            x = layer1.forward(x);
            x = activation.forward(x);
            x = layer2.forward(x);
            x = activation.forward(x);
            // 这行代码将导致错误(若layer3输入维度设置错误)
            x = layer3.forward(x);
            return x;
        }

        /**
         * 资源释放:重写close方法释放层资源
         */
        @Override
        public void close() {
            layer1.close();
            activation.close();
            layer2.close();
            layer3.close();
//            super.close();
        }
    }

    // ======================== 主函数:模拟输入 + 前向传播 + 异常捕获 ========================
    public static void main(String[] args) {
        // 1. 创建模拟输入批次(批次大小4,特征数784)
        // 等效Scala:val dummy_input = torch.randn(4, 784)
        Tensor dummyInput = torch.randn(new long[]{4, 784});
        System.out.println("模拟输入张量形状: " + getTensorShapeStr(dummyInput));

        // 2. 实例化模型
        SimpleMLP model = new SimpleMLP();

        // 3. 尝试前向传播(等效Scala的try-catch)
        try {
            Tensor output = model.forward(dummyInput);
            System.out.println("模型运行成功!");
            System.out.println("模型输出张量形状: " + getTensorShapeStr(output));

            // 释放输出张量资源
            output.close();
        } catch (Exception e) {
            // 捕获RuntimeError(维度不匹配等运行时错误)
            System.out.println("捕获到错误:" + e.getMessage());
            // 可选:打印完整堆栈信息,便于定位错误
            // e.printStackTrace();
        } finally {
            // 确保资源释放(无论是否异常)
            dummyInput.close();
            model.close();
        }
    }

    /**
     * 辅助方法:获取张量形状的可读字符串(模拟Scala的x.shape)
     */
    private static String getTensorShapeStr(Tensor tensor) {
        long[] shape = tensor.sizes().vec().get();
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < shape.length; i++) {
            sb.append(shape[i]);
            if (i < shape.length - 1) sb.append(", ");
        }
        sb.append("]");
        return sb.toString();
    }
}
  1. 运行代码: 执行上面的代码片段。你会遇到一个 RuntimeError。仔细查看错误信息。它通常指向尺寸不匹配,常常会提到特定层的预期输入尺寸和实际输入尺寸(在本例中是 mat1 and mat2 shapes cannot be multiplied)。

  2. 诊断: 错误发生在输入 x 到达 self.layer3 时。前一个层 self.layer2 输出一个形状为 (batch_size, 64) 的张量。然而,self.layer3 被定义为 nn.Linear(100, 10),它预期输入有 100 个特征。这种不匹配导致了错误。

    • 你可以在报错行之前插入打印语句来确认形状:

      scala 复制代码
      // 在 forward 方法中,在 layer3(x) 之前
      println("layer3 之前的形状:" + x.shape)
      val x = layer3(x)
  3. 修正代码: 修改 __init__ 方法中 self.layer3 的定义,以接受正确数量的输入特征(即 64,self.layer2 的输出大小)。

    scala 复制代码
    // 修正后的层定义
    val layer3 = nn.Linear(64, 10) // 输入 64,输出 10
  4. 验证: 使用修正后的层定义重新运行脚本。前向传播现在应该能顺利完成,没有错误。

练习 2:修正设备放置

使用 GPU 时,很重要的一点是模型和数据都在同一个设备上。我们来模拟一个错误场景,其中模型被移到 GPU,但输入张量仍留在 CPU 上。

scala 复制代码
import torch
import torch.nn as nn

// 假设可用的 CUDA GPU
if torch.cuda.is_available():
    val device = torch.device("cuda")
    println("正在使用 GPU:" + torch.cuda.get_device_name(0))
else:
    val device = torch.device("cpu")
    println("正在使用 CPU")

class SimpleNet extends nn.Module:
    def __init__(self):
        super().__init__()
        val linear = nn.Linear(10, 5)

    def forward(x: torch.Tensor):
        return linear(x)

// 创建模型并将其移动到目标设备(例如,GPU)
val model = SimpleNet().to(device)
println(s"模型参数位于:${next(model.parameters()).device}")

// 创建输入数据 - 有意留在 CPU 上
val input_data = torch.randn(8, 10)
println(s"输入数据位于:${input_data.device}")

// 尝试前向传播 - 如果设备是 'cuda',这可能会导致错误
try:
    val output = model(input_data)
    println("前向传播成功!")
catch RuntimeError as e:
    println(s"\n捕获到错误:$e")
    println("\n提示:检查模型和输入数据是否在同一个设备上。")
java 复制代码
package vals;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;
import org.bytedeco.pytorch.global.torch_cuda;

/**
 * 模型与输入设备不匹配错误演示Java实现:
 * 1. 检测CUDA可用性,选择GPU/CPU设备
 * 2. 定义SimpleNet模型并移至目标设备
 * 3. 输入数据故意留在CPU,触发设备不匹配错误
 * 4. 捕获RuntimeError并给出明确的错误提示
 */
public class DeviceMismatchErrorDemo {

    /**
     * 自定义SimpleNet模型(单层线性层)
     */
    public static class SimpleNet extends Module {
        private final LinearImpl linear;

        public SimpleNet() {
            super("SimpleNet");
            // 初始化线性层:输入10维,输出5维
            linear = new LinearImpl(10, 5);
        }

//        @Override
        public Tensor forward(Tensor x) {
            // 前向传播:仅执行线性层计算
            return linear.forward(x);
        }

        /**
         * 资源释放:释放线性层资源
         */
        @Override
        public void close() {
            linear.close();
//            super.close();
        }
    }

    // ======================== 主函数:设备检测 + 模型/数据设备不匹配演示 ========================
    public static void main(String[] args) {
        // 1. 检测CUDA可用性,选择目标设备(等效Scala的if torch.cuda.is_available())
        Device device;
        if (torch.cuda_is_available()) {
            device = new Device(torch.DeviceType.CUDA); // 指定GPU 0
            // 获取GPU名称(等效Scala torch.cuda.get_device_name(0))
         
//            String gpuName = torch.cuda_get_device_name(0).getString();
//            System.out.println("正在使用 GPU: " + gpuName);
        } else {
            device = new Device(torch.DeviceType.CPU);
            System.out.println("正在使用 CPU");
        }

        // 2. 创建模型并移至目标设备(等效Scala model = SimpleNet().to(device))
        SimpleNet model = new SimpleNet();
        model.to(device, false); // 移至目标设备,false表示非原地修改

        // 打印模型参数所在设备(等效Scala next(model.parameters()).device)
        TensorVector params = model.parameters();
        Tensor firstParam = params.get(0);
        Device paramDevice = firstParam.device();
        String paramDeviceStr = getDeviceString(paramDevice);
        System.out.println("模型参数位于:" + paramDeviceStr);

        // 3. 创建输入数据 - 故意留在CPU上(核心:不调用to(device))
        Tensor inputData = torch.randn(new long[]{8, 10}); // 批次8,特征10
        String inputDeviceStr = getDeviceString(inputData.device());
        System.out.println("输入数据位于:" + inputDeviceStr);

        // 4. 尝试前向传播 - 设备不匹配时触发错误
        try {
            Tensor output = model.forward(inputData);
            System.out.println("前向传播成功!");
            // 释放输出张量资源
            output.close();
        } catch (Exception e) {
            // 捕获设备不匹配的RuntimeError
            System.out.println("\n捕获到错误:" + e.getMessage());
            System.out.println("\n提示:检查模型和输入数据是否在同一个设备上。");
            // 可选:打印完整堆栈信息,便于定位
            // e.printStackTrace();
        } finally {
            // 确保所有资源释放(无论是否异常)
            inputData.close();
            model.close();
            device.close();
//            params.close();
//            firstParam.close();
        }
    }

    /**
     * 辅助方法:将Device对象转为可读字符串(如 "cuda:0" / "cpu")
     */
    private static String getDeviceString(Device device) {
        if (device.type() == torch.DeviceType.CUDA) {
            return "cuda:" + device.index();
        } else {
            return "cpu";
        }
    }
}
  1. 运行代码: 如果你有支持 CUDA 的 GPU,运行这段代码会产生一个 RuntimeError。错误信息很可能会显示类似 Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! 的内容。

  2. 诊断: 打印语句确认模型位于 cuda 设备上(如果可用),而 input_data 位于 cpu 设备上。PyTorch 操作通常要求操作数位于同一个设备上。

  3. 修正代码: 在将 input_data 传递给模型之前,将其移动到模型所在的设备上。

    scala 复制代码
    // 将输入数据移动到正确设备
    val input_data = input_data.to(device)
    println(s"输入数据已移至:${input_data.device}")
    
    // 现在,再次尝试前向传播
    val output = model(input_data)
    println("数据移动后前向传播成功!")
  4. 验证: 重新运行修正后的脚本。前向传播应该能顺利执行,没有设备不匹配错误。记住这个原理也适用于训练循环中;从 DataLoader 获取的每个批次都需要被移动到相应的设备上。

练习 3:使用 TensorBoard 可视化训练

TensorBoard 为训练过程提供了宝贵的参考信息。让我们将其集成到一个简化训练循环中。我们将模拟训练数据并追踪一个模拟损失值。

scala 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.tensorboard.SummaryWriter
import time

// 1. 设置 TensorBoard 写入器
// 日志文件将保存在 'runs/simple_experiment' 目录中
val writer = new SummaryWriter("runs/simple_experiment")

// 2. 定义一个简单模型、损失函数和优化器
val model = nn.Linear(10, 2) // 简单的线性模型
val criterion = nn.MSELoss()
val optimizer = optim.SGD(model.parameters(), lr=0.01)

// 模拟一个简单数据集
val inputs = torch.randn(100, 10) // 100 个样本,10 个特征
val targets = torch.randn(100, 2) // 100 个样本,2 个输出值

// 3. 简单训练循环
println("开始模拟训练...")
val num_epochs = 50
for (epoch <- 0 until num_epochs):
    optimizer.zero_grad()    // 梯度清零
    val outputs = model(inputs)  // 前向传播
    val loss = criterion(outputs, targets) // 计算损失

    // 模拟损失变化(在实际训练中替换为真实损失)
    // 为演示目的,让损失随周期递减
    val simulated_loss = loss + torch.randn(1) * 0.1 + (num_epochs - epoch) / num_epochs

    simulated_loss.backward() // 反向传播(使用模拟损失进行演示)
    optimizer.step()       // 更新权重

    // 4. 将指标记录到 TensorBoard
    if (epoch + 1) % 5 == 0: // 每 5 个周期记录一次
        // 记录标量"损失"值
        writer.add_scalar('Training/Loss', simulated_loss.item(), epoch)

        // 记录模型权重分布(以线性层为例)
        writer.add_histogram('Model/Weights', model.weight, epoch)
        writer.add_histogram('Model/Bias', model.bias, epoch)

        println(f'周期 [{epoch+1}/{num_epochs}],模拟损失:{simulated_loss.item():.4f}')

    time.sleep(0.1) # 模拟训练时间

// 5. 添加模型图(可选)
// 确保输入形状与模型预期一致
// writer.add_graph(model, inputs[0].unsqueeze(0)) // 提供一个样本输入批次

// 6. 关闭写入器
writer.close()
println("模拟训练完成。TensorBoard 日志已保存到 'runs/simple_experiment'。")
println("在你的终端中运行 'tensorboard --logdir=runs' 来查看。")
java 复制代码
package vals;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import org.tensorboard.writer.SummaryWriter;

/**
 * 带TensorBoard日志记录的模拟训练循环Java实现:
 * 1. 初始化TensorBoard写入器,指定日志保存路径
 * 2. 定义线性模型、MSE损失函数、SGD优化器
 * 3. 模拟数据集创建 + 多Epoch训练循环
 * 4. 记录训练损失(标量)、模型权重/偏置分布(直方图)
 * 5. 资源释放与日志提示
 */
public class TensorBoardTrainingDemo {

    public static void main(String[] args) throws InterruptedException {
        // ======================== 1. 初始化TensorBoard写入器 ========================
        // 等效Scala:val writer = new SummaryWriter("runs/simple_experiment")
        SummaryWriter writer = new SummaryWriter("runs/simple_experiment");
        System.out.println("TensorBoard写入器初始化完成,日志将保存至: runs/simple_experiment");

        // ======================== 2. 定义模型、损失函数、优化器 ========================
        // 线性模型:输入10维,输出2维(等效Scala nn.Linear(10,2))
        LinearImpl model = new LinearImpl(10, 2);
        // MSE损失函数(等效Scala nn.MSELoss())
        MSELossImpl criterion = new MSELossImpl();
        // SGD优化器:学习率0.01(等效Scala optim.SGD(model.parameters(), lr=0.01))
        SGDOptions options = new SGDOptions( 0.01f);
        Optimizer optimizer = new SGD(model.parameters(), options);

        // ======================== 3. 创建模拟数据集 ========================
        // 输入:100样本 × 10特征(等效Scala torch.randn(100,10))
        Tensor inputs = torch.randn(new long[]{100, 10});
        // 目标:100样本 × 2输出(等效Scala torch.randn(100,2))
        Tensor targets = torch.randn(new long[]{100, 2});

        // ======================== 4. 模拟训练循环 ========================
        System.out.println("开始模拟训练...");
        int numEpochs = 50; // 训练轮数
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            // 4.1 梯度清零(等效Scala optimizer.zero_grad())
            optimizer.zero_grad();

            // 4.2 前向传播(等效Scala val outputs = model(inputs))
            Tensor outputs = model.forward(inputs);

            // 4.3 计算原始损失(等效Scala val loss = criterion(outputs, targets))
            Tensor loss = criterion.forward(outputs, targets);

            // 4.4 模拟损失变化(随Epoch递减,与Scala逻辑一致)
            // 公式:simulated_loss = loss + randn(1)*0.1 + (num_epochs - epoch)/num_epochs
            Tensor randnTensor = torch.randn(new long[]{1}).mul(new Scalar(0.1)); // 随机扰动
            double epochFactor = (double) (numEpochs - epoch) / numEpochs; // 递减因子
            Tensor simulatedLoss = loss.add(randnTensor).add(new Scalar(epochFactor));

            // 4.5 反向传播 + 优化器步进(等效Scala backward() + step())
            simulatedLoss.backward();
            optimizer.step();

            // ======================== 5. TensorBoard日志记录 ========================
            // 每5个Epoch记录一次(等效Scala if (epoch + 1) % 5 == 0)
            if ((epoch + 1) % 5 == 0) {
                // 5.1 记录标量:训练损失(等效Scala writer.add_scalar)
                double simulatedLossVal = simulatedLoss.item().toDouble();
                writer.addScalar("Training/Loss", simulatedLossVal, epoch);

                // 5.2 记录直方图:模型权重分布(等效Scala writer.add_histogram)
                writer.addHistogram("Model/Weights", model.weight().data(), epoch);
                // 5.3 记录直方图:模型偏置分布
                writer.addHistogram("Model/Bias", model.bias().data(), epoch);

                // 打印训练信息(等效Scala println)
                System.out.printf("周期 [%d/%d],模拟损失:%.4f%n", epoch + 1, numEpochs, simulatedLossVal);
            }

            // 模拟训练耗时(等效Scala time.sleep(0.1))
            Thread.sleep(100); // 单位:毫秒

            // 释放当前Epoch临时张量资源
            outputs.close();
            loss.close();
            randnTensor.close();
            simulatedLoss.close();
        }

        // ======================== 6. 可选:添加模型图(需适配TensorBoard Java SDK) ========================
        // 注:TensorBoard Java SDK的add_graph功能有限,此处给出等效实现思路
        // 1. 生成单样本输入(等效Scala inputs[0].unsqueeze(0))
        Tensor sampleInput = inputs.narrow(0, 0, 1); // 取第一个样本,形状[1,10]
        // 2. (可选)若使用PyTorch原生SummaryWriter,可调用:
        // torch.utils.tensorboard.SummaryWriter.add_graph(model, sampleInput);
        sampleInput.close();

        // ======================== 7. 资源释放 ========================
        // 关闭TensorBoard写入器(等效Scala writer.close())
//        writer.close();
        // 释放核心资源
        model.close();
        criterion.close();
        optimizer.close();
        inputs.close();
        targets.close();

        // ======================== 8. 训练完成提示 ========================
        System.out.println("模拟训练完成。TensorBoard 日志已保存到 'runs/simple_experiment'。");
        System.out.println("在你的终端中运行 'tensorboard --logdir=runs' 来查看。");
    }
}
  1. 运行代码: 执行 Python 脚本。它会打印周期进度并提及保存日志。

  2. 启动 TensorBoard: 打开你的终端或命令提示符,导航到 包含 runs 文件夹的目录(而不是 runs 文件夹内部),并运行命令:

    bash 复制代码
    tensorboard --logdir=runs
  3. 在浏览器中查看: TensorBoard 会输出一个 URL(通常是 http://localhost:6006/)。在你的网页浏览器中打开这个 URL。

  4. 浏览: 浏览 TensorBoard 界面。你应该能找到 simple_experiment 运行记录。在"标量"(Scalars)标签页下,你会看到"训练/损失"(Training/Loss)图表,它显示了我们模拟损失的递减趋势。在"直方图"(Histograms)或"分布"(Distributions)标签页下,你可以观察模型权重和偏差的分布如何随周期变化(或者在这个简单模拟中变化不大)。如果你取消注释了 add_graph 那行代码,你还会发现模型架构的可视化图表在"图"(Graphs)标签页下。

下面这个图表显示了一个在 TensorBoard 中查看时,损失如何随周期递减的例子。

一个折线图,描绘了 50 个周期内模拟训练损失的递减情况,每 5 个周期记录一次。

练习 4:使用 Scala 调试器 (sdb)

有时,打印语句不足以解决问题,你需要交互式地检查程序状态。Scala3 调试器(sdb)是一个强大的工具。让我们回顾练习 1 中的形状不匹配场景,并使用 sdb

通过在顶部添加 import sdb 并在导致错误的行之前添加 sdb.set_trace(),修改练习 1 中 原始的 失败代码:

scala 复制代码
import torch
import torch.nn as nn
import sdb // 导入调试器

class SimpleMLP extends nn.Module:
    def __init__(self):
        super().__init__()
        val layer1 = nn.Linear(784, 128)
        self.activation = nn.ReLU()
        val layer2 = nn.Linear(128, 64)
        // layer3 的输入大小不正确
        val layer3 = nn.Linear(100, 10) // 这里有错误!

    def forward(x: torch.Tensor):
        x = layer1(x)
        x = activation(x)
        x = layer2(x)
        x = activation(x)
        print("即将进入 pdb...")
        pdb.set_trace() // 在此设置断点
        // 执行将在此暂停
        print("layer3 之前的形状:", x.shape) // 我们可以在 pdb 中检查 x
        x = layer3(x) // 这行代码将导致错误
        return x

val dummy_input = torch.randn(4, 784)
val model = SimpleMLP()
val output = model(dummy_input) # 运行现在将在 forward() 内部暂停
java 复制代码
package vals;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;

/**
 * 含维度错误+调试断点的SimpleMLP模型Java实现:
 * 1. 复刻三层MLP结构(layer3故意设置维度错误:100输入→与layer2输出64不匹配)
 * 2. 前向传播中添加调试断点(IDE适配)+ 维度打印
 * 3. 复现维度不匹配的运行时错误,模拟pdb调试场景
 */
public class SimpleMLPWithDebug {

    /**
     * 自定义SimpleMLP模型(含layer3维度错误)
     */
    public static class SimpleMLP extends Module {
        // 定义网络层(对应Scala的layer1/activation/layer2/layer3)
        private final LinearImpl layer1;
        private final ReLUImpl activation;
        private final LinearImpl layer2;
        private final LinearImpl layer3; // 核心错误:输入维度100≠layer2输出64

        /**
         * 构造函数(等效Scala __init__)
         */
        public SimpleMLP() {
            super("SimpleMLP");
            // 初始化层:与Scala代码完全对齐(保留维度错误)
            layer1 = new LinearImpl(784, 128);  // 输入784,输出128
            activation = new ReLUImpl();        // ReLU激活函数
            layer2 = new LinearImpl(128, 64);   // 输入128,输出64
            // 错误点:layer3输入维度应为64,故意设为100(与Scala一致)
            layer3 = new LinearImpl(100, 10);   // 输入100,输出10
        }

        /**
         * 前向传播方法(含调试断点+维度打印)
         */
//        @Override
        public Tensor forward(Tensor x) {
            // 逐层前向传播
            x = layer1.forward(x);
            x = activation.forward(x);
            x = layer2.forward(x);
            x = activation.forward(x);

            // ======================== 等效Scala的pdb.set_trace() ========================
            System.out.println("即将进入 调试断点...");
            // 【关键】在此处添加IDE断点(IDEA/Eclipse中点击行号旁红点)
            // 调试模式运行时会暂停,可查看x的形状、layer3的参数等信息
            // 1. 打印layer3前的张量形状(等效Scala print("layer3 之前的形状:", x.shape))
            System.out.println("layer3 之前的形状:" + getTensorShapeStr(x));
            // 2. 打印layer3的输入维度要求(辅助调试)
            System.out.println("layer3 要求的输入维度:" + layer3.options().in_features());
            System.out.println("当前张量的最后一维:" + x.size(1));
            // ==========================================================================

            // 执行此行会触发维度不匹配错误(核心错误场景)
            x = layer3.forward(x);
            return x;
        }

        /**
         * 辅助方法:获取张量形状的可读字符串(模拟Scala的x.shape)
         */
        private String getTensorShapeStr(Tensor tensor) {
            long[] shape = tensor.sizes().vec().get();
            StringBuilder sb = new StringBuilder("[");
            for (int i = 0; i < shape.length; i++) {
                sb.append(shape[i]);
                if (i < shape.length - 1) sb.append(", ");
            }
            sb.append("]");
            return sb.toString();
        }

        /**
         * 资源释放:释放所有层资源
         */
        @Override
        public void close() {
            layer1.close();
            activation.close();
            layer2.close();
            layer3.close();
//            super.close();
        }
    }

    // ======================== 主函数:模型实例化 + 前向传播(触发调试/错误) ========================
    public static void main(String[] args) {
        try {
            // 1. 创建模拟输入张量(批次大小4,特征数784)
            // 等效Scala:val dummy_input = torch.randn(4, 784)
            Tensor dummyInput = torch.randn(new long[]{4, 784});
            System.out.println("模拟输入张量形状: " + getTensorShapeStr(dummyInput));

            // 2. 实例化模型(含layer3维度错误)
            SimpleMLP model = new SimpleMLP();

            // 3. 前向传播(会在调试断点处暂停/触发错误)
            // 等效Scala:val output = model(dummy_input)
            Tensor output = model.forward(dummyInput);
            System.out.println("模型输出张量形状: " + getTensorShapeStr(output));

            // 释放资源
            dummyInput.close();
            output.close();
            model.close();

        } catch (Exception e) {
            // 捕获维度不匹配的运行时错误
            System.err.println("\n=== 捕获到运行时错误 ===");
            System.err.println("错误信息:" + e.getMessage());
            System.err.println("\n错误原因:layer3输入维度要求100,但实际输入维度为64");
        }
    }

    /**
     * 工具方法:获取张量形状字符串(供main函数使用)
     */
    private static String getTensorShapeStr(Tensor tensor) {
        long[] shape = tensor.sizes().vec().get();
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < shape.length; i++) {
            sb.append(shape[i]);
            if (i < shape.length - 1) sb.append(", ");
        }
        sb.append("]");
        return sb.toString();
    }
}
  1. 运行修改后的代码: 执行脚本。当程序执行到 pdb.set_trace() 时,它会暂停,你会在终端中看到 (Pdb) 提示符。
  2. 与 pdb 交互:
    • 输入 p x.shape(打印 x.shape)然后按回车键。你会看到 torch.Size([4, 64])
    • 输入 p self.layer3 然后按回车键。你会看到定义 Linear(in_features=100, out_features=10, bias=True)
    • 比较输入形状(64 个特征)与层预期输入(100 个特征),可以清楚地看到不匹配。
    • 输入 n(下一步)然后按回车键。这将尝试执行下一行代码(x = self.layer3(x)),这会导致 RuntimeError,并可能退出调试器或在其中显示错误堆栈。
    • 或者,输入 c(继续)让程序继续运行直到下一个断点或出现错误。
    • 输入 q(退出)立即退出调试器并终止脚本。
  3. 修正并删除: 一旦你理解了问题,你可以输入 q 退出,然后像练习 1 中那样修正代码,并删除 import pdbpdb.set_trace() 行。

这个练习为你在 PyTorch 项目中处理调试和监控任务打下了基础。记住,使用打印语句进行快速检查,pdb 进行交互式检查,以及 TensorBoard 用于可视化训练过程和模型结构。这些工具对于构建、训练和改进高效的深度学习模型来说非常重要。

相关推荐
剑穗挂着新流苏3121 小时前
115_PyTorch 实战:从零搭建 CIFAR-10 完整训练与测试流水线
人工智能·pytorch·深度学习·神经网络
weisian1511 小时前
Java并发编程--19-ThreadPoolExecutor七参数详解:拒绝Executors,手动掌控线程池
java·线程池·threadpool·七大参数
csdn5659738501 小时前
Java打包时,本地仓库有jar 包,Maven打包却还去远程拉取
java·maven·jar
Demon_Hao2 小时前
JAVA通过Redis实现Key分区分片聚合点赞、收藏等计数同步数据库,并且通过布隆过滤器防重复点赞
java·数据库·redis
链上杯子2 小时前
《2026 LangChain零基础入门:用AI应用框架快速搭建智能助手》第8课(完结篇):小项目实战 + 部署 —— 构建网页版个人知识库 AI 助手
人工智能·langchain
华科易迅2 小时前
Spring装配对象方法-注解
java·后端·spring
东方不败之鸭梨的测试笔记2 小时前
AI生成测试用例方案
人工智能·测试用例
庄周的大鱼3 小时前
分析@TransactionalEventListener注解失效
java·spring·springboot·事务监听器·spring 事件机制·事务注解失效解决
笨手笨脚の3 小时前
AI 基础概念
人工智能·大模型·prompt·agent·tool