A Summary of Common PyTorch Methods


🔹 Construction / Creation

torch.xxx: torch.tensor(), torch.from_numpy(), torch.as_tensor(), torch.arange(), torch.linspace(), torch.logspace(), torch.zeros(), torch.ones(), torch.full(), torch.eye(), torch.empty(), torch.rand(), torch.randint(), torch.randn(), torch.normal(), torch.zeros_like(), torch.ones_like(), torch.empty_like(), torch.rand_like()
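
A few of these constructors in action (shapes and values chosen arbitrarily for illustration; note there is no torch.uniform() — uniform sampling is torch.rand() or the in-place tensor.uniform_()):

```python
import torch
import numpy as np

a = torch.tensor([[1., 2.], [3., 4.]])   # from nested lists
b = torch.from_numpy(np.ones(3))         # shares memory with the NumPy array
c = torch.arange(0, 10, 2)               # tensor([0, 2, 4, 6, 8])
d = torch.linspace(0, 1, steps=5)        # 5 evenly spaced points in [0, 1]
e = torch.zeros(2, 3)                    # 2x3 of zeros
f = torch.randn(2, 3)                    # standard-normal samples
g = torch.randint(0, 10, (4,))           # 4 random ints in [0, 10)
h = torch.empty(2, 2).uniform_(0, 1)     # in-place uniform fill
i = torch.zeros_like(f)                  # same shape/dtype/device as f
```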

🔹 Random Seeds

torch.xxx: torch.manual_seed(), torch.initial_seed(), torch.get_rng_state(), torch.set_rng_state()
torch.cuda.xxx: torch.cuda.manual_seed(), torch.cuda.manual_seed_all()
torch.Generator
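
A typical seeding pattern for reproducible runs (a sketch; full determinism also depends on cuDNN and algorithm settings):

```python
import torch

torch.manual_seed(42)                  # seeds the CPU (and current CUDA) RNG
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)     # seeds all CUDA devices

# A local Generator keeps randomness isolated from the global RNG state
g = torch.Generator().manual_seed(1234)
x = torch.randn(3, generator=g)

state = torch.get_rng_state()          # snapshot the global CPU RNG state...
torch.set_rng_state(state)             # ...and restore it later
```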

🔹 Tensor Operations / Shape Manipulation

tensor.xxx: tensor.reshape(), tensor.view(), tensor.squeeze(), tensor.unsqueeze(), tensor.transpose(), tensor.permute(), tensor.repeat(), tensor.expand(), tensor.flatten(), tensor.unflatten(), tensor.flip(), tensor.roll(), tensor.narrow(), tensor.select()
torch.xxx: torch.cat(), torch.stack(), torch.chunk(), torch.split(), torch.rot90()
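
A quick tour of the shape utilities (view() requires contiguous memory, reshape() does not):

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)   # reshape to [2, 3, 4]
y = x.permute(2, 0, 1)                  # reorder dims -> [4, 2, 3]
z = x.unsqueeze(0)                      # add a leading dim -> [1, 2, 3, 4]
w = z.squeeze(0)                        # drop the size-1 dim -> [2, 3, 4]
flat = x.flatten(start_dim=1)           # [2, 12]

a, b = torch.randn(2, 3), torch.randn(2, 3)
cat = torch.cat([a, b], dim=0)          # [4, 3]: join along an existing dim
stk = torch.stack([a, b], dim=0)        # [2, 2, 3]: new leading dim
halves = torch.chunk(cat, 2, dim=0)     # split back into two [2, 3] tensors
```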

🔹 Math Operations

tensor.xxx: tensor.add(), tensor.sub(), tensor.mul(), tensor.div(), tensor.pow(), tensor.exp(), tensor.log(), tensor.log10(), tensor.log2(), tensor.sqrt(), tensor.sin(), tensor.cos(), tensor.tan(), tensor.abs(), tensor.ceil(), tensor.floor(), tensor.round(), tensor.sum(), tensor.prod(), tensor.cumsum(), tensor.cumprod(), tensor.clamp(), tensor.clip(), tensor.sigmoid(), tensor.tanh()
torch.xxx: torch.add(), torch.sub(), torch.mul(), torch.div(), torch.pow(), torch.exp(), torch.log(), torch.sqrt(), torch.sin(), torch.cos(), torch.tan(), torch.abs(), torch.sum(), torch.prod(), torch.cumsum(), torch.cumprod(), torch.clamp(), torch.sigmoid(), torch.tanh()
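
Most of these exist in three flavors: torch.fn(t), t.fn(), and an in-place t.fn_(). A small illustration:

```python
import torch

x = torch.tensor([-1.5, 0.0, 2.5])
print(x.abs())                # tensor([1.5000, 0.0000, 2.5000])
print(x.clamp(min=0.0))       # ReLU-like: tensor([0.0000, 0.0000, 2.5000])
print(torch.exp(x).log())     # round-trips back to x (up to float error)
print(x.sum(), x.cumsum(0))   # full reduction vs. running sum
x.add_(1.0)                   # trailing underscore = in-place variant
```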

🔹 Statistics

tensor.xxx: tensor.mean(), tensor.median(), tensor.mode(), tensor.std(), tensor.var(), tensor.min(), tensor.max(), tensor.argmin(), tensor.argmax(), tensor.unique(), tensor.topk(), tensor.sort(), tensor.argsort(), tensor.quantile(), tensor.kthvalue()
torch.xxx: torch.mean(), torch.median(), torch.mode(), torch.std(), torch.var(), torch.min(), torch.max(), torch.argmin(), torch.argmax(), torch.unique(), torch.topk(), torch.sort(), torch.argsort(), torch.quantile(), torch.kthvalue()
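
Note that min/max/median/mode/topk/sort along a dim return named tuples of (values, indices):

```python
import torch

x = torch.tensor([[3., 1., 2.], [6., 5., 4.]])
print(x.mean(), x.std())           # global reductions
print(x.max(dim=1))                # (values, indices) per row
print(x.argmax(dim=1))             # tensor([0, 0])
vals, idx = x.topk(2, dim=1)       # top-2 values per row, with indices
sorted_x, order = x.sort(dim=1)    # ascending sort per row
print(torch.unique(torch.tensor([1, 1, 2, 3, 3])))  # tensor([1, 2, 3])
```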

🔹 Indexing / Conditionals

tensor.xxx: tensor.nonzero(), tensor.masked_select(), tensor.masked_fill(), tensor.any(), tensor.all()
torch.xxx: torch.where(), torch.nonzero(), torch.index_select(), torch.masked_select(), torch.gather(), torch.scatter(), torch.take(), torch.masked_fill(), torch.any(), torch.all()
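
The difference between masked_select (flattens to 1-D) and masked_fill (keeps shape) is a common stumbling block:

```python
import torch

x = torch.tensor([[1, -2], [-3, 4]])
mask = x > 0
print(torch.where(mask, x, torch.zeros_like(x)))  # keep positives, zero the rest
print(x.masked_select(mask))                      # 1-D tensor of selected values
print(x.masked_fill(~mask, 0))                    # same effect, keeps shape
print(x.nonzero())                                # indices of nonzero entries

src = torch.arange(10).reshape(2, 5)
print(torch.index_select(src, 1, torch.tensor([0, 4])))  # columns 0 and 4
print(torch.gather(src, 1, torch.tensor([[1], [3]])))    # per-row index picks
```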

🔹 Comparison Operations

tensor.xxx: tensor.eq(), tensor.ne(), tensor.lt(), tensor.le(), tensor.gt(), tensor.ge(), tensor.equal(), tensor.isnan(), tensor.isinf(), tensor.isfinite()
torch.xxx: torch.eq(), torch.ne(), torch.lt(), torch.le(), torch.gt(), torch.ge(), torch.equal(), torch.allclose(), torch.isnan(), torch.isinf(), torch.isfinite()
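
eq/ne/lt/etc. return boolean tensors elementwise, while torch.equal() returns a single Python bool; allclose() is the right tool for float comparisons:

```python
import torch

a = torch.tensor([1.0, 2.0, float('nan')])
b = torch.tensor([1.0, 2.0001, 3.0])
print(a.eq(b))                  # elementwise -> BoolTensor
print(torch.isnan(a))           # tensor([False, False,  True])
print(torch.allclose(b, torch.tensor([1.0, 2.0, 3.0]), atol=1e-3))  # True

# torch.equal() is an exact whole-tensor check returning a Python bool:
print(torch.equal(a, a.clone()))  # False! NaN != NaN elementwise
```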

🔹 Linear Algebra

tensor.xxx: tensor.matmul(), tensor.mm(), tensor.bmm(), tensor.dot(), tensor.cross(), tensor.norm()
torch.xxx: torch.matmul(), torch.mm(), torch.bmm(), torch.dot(), torch.cross(), torch.norm()
torch.linalg.xxx: torch.linalg.det(), torch.linalg.inv(), torch.linalg.pinv(), torch.linalg.matrix_rank(), torch.linalg.solve(), torch.linalg.eig(), torch.linalg.svd(), torch.linalg.qr(), torch.linalg.cholesky(), torch.linalg.lstsq(), torch.linalg.norm()
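
A few linalg calls; solve() is generally preferred over explicitly inverting a matrix:

```python
import torch

A = torch.tensor([[3., 1.], [1., 2.]])
b = torch.tensor([9., 8.])

x = torch.linalg.solve(A, b)         # solves Ax = b (prefer over inv(A) @ b)
print(torch.allclose(A @ x, b))      # True

print(torch.linalg.det(A))           # determinant
U, S, Vh = torch.linalg.svd(A)       # singular value decomposition
print(torch.linalg.matrix_rank(A))   # numerical rank

v, w = torch.randn(3), torch.randn(3)
print(torch.dot(v, w))               # 1-D inner product
M, N = torch.randn(4, 3), torch.randn(3, 2)
print((M @ N).shape)                 # matmul via the @ operator -> [4, 2]
```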

🔹 Automatic Differentiation (autograd)

Tensor attributes: tensor.requires_grad, tensor.grad
Tensor methods: tensor.backward(), tensor.detach()
torch.xxx: torch.no_grad(), torch.enable_grad(), torch.set_grad_enabled()
torch.autograd.grad()
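
The core autograd loop in miniature:

```python
import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()          # y = x1^2 + x2^2
y.backward()                # populates x.grad
print(x.grad)               # tensor([4., 6.]) = dy/dx

with torch.no_grad():       # disable graph building, e.g. for inference
    z = x * 2

detached = x.detach()       # same data, cut from the graph

# Functional alternative to .backward():
x2 = torch.tensor(1.0, requires_grad=True)
(g,) = torch.autograd.grad(x2 ** 3, x2)
print(g)                    # tensor(3.)
```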

🔹 Neural Network Modules (torch.nn)

Basic layers

nn.Linear, nn.Bilinear, nn.Identity

Convolutional layers

nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.Unfold, nn.Fold

Pooling layers

nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveMaxPool2d, nn.AdaptiveAvgPool2d (PyTorch has no nn.GlobalMaxPool2d; global pooling is adaptive pooling with output size 1)

Recurrent layers

nn.RNN, nn.LSTM, nn.GRU, nn.RNNCell, nn.LSTMCell, nn.GRUCell

Normalization / regularization layers

nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm, nn.GroupNorm, nn.InstanceNorm2d, nn.Dropout, nn.Dropout2d, nn.Dropout3d

Activation functions

nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.ELU, nn.SELU, nn.Sigmoid, nn.Tanh, nn.Softmax, nn.LogSoftmax, nn.Softmin, nn.GELU, nn.SiLU (a.k.a. Swish), nn.Mish

Activation functions (torch.nn.functional)

F.xxx: F.relu(), F.leaky_relu(), F.elu(), F.selu(), F.sigmoid(), F.tanh(), F.softmax(), F.log_softmax(), F.softmin(), F.gelu(), F.silu(), F.mish()
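
The functional forms are handy inside forward(); softmax/log_softmax require an explicit dim:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)           # e.g. a batch of 4, 10 classes
probs = F.softmax(logits, dim=1)      # each row sums to 1
logp = F.log_softmax(logits, dim=1)   # more stable than log(softmax)
h = F.relu(torch.randn(4, 8))
g = F.gelu(torch.randn(4, 8))
s = F.silu(torch.randn(4, 8))         # SiLU is PyTorch's name for Swish
```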

Loss functions

nn.MSELoss, nn.L1Loss, nn.CrossEntropyLoss, nn.NLLLoss, nn.BCELoss, nn.BCEWithLogitsLoss, nn.KLDivLoss, nn.SmoothL1Loss, nn.PoissonNLLLoss, nn.CosineEmbeddingLoss, nn.MarginRankingLoss

Loss functions (torch.nn.functional)

F.xxx: F.mse_loss(), F.l1_loss(), F.cross_entropy(), F.nll_loss(), F.binary_cross_entropy(), F.binary_cross_entropy_with_logits(), F.kl_div(), F.smooth_l1_loss(), F.cosine_embedding_loss()

Containers

nn.Sequential, nn.ModuleList, nn.ModuleDict, nn.ParameterList, nn.ParameterDict
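
nn.Sequential chains layers; ModuleList/ModuleDict register submodules so their parameters are visible to the optimizer (a plain Python list would not be). A small sketch:

```python
import torch
import torch.nn as nn

mlp = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 4),
)
print(mlp(torch.randn(8, 16)).shape)   # torch.Size([8, 4])

class Blocks(nn.Module):
    def __init__(self, n):
        super().__init__()
        # registered submodules: model.parameters() includes them
        self.layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(n)])

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x
```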

Other common functions (torch.nn.functional)

F.xxx: F.conv2d(), F.max_pool2d(), F.avg_pool2d(), F.dropout(), F.batch_norm(), F.layer_norm(), F.interpolate(), F.pad(), F.normalize(), F.embedding(), F.one_hot()

🔹 Optimizers (torch.optim)

optim.SGD, optim.Adam, optim.AdamW, optim.RMSprop, optim.Adagrad, optim.Adadelta, optim.Adamax, optim.NAdam, optim.RAdam, optim.LBFGS

Optimizer methods

optimizer.xxx: optimizer.step(), optimizer.zero_grad(), optimizer.state_dict(), optimizer.load_state_dict()

Optimizer parameters

Parameters: lr, momentum, weight_decay, betas, eps, amsgrad

🔹 Learning Rate Schedulers (torch.optim.lr_scheduler)

lr_scheduler.StepLR, lr_scheduler.MultiStepLR, lr_scheduler.ExponentialLR, lr_scheduler.CosineAnnealingLR, lr_scheduler.ReduceLROnPlateau, lr_scheduler.CyclicLR, lr_scheduler.OneCycleLR, lr_scheduler.LambdaLR, lr_scheduler.CosineAnnealingWarmRestarts

Scheduler methods

scheduler.xxx: scheduler.step(), scheduler.get_last_lr(), scheduler.state_dict(), scheduler.load_state_dict()

Scheduler parameters

Parameters: step_size, gamma, milestones, T_max, eta_min, factor, patience, base_lr, max_lr

🔹 Data Handling (torch.utils.data)

Dataset classes

data.Dataset, data.TensorDataset, data.ConcatDataset, data.Subset

Data loader

data.DataLoader
Parameters: batch_size, shuffle, num_workers, pin_memory, drop_last, sampler

Samplers

data.Sampler, data.RandomSampler, data.SequentialSampler, data.WeightedRandomSampler, data.SubsetRandomSampler

Dataset splitting

Function: data.random_split()
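
Putting Dataset, random_split, and DataLoader together (a sketch with synthetic tensors):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

X = torch.randn(1000, 20)                 # synthetic features
y = torch.randint(0, 2, (1000,))          # synthetic labels
dataset = TensorDataset(X, y)

train_set, val_set = random_split(dataset, [800, 200])

train_loader = DataLoader(train_set, batch_size=64, shuffle=True,
                          num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False)

for xb, yb in train_loader:               # yields [64, 20] / [64] batches
    break
```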

🔹 Data Transforms (torchvision.transforms)

transforms.Compose, transforms.ToTensor, transforms.Normalize, transforms.Resize, transforms.CenterCrop, transforms.RandomCrop, transforms.RandomHorizontalFlip, transforms.RandomVerticalFlip, transforms.RandomRotation, transforms.ColorJitter, transforms.RandomResizedCrop, transforms.Pad, transforms.Lambda

Transform parameters

Parameters: size, mean, std, scale, ratio, degrees, brightness, contrast, saturation, hue
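
A typical training-time pipeline; Normalize expects per-channel mean/std (the values below are the commonly used ImageNet statistics):

```python
from torchvision import transforms

train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),                 # PIL/uint8 -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```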

🔹 CUDA / Devices

Device utilities

torch.xxx: torch.device(), torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.get_device_name(), torch.cuda.current_device(), torch.cuda.empty_cache(), torch.cuda.memory_allocated(), torch.cuda.memory_reserved()

Tensor device / dtype operations

tensor.xxx: tensor.to(), tensor.cpu(), tensor.cuda(), tensor.half(), tensor.float(), tensor.double(), tensor.int(), tensor.long(), tensor.bool()

Device parameters

Parameters: device='cuda', device='cpu', non_blocking=True, memory_format
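
The usual device-agnostic pattern; note that non_blocking=True only helps when the source tensor sits in pinned (page-locked) memory:

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

x = torch.randn(32, 128)
x = x.to(device, non_blocking=True)   # async copy if x is pinned

if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.memory_allocated() / 1e6, 'MB in use')

x16 = x.half()             # cast to float16
x64 = x.double()           # cast to float64
back = x16.float().cpu()   # back to float32 on CPU
```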

🔹 Saving / Loading

torch.xxx: torch.save(), torch.load(), torch.jit.save(), torch.jit.load(), torch.jit.script(), torch.jit.trace()

Model state

model.xxx: model.state_dict(), model.load_state_dict(), model.parameters(), model.named_parameters(), model.modules(), model.named_modules()

Save/load parameters

Parameters: map_location, strict=False, weights_only=True
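
The recommended pattern is to save/load state_dicts rather than whole pickled models; weights_only=True (available in newer PyTorch versions) restricts unpickling for safety:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
torch.save(model.state_dict(), 'model.pth')          # save parameters only

state = torch.load('model.pth', map_location='cpu',  # remap GPU tensors to CPU
                   weights_only=True)                # safer unpickling
model.load_state_dict(state, strict=True)            # strict=False tolerates
                                                     # missing/unexpected keys
```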

🔹 Distributed Training

nn.DataParallel, nn.parallel.DistributedDataParallel
Functions: torch.distributed.init_process_group(), torch.distributed.barrier(), torch.distributed.all_reduce(), torch.distributed.broadcast()
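
A heavily abbreviated DDP sketch (assumes launch via torchrun, which sets RANK/LOCAL_RANK/WORLD_SIZE; not a complete script):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend='nccl')       # reads env:// settings from torchrun
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(10, 2).cuda(local_rank)
model = DDP(model, device_ids=[local_rank])   # gradients all-reduced automatically

# ... training loop ...
dist.barrier()                                # synchronize all ranks
dist.destroy_process_group()
```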

🔹 Model Mode Control

model.xxx: model.train(), model.eval(), model.requires_grad_()
Attribute: model.training (True in train mode, False in eval mode)

🔹 Initialization Methods (torch.nn.init)

init.xxx: init.xavier_uniform_(), init.xavier_normal_(), init.kaiming_uniform_(), init.kaiming_normal_(), init.normal_(), init.uniform_(), init.constant_(), init.zeros_(), init.ones_()

🔹 Utility Functions

torch.xxx: torch.__version__, torch.set_printoptions(), torch.get_default_dtype(), torch.set_default_dtype(), torch.set_num_threads(), torch.get_num_threads()

Print option parameters

Parameters: precision, threshold, edgeitems, linewidth, profile, sci_mode
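
set_printoptions controls how tensors render, e.g.:

```python
import torch

torch.set_printoptions(precision=3, sci_mode=False, linewidth=100)
print(torch.randn(2, 3))                # prints with 3 decimals, no sci notation

torch.set_printoptions(profile='full')  # print without summarizing
                                        # ('default' restores the defaults)
print(torch.__version__)
torch.set_num_threads(4)                # cap intra-op CPU parallelism
```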


📋 Standard Model-Building Template

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StandardModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout=0.1):
        super(StandardModel, self).__init__()
        # define the network layers
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        # regularization
        self.dropout = nn.Dropout(dropout)
        self.batch_norm = nn.BatchNorm1d(hidden_size)
        
    def forward(self, x):
        # x.size() -> [batch_size, input_size]
        x = F.relu(self.fc1(x))           # functional activation via F.relu()
        x = self.batch_norm(x)            # batch normalization
        x = self.dropout(x)               # dropout regularization
        x = F.relu(self.fc2(x))           # second layer + activation
        x = self.dropout(x)               # dropout
        x = self.fc3(x)                   # output layer (no activation)
        return x

# instantiate the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = StandardModel(input_size=784, hidden_size=256, output_size=10).to(device)

# weight initialization
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)    # Xavier-initialize weights
        m.bias.data.fill_(0.01)                    # constant bias init

model.apply(init_weights)

# inspect model size
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
```

🚀 Standard Training Template

```python
import torch
import torch.nn as nn
import torch.nn.functional as F   # used by predict() below
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

def train_model(model, train_loader, val_loader, epochs=100, device='cpu'):
    # loss function and optimizer
    criterion = nn.CrossEntropyLoss()                              # loss function
    optimizer = optim.Adam(model.parameters(),                     # optimizer
                          lr=0.001,
                          weight_decay=1e-4,
                          betas=(0.9, 0.999))
    scheduler = StepLR(optimizer,                                  # LR scheduler
                      step_size=30,
                      gamma=0.1)
    
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        # ============ training phase ============
        model.train()                           # switch to training mode
        train_loss = 0.0
        train_correct = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)     # move data to device

            optimizer.zero_grad()               # clear gradients
            output = model(data)                # forward pass
            loss = criterion(output, target)    # compute loss
            loss.backward()                     # backpropagate
            optimizer.step()                    # update parameters

            # statistics
            train_loss += loss.item()
            pred = output.argmax(dim=1)         # predicted class indices
            train_correct += pred.eq(target).sum().item()
        
        # ============ validation phase ============
        model.eval()                            # switch to evaluation mode
        val_loss = 0.0
        val_correct = 0

        with torch.no_grad():                   # disable gradient tracking
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_loss += criterion(output, target).item()
                pred = output.argmax(dim=1)
                val_correct += pred.eq(target).sum().item()
        
        # ============ compute metrics ============
        train_acc = 100. * train_correct / len(train_loader.dataset)
        val_acc = 100. * val_correct / len(val_loader.dataset)
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        
        print(f'Epoch {epoch+1}/{epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print(f'Learning Rate: {scheduler.get_last_lr()[0]:.6f}')
        
        # ============ save the best model ============
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),         # model parameters
                'optimizer_state_dict': optimizer.state_dict(), # optimizer state
                'scheduler_state_dict': scheduler.state_dict(), # scheduler state
                'loss': val_loss,
                'accuracy': val_acc,
            }, 'best_model.pth')

        scheduler.step()                        # advance the LR schedule
        print('-' * 50)

# ============ checkpoint restore ============
def load_checkpoint(model, optimizer, scheduler, filename):
    """Load a training checkpoint."""
    checkpoint = torch.load(filename, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return epoch, loss

# ============ prediction ============
def predict(model, data_loader, device):
    """Run inference over a data loader."""
    model.eval()                            # evaluation mode
    predictions = []
    probabilities = []

    with torch.no_grad():                   # disable gradients
        for data, _ in data_loader:
            data = data.to(device)
            output = model(data)            # forward pass
            prob = F.softmax(output, dim=1) # logits -> probabilities
            pred = output.argmax(dim=1)     # predicted classes
            
            predictions.extend(pred.cpu().numpy())
            probabilities.extend(prob.cpu().numpy())
    
    return predictions, probabilities

# ============ model evaluation ============
def evaluate_model(model, test_loader, device):
    """Evaluate model performance on a test set."""
    model.eval()
    test_loss = 0
    correct = 0
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    
    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print(f'Test Loss: {test_loss:.4f}')
    print(f'Test Accuracy: {accuracy:.2f}%')
    
    return test_loss, accuracy
```