DAY34 GPU 训练与类的 call 方法

1. CPU 性能查看
  • 核心指标
    • 架构代际:如 Intel 13 代 / 14 代、AMD Zen 3/Zen 4,代际越新通常能效比越高。
    • 核心数 / 线程数:核心数决定并行计算能力,线程数(超线程技术)可提升多任务处理效率。
  • 查看方式
    • Windows:任务管理器 → 性能 → CPU;或使用 wmic cpu get caption,deviceid,maxclockspeed,currentclockspeed,numberofcores,numberoflogicalprocessors 命令。
    • Linux:lscpucat /proc/cpuinfo

2. GPU 性能查看
  • 核心指标
    • 显存:决定可加载的模型与数据规模,单位为 GB(如 8GB/24GB)。
    • 级别:消费级(RTX 4090)、专业级(RTX A6000)、数据中心级(A100/H100)。
    • 架构代际:如 NVIDIA Ampere(30 系)、Ada Lovelace(40 系)、Hopper(H100),代际更新会带来算力与能效提升。
  • 查看方式
    • NVIDIA GPU:nvidia-smi 命令(Linux/Windows),可查看显存占用、温度、功耗等。
    • 图形界面:NVIDIA Control Panel 或 GPU-Z。

3. GPU 训练方法
  • 核心逻辑:将模型与数据移动到 GPU 设备上进行加速计算。

PyTorch 示例

python 复制代码
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 模型移动到GPU
model = Model().to(device)
# 数据移动到GPU
inputs = inputs.to(device)
labels = labels.to(device)
  • 关键注意:模型与数据必须在同一设备(CPU/GPU)上,否则会报错。

4. 类的 call 方法与前向传播
  • 原理 :在 PyTorch 中,nn.Module 类实现了 __call__ 方法,该方法会自动调用 forward() 函数。

  • 代码表现

    python 复制代码
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 5)
        def forward(self, x):
            # 等价于 self.fc1.__call__(x)
            x = self.fc1(x)
            return x
  • 原因self.fc1nn.Linear 实例,继承自 nn.Module,因此可像函数一样直接调用 self.fc1(x),底层会执行其 forward 逻辑。


💡 实用技巧

训练过程中,可在命令行输入 nvidia-smi 实时查看显存占用、GPU 利用率等信息,帮助排查显存不足或性能瓶颈问题。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# ===================== 1. 设备配置(自动使用GPU,没有则用CPU)=====================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 查看GPU信息(如果有)
if torch.cuda.is_available():
    print(f"GPU名称: {torch.cuda.get_device_name(0)}")
    print(f"可用显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# ===================== 2. 构建简易数据集(仅演示,替换成你的数据)=====================
# 生成随机数据:1000个样本,输入维度20,输出分类10类
x = torch.randn(1000, 20)
y = torch.randint(0, 10, (1000,))
dataset = TensorDataset(x, y)
# 数据加载器
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# ===================== 3. 定义模型 =====================
class SimpleNet(nn.Module):
    def __init__(self, input_dim=20, hidden_dim=64, num_classes=10):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    # 前向传播(__call__会自动调用这个方法)
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 初始化模型 + 迁移到GPU/CPU
model = SimpleNet().to(device)

# ===================== 4. 损失函数 + 优化器 =====================
criterion = nn.CrossEntropyLoss()  # 分类任务
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ===================== 5. 训练循环 =====================
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()  # 训练模式
    total_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in loader:
        # 【关键】数据迁移到设备
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # 反向传播 + 优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 统计
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(loader)
    acc = 100 * correct / total
    return avg_loss, acc

# ===================== 6. 开始训练 =====================
epochs = 10
best_acc = 0.0

for epoch in range(epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # 打印日志
    print(f"Epoch [{epoch+1}/{epochs}] | Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%")

    # 保存准确率最高的模型
    if train_acc > best_acc:
        best_acc = train_acc
        torch.save(model.state_dict(), "best_model.pth")
        print(f"✅ 已保存最优模型,最佳准确率: {best_acc:.2f}%")

print("\n训练完成!")
print(f"最终最佳准确率: {best_acc:.2f}%")

核心 GPU 使用要点(必看)

  1. 统一设备 模型 (model.to(device)) + 数据 (inputs.to(device)) 必须在同一设备,否则报错。
  2. 无需手动改代码有 GPU 自动用 GPU,没有自动用 CPU,跨平台通用。
  3. 查看 GPU 状态 命令行输入:nvidia-smi,实时看显存、利用率。
如何替换成你的任务
  1. 数据集替换成你的图片 / 文本数据;
  2. 模型替换成你的网络结构;
  3. 调整batch_sizelrepochs超参数即可。
总结
  • 模板自动适配 GPU/CPU,零修改直接运行;
  • 核心就两步:模型.to (device)、数据.to (device);
  • 包含训练、统计、保存最优模型全套功能,直接用于项目。

@浙大疏锦行

相关推荐
u01091476040 分钟前
CSS组件库如何快速扩展_通过Sass @extend继承基础布局
jvm·数据库·python
baidu_3409988244 分钟前
Golang怎么用go-noescape优化性能_Golang如何使用编译器指令控制逃逸分析行为【进阶】
jvm·数据库·python
m0_6784854544 分钟前
如何利用虚拟 DOM 实现无痕刷新?基于 VNode 对比的状态保持技巧
jvm·数据库·python
不吃香菜学java1 小时前
Redis的java客户端
java·开发语言·spring boot·redis·缓存
qq_342295821 小时前
CSS如何实现透明背景效果_通过RGBA色彩模式控制透明度
jvm·数据库·python
TechWayfarer1 小时前
知乎/微博的IP属地显示为什么偶尔错误?用IP归属地查询平台自检工具3步验证
网络·python·网络协议·tcp/ip·网络安全
Greyson11 小时前
CSS如何处理超长文本换行问题_结合word-wrap属性
jvm·数据库·python
justjinji1 小时前
如何批量更新SQL数据表_使用UPDATE JOIN语法提升效率
jvm·数据库·python
小江的记录本1 小时前
【网络安全】《网络安全常见攻击与防御》(附:《六大攻击核心特性横向对比表》)
java·网络·人工智能·后端·python·安全·web安全
贵沫末1 小时前
python——打包自己的库并安装
开发语言·windows·python