- 对 Inception 网络在 CIFAR-10 上观察精度
- 消融实验:分别引入残差机制和 CBAM 模块进行消融
# ---------------------------------------------------------------------------
# Listing 1 - CBAM attention.
# Presumably models/cbam.py (inferred from `from .cbam import CBAM` below).
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel gate of CBAM.

    The input is squeezed to 1x1 by global average pooling and by global max
    pooling; both descriptors go through the same two-layer 1x1-conv
    bottleneck, are summed, and are squashed with a sigmoid.

    Args:
        in_planes: number of channels of the incoming feature map.
        ratio: channel reduction factor of the bottleneck.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Keep at least one hidden channel: `in_planes // ratio` is 0 when
        # in_planes < ratio, which would yield a degenerate 0-channel conv.
        hidden = max(in_planes // ratio, 1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel weights in (0, 1) with spatial size 1x1."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Spatial gate of CBAM.

    The channel-wise mean and channel-wise max of the input are stacked into
    a 2-channel map, convolved down to one channel and squashed with a
    sigmoid.

    Args:
        kernel_size: size of the convolution kernel; must be 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for an odd kernel (3 -> 1, 7 -> 3).
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a single-channel weight map with the spatial size of `x`."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        stacked = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv1(stacked))


class CBAM(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        in_planes: number of channels of the incoming feature map.
        ratio: reduction factor passed to `ChannelAttention`.
        kernel_size: kernel size passed to `SpatialAttention`.
    """

    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Re-weight `x` by the channel gate, then by the spatial gate."""
        x = x * self.ca(x)
        x = x * self.sa(x)
        return x


# ---------------------------------------------------------------------------
# Listing 2 - Inception module.
# Presumably models/inception.py (inferred from `from .inception import ...`).
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


def _conv_bn_relu(in_channels, out_channels, **conv_kwargs):
    """Return the [Conv2d, BatchNorm2d, ReLU] layers used by every branch."""
    return [
        nn.Conv2d(in_channels, out_channels, **conv_kwargs),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]


class InceptionModule(nn.Module):
    """Four parallel branches whose outputs are concatenated on dim 1.

    Every branch uses stride 1 with padding matched to its kernel, so the
    spatial size is preserved; the output has
    `n1x1 + n3x3 + n5x5 + pool_proj` channels.

    Args:
        in_channels: channels of the incoming feature map.
        n1x1: output channels of the 1x1 branch.
        n3x3_reduce: channels of the 1x1 reduction before the 3x3 conv.
        n3x3: output channels of the 3x3 branch.
        n5x5_reduce: channels of the 1x1 reduction before the 5x5 conv.
        n5x5: output channels of the 5x5 branch.
        pool_proj: output channels of the 1x1 conv after the 3x3 max-pool.
    """

    def __init__(self, in_channels, n1x1, n3x3_reduce, n3x3,
                 n5x5_reduce, n5x5, pool_proj):
        super().__init__()

        # Layer order inside each Sequential is unchanged, so state_dict
        # keys (b1.0, b2.3, ...) stay compatible with saved checkpoints.

        # 1x1 conv branch
        self.b1 = nn.Sequential(
            *_conv_bn_relu(in_channels, n1x1, kernel_size=1),
        )

        # 1x1 -> 3x3 conv branch
        self.b2 = nn.Sequential(
            *_conv_bn_relu(in_channels, n3x3_reduce, kernel_size=1),
            *_conv_bn_relu(n3x3_reduce, n3x3, kernel_size=3, padding=1),
        )

        # 1x1 -> 5x5 conv branch
        self.b3 = nn.Sequential(
            *_conv_bn_relu(in_channels, n5x5_reduce, kernel_size=1),
            *_conv_bn_relu(n5x5_reduce, n5x5, kernel_size=5, padding=2),
        )

        # 3x3 pool -> 1x1 conv branch
        self.b4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *_conv_bn_relu(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        """Run all four branches on `x` and concatenate along channels."""
        return torch.cat([self.b1(x), self.b2(x), self.b3(x), self.b4(x)], 1)


# ---------------------------------------------------------------------------
# Listing 3 - the network used for the ablation.
# Presumably models/network.py (inferred from `from models.network import`).
# ---------------------------------------------------------------------------
import torch.nn as nn

from .inception import InceptionModule
from .cbam import CBAM


class InceptionNet(nn.Module):
    """Two-stage Inception classifier with optional residuals and CBAM.

    Layout: 3x3 stem (3 -> 64 channels), Inception stage `a3` (64 -> 256),
    Inception stage `b3` (256 -> 480), global average pooling, dropout and a
    10-way linear classifier. Each stage ends with a stride-2 max-pool.

    Args:
        use_residual: add a 1x1-conv + BatchNorm projection of the stage
            input to the stage output, followed by a ReLU.
        use_cbam: apply a CBAM block to the output of each Inception stage.
    """

    def __init__(self, use_residual=False, use_cbam=False):
        super().__init__()
        self.use_residual = use_residual
        self.use_cbam = use_cbam

        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.a3 = InceptionModule(64, 64, 96, 128, 16, 32, 32)      # out: 256
        self.b3 = InceptionModule(256, 128, 128, 192, 32, 96, 64)   # out: 480

        if self.use_cbam:
            self.cbam1 = CBAM(256)
            self.cbam2 = CBAM(480)

        if self.use_residual:
            # 1x1 projections so the shortcut matches the stage's channels.
            self.shortcut1 = nn.Sequential(
                nn.Conv2d(64, 256, kernel_size=1),
                nn.BatchNorm2d(256),
            )
            self.shortcut2 = nn.Sequential(
                nn.Conv2d(256, 480, kernel_size=1),
                nn.BatchNorm2d(480),
            )

        # Parameter-free, so registering it once (instead of building a new
        # nn.ReLU on every forward call) leaves the state_dict unchanged.
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(480, 10)

    def _stage(self, x, inception, cbam, shortcut):
        """Run one Inception stage; `cbam` / `shortcut` may be None."""
        out = inception(x)
        if cbam is not None:
            out = cbam(out)
        if shortcut is not None:
            out += shortcut(x)
            # Only needed here: without the shortcut the output is already
            # non-negative (ReLU outputs, optionally scaled by sigmoid gates),
            # so a ReLU would be a no-op.
            out = self.relu(out)
        return self.maxpool(out)

    def forward(self, x):
        """Return class logits with 10 entries per sample."""
        x = self.pre_layers(x)

        # Inception block 1
        x = self._stage(
            x,
            self.a3,
            self.cbam1 if self.use_cbam else None,
            self.shortcut1 if self.use_residual else None,
        )

        # Inception block 2
        x = self._stage(
            x,
            self.b3,
            self.cbam2 if self.use_cbam else None,
            self.shortcut2 if self.use_residual else None,
        )

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        return self.fc(x)


# ---------------------------------------------------------------------------
# Listing 4 - data loading, training and evaluation.
# Presumably trainer.py (inferred from `from trainer import run_experiment`).
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.network import InceptionNet

# Per-channel statistics passed to transforms.Normalize(mean, std).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def get_cifar10_loaders(batch_size=128, test_batch_size=100,
                        data_root='./data', num_workers=2):
    """Build the CIFAR-10 train and test data loaders.

    The training set uses random crop (padding 4) and horizontal flip; both
    sets are converted to tensors and normalised. The dataset is downloaded
    into `data_root` if it is not already there.

    Args:
        batch_size: batch size of the (shuffled) training loader.
        test_batch_size: batch size of the (unshuffled) test loader.
        data_root: directory the dataset is stored in.
        num_workers: worker processes per loader.

    Returns:
        A `(trainloader, testloader)` tuple.
    """
    normalize = transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=transform_train)
    trainloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=transform_test)
    testloader = DataLoader(
        testset, batch_size=test_batch_size, shuffle=False,
        num_workers=num_workers)

    return trainloader, testloader


def train(epoch, net, trainloader, optimizer, criterion, device):
    """Train `net` for one pass over `trainloader`.

    Running mean loss and accuracy are shown on the tqdm progress bar.

    Args:
        epoch: epoch index, used only for the progress-bar label.
        net: model to optimise (switched to train mode).
        trainloader: iterable of `(inputs, targets)` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as `criterion(outputs, targets)`.
        device: device the batches are moved to.
    """
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    progress_bar = tqdm(trainloader, desc=f'Epoch {epoch:03d}')
    for batch_idx, (inputs, targets) in enumerate(progress_bar):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar.set_postfix(
            loss=f'{train_loss/(batch_idx+1):.3f}',
            acc=f'{100.*correct/total:.3f}%',
        )


def test(epoch, net, testloader, criterion, device):
    """Evaluate `net` on `testloader` without tracking gradients.

    Args:
        epoch: epoch index, used only for the progress-bar label.
        net: model to evaluate (switched to eval mode).
        testloader: iterable of `(inputs, targets)` batches.
        criterion: loss called as `criterion(outputs, targets)`.
        device: device the batches are moved to.

    Returns:
        Accuracy in percent; 0.0 when the loader yields no samples.
    """
    net.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        progress_bar = tqdm(testloader, desc=f'Test Epoch {epoch:03d}')
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar.set_postfix(
                loss=f'{test_loss/(batch_idx+1):.3f}',
                acc=f'{100.*correct/total:.3f}%',
            )

    # Guard against an empty loader instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return 100. * correct / total


def run_experiment(model_name, use_residual, use_cbam, epochs=50):
    """Train one `InceptionNet` variant on CIFAR-10 and report its best score.

    Uses SGD (lr 0.1, momentum 0.9, weight decay 5e-4) with cosine annealing
    over `epochs`. The model is evaluated on the test loader after every
    epoch and the highest accuracy seen is kept.

    Args:
        model_name: label printed in the banner and the final summary.
        use_residual: forwarded to `InceptionNet`.
        use_cbam: forwarded to `InceptionNet`.
        epochs: number of training epochs.

    Returns:
        The best test accuracy in percent.
    """
    print(f"\n{'='*30}")
    print(f"Running Experiment: {model_name}")
    print(f"{'='*30}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Data
    trainloader, testloader = get_cifar10_loaders()

    # Model
    net = InceptionNet(use_residual=use_residual, use_cbam=use_cbam)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9,
                          weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs)

    best_acc = 0
    for epoch in range(epochs):
        train(epoch, net, trainloader, optimizer, criterion, device)
        acc = test(epoch, net, testloader, criterion, device)
        scheduler.step()
        best_acc = max(best_acc, acc)

    print(f"\nBest Test Accuracy for {model_name}: {best_acc:.2f}%\n")
    return best_acc


# ---------------------------------------------------------------------------
# Listing 5 - entry script that runs the four ablation configurations.
# ---------------------------------------------------------------------------
from trainer import run_experiment

# (model_name, use_residual, use_cbam), run in this order.
EXPERIMENTS = (
    # 1. Observe the accuracy of the plain Inception network on CIFAR-10
    ("Baseline InceptionNet", False, False),
    # 2. Ablation: add the residual mechanism
    ("InceptionNet with Residuals", True, False),
    # 3. Ablation: add the CBAM module
    ("InceptionNet with CBAM", False, True),
    # 4. (Optional) Ablation: add both residuals and CBAM
    ("InceptionNet with Residuals and CBAM", True, True),
)

if __name__ == '__main__':
    epochs = 50  # adjust the number of training epochs here

    for model_name, use_residual, use_cbam in EXPERIMENTS:
        run_experiment(
            model_name=model_name,
            use_residual=use_residual,
            use_cbam=use_cbam,
            epochs=epochs,
        )
DAY 54 Inception网络及其思考
冬天给予的预感 · 2025-07-09 6:09
相关推荐
lucky_lyovo5 分钟前
深度学习--tensor(创建、属性)李加号pluuuus7 分钟前
【论文阅读】CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer陈敬雷-充电了么-CEO兼CTO16 分钟前
复杂任务攻坚:多模态大模型推理技术从 CoT 数据到 RL 优化的突破之路Bruce_Liuxiaowei26 分钟前
Netstat高级分析工具:Windows与Linux双系统兼容的精准筛查利器盼小辉丶35 分钟前
TensorFlow深度学习实战——基于自编码器构建句子向量iFulling41 分钟前
【计算机网络】第三章:数据链路层(下)YOLO大师42 分钟前
华为OD机试 2025B卷 - 小明减肥(C++&Python&JAVA&JS&C语言)xiao5kou4chang6kai41 小时前
【Python-GEE】如何利用Landsat时间序列影像通过调和回归方法提取农作物特征并进行分类kaikaile19951 小时前
使用Python进行数据可视化的初学者指南Par@ish1 小时前
【网络安全】恶意 Python 包“psslib”仿冒 passlib,可导致 Windows 系统关闭