-
第一步:预训练(基础教育)
- 目标:让模型学通用知识,不针对任何具体任务。
- 做法 :用超大范围的通用数据喂给模型 ------ 比如互联网上的万亿级文本(新闻、小说、论文)、亿级的图片(风景、动物、人脸)。
- 学习内容 :
- 文本类模型(比如 GPT、BERT):学词语的搭配、语法、逻辑、常识(比如 "吃饭" 要搭配 "筷子","下雨" 要搭配 "雨伞");
- 图像类模型(比如 ResNet、ViT):学图像的特征(比如猫有尖耳朵、狗有长尾巴,圆形的轮子、方形的盒子)。
- 特点 :这个过程耗时长、费算力(需要用很多 GPU/TPU 跑几天甚至几周),但只需要做一次 ------ 训练好的模型可以共享给全世界用。
-
第二步:微调(专业深造)
-
目标 :让预训练好的模型适应具体任务。
-
做法 :用小范围的任务专属数据"微调" 模型。比如你想让模型做 "医学论文摘要生成",就喂给它几百篇医学论文和对应的摘要;想让它做 "猫咪图片识别",就喂给它几千张猫咪的标注图片。
-
特点 :这个过程快、省钱 ,普通的电脑或者单块 GPU 就能搞定 ------ 因为模型已经有了基础,只需要 "校准" 一下方向就行。
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt # 设置中文字体支持 plt.rcParams["font.family"] = ["SimHei"] plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 # 检查GPU是否可用 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") # 1. 数据预处理(训练集增强,测试集标准化) train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) # 2. 加载CIFAR-10数据集 train_dataset = datasets.CIFAR10( root='./data', train=True, download=True, transform=train_transform ) test_dataset = datasets.CIFAR10( root='./data', train=False, transform=test_transform ) # 3. 创建数据加载器(可调整batch_size) batch_size = 64 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 4. 训练函数(支持学习率调度器) def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs): model.train() # 设置为训练模式 train_loss_history = [] test_loss_history = [] train_acc_history = [] test_acc_history = [] all_iter_losses = [] iter_indices = [] for epoch in range(epochs): running_loss = 0.0 correct_train = 0 total_train = 0 for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() # 记录Iteration损失 iter_loss = loss.item() all_iter_losses.append(iter_loss) iter_indices.append(epoch * len(train_loader) + batch_idx + 1) # 统计训练指标 running_loss += iter_loss _, predicted = output.max(1) total_train += target.size(0) correct_train += predicted.eq(target).sum().item() # 每100批次打印进度 if (batch_idx + 1) % 100 == 0: print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} " f"| 单Batch损失: {iter_loss:.4f}") # 计算 epoch 级指标 epoch_train_loss = running_loss / len(train_loader) epoch_train_acc = 100. * correct_train / total_train # 测试阶段 model.eval() correct_test = 0 total_test = 0 test_loss = 0.0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) test_loss += criterion(output, target).item() _, predicted = output.max(1) total_test += target.size(0) correct_test += predicted.eq(target).sum().item() epoch_test_loss = test_loss / len(test_loader) epoch_test_acc = 100. * correct_test / total_test # 记录历史数据 train_loss_history.append(epoch_train_loss) test_loss_history.append(epoch_test_loss) train_acc_history.append(epoch_train_acc) test_acc_history.append(epoch_test_acc) # 更新学习率调度器 if scheduler is not None: scheduler.step(epoch_test_loss) # 打印 epoch 结果 print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} " f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%") # 绘制损失和准确率曲线 plot_iter_losses(all_iter_losses, iter_indices) plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history) return epoch_test_acc # 返回最终测试准确率 # 5. 绘制Iteration损失曲线 def plot_iter_losses(losses, indices): plt.figure(figsize=(10, 4)) plt.plot(indices, losses, 'b-', alpha=0.7) plt.xlabel('Iteration(Batch序号)') plt.ylabel('损失值') plt.title('训练过程中的Iteration损失变化') plt.grid(True) plt.show() # 6. 绘制Epoch级指标曲线 def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss): epochs = range(1, len(train_acc) + 1) plt.figure(figsize=(12, 5)) # 准确率曲线 plt.subplot(1, 2, 1) plt.plot(epochs, train_acc, 'b-', label='训练准确率') plt.plot(epochs, test_acc, 'r-', label='测试准确率') plt.xlabel('Epoch') plt.ylabel('准确率 (%)') plt.title('准确率随Epoch变化') plt.legend() plt.grid(True) # 损失曲线 plt.subplot(1, 2, 2) plt.plot(epochs, train_loss, 'b-', label='训练损失') plt.plot(epochs, test_loss, 'r-', label='测试损失') plt.xlabel('Epoch') plt.ylabel('损失值') plt.title('损失值随Epoch变化') plt.legend() plt.grid(True) plt.tight_layout() plt.show() # 导入ResNet模型 from torchvision.models import resnet18 # 定义ResNet18模型(支持预训练权重加载) def create_resnet18(pretrained=True, num_classes=10): # 加载预训练模型(ImageNet权重) model = resnet18(pretrained=pretrained) # 修改最后一层全连接层,适配CIFAR-10的10分类任务 in_features = model.fc.in_features model.fc = nn.Linear(in_features, num_classes) # 将模型转移到指定设备(CPU/GPU) model = model.to(device) return model # 创建ResNet18模型(加载ImageNet预训练权重,不进行微调) model = create_resnet18(pretrained=True, num_classes=10) model.eval() # 设置为推理模式 # 测试单张图片(示例) from torchvision import utils # 从测试数据集中获取一张图片 dataiter = iter(test_loader) images, labels = dataiter.next() images = images[:1].to(device) # 取第1张图片 # 前向传播 with torch.no_grad(): outputs = model(images) _, predicted = torch.max(outputs.data, 1) # 显示图片和预测结果 plt.imshow(utils.make_grid(images.cpu(), normalize=True).permute(1, 2, 0)) plt.title(f"预测类别: {predicted.item()}") plt.axis('off') plt.show()
-
Day 49 预训练模型
江上鹤.1482025-12-28 17:07
相关推荐
zuozewei2 小时前
7D-AI系列:Transformer 与深度学习核心概念victory04312 小时前
大模型长上下文长度使用窗口注意力表现有下降吗乐迪信息2 小时前
乐迪信息:异物入侵识别算法上线,AI摄像机保障智慧煤矿生产稳定CareyWYR2 小时前
每周AI论文速递(251222-251226)玄同7652 小时前
Python 真零基础入门:从 “什么是编程” 到 LLM Prompt 模板生成虹科网络安全2 小时前
艾体宝洞察 | 生成式AI上线倒计时:Redis如何把“延迟”与“幻觉”挡在生产线之外?Java后端的Ai之路2 小时前
【神经网络基础】-深度学习框架学习指南熬夜敲代码的小N2 小时前
从SEO到GEO:AI时代内容优化的范式革命FakeOccupational2 小时前
【经济学】 基本面数据(Fundamental Data)之 美国劳动力报告&非农就业NFP + ADP + 美国劳动力参与率LFPR