Inception【代码详解Pytorch version】

"""Contains the Inception V3 model, which is used for inference ONLY.

This file is mostly borrowed from torchvision/models/inception.py.

Inception model is widely used to compute FID or IS metric for evaluating

generative models. However, the pre-trained models from torchvision is slightly

different from the TensorFlow version

http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz

which is used by the official FID implementation

https://github.com/bioinf-jku/TTUR

In particular:

(1) The number of classes in TensorFlow model is 1008 instead of 1000.

(2) The avg_pool() layers in TensorFlow model does not include the padded zero.

(3) The last Inception E Block in TensorFlow model use max_pool() instead of

avg_pool().

Hence, to align the evaluation results with those from TensorFlow

implementation, we modified the inception model to support both versions. Please

use align_tf argument to control the version.

"""

import warnings

import torch

import torch.nn as nn

import torch.nn.functional as F

import torch.distributed as dist

from utils.misc import download_url

all = 'InceptionModel'

_MODEL_URL_SHA256 = {

This model is provided by torchvision, which is ported from TensorFlow.

'torchvision_official': (

'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',

'1a9a5a14f40645a370184bd54f4e8e631351e71399112b43ad0294a79da290c8' # hash sha256

),

复制代码
# This model is provided by https://github.com/mseitzer/pytorch-fid
'tf_inception_v3': (
    'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth',
    '6726825d0af5f729cebd5821db510b11b1cfad8faad88a03f1befd49fb9129b2'  # hash sha256
)

}

class InceptionModel(object):

"""Defines the Inception (V3) model.

复制代码
This is a static class, which is used to avoid this model to be built
repeatedly. Consequently, this model is particularly used for inference,
like computing FID. If training is required, please use the model from
`torchvision.models` or implement by yourself.

NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
order and pixel range [-1, 1], and will also resize the images to shape
[299, 299] automatically. If your input is normalized by subtracting
(0.485, 0.456, 0.406) and dividing (0.229, 0.224, 0.225), please use
`transform_input` in the `forward()` function to un-normalize it.
"""
models = dict()

@staticmethod
def build_model(align_tf=True):
    """Builds the model and load pre-trained weights.

    If `align_tf` is set as True, the model will predict 1008 classes, and
    the pre-trained weight from `https://github.com/mseitzer/pytorch-fid`
    will be loaded. Otherwise, the model will predict 1000 classes, and will
    load the model from `torchvision`.

    The built model supports following arguments when forwarding:

    - transform_input: Whether to transform the input back to pixel range
        (-1, 1). Please disable this argument if your input is already with
        pixel range (-1, 1). (default: False)
    - output_logits: Whether to output the categorical logits instead of
        features. (default: False)
    - remove_logits_bias: Whether to remove the bias when computing the
        logits. The official implementation removes the bias by default.
        Please refer to
        `https://github.com/openai/improved-gan/blob/master/inception_score/model.py`.
        (default: False)
    - output_predictions: Whether to output the final predictions, i.e.,
        `softmax(logits)`. (default: False)
    """
    if align_tf:
        num_classes = 1008
        model_source = 'tf_inception_v3'
    else:
        num_classes = 1000
        model_source = 'torchvision_official'

    fingerprint = model_source

    if fingerprint not in InceptionModel.models:
        # Build model.
        model = Inception3(num_classes=num_classes,
                           aux_logits=False,
                           init_weights=False,
                           align_tf=align_tf)

        # Download pre-trained weights.
        if dist.is_initialized() and dist.get_rank() != 0:
            dist.barrier()  # Download by chief.

        url, sha256 = _MODEL_URL_SHA256[model_source]
        filename = f'inception_model_{model_source}_{sha256}.pth'
        model_path, hash_check = download_url(url,
                                              filename=filename,
                                              sha256=sha256)
        state_dict = torch.load(model_path, map_location='cpu')
        if hash_check is False:
            warnings.warn(f'Hash check failed! The remote file from URL '
                          f'`{url}` may be changed, or the downloading is '
                          f'interrupted. The loaded inception model may '
                          f'have unexpected behavior.')

        if dist.is_initialized() and dist.get_rank() == 0:
            dist.barrier()  # Wait for other replicas.

        # Load weights.
        model.load_state_dict(state_dict, strict=False)
        del state_dict

        # For inference only.
        model.eval().requires_grad_(False).cuda()
        InceptionModel.models[fingerprint] = model

    return InceptionModel.models[fingerprint]

class Inception3(nn.Module):

复制代码
def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None,
             init_weights=True, align_tf=True):
    super(Inception3, self).__init__()
    if inception_blocks is None:
        inception_blocks = [
            BasicConv2d, InceptionA, InceptionB, InceptionC,
            InceptionD, InceptionE, InceptionAux
        ]
    assert len(inception_blocks) == 7
    conv_block = inception_blocks[0]
    inception_a = inception_blocks[1]
    inception_b = inception_blocks[2]
    inception_c = inception_blocks[3]
    inception_d = inception_blocks[4]
    inception_e = inception_blocks[5]
    inception_aux = inception_blocks[6]

    self.aux_logits = aux_logits
    self.align_tf = align_tf
    self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
    self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
    self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
    self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
    self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
    self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf)
    self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf)
    self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf)
    self.Mixed_6a = inception_b(288)
    self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf)
    self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
    self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
    self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf)
    if aux_logits:
        self.AuxLogits = inception_aux(768, num_classes)
    self.Mixed_7a = inception_d(768)
    self.Mixed_7b = inception_e(1280, align_tf=self.align_tf)
    self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf)
    self.fc = nn.Linear(2048, num_classes)
    if init_weights:
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                import scipy.stats as stats
                stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                X = stats.truncnorm(-2, 2, scale=stddev)
                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                values = values.view(m.weight.size())
                with torch.no_grad():
                    m.weight.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

@staticmethod
def _transform_input(x, transform_input=False):
    if transform_input:
        x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
        x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
        x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
    return x

def _forward(self,
             x,
             output_logits=False,
             remove_logits_bias=False,
             output_predictions=False):
    # Upsample if necessary.
    if x.shape[2] != 299 or x.shape[3] != 299:
        if self.align_tf:
            theta = torch.eye(2, 3).to(x)
            theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 299
            theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 299
            theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
            grid = F.affine_grid(theta,
                                 size=(x.shape[0], x.shape[1], 299, 299),
                                 align_corners=False)
            x = F.grid_sample(x, grid,
                              mode='bilinear',
                              padding_mode='border',
                              align_corners=False)
        else:
            x = F.interpolate(
                x, size=(299, 299), mode='bilinear', align_corners=False)
    if x.shape[1] == 1:
        x = x.repeat((1, 3, 1, 1))

    if self.align_tf:
        x = (x * 127.5 + 127.5 - 128) / 128

    # N x 3 x 299 x 299
    x = self.Conv2d_1a_3x3(x)
    # N x 32 x 149 x 149
    x = self.Conv2d_2a_3x3(x)
    # N x 32 x 147 x 147
    x = self.Conv2d_2b_3x3(x)
    # N x 64 x 147 x 147
    x = F.max_pool2d(x, kernel_size=3, stride=2)
    # N x 64 x 73 x 73
    x = self.Conv2d_3b_1x1(x)
    # N x 80 x 73 x 73
    x = self.Conv2d_4a_3x3(x)
    # N x 192 x 71 x 71
    x = F.max_pool2d(x, kernel_size=3, stride=2)
    # N x 192 x 35 x 35
    x = self.Mixed_5b(x)
    # N x 256 x 35 x 35
    x = self.Mixed_5c(x)
    # N x 288 x 35 x 35
    x = self.Mixed_5d(x)
    # N x 288 x 35 x 35
    x = self.Mixed_6a(x)
    # N x 768 x 17 x 17
    x = self.Mixed_6b(x)
    # N x 768 x 17 x 17
    x = self.Mixed_6c(x)
    # N x 768 x 17 x 17
    x = self.Mixed_6d(x)
    # N x 768 x 17 x 17
    x = self.Mixed_6e(x)
    # N x 768 x 17 x 17
    if self.training and self.aux_logits:
        aux = self.AuxLogits(x)
    else:
        aux = None
    # N x 768 x 17 x 17
    x = self.Mixed_7a(x)
    # N x 1280 x 8 x 8
    x = self.Mixed_7b(x)
    # N x 2048 x 8 x 8
    x = self.Mixed_7c(x)
    # N x 2048 x 8 x 8
    # Adaptive average pooling
    x = F.adaptive_avg_pool2d(x, (1, 1))
    # N x 2048 x 1 x 1
    x = F.dropout(x, training=self.training)
    # N x 2048 x 1 x 1
    x = torch.flatten(x, 1)
    # N x 2048
    if output_logits or output_predictions:
        x = self.fc(x)
        # N x 1000 (num_classes)
        if remove_logits_bias:
            x = x - self.fc.bias.view(1, -1)
        if output_predictions:
            x = F.softmax(x, dim=1)
    return x, aux

def forward(self,
            x,
            transform_input=False,
            output_logits=False,
            remove_logits_bias=False,
            output_predictions=False):
    x = self._transform_input(x, transform_input)
    x, aux = self._forward(
        x, output_logits, remove_logits_bias, output_predictions)
    if self.training and self.aux_logits:
        return x, aux
    else:
        return x

class InceptionA(nn.Module):

复制代码
def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False):
    super(InceptionA, self).__init__()
    if conv_block is None:
        conv_block = BasicConv2d
    self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

    self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
    self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

    self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
    self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
    self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

    self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
    self.pool_include_padding = not align_tf

def _forward(self, x):
    branch1x1 = self.branch1x1(x)

    branch5x5 = self.branch5x5_1(x)
    branch5x5 = self.branch5x5_2(branch5x5)

    branch3x3dbl = self.branch3x3dbl_1(x)
    branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
    branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

    branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                               count_include_pad=self.pool_include_padding)
    branch_pool = self.branch_pool(branch_pool)

    outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
    return outputs

def forward(self, x):
    outputs = self._forward(x)
    return torch.cat(outputs, 1)

class InceptionB(nn.Module):

复制代码
def __init__(self, in_channels, conv_block=None):
    super(InceptionB, self).__init__()
    if conv_block is None:
        conv_block = BasicConv2d
    self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

    self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
    self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
    self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

def _forward(self, x):
    branch3x3 = self.branch3x3(x)

    branch3x3dbl = self.branch3x3dbl_1(x)
    branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
    branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

    branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

    outputs = [branch3x3, branch3x3dbl, branch_pool]
    return outputs

def forward(self, x):
    outputs = self._forward(x)
    return torch.cat(outputs, 1)

class InceptionC(nn.Module):

复制代码
def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False):
    super(InceptionC, self).__init__()
    if conv_block is None:
        conv_block = BasicConv2d
    self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

    c7 = channels_7x7
    self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
    self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
    self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

    self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
    self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
    self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
    self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
    self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

    self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
    self.pool_include_padding = not align_tf

def _forward(self, x):
    branch1x1 = self.branch1x1(x)

    branch7x7 = self.branch7x7_1(x)
    branch7x7 = self.branch7x7_2(branch7x7)
    branch7x7 = self.branch7x7_3(branch7x7)

    branch7x7dbl = self.branch7x7dbl_1(x)
    branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
    branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
    branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
    branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

    branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                               count_include_pad=self.pool_include_padding)
    branch_pool = self.branch_pool(branch_pool)

    outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
    return outputs

def forward(self, x):
    outputs = self._forward(x)
    return torch.cat(outputs, 1)

class InceptionD(nn.Module):

复制代码
def __init__(self, in_channels, conv_block=None):
    super(InceptionD, self).__init__()
    if conv_block is None:
        conv_block = BasicConv2d
    self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
    self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

    self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
    self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
    self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
    self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

def _forward(self, x):
    branch3x3 = self.branch3x3_1(x)
    branch3x3 = self.branch3x3_2(branch3x3)

    branch7x7x3 = self.branch7x7x3_1(x)
    branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
    branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
    branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

    branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
    outputs = [branch3x3, branch7x7x3, branch_pool]
    return outputs

def forward(self, x):
    outputs = self._forward(x)
    return torch.cat(outputs, 1)

class InceptionE(nn.Module):

复制代码
def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False):
    super(InceptionE, self).__init__()
    if conv_block is None:
        conv_block = BasicConv2d
    self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

    self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
    self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
    self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

    self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
    self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
    self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
    self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

    self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
    self.pool_include_padding = not align_tf
    self.use_max_pool = use_max_pool

def _forward(self, x):
    branch1x1 = self.branch1x1(x)

    branch3x3 = self.branch3x3_1(x)
    branch3x3 = [
        self.branch3x3_2a(branch3x3),
        self.branch3x3_2b(branch3x3),
    ]
    branch3x3 = torch.cat(branch3x3, 1)

    branch3x3dbl = self.branch3x3dbl_1(x)
    branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
    branch3x3dbl = [
        self.branch3x3dbl_3a(branch3x3dbl),
        self.branch3x3dbl_3b(branch3x3dbl),
    ]
    branch3x3dbl = torch.cat(branch3x3dbl, 1)

    if self.use_max_pool:
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
    else:
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=self.pool_include_padding)
    branch_pool = self.branch_pool(branch_pool)

    outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
    return outputs

def forward(self, x):
    outputs = self._forward(x)
    return torch.cat(outputs, 1)

class InceptionAux(nn.Module):

复制代码
def __init__(self, in_channels, num_classes, conv_block=None):
    super(InceptionAux, self).__init__()
    if conv_block is None:
        conv_block = BasicConv2d
    self.conv0 = conv_block(in_channels, 128, kernel_size=1)
    self.conv1 = conv_block(128, 768, kernel_size=5)
    self.conv1.stddev = 0.01
    self.fc = nn.Linear(768, num_classes)
    self.fc.stddev = 0.001

def forward(self, x):
    # N x 768 x 17 x 17
    x = F.avg_pool2d(x, kernel_size=5, stride=3)
    # N x 768 x 5 x 5
    x = self.conv0(x)
    # N x 128 x 5 x 5
    x = self.conv1(x)
    # N x 768 x 1 x 1
    # Adaptive average pooling
    x = F.adaptive_avg_pool2d(x, (1, 1))
    # N x 768 x 1 x 1
    x = torch.flatten(x, 1)
    # N x 768
    x = self.fc(x)
    # N x 1000
    return x

class BasicConv2d(nn.Module):

复制代码
def __init__(self, in_channels, out_channels, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
    self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

def forward(self, x):
    x = self.conv(x)
    x = self.bn(x)
    return F.relu(x, inplace=True
相关推荐
TYUT_xiaoming几秒前
yolo模型训练
人工智能·python·yolo
2301_780789666 分钟前
零信任架构中,身份感知防火墙(IAFW)的部署要点与最佳实践
linux·运维·服务器·人工智能·tcp/ip·架构
MicroTech20257 分钟前
业绩披露|微算法科技(MLGO)2025年净利润1.27亿元
大数据·人工智能·科技
百度Geek说7 分钟前
Superpowers:给 Claude Code 装上“工程大脑”
人工智能
AGIPlayer7 分钟前
没有生态的大模型不算前沿
大数据·人工智能·物联网
lulu121654407812 分钟前
OpenRouter Fusion 多模型融合架构深度拆解:预算级模型组团打平 Fable 5,多模型协作才是 AGI 的正确打开方式?
java·人工智能·架构·ai编程·agi
恋猫de小郭18 分钟前
Redis 作者反驳「中国模型之所以强,是因为通过 API 蒸馏了美国模型」
前端·人工智能·ai编程
林间码客28 分钟前
04 ROC曲线与AUC:从零开始手动计算
大数据·人工智能·算法
codexu31 分钟前
NoteGen 里一条记录如何变成 Markdown
人工智能
澄旭39 分钟前
拆解一个成熟 Skill,看懂 Skill 到底该怎么写
人工智能