计算深度学习的参数

构建模型和学习率衰减

python 复制代码
model = TextCNN().to(device)
criterion = nn.CrossEntropyLoss().to(device)  #
# optimizer = optim.AdamW(model.parameters(), lr=5e-4)  # weight_decay=1e-4 weight_decay 就是 L2 正则化系数  , betas=(0.9, 0.888)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)  # weight_decay=1e-4 weight_decay 就是 L2 正则化系数  , betas=(0.9, 0.888)

# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=10, verbose=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, min_lr=1e-5,patience=20, verbose=True)

计算相应指标并画图

python 复制代码
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import matthews_corrcoef, f1_score, precision_score, recall_score

best_val_accuracy = 0  # 设置初始最佳验证准确率为0

# 用于存储每个 epoch 的训练和验证结果
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(300):
    print('Epoch {}/{}'.format(epoch, 300))

    # 训练过程
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    all_train_preds = []
    all_train_targets = []

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        optimizer.zero_grad()  # 清除梯度

        train_loss += loss.item()
        _, train_predicted = torch.max(pred, 1)
        train_total += y.size(0)
        train_correct += (train_predicted == y).sum().item()

        all_train_preds.extend(train_predicted.cpu().numpy())
        all_train_targets.extend(y.cpu().numpy())

    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    train_mcc = matthews_corrcoef(all_train_targets, all_train_preds)
    train_f1 = f1_score(all_train_targets, all_train_preds, average='weighted')
    train_precision = precision_score(all_train_targets, all_train_preds, average='weighted')
    train_recall = recall_score(all_train_targets, all_train_preds, average='weighted')

    print(f'Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%,Train MCC: {train_mcc:.4f}, Train F1: {train_f1:.4f},Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}')
    # print(f'Train MCC: {train_mcc:.4f}, Train F1: {train_f1:.4f}')
    # print(f'Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}')

    # 保存训练集上的损失和准确率
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)
    # current_lr = scheduler.optimizer.param_groups[0]['lr']
    # print(f'Current Learning Rate: {current_lr}')
    # scheduler.step(avg_val_loss)

    # 验证过程
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    all_val_preds = []
    all_val_targets = []

    with torch.no_grad():
        for inputs, target in val_loader:
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)
            loss = criterion(output, target)

            val_loss += loss.item()
            _, val_predicted = torch.max(output, 1)
            val_total += target.size(0)
            val_correct += (val_predicted == target).sum().item()

            all_val_preds.extend(val_predicted.cpu().numpy())
            all_val_targets.extend(target.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * val_correct / val_total
    val_mcc = matthews_corrcoef(all_val_targets, all_val_preds)
    val_f1 = f1_score(all_val_targets, all_val_preds, average='weighted')
    val_precision = precision_score(all_val_targets, all_val_preds, average='weighted')
    val_recall = recall_score(all_val_targets, all_val_preds, average='weighted')
    ################################
    current_lr = scheduler.optimizer.param_groups[0]['lr']
    print(f'Current Learning Rate: {current_lr}')
    scheduler.step(avg_val_loss)

    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}% ,Validation MCC: {val_mcc:.4f}, Validation F1: {val_f1:.4f} ,Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}')
    # print(f'Validation MCC: {val_mcc:.4f}, Validation F1: {val_f1:.4f}')
    # print(f'Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}')

    # 保存验证集上的损失和准确率
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_accuracy)

    # 如果需要保存验证集上表现最好的模型,可以添加如下代码
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model_{}.pth'.format(epoch))
        print('Best model saved best_model_{}.pth'.format(epoch))

# 训练结束后绘制损失和准确率曲线
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 5))

# 绘制损失曲线
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, 'b', label='Train Loss')
plt.plot(epochs, val_losses, 'r', label='Validation Loss')
plt.title('Train and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# 绘制准确率曲线
plt.subplot(1, 2, 2)
plt.plot(epochs, train_accuracies, 'b', label='Train Accuracy')
plt.plot(epochs, val_accuracies, 'r', label='Validation Accuracy')
plt.title('Train and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()

#
# from google.colab import drive
# drive.mount('/content/drive')
#
# # 保存模型到 Google Drive 中
# model_save_path = '/content/drive/MyDrive/best_model_{}.pth'.format(epoch)
# torch.save(BiGRU.state_dict(), model_save_path)
相关推荐
ZOMI酱30 分钟前
【AI系统】GPU 架构回顾(从2018年-2024年)
人工智能·架构
土豆炒马铃薯。1 小时前
【深度学习】Pytorch 1.x 安装命令
linux·人工智能·pytorch·深度学习·ubuntu·centos
阿_旭1 小时前
【超全】目标检测模型分类对比与综述:单阶段、双阶段、有无锚点、DETR、旋转框
人工智能·深度学习·目标检测·分类
研一计算机小白一枚1 小时前
Which Tasks Should Be Learned Together in Multi-task Learning? 译文
人工智能·python·学习·机器学习
xianghan收藏册1 小时前
基于lora的llama2二次预训练
人工智能·深度学习·机器学习·chatgpt·transformer
2zcode1 小时前
基于YOLOv8深度学习的智慧农业果园果树柑橘类果实目标检测系统(PyQt5界面+数据集+训练代码)
深度学习·yolo·目标检测
Eric.Lee20211 小时前
数据集-目标检测系列- 蘑菇 检测数据集 mushroom >> DataBall
人工智能·python·yolo·目标检测·计算机视觉·蘑菇检测
像污秽一样1 小时前
根据气候变化自动制定鲜花存储策略(BabyAGI)
人工智能·chatgpt·langchain
Struart_R2 小时前
Epipolar-Free 3D Gaussian Splatting for Generalizable Novel View Synthesis 论文解读
人工智能·深度学习·计算机视觉·3d·transformer·三维重建·新视角生成
不去幼儿园2 小时前
【RL Base】多级反馈队列(MFQ)算法
人工智能·python·算法·机器学习·强化学习