```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
```

```python
# Data preprocessing: convert images to tensors and normalize each channel to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False)
```
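As a quick aside, `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` computes `(x - 0.5) / 0.5` per channel, mapping pixel values from [0, 1] to [-1, 1]. A minimal check of that arithmetic:

```python
import torch

# (x - mean) / std with mean = std = 0.5 maps [0, 1] onto [-1, 1]
x = torch.tensor([0.0, 0.5, 1.0])
print((x - 0.5) / 0.5)  # tensor([-1., 0., 1.])
```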
```python
# Original Inception module: four parallel branches concatenated along the
# channel dimension (64 + 128 + 32 + 32 = 256 output channels)
class Inception(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # 1x1 convolution branch
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=1),
            nn.ReLU()
        )
        # 1x1 reduction followed by a 3x3 convolution
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        # 1x1 reduction followed by a 5x5 convolution
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.ReLU()
        )
        # 3x3 max pooling followed by a 1x1 projection
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, 32, kernel_size=1),
            nn.ReLU()
        )

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)
        return torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], dim=1)
```
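A minimal shape check (using the `Inception` class above) confirms that the four branches concatenate to 64 + 128 + 32 + 32 = 256 channels while preserving spatial size:

```python
block = Inception(64)
out = block(torch.randn(2, 64, 8, 8))  # dummy 64-channel feature map
print(out.shape)  # torch.Size([2, 256, 8, 8])
```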
```python
# Inception module with a residual shortcut: the concatenated branch output
# is added to a (possibly projected) copy of the input before the final ReLU
class InceptionWithResidual(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=1),
            nn.ReLU()
        )
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.ReLU()
        )
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, 32, kernel_size=1),
            nn.ReLU()
        )
        # Project the input to 256 channels when the widths do not match
        if in_channels != 256:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, 256, kernel_size=1),
                nn.BatchNorm2d(256)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)
        outputs = torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], dim=1)
        shortcut = self.shortcut(x)
        return F.relu(outputs + shortcut)
```
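The shortcut is an identity when the input already has 256 channels and a 1x1 projection otherwise, so the addition in `forward` is always shape-compatible. A quick check with the class above:

```python
print(type(InceptionWithResidual(64).shortcut).__name__)   # Sequential (1x1 conv + BN)
print(type(InceptionWithResidual(256).shortcut).__name__)  # Identity
```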
```python
# CBAM attention module: channel attention followed by spatial attention
class CBAM(nn.Module):
    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP for the channel-attention branch
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(channels // reduction_ratio, channels)
        )
        # 7x7 convolution over the stacked avg/max maps for spatial attention
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        # Channel attention (flatten(1) keeps the batch dimension intact,
        # unlike a bare squeeze(), which would break for batch size 1)
        avg_out = self.fc(self.avg_pool(x).flatten(1))
        max_out = self.fc(self.max_pool(x).flatten(1))
        channel_att = torch.sigmoid(avg_out + max_out).unsqueeze(2).unsqueeze(3)
        x = x * channel_att
        # Spatial attention
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_att = torch.cat([avg_out, max_out], dim=1)
        spatial_att = torch.sigmoid(self.conv(spatial_att))
        return x * spatial_att

# Inception module followed by CBAM attention
class InceptionWithCBAM(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=1),
            nn.ReLU()
        )
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.ReLU()
        )
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, 32, kernel_size=1),
            nn.ReLU()
        )
        # Attention over the 256-channel concatenated output
        self.cbam = CBAM(256)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)
        outputs = torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], dim=1)
        return self.cbam(outputs)
```
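CBAM gates the feature map twice: the channel gate has shape [N, C, 1, 1] and the spatial gate [N, 1, H, W], so the output shape matches the input. A minimal sketch with the `CBAM` class above:

```python
att = CBAM(256)
y = att(torch.randn(4, 256, 8, 8))
print(y.shape)  # torch.Size([4, 256, 8, 8]), same shape as the input
```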
```python
# Full network: stem convolution, two Inception blocks, global average
# pooling, and a linear classifier
class InceptionNet(nn.Module):
    def __init__(self, num_classes=10, module_type='original'):
        super().__init__()
        # Stem: 32x32 input -> 8x8 feature map after the stride-2 conv and pool
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        if module_type == 'original':
            self.inception1 = Inception(64)
            self.inception2 = Inception(256)
        elif module_type == 'residual':
            self.inception1 = InceptionWithResidual(64)
            self.inception2 = InceptionWithResidual(256)
        elif module_type == 'cbam':
            self.inception1 = InceptionWithCBAM(64)
            self.inception2 = InceptionWithCBAM(256)
        else:
            raise ValueError(f'Unknown module_type: {module_type}')
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```
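For a rough sense of how much each variant adds, a parameter count over the three configurations (a quick sketch using the classes above; the exact numbers follow from the layer sizes chosen there):

```python
for kind in ['original', 'residual', 'cbam']:
    net = InceptionNet(module_type=kind)
    n_params = sum(p.numel() for p in net.parameters())
    print(f'{kind}: {n_params:,} parameters')
```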
```python
# Training loop for one epoch (uses the module-level trainloader, device,
# criterion, and optimizer defined in the main block below)
def train(model, epoch):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Report the average loss every 100 batches
        if i % 100 == 99:
            print(f'Epoch: {epoch + 1}, Batch: {i + 1}, Loss: {running_loss / 100:.3f}')
            running_loss = 0.0

# Evaluation on the test set
def test(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy on test set: {100 * correct / total:.2f}%')
```
```python
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Original Inception network
    print("Training the original Inception network:")
    model = InceptionNet(module_type='original').to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(5):
        train(model, epoch)
        test(model)

    # Inception network with residual shortcuts
    print("\nTraining the Inception network with residual shortcuts:")
    model = InceptionNet(module_type='residual').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(5):
        train(model, epoch)
        test(model)

    # Inception network with CBAM
    print("\nTraining the Inception network with CBAM:")
    model = InceptionNet(module_type='cbam').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(5):
        train(model, epoch)
        test(model)
```

```text
Training the original Inception network:
Epoch: 1, Batch: 100, Loss: 1.982
Epoch: 1, Batch: 200, Loss: 1.718
Epoch: 1, Batch: 300, Loss: 1.602
Accuracy on test set: 43.35%
Epoch: 2, Batch: 100, Loss: 1.475
Epoch: 2, Batch: 200, Loss: 1.405
Epoch: 2, Batch: 300, Loss: 1.371
Accuracy on test set: 53.42%
Epoch: 3, Batch: 100, Loss: 1.279
Epoch: 3, Batch: 200, Loss: 1.239
Epoch: 3, Batch: 300, Loss: 1.197
Accuracy on test set: 59.10%
Epoch: 4, Batch: 100, Loss: 1.130
Epoch: 4, Batch: 200, Loss: 1.118
Epoch: 4, Batch: 300, Loss: 1.084
Accuracy on test set: 60.84%
Epoch: 5, Batch: 100, Loss: 1.061
Epoch: 5, Batch: 200, Loss: 1.015
Epoch: 5, Batch: 300, Loss: 1.005
Accuracy on test set: 59.86%
Training the Inception network with residual shortcuts:
Epoch: 1, Batch: 100, Loss: 1.829
Epoch: 1, Batch: 200, Loss: 1.600
Epoch: 1, Batch: 300, Loss: 1.473
Accuracy on test set: 50.87%
Epoch: 2, Batch: 100, Loss: 1.324
Epoch: 2, Batch: 200, Loss: 1.267
Epoch: 2, Batch: 300, Loss: 1.231
Accuracy on test set: 58.51%
Epoch: 3, Batch: 100, Loss: 1.132
Epoch: 3, Batch: 200, Loss: 1.100
Epoch: 3, Batch: 300, Loss: 1.074
Accuracy on test set: 60.79%
Epoch: 4, Batch: 100, Loss: 1.027
Epoch: 4, Batch: 200, Loss: 1.000
Epoch: 4, Batch: 300, Loss: 0.987
Accuracy on test set: 60.19%
Epoch: 5, Batch: 100, Loss: 0.965
Epoch: 5, Batch: 200, Loss: 0.934
Epoch: 5, Batch: 300, Loss: 0.918
Accuracy on test set: 66.30%
Training the Inception network with CBAM:
Epoch: 1, Batch: 100, Loss: 2.038
Epoch: 1, Batch: 200, Loss: 1.754
Epoch: 1, Batch: 300, Loss: 1.653
Accuracy on test set: 40.46%
Epoch: 2, Batch: 100, Loss: 1.523
Epoch: 2, Batch: 200, Loss: 1.450
Epoch: 2, Batch: 300, Loss: 1.414
Accuracy on test set: 51.94%
Epoch: 3, Batch: 100, Loss: 1.324
Epoch: 3, Batch: 200, Loss: 1.287
Epoch: 3, Batch: 300, Loss: 1.225
Accuracy on test set: 56.31%
Epoch: 4, Batch: 100, Loss: 1.177
Epoch: 4, Batch: 200, Loss: 1.135
Epoch: 4, Batch: 300, Loss: 1.105
Accuracy on test set: 62.34%
Epoch: 5, Batch: 100, Loss: 1.072
Epoch: 5, Batch: 200, Loss: 1.029
Epoch: 5, Batch: 300, Loss: 1.008
Accuracy on test set: 64.95%
```

After five epochs, the residual variant reaches the best test accuracy (66.30%), followed by the CBAM variant (64.95%) and the original module (59.86%).

@浙大疏锦行