DAY 54 The Inception Network and Some Reflections

  1. Observe the accuracy of an Inception network on CIFAR-10

  2. Ablation study: introduce a residual mechanism and a CBAM module, ablating each one separately

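The implementation below is split across several files; the filenames are inferred from the import statements that follow (models/cbam.py, models/inception.py, models/network.py, trainer.py, and a main script). First, the CBAM attention module: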
    import torch
    import torch.nn as nn
    
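    # Channel attention: global average- and max-pooled descriptors pass through
    # a shared bottleneck MLP; their sum is squashed into per-channel weights.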
    class ChannelAttention(nn.Module):
        def __init__(self, in_planes, ratio=16):
            super(ChannelAttention, self).__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.max_pool = nn.AdaptiveMaxPool2d(1)
    
            self.fc = nn.Sequential(
                nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
                nn.ReLU(),
                nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
            )
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            avg_out = self.fc(self.avg_pool(x))
            max_out = self.fc(self.max_pool(x))
            out = avg_out + max_out
            return self.sigmoid(out)
    
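    # Spatial attention: channel-wise mean and max maps are concatenated and
    # convolved (7x7 by default) into a single per-pixel weight map.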
    class SpatialAttention(nn.Module):
        def __init__(self, kernel_size=7):
            super(SpatialAttention, self).__init__()
            assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
            padding = 3 if kernel_size == 7 else 1
    
            self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            avg_out = torch.mean(x, dim=1, keepdim=True)
            max_out, _ = torch.max(x, dim=1, keepdim=True)
            x = torch.cat([avg_out, max_out], dim=1)
            x = self.conv1(x)
            return self.sigmoid(x)
    
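    # CBAM chains the two: channel attention first, then spatial attention.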
    class CBAM(nn.Module):
        def __init__(self, in_planes, ratio=16, kernel_size=7):
            super(CBAM, self).__init__()
            self.ca = ChannelAttention(in_planes, ratio)
            self.sa = SpatialAttention(kernel_size)
    
        def forward(self, x):
            x = x * self.ca(x)
            x = x * self.sa(x)
            return x
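As a quick sanity check at the bottom of cbam.py (my own addition, not part of the original post), note that CBAM only rescales activations, so the output shape must equal the input shape:

    if __name__ == '__main__':
        cbam = CBAM(in_planes=256)
        x = torch.randn(2, 256, 32, 32)  # (batch, channels, height, width)
        y = cbam(x)
        assert y.shape == x.shape  # attention rescales, never reshapes
        print(y.shape)  # torch.Size([2, 256, 32, 32])

Next, models/inception.py implements the classic GoogLeNet-style Inception module: four parallel branches (1x1, 1x1 -> 3x3, 1x1 -> 5x5, and 3x3 max-pool -> 1x1), each followed by BatchNorm and ReLU, concatenated along the channel dimension: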
    import torch
    import torch.nn as nn
    
    class InceptionModule(nn.Module):
        def __init__(self, in_channels, n1x1, n3x3_reduce, n3x3, n5x5_reduce, n5x5, pool_proj):
            super(InceptionModule, self).__init__()
    
            # 1x1 conv branch
            self.b1 = nn.Sequential(
                nn.Conv2d(in_channels, n1x1, kernel_size=1),
                nn.BatchNorm2d(n1x1),
                nn.ReLU(True),
            )
    
            # 1x1 -> 3x3 conv branch
            self.b2 = nn.Sequential(
                nn.Conv2d(in_channels, n3x3_reduce, kernel_size=1),
                nn.BatchNorm2d(n3x3_reduce),
                nn.ReLU(True),
                nn.Conv2d(n3x3_reduce, n3x3, kernel_size=3, padding=1),
                nn.BatchNorm2d(n3x3),
                nn.ReLU(True),
            )
    
            # 1x1 -> 5x5 conv branch
            self.b3 = nn.Sequential(
                nn.Conv2d(in_channels, n5x5_reduce, kernel_size=1),
                nn.BatchNorm2d(n5x5_reduce),
                nn.ReLU(True),
                nn.Conv2d(n5x5_reduce, n5x5, kernel_size=5, padding=2),
                nn.BatchNorm2d(n5x5),
                nn.ReLU(True),
            )
    
            # 3x3 pool -> 1x1 conv branch
            self.b4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                nn.Conv2d(in_channels, pool_proj, kernel_size=1),
                nn.BatchNorm2d(pool_proj),
                nn.ReLU(True),
            )
    
        def forward(self, x):
            return torch.cat([self.b1(x), self.b2(x), self.b3(x), self.b4(x)], 1)
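The output width is simply the sum of the four branch widths, n1x1 + n3x3 + n5x5 + pool_proj, and the padding keeps the spatial resolution unchanged. A minimal check of the first configuration used below (my own sketch):

    if __name__ == '__main__':
        m = InceptionModule(64, 64, 96, 128, 16, 32, 32)
        x = torch.randn(2, 64, 32, 32)
        print(m(x).shape)  # 64 + 128 + 32 + 32 = 256 channels: torch.Size([2, 256, 32, 32])

models/network.py then stacks two Inception modules into a small CIFAR-10 classifier. The two ablation switches are constructor flags: use_residual adds 1x1-convolution shortcuts to match channel counts, and use_cbam inserts a CBAM block after each Inception module: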
    import torch.nn as nn
    import torch.nn.functional as F
    from .inception import InceptionModule
    from .cbam import CBAM
    
    class InceptionNet(nn.Module):
        def __init__(self, use_residual=False, use_cbam=False):
            super(InceptionNet, self).__init__()
            self.use_residual = use_residual
            self.use_cbam = use_cbam
    
            self.pre_layers = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
            )
    
            self.a3 = InceptionModule(64, 64, 96, 128, 16, 32, 32) # out: 256
            self.b3 = InceptionModule(256, 128, 128, 192, 32, 96, 64) # out: 480
    
            if self.use_cbam:
                self.cbam1 = CBAM(256)
                self.cbam2 = CBAM(480)
    
            if self.use_residual:
                self.shortcut1 = nn.Sequential(
                    nn.Conv2d(64, 256, kernel_size=1),
                    nn.BatchNorm2d(256)
                )
                self.shortcut2 = nn.Sequential(
                    nn.Conv2d(256, 480, kernel_size=1),
                    nn.BatchNorm2d(480)
                )
    
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.dropout = nn.Dropout(0.4)
            self.fc = nn.Linear(480, 10)
    
        def forward(self, x):
            x = self.pre_layers(x)
            
            # Inception block 1
            identity1 = x
            out1 = self.a3(x)
            if self.use_cbam:
                out1 = self.cbam1(out1)
            if self.use_residual:
                identity1 = self.shortcut1(identity1)
                out1 += identity1
            out1 = F.relu(out1, inplace=True)  # functional ReLU; avoids building a new module each forward
            out1 = self.maxpool(out1)
    
            # Inception block 2
            identity2 = out1
            out2 = self.b3(out1)
            if self.use_cbam:
                out2 = self.cbam2(out2)
            if self.use_residual:
                identity2 = self.shortcut2(identity2)
                out2 += identity2
            out2 = F.relu(out2, inplace=True)
            out2 = self.maxpool(out2)
    
            x = self.avgpool(out2)
            x = x.view(x.size(0), -1)
            x = self.dropout(x)
            x = self.fc(x)
            return x
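Before training, a quick smoke test (my own script, assuming the models/ package layout implied by the imports above) runs a dummy CIFAR-10-sized batch through all four ablation variants and confirms the logits have shape (batch, 10):

    import torch
    from models.network import InceptionNet

    for res in (False, True):
        for cbam in (False, True):
            net = InceptionNet(use_residual=res, use_cbam=cbam).eval()
            with torch.no_grad():
                out = net(torch.randn(2, 3, 32, 32))
            assert out.shape == (2, 10)
            print(f'residual={res}, cbam={cbam} -> OK')

trainer.py holds the CIFAR-10 pipeline (random crop and horizontal flip for training, per-channel normalization for both splits), the train/test loops, and run_experiment, which trains one configuration end to end with SGD plus cosine annealing: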
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    
    from models.network import InceptionNet
    
    def get_cifar10_loaders(batch_size=128):
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
        testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
        
        return trainloader, testloader
    
    def train(epoch, net, trainloader, optimizer, criterion, device):
        net.train()
        train_loss = 0
        correct = 0
        total = 0
        
        progress_bar = tqdm(trainloader, desc=f'Epoch {epoch:03d}')
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
    
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
            progress_bar.set_postfix(loss=f'{train_loss/(batch_idx+1):.3f}', acc=f'{100.*correct/total:.3f}%')
    
    def test(epoch, net, testloader, criterion, device):
        net.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            progress_bar = tqdm(testloader, desc=f'Test Epoch {epoch:03d}')
            for batch_idx, (inputs, targets) in enumerate(progress_bar):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
    
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                
                progress_bar.set_postfix(loss=f'{test_loss/(batch_idx+1):.3f}', acc=f'{100.*correct/total:.3f}%')
        
        acc = 100.*correct/total
        return acc
    
    def run_experiment(model_name, use_residual, use_cbam, epochs=50):
        print(f"\n{'='*30}")
        print(f"Running Experiment: {model_name}")
        print(f"{'='*30}")
    
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
        # Data
        trainloader, testloader = get_cifar10_loaders()
        
        # Model
        net = InceptionNet(use_residual=use_residual, use_cbam=use_cbam)
        net = net.to(device)
        if device == 'cuda':
            net = torch.nn.DataParallel(net)
    
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
        best_acc = 0
        for epoch in range(epochs):
            train(epoch, net, trainloader, optimizer, criterion, device)
            acc = test(epoch, net, testloader, criterion, device)
            scheduler.step()
            
            if acc > best_acc:
                best_acc = acc
    
        print(f"\nBest Test Accuracy for {model_name}: {best_acc:.2f}%\n")
        return best_acc
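Finally, the main script runs the baseline and the three ablation configurations in sequence: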
    from trainer import run_experiment
    
    if __name__ == '__main__':
        # 1. Observe the Inception network's accuracy on CIFAR-10
        run_experiment(
            model_name="Baseline InceptionNet", 
            use_residual=False, 
            use_cbam=False,
            epochs=50  # adjust the number of training epochs as needed
        )
    
        # 2. Ablation: introduce the residual mechanism
        run_experiment(
            model_name="InceptionNet with Residuals", 
            use_residual=True, 
            use_cbam=False,
            epochs=50
        )
    
        # 3. Ablation: introduce the CBAM module
        run_experiment(
            model_name="InceptionNet with CBAM", 
            use_residual=False, 
            use_cbam=True,
            epochs=50
        )
        
        # 4. (Optional) Ablation: residuals and CBAM together
        run_experiment(
            model_name="InceptionNet with Residuals and CBAM",
            use_residual=True,
            use_cbam=True,
            epochs=50
        )
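Since run_experiment returns the best test accuracy, one small variant (my own sketch, equivalent to the four explicit calls above) is to loop over the configurations and print a summary at the end, which makes the ablation comparison easier to read:

    results = {}
    configs = [
        ('Baseline InceptionNet', False, False),
        ('InceptionNet with Residuals', True, False),
        ('InceptionNet with CBAM', False, True),
        ('InceptionNet with Residuals and CBAM', True, True),
    ]
    for name, res, cbam in configs:
        results[name] = run_experiment(name, use_residual=res, use_cbam=cbam, epochs=50)

    for name, acc in results.items():
        print(f'{name:<40s} {acc:.2f}%')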