import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
# Set the random seed so results are reproducible
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create a directory for saving visualization images
os.makedirs('visualizations', exist_ok=True)
# Data loading and preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
# A moderate batch_size is used; increasing it may exhaust GPU memory
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
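# Note: (0.1307,) and (0.3081,) are the commonly quoted mean and standard deviation
# of the MNIST training pixels, so after normalization the inputs are roughly
# zero-mean with unit variance.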
# Define a simplified VSSM (variational state space model)
class VSSM(nn.Module):
    def __init__(self, input_size=784, hidden_size=32, state_size=16, output_size=10):
        super(VSSM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.output_size = output_size
        # Encoder network - maps the input to a hidden representation
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        # Variational inference heads - produce the mean and log-variance of the latent state
        self.fc_mu = nn.Linear(hidden_size, state_size)
        self.fc_logvar = nn.Linear(hidden_size, state_size)
        # State transition network - predicts the next latent state
        self.transition = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, state_size)
        )
        # Decoder network - reconstructs the input from the latent state
        self.decoder = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size)
        )
        # Classifier network - predicts the class label from the latent state
        self.classifier = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),  # Dropout to reduce overfitting
            nn.Linear(hidden_size, output_size)
        )
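    # Data flow in forward(): x -> encoder -> (mu, logvar) -> reparameterize -> z;
    # z -> transition -> z_next -> decoder -> reconstruction; the classifier reads
    # the current latent state z directly.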
    def encode(self, x):
        # x: [batch_size, input_size]
        h = self.encoder(x)
        mu = self.fc_mu(h)          # [batch_size, state_size]
        logvar = self.fc_logvar(h)  # [batch_size, state_size]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: sample the latent variable while keeping it differentiable
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std  # [batch_size, state_size]

    def decode(self, z):
        # z: [batch_size, state_size]
        return self.decoder(z)  # [batch_size, input_size]

    def classify(self, z):
        # z: [batch_size, state_size]
        return self.classifier(z)  # [batch_size, output_size]

    def forward(self, x):
        # x: [batch_size, 1, 28, 28]
        batch_size = x.size(0)
        x_flat = x.view(batch_size, -1)  # [batch_size, 784]
        # Encode and sample the latent state
        mu, logvar = self.encode(x_flat)
        z = self.reparameterize(mu, logvar)
        # State transition
        z_next = self.transition(z)
        # Decode and classify
        recon_flat = self.decode(z_next)
        pred = self.classify(z)
        return recon_flat, pred, mu, logvar, z, x_flat
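# Optional shape sanity check (a minimal sketch -- uncomment to run once before training):
# _m = VSSM()
# _recon, _logits, *_ = _m(torch.zeros(2, 1, 28, 28))
# assert _recon.shape == (2, 784) and _logits.shape == (2, 10)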
# VSSM loss function
def vssm_loss(recon_x, x, pred, target, mu, logvar, lambda_kl=0.1, lambda_cls=1.0):
    # Reconstruction loss - how far the reconstruction is from the original image
    recon_loss = F.mse_loss(recon_x, x.view(x.size(0), -1), reduction='sum')
    # KL divergence - how far the latent distribution is from a standard normal
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Classification loss
    cls_loss = F.cross_entropy(pred, target, reduction='sum')
    # Total loss
    batch_size = x.size(0)
    total_loss = (recon_loss + lambda_kl * kl_loss + lambda_cls * cls_loss) / batch_size
    return total_loss, recon_loss.item() / batch_size, kl_loss.item() / batch_size, cls_loss.item() / batch_size
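# Note: kl_loss above is the closed-form KL divergence between the approximate
# posterior N(mu, sigma^2) and the standard normal prior N(0, I):
#   KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
# lambda_kl and lambda_cls trade the KL and classification terms off against the
# reconstruction term.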
# Plot the training and test loss curves
def pltLoss(train_losses, test_losses, epochs):
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, epochs + 1), train_losses, 'b-', label='Training Loss')
    plt.plot(range(1, epochs + 1), test_losses, 'r-', label='Test Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Test Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('loss_curve.png')
    plt.close()
# Visualize a test sample and its prediction
def plotTest(model, test_loader, device, epoch):
    model.eval()
    best_sample = None
    best_confidence = -1
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # Forward pass to obtain the intermediate results
            recon_flat, pred, mu, logvar, z, x_flat = model(data)
            # Prediction confidence for each sample
            confidence = F.softmax(pred, dim=1).max(dim=1)[0]
            # Find the most confident sample in this batch
            max_idx = confidence.argmax().item()
            if confidence[max_idx] > best_confidence:
                best_confidence = confidence[max_idx].item()
                best_sample = {
                    'input': data[max_idx].cpu(),
                    'recon': recon_flat[max_idx].cpu().view(1, 28, 28),
                    'target': target[max_idx].cpu().item(),
                    'pred': pred[max_idx].argmax().cpu().item(),
                    'confidence': best_confidence,
                    'mu': mu[max_idx].cpu().numpy(),
                    'logvar': logvar[max_idx].cpu().numpy(),
                    'z': z[max_idx].cpu().numpy(),
                    'pred_dist': F.softmax(pred[max_idx], dim=0).cpu().numpy()
                }
            # Free tensors that are no longer needed to save GPU memory
            del data, target, recon_flat, pred, mu, logvar, z, x_flat, confidence, max_idx
            torch.cuda.empty_cache()
    if best_sample is not None:
        # Build the visualization
        plt.figure(figsize=(12, 8))
        # 1. Original input image
        plt.subplot(2, 3, 1)
        plt.title(f'Input Image (True: {best_sample["target"]})')
        plt.imshow(best_sample['input'].squeeze().numpy(), cmap='gray')
        plt.axis('off')
        # 2. Reconstructed image
        plt.subplot(2, 3, 2)
        plt.title('Reconstructed Image')
        plt.imshow(best_sample['recon'].squeeze().numpy(), cmap='gray')
        plt.axis('off')
        # 3. Latent mean
        plt.subplot(2, 3, 3)
        plt.title('Latent Mean (μ)')
        plt.bar(range(len(best_sample['mu'])), best_sample['mu'])
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        # 4. Latent log-variance
        plt.subplot(2, 3, 4)
        plt.title('Latent Log Variance (log σ²)')
        plt.bar(range(len(best_sample['logvar'])), best_sample['logvar'])
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        # 5. Sampled latent variable
        plt.subplot(2, 3, 5)
        plt.title('Sampled Latent Variable (z)')
        plt.bar(range(len(best_sample['z'])), best_sample['z'])
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        # 6. Prediction distribution
        plt.subplot(2, 3, 6)
        plt.title(f'Prediction Distribution (Pred: {best_sample["pred"]}, Conf: {best_sample["confidence"]:.4f})')
        plt.bar(range(10), best_sample['pred_dist'])
        plt.xticks(range(10))
        plt.xlabel('Class')
        plt.ylabel('Probability')
        plt.tight_layout()
        plt.savefig(f'visualizations/epoch_{epoch}_best_sample.png')
        plt.close()
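# plotTest writes one figure per epoch to visualizations/epoch_<epoch>_best_sample.png,
# showing the most confident test sample: input, reconstruction, latent mean,
# log-variance, sampled z, and the predicted class distribution.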
# Initialize the model, optimizer, and learning-rate scheduler
model = VSSM().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)
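# ReduceLROnPlateau halves the learning rate (factor=0.5) when the monitored test
# loss has not improved for `patience` epochs. The `verbose` argument is deprecated
# in recent PyTorch releases and may emit a warning; it can be dropped if so.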
# Training function
def train(model, train_loader, optimizer, epoch, device):
    model.train()
    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_cls_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Forward pass - unpack all six return values
        recon, pred, mu, logvar, z, x_flat = model(data)
        # Compute the loss
        loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
        # Backpropagation and optimization step
        loss.backward()
        optimizer.step()
        # Accumulate losses
        train_loss += loss.item()
        train_recon_loss += recon_loss
        train_kl_loss += kl_loss
        train_cls_loss += cls_loss
        # Optionally free tensors to save GPU memory
        # del data, target, recon, pred, mu, logvar, z, x_flat, loss, recon_loss, kl_loss, cls_loss
        # torch.cuda.empty_cache()
        # Print training progress
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    # Average losses over all batches
    avg_loss = train_loss / len(train_loader)
    avg_recon_loss = train_recon_loss / len(train_loader)
    avg_kl_loss = train_kl_loss / len(train_loader)
    avg_cls_loss = train_cls_loss / len(train_loader)
    print(f'Epoch: {epoch} Average training loss: {avg_loss:.4f} '
          f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
    return avg_loss
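# Note: vssm_loss already divides by the batch size, so the values accumulated in
# train() are per-sample averages; dividing by len(train_loader) then averages them
# over batches.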
# Evaluation function
def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    test_recon_loss = 0
    test_kl_loss = 0
    test_cls_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # Forward pass - unpack all six return values
            recon, pred, mu, logvar, z, x_flat = model(data)
            # Compute the loss
            loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
            # Accumulate losses
            test_loss += loss.item()
            test_recon_loss += recon_loss
            test_kl_loss += kl_loss
            test_cls_loss += cls_loss
            # Classification accuracy
            pred_class = pred.argmax(dim=1, keepdim=True)
            correct += pred_class.eq(target.view_as(pred_class)).sum().item()
            # Optionally free tensors to save GPU memory
            # del data, target, recon, pred, mu, logvar, z, x_flat, loss, recon_loss, kl_loss, cls_loss, pred_class
            # torch.cuda.empty_cache()
    # Average losses and accuracy
    avg_loss = test_loss / len(test_loader)
    avg_recon_loss = test_recon_loss / len(test_loader)
    avg_kl_loss = test_kl_loss / len(test_loader)
    avg_cls_loss = test_cls_loss / len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Average test loss: {avg_loss:.4f} '
          f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
    print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return avg_loss, accuracy
# Main training loop
epochs = 10
train_losses = []
test_losses = []
best_accuracy = 0.0
for epoch in range(1, epochs + 1):
    print(f'\nEpoch {epoch}/{epochs}')
    # Train for one epoch
    train_loss = train(model, train_loader, optimizer, epoch, device)
    train_losses.append(train_loss)
    # Evaluate on the test set
    test_loss, accuracy = test(model, test_loader, device)
    test_losses.append(test_loss)
    # Visualize the most confident test sample
    plotTest(model, test_loader, device, epoch)
    # Adjust the learning rate based on the test loss
    scheduler.step(test_loss)
    # Save the best model so far
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Best model saved with accuracy: {accuracy:.2f}%')
    # Plot the loss curves up to the current epoch
    pltLoss(train_losses, test_losses, epoch)
    # Release cached GPU memory
    torch.cuda.empty_cache()
print(f'\nTraining completed. Best accuracy: {best_accuracy:.2f}%')
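# A minimal inference sketch, assuming the training loop above has finished and saved
# 'best_model.pth'; the test sample index 0 is arbitrary and for illustration only.
inference_model = VSSM().to(device)
inference_model.load_state_dict(torch.load('best_model.pth', map_location=device))
inference_model.eval()
with torch.no_grad():
    sample, label = test_dataset[0]
    _, logits, *_ = inference_model(sample.unsqueeze(0).to(device))
    print(f'True label: {label}, predicted: {logits.argmax(dim=1).item()}, '
          f'confidence: {F.softmax(logits, dim=1).max().item():.4f}')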