计算深度学习的参数

构建模型和学习率衰减

python 复制代码
model = TextCNN().to(device)
criterion = nn.CrossEntropyLoss().to(device)  #
# optimizer = optim.AdamW(model.parameters(), lr=5e-4)  # weight_decay=1e-4 weight_decay 就是 L2 正则化系数  , betas=(0.9, 0.888)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)  # weight_decay=1e-4 weight_decay 就是 L2 正则化系数  , betas=(0.9, 0.888)

# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=10, verbose=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, min_lr=1e-5,patience=20, verbose=True)

计算相应指标并画图

python 复制代码
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import matthews_corrcoef, f1_score, precision_score, recall_score

best_val_accuracy = 0  # 设置初始最佳验证准确率为0

# 用于存储每个 epoch 的训练和验证结果
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(300):
    print('Epoch {}/{}'.format(epoch, 300))

    # 训练过程
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    all_train_preds = []
    all_train_targets = []

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        optimizer.zero_grad()  # 清除梯度

        train_loss += loss.item()
        _, train_predicted = torch.max(pred, 1)
        train_total += y.size(0)
        train_correct += (train_predicted == y).sum().item()

        all_train_preds.extend(train_predicted.cpu().numpy())
        all_train_targets.extend(y.cpu().numpy())

    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    train_mcc = matthews_corrcoef(all_train_targets, all_train_preds)
    train_f1 = f1_score(all_train_targets, all_train_preds, average='weighted')
    train_precision = precision_score(all_train_targets, all_train_preds, average='weighted')
    train_recall = recall_score(all_train_targets, all_train_preds, average='weighted')

    print(f'Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%,Train MCC: {train_mcc:.4f}, Train F1: {train_f1:.4f},Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}')
    # print(f'Train MCC: {train_mcc:.4f}, Train F1: {train_f1:.4f}')
    # print(f'Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}')

    # 保存训练集上的损失和准确率
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)
    # current_lr = scheduler.optimizer.param_groups[0]['lr']
    # print(f'Current Learning Rate: {current_lr}')
    # scheduler.step(avg_val_loss)

    # 验证过程
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    all_val_preds = []
    all_val_targets = []

    with torch.no_grad():
        for inputs, target in val_loader:
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)
            loss = criterion(output, target)

            val_loss += loss.item()
            _, val_predicted = torch.max(output, 1)
            val_total += target.size(0)
            val_correct += (val_predicted == target).sum().item()

            all_val_preds.extend(val_predicted.cpu().numpy())
            all_val_targets.extend(target.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * val_correct / val_total
    val_mcc = matthews_corrcoef(all_val_targets, all_val_preds)
    val_f1 = f1_score(all_val_targets, all_val_preds, average='weighted')
    val_precision = precision_score(all_val_targets, all_val_preds, average='weighted')
    val_recall = recall_score(all_val_targets, all_val_preds, average='weighted')
    ################################
    current_lr = scheduler.optimizer.param_groups[0]['lr']
    print(f'Current Learning Rate: {current_lr}')
    scheduler.step(avg_val_loss)

    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}% ,Validation MCC: {val_mcc:.4f}, Validation F1: {val_f1:.4f} ,Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}')
    # print(f'Validation MCC: {val_mcc:.4f}, Validation F1: {val_f1:.4f}')
    # print(f'Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}')

    # 保存验证集上的损失和准确率
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_accuracy)

    # 如果需要保存验证集上表现最好的模型,可以添加如下代码
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model_{}.pth'.format(epoch))
        print('Best model saved best_model_{}.pth'.format(epoch))

# 训练结束后绘制损失和准确率曲线
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 5))

# 绘制损失曲线
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, 'b', label='Train Loss')
plt.plot(epochs, val_losses, 'r', label='Validation Loss')
plt.title('Train and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# 绘制准确率曲线
plt.subplot(1, 2, 2)
plt.plot(epochs, train_accuracies, 'b', label='Train Accuracy')
plt.plot(epochs, val_accuracies, 'r', label='Validation Accuracy')
plt.title('Train and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()

#
# from google.colab import drive
# drive.mount('/content/drive')
#
# # 保存模型到 Google Drive 中
# model_save_path = '/content/drive/MyDrive/best_model_{}.pth'.format(epoch)
# torch.save(BiGRU.state_dict(), model_save_path)
相关推荐
guanshiyishi2 小时前
ABeam 德硕 | 中国汽车市场(2)——新能源车的崛起与中国汽车市场机遇与挑战
人工智能
极客天成ScaleFlash2 小时前
极客天成NVFile:无缓存直击存储性能天花板,重新定义AI时代并行存储新范式
人工智能·缓存
Uzuki2 小时前
AI可解释性 II | Saliency Maps-based 归因方法(Attribution)论文导读(持续更新)
深度学习·机器学习·可解释性
澳鹏Appen3 小时前
AI安全:构建负责任且可靠的系统
人工智能·安全
蹦蹦跳跳真可爱5894 小时前
Python----机器学习(KNN:使用数学方法实现KNN)
人工智能·python·机器学习
视界宝藏库4 小时前
多元 AI 配音软件,打造独特音频体验
人工智能
xinxiyinhe4 小时前
GitHub上英语学习工具的精选分类汇总
人工智能·deepseek·学习英语精选
ZStack开发者社区5 小时前
全球化2.0 | ZStack举办香港Partner Day,推动AIOS智塔+DeepSeek海外实践
人工智能·云计算
Spcarrydoinb6 小时前
基于yolo11的BGA图像目标检测
人工智能·目标检测·计算机视觉
非ban必选6 小时前
spring-ai-alibaba第四章阿里dashscope集成百度翻译tool
java·人工智能·spring