Day 49 预训练模型

  1. 第一步:预训练(基础教育)

    • 目标:让模型学通用知识,不针对任何具体任务。
    • 做法 :用超大范围的通用数据喂给模型 ------ 比如互联网上的万亿级文本(新闻、小说、论文)、亿级的图片(风景、动物、人脸)。
    • 学习内容
      • 文本类模型(比如 GPT、BERT):学词语的搭配、语法、逻辑、常识(比如 "吃饭" 要搭配 "筷子","下雨" 要搭配 "雨伞");
      • 图像类模型(比如 ResNet、ViT):学图像的特征(比如猫有尖耳朵、狗有长尾巴,圆形的轮子、方形的盒子)。
    • 特点 :这个过程耗时长、费算力(需要用很多 GPU/TPU 跑几天甚至几周),但只需要做一次 ------ 训练好的模型可以共享给全世界用。
  2. 第二步:微调(专业深造)

    • 目标 :让预训练好的模型适应具体任务

    • 做法 :用小范围的任务专属数据"微调" 模型。比如你想让模型做 "医学论文摘要生成",就喂给它几百篇医学论文和对应的摘要;想让它做 "猫咪图片识别",就喂给它几千张猫咪的标注图片。

    • 特点 :这个过程快、省钱 ,普通的电脑或者单块 GPU 就能搞定 ------ 因为模型已经有了基础,只需要 "校准" 一下方向就行。

      复制代码
      import torch
      import torch.nn as nn
      import torch.optim as optim
      from torchvision import datasets, transforms
      from torch.utils.data import DataLoader
      import matplotlib.pyplot as plt
       
      # 设置中文字体支持
      plt.rcParams["font.family"] = ["SimHei"]
      plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
       
      # 检查GPU是否可用
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      print(f"使用设备: {device}")
       
      # 1. 数据预处理(训练集增强,测试集标准化)
      train_transform = transforms.Compose([
          transforms.RandomCrop(32, padding=4),
          transforms.RandomHorizontalFlip(),
          transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
          transforms.RandomRotation(15),
          transforms.ToTensor(),
          transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
      ])
       
      test_transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
      ])
       
      # 2. 加载CIFAR-10数据集
      train_dataset = datasets.CIFAR10(
          root='./data',
          train=True,
          download=True,
          transform=train_transform
      )
       
      test_dataset = datasets.CIFAR10(
          root='./data',
          train=False,
          transform=test_transform
      )
       
      # 3. 创建数据加载器(可调整batch_size)
      batch_size = 64
      train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
      test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
       
      # 4. 训练函数(支持学习率调度器)
      def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):
          model.train()  # 设置为训练模式
          train_loss_history = []
          test_loss_history = []
          train_acc_history = []
          test_acc_history = []
          all_iter_losses = []
          iter_indices = []
       
          for epoch in range(epochs):
              running_loss = 0.0
              correct_train = 0
              total_train = 0
              
              for batch_idx, (data, target) in enumerate(train_loader):
                  data, target = data.to(device), target.to(device)
                  optimizer.zero_grad()
                  output = model(data)
                  loss = criterion(output, target)
                  loss.backward()
                  optimizer.step()
                  
                  # 记录Iteration损失
                  iter_loss = loss.item()
                  all_iter_losses.append(iter_loss)
                  iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
                  
                  # 统计训练指标
                  running_loss += iter_loss
                  _, predicted = output.max(1)
                  total_train += target.size(0)
                  correct_train += predicted.eq(target).sum().item()
                  
                  # 每100批次打印进度
                  if (batch_idx + 1) % 100 == 0:
                      print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "
                            f"| 单Batch损失: {iter_loss:.4f}")
              
              # 计算 epoch 级指标
              epoch_train_loss = running_loss / len(train_loader)
              epoch_train_acc = 100. * correct_train / total_train
              
              # 测试阶段
              model.eval()
              correct_test = 0
              total_test = 0
              test_loss = 0.0
              with torch.no_grad():
                  for data, target in test_loader:
                      data, target = data.to(device), target.to(device)
                      output = model(data)
                      test_loss += criterion(output, target).item()
                      _, predicted = output.max(1)
                      total_test += target.size(0)
                      correct_test += predicted.eq(target).sum().item()
              
              epoch_test_loss = test_loss / len(test_loader)
              epoch_test_acc = 100. * correct_test / total_test
              
              # 记录历史数据
              train_loss_history.append(epoch_train_loss)
              test_loss_history.append(epoch_test_loss)
              train_acc_history.append(epoch_train_acc)
              test_acc_history.append(epoch_test_acc)
              
              # 更新学习率调度器
              if scheduler is not None:
                  scheduler.step(epoch_test_loss)
              
              # 打印 epoch 结果
              print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "
                    f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")
          
          # 绘制损失和准确率曲线
          plot_iter_losses(all_iter_losses, iter_indices)
          plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
          
          return epoch_test_acc  # 返回最终测试准确率
       
      # 5. 绘制Iteration损失曲线
      def plot_iter_losses(losses, indices):
          plt.figure(figsize=(10, 4))
          plt.plot(indices, losses, 'b-', alpha=0.7)
          plt.xlabel('Iteration(Batch序号)')
          plt.ylabel('损失值')
          plt.title('训练过程中的Iteration损失变化')
          plt.grid(True)
          plt.show()
       
      # 6. 绘制Epoch级指标曲线
      def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):
          epochs = range(1, len(train_acc) + 1)
          
          plt.figure(figsize=(12, 5))
          
          # 准确率曲线
          plt.subplot(1, 2, 1)
          plt.plot(epochs, train_acc, 'b-', label='训练准确率')
          plt.plot(epochs, test_acc, 'r-', label='测试准确率')
          plt.xlabel('Epoch')
          plt.ylabel('准确率 (%)')
          plt.title('准确率随Epoch变化')
          plt.legend()
          plt.grid(True)
          
          # 损失曲线
          plt.subplot(1, 2, 2)
          plt.plot(epochs, train_loss, 'b-', label='训练损失')
          plt.plot(epochs, test_loss, 'r-', label='测试损失')
          plt.xlabel('Epoch')
          plt.ylabel('损失值')
          plt.title('损失值随Epoch变化')
          plt.legend()
          plt.grid(True)
          plt.tight_layout()
          plt.show()
      
      # 导入ResNet模型
      from torchvision.models import resnet18
       
      # 定义ResNet18模型(支持预训练权重加载)
      def create_resnet18(pretrained=True, num_classes=10):
          # 加载预训练模型(ImageNet权重)
          model = resnet18(pretrained=pretrained)
          
          # 修改最后一层全连接层,适配CIFAR-10的10分类任务
          in_features = model.fc.in_features
          model.fc = nn.Linear(in_features, num_classes)
          
          # 将模型转移到指定设备(CPU/GPU)
          model = model.to(device)
          return model
      
      # 创建ResNet18模型(加载ImageNet预训练权重,不进行微调)
      model = create_resnet18(pretrained=True, num_classes=10)
      model.eval()  # 设置为推理模式
       
      # 测试单张图片(示例)
      from torchvision import utils
       
      # 从测试数据集中获取一张图片
      dataiter = iter(test_loader)
      images, labels = dataiter.next()
      images = images[:1].to(device)  # 取第1张图片
       
      # 前向传播
      with torch.no_grad():
          outputs = model(images)
          _, predicted = torch.max(outputs.data, 1)
       
      # 显示图片和预测结果
      plt.imshow(utils.make_grid(images.cpu(), normalize=True).permute(1, 2, 0))
      plt.title(f"预测类别: {predicted.item()}")
      plt.axis('off')
      plt.show()

      @浙大疏锦行

相关推荐
zuozewei2 小时前
7D-AI系列:Transformer 与深度学习核心概念
人工智能·深度学习·transformer
victory04312 小时前
大模型长上下文长度使用窗口注意力表现有下降吗
深度学习
乐迪信息2 小时前
乐迪信息:异物入侵识别算法上线,AI摄像机保障智慧煤矿生产稳定
大数据·运维·人工智能·物联网·安全
CareyWYR2 小时前
每周AI论文速递(251222-251226)
人工智能
玄同7652 小时前
Python 真零基础入门:从 “什么是编程” 到 LLM Prompt 模板生成
人工智能·python·语言模型·自然语言处理·llm·nlp·prompt
虹科网络安全2 小时前
艾体宝洞察 | 生成式AI上线倒计时:Redis如何把“延迟”与“幻觉”挡在生产线之外?
数据库·人工智能·redis
Java后端的Ai之路2 小时前
【神经网络基础】-深度学习框架学习指南
人工智能·深度学习·神经网络·机器学习
熬夜敲代码的小N2 小时前
从SEO到GEO:AI时代内容优化的范式革命
大数据·人工智能·计算机网络
FakeOccupational2 小时前
【经济学】 基本面数据(Fundamental Data)之 美国劳动力报告&非农就业NFP + ADP + 美国劳动力参与率LFPR
开发语言·人工智能·python