深度学习任务分类与示例(一)


文章目录

  • 一、深度学习图像任务的完整分类
    • [1. 感知层(Low-level Vision)](#1. 感知层(Low-level Vision))
    • [2. 理解层(Middle-level Vision)](#2. 理解层(Middle-level Vision))
    • [3. 高级理解层(High-level Vision)](#3. 高级理解层(High-level Vision))
    • [4. 决策层(Decision-making)](#4. 决策层(Decision-making))
  • 二、图像任务的关系与层次结构
    • [1. 任务层次关系](#1. 任务层次关系)
    • [2. 技术依赖关系](#2. 技术依赖关系)
    • [3. 模型架构演进关系](#3. 模型架构演进关系)
  • 三、图像任务的输入输出机制与示例
    • [1. 图像分类](#1. 图像分类)
    • [2. 目标检测](#2. 目标检测)
    • [3. 语义分割](#3. 语义分割)
    • [4. 实例分割](#4. 实例分割)
    • [5. 图像生成](#5. 图像生成)

一、深度学习图像任务的完整分类

根据计算机视觉任务的层次结构和功能特性,我们可以将深度学习图像任务分为四大层次,每个层次包含多个具体任务:

1. 感知层(Low-level Vision)

感知层任务关注图像的基础处理和特征提取,解决图像质量改善和基础表示问题:

  • 图像重建:修复破损或缺失部分的图像(如医学图像修复)
  • 图像去噪:去除图像中的噪声,提高图像质量
  • 图像超分辨率:从低分辨率图像生成高分辨率图像
  • 图像配准:将不同来源或时间的图像对齐到同一坐标系
  • 图像压缩:以较少的比特表示原像素矩阵,减少存储和传输开销
  • 图像隐写:在图像中隐藏秘密信息,如LSB技术

2. 理解层(Middle-level Vision)

理解层任务关注图像内容的识别和定位,建立对图像的初步认知:

  • 图像分类:判断图像的整体类别(如识别手写数字0-9)
  • 目标检测:在图像中定位并识别多个物体
  • 语义分割:为图像中每个像素分配语义类别标签
  • 实例分割:区分图像中同一类别的不同实例
  • 目标跟踪:在视频序列中持续追踪特定目标
  • 图像生成:基于模型生成新的图像内容(如GAN、VAE)
  • 图像翻译:将图像从一个域转换到另一个域(如CycleGAN)
  • 图像风格迁移:将一张图像的风格应用到另一张图像上

3. 高级理解层(High-level Vision)

高级理解层任务关注图像内容的深度理解和推理:

  • 视觉关系检测:识别图像中物体之间的关系或相互作用
  • 视觉问答(VQA):根据图像和问题生成文本答案
  • 图像描述生成:为图像生成自然语言描述
  • 3D视觉任务
    • 3D目标检测
    • 3D场景重建(如NeRF)
    • 3D位姿估计(如手部姿态估计)
    • 3D视觉定位(如点云中的物体定位)
  • 医学影像特有任务
    • 病灶分割(如U-Net用于肿瘤分割)
    • 医学图像重建(如CT去伪影)
    • 医学图像配准

4. 决策层(Decision-making)

决策层任务关注基于视觉理解的决策和行动规划:

  • 自动驾驶感知:环境理解与路径规划
  • 机器人视觉:物体抓取与场景交互
  • 视频分析
    • 视频分类
    • 视频目标检测
    • 行为识别
    • 视频超分辨率
  • 多模态任务融合:结合视觉、文本、语音等多模态信息进行推理

二、图像任务的关系与层次结构

1. 任务层次关系

计算机视觉任务形成了清晰的层次结构,从低级到高级逐步提升认知能力:

图像像素 → 感知层 → 理解层 → 高级理解层 → 决策层

感知层 是视觉处理的基础,通过改进图像质量(如去噪、超分辨率)和提取基础特征,为后续任务提供良好输入。理解层 在此基础上进行物体识别和定位,如分类、检测和分割任务。高级理解层 则利用理解层的结果,进行更复杂的场景理解、关系推理和3D重建。决策层将视觉理解与其他模态信息结合,做出最终决策。

2. 技术依赖关系

许多高级任务直接依赖于基础任务的输出:

  • 实例分割 依赖于目标检测,在检测框基础上生成像素级掩码
  • 视觉关系检测 需要目标检测的结果作为输入,识别物体间关系
  • 3D位姿估计 通常需要目标检测实例分割的2D位置信息
  • 视觉问答 需要结合目标检测语义分割和自然语言处理
  • 视频目标检测图像目标检测的时序扩展,需处理连续帧序列

3. 模型架构演进关系

深度学习模型架构也呈现从基础到高级的演进关系:

  • 全连接网络(如用于MNIST分类)是最简单的网络结构,将图像展平为一维向量处理
  • 卷积神经网络(CNN)(如ResNet、VGG)是图像处理的核心架构,通过局部感受野提取空间特征
  • Transformer架构(如DETR、Swin Transformer)提供了全局依赖建模能力,用于复杂场景理解
  • 扩散模型(如Stable Diffusion)和**神经辐射场(NeRF)**代表了最新的生成和3D重建技术

三、图像任务的输入输出机制与示例

1. 图像分类

任务定义:将图像分为不同的类别,如识别手写数字0-9。

简单示例:全连接网络处理MNIST数据集

python 复制代码
import torch
import torch.nn as nn

class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128)  # 输入层:784维 → 隐藏层:128维
        self.relu = nn.ReLU()               # ReLU激活函数
        self.fc2 = nn.Linear(128, 10)      # 输出层:128维 → 10个类别

    def forward(self, x):
        x = x.view(-1, 28*28)  # 将28x28图像展平为784维向量
        x = self.relu(self.fc1(x))  # 隐藏层前向传播
        x = self.fc2(x)               # 输出层前向传播
        return x

示例输入输出
image = torch.randn(1, 1, 28, 28)  # MNIST图像形状为(1,28,28)
model = SimpleClassifier()
output = model(image)  # 输出形状为(1,10),表示每个数字的概率
print(output.shape)   # torch.Size([1, 10])

**流行算法示例**:ResNet(残差网络)

import torchvision.models as models

加载预训练的ResNet-50模型
model = models.resnet50(pretrained=True)

示例输入输出
image = torch.randn(1, 3, 224, 224)  # 彩色图像形状为(3,224,224)
output = model(image)  # 输出形状为(1,1000),表示ImageNet数据集的1000个类别的概率
print(output.shape)   # torch.Size([1, 1000])

2. 目标检测

任务定义:在图像中定位并识别多个物体,输出边界框及其类别标签。

简单示例:目标检测的基本结构

python 复制代码
import torch
import torch.nn as nn

class SimpleYOLO(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU()
        )
        self.detection_head = nn.Conv2d(32, num_classes*5, kernel_size=1)  # 5个输出:x,y,w,h,confidence

    def forward(self, x):
        features = self.backbone(x)  # 提取特征
        predictions = self.detection_head(features)  # 预测边界框和类别
        return predictions

3. 语义分割

任务定义:为图像中每个像素分配语义类别标签,如将图像中的每个像素分为"人"、"车"或"背景"等类别。

简单示例:UNet的基本结构

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """两次连续的卷积层(卷积+BN+ReLU)"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=2):
        """
        参数:
            in_channels: 输入通道数 (RGB=3, 灰度=1)
            out_channels: 输出通道数 (等于类别数)
        """
        super().__init__()
        
        # 编码器 (下采样路径)
        self.encoder1 = DoubleConv(in_channels, 64)    # 1/1
        self.encoder2 = DoubleConv(64, 128)            # 1/2
        self.encoder3 = DoubleConv(128, 256)           # 1/4
        self.encoder4 = DoubleConv(256, 512)           # 1/8
        
        # 下采样
        self.pool = nn.MaxPool2d(2)
        
        # 底部
        self.bottleneck = DoubleConv(512, 1024)       # 1/16
        
        # 解码器 (上采样路径)
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.decoder4 = DoubleConv(1024, 512)         # 1024=上采样512+跳跃连接512
        
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.decoder3 = DoubleConv(512, 256)          # 512=上采样256+跳跃连接256
        
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.decoder2 = DoubleConv(256, 128)          # 256=上采样128+跳跃连接128
        
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.decoder1 = DoubleConv(128, 64)           # 128=上采样64+跳跃连接64
        
        # 输出层
        self.output = nn.Conv2d(64, out_channels, 1)  # 1x1卷积,将特征映射到类别数
    
    def forward(self, x):
        """前向传播过程"""
        # 保存编码器的特征图用于跳跃连接
        e1 = self.encoder1(x)      # 通道: 3->64
        e2 = self.encoder2(self.pool(e1))  # 下采样,通道: 64->128
        e3 = self.encoder3(self.pool(e2))  # 下采样,通道: 128->256
        e4 = self.encoder4(self.pool(e3))  # 下采样,通道: 256->512
        
        # 瓶颈层
        b = self.bottleneck(self.pool(e4))  # 下采样,通道: 512->1024
        
        # 解码器,通过跳跃连接融合特征
        d4 = self.upconv4(b)  # 上采样,1024->512
        d4 = torch.cat([e4, d4], dim=1)  # 跳跃连接:拼接e4和上采样结果
        d4 = self.decoder4(d4)  # 512+512=1024 -> 512
        
        d3 = self.upconv3(d4)  # 上采样,512->256
        d3 = torch.cat([e3, d3], dim=1)  # 跳跃连接
        d3 = self.decoder3(d3)  # 256+256=512 -> 256
        
        d2 = self.upconv2(d3)  # 上采样,256->128
        d2 = torch.cat([e2, d2], dim=1)  # 跳跃连接
        d2 = self.decoder2(d2)  # 128+128=256 -> 128
        
        d1 = self.upconv1(d2)  # 上采样,128->64
        d1 = torch.cat([e1, d1], dim=1)  # 跳跃连接
        d1 = self.decoder1(d1)  # 64+64=128 -> 64
        
        # 输出
        output = self.output(d1)  # 64 -> 类别数
        
        return output

参考文章

U-Net呈现为一个经典的"U"形编码器-解码器结构。左侧为编码器(收缩路径),它像是一个特征提取器,通过卷积和下采样逐步压缩图像尺寸、增加通道数,以捕获图像的深层语义信息。右侧为解码器(扩展路径),其过程相反,通过上采样和卷积逐步恢复图像的空间细节。连接左右两侧同层级特征的灰色跳跃连接是U-Net设计的精髓,它能将编码器捕捉到的精细纹理和位置信息直接传递给解码器,从而在恢复大尺寸图像时,能同时利用浅层的细节与深层的语义,实现精准的像素级定位。

图中不同颜色的箭头代表了核心操作:

蓝紫箭头 (3×3卷积 + ReLU):是特征提取的基本单元。以最左上角输入为例,一张572×572的图像,经过第一层"3×3卷积→ReLU"后,由于卷积操作的特性,尺寸会略微缩小,得到570×570的特征图;再经过一次相同操作,得到568×568的特征图。此过程不改变通道数的本质,而是通过多组卷积核来生成新的特征表达。

红色箭头 (2×2最大池化/下采样):在每提取两次特征后,它会将特征图尺寸减半(如从568×568变为284×284),此举在扩大后续卷积感受野的同时,实现了数据压缩和一定程度的平移不变性。

绿色箭头 (2×2反卷积/上采样):位于解码器,功能是进行"解码"以恢复分辨率。例如,将196×196的低分辨率特征图,通过2×2的反卷积核进行上采样,尺寸放大一倍至392×392。

灰色箭头 (裁剪与复制):即跳跃连接的具体实现。由于编码器每次卷积会导致特征图边界丢失,其输出尺寸(如568×568)通常大于解码器对应层上采样后的尺寸(如392×392)。因此,需要将编码器输出的特征图裁剪中心部分,再复制到解码器,与上采样结果在通道维度上进行拼接(如上例中,392×392×64 + 392×392×64 = 392×392×128),从而融合多尺度信息。

青蓝箭头(1×1卷积):位于网络最末端。它不改变特征图的空间尺寸,但负责将高维特征通道(如64通道)映射到目标类别数(如2通道),其输出388×388×2的每个空间位置都对应输入图像一个像素点的类别概率分布,最终实现语义分割。

4. 实例分割

任务定义:不仅识别图像中的物体类别,还需区分同一类别的不同实例。

简单示例:结合目标检测和语义分割的简单实例分割模型

python 复制代码
import torch
import torch.nn as nn
import numpy as np
import cv2
from scipy import ndimage
from skimage import measure, morphology
import matplotlib.pyplot as plt

class SimpleInstanceSegmentation(nn.Module):
    """
    基于距离变换的简单实例分割模型
    1. 语义分割:分离前景(细胞)和背景
    2. 距离变换:计算每个前景像素到最近背景像素的距离
    3. 分水岭分割:分离相互接触的细胞
    """
    
    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        
        # 简单的编码器-解码器结构
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),  # 输入: 3通道RGB
            nn.ReLU(),
            nn.MaxPool2d(2),  # 下采样
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 2, stride=2),  # 上采样
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, 2, stride=2),    # 上采样
            nn.ReLU(),
            nn.Conv2d(8, out_channels, 1)  # 1x1卷积,输出分割掩码
        )
        
        # 可选:距离变换头
        self.distance_head = nn.Conv2d(32, 1, 1)  # 预测距离变换图
        
    def forward(self, x):
        """前向传播:语义分割"""
        features = self.encoder(x)
        semantic_mask = torch.sigmoid(self.decoder(features))  # 二值掩码
        
        # 可选:同时预测距离变换
        distance_map = self.distance_head(features)
        
        return semantic_mask, distance_map
    
    def separate_instances(self, semantic_mask, min_size=50):
        """
        将语义分割结果分离为各个实例
        参数:
            semantic_mask: 语义分割掩码 (H, W),值为0或1
            min_size: 最小实例大小(过滤小区域)
        返回:
            instance_mask: 实例掩码 (H, W),每个实例有唯一ID
        """
        if isinstance(semantic_mask, torch.Tensor):
            semantic_mask = semantic_mask.detach().cpu().numpy()
        
        # 确保是二值图像
        binary_mask = (semantic_mask > 0.5).astype(np.uint8)
        
        # 步骤1: 距离变换
        # 计算每个前景像素到最近背景像素的欧几里得距离
        distance = ndimage.distance_transform_edt(binary_mask)
        
        # 步骤2: 寻找局部最大值(细胞中心)
        # 使用最大值滤波找到局部最大值
        local_max = self._find_local_maxima(distance, min_distance=10)
        
        # 标记局部最大值为种子点
        markers, num_seeds = ndimage.label(local_max)
        
        # 步骤3: 分水岭分割
        # 将距离图的负值作为地形高度
        elevation_map = -distance
        
        # 应用分水岭算法
        segmentation = morphology.watershed(
            elevation_map, 
            markers, 
            mask=binary_mask
        )
        
        # 步骤4: 过滤小区域
        filtered_segmentation = np.zeros_like(segmentation)
        for region in measure.regionprops(segmentation):
            if region.area >= min_size:
                filtered_segmentation[segmentation == region.label] = region.label
        
        return filtered_segmentation
    
    def _find_local_maxima(self, distance, min_distance=5):
        """寻找局部最大值点"""
        # 使用最大值滤波
        size = 2 * min_distance + 1
        kernel = np.ones((size, size), dtype=bool)
        local_max = ndimage.maximum_filter(distance, footprint=kernel) == distance
        
        # 只在前景区域寻找
        local_max[distance == 0] = 0
        
        return local_max

代码实现了一个简化的实例分割模型,其编码器-解码器主干网络与图中描述的U-Net核心思想一脉相承。具体来看:

  1. 编码器(左侧收缩路径):代码中的self.encoder依次执行"蓝紫箭头"所示的3×3卷积+ReLU操作,以及"红色箭头"所示的2×2最大池化/下采样,这与图中左上角描述的流程(572×572 → 卷积 → 568×568 → 池化 → 284×284)完全对应,目的都是逐步提取深层语义特征。

  2. 解码器(右侧扩展路径):代码中的self.decoder通过"绿色箭头"所示的2×2转置卷积(反卷积) 进行上采样,逐步恢复特征图的空间分辨率,最终通过"青蓝箭头"所示的1×1卷积输出每个像素的语义类别(前景/背景),对应图中最右侧将特征映射到类别空间的过程。

  3. 关键差异与实例分割实现:在获得二值语义掩码后,通过距离变换找到细胞中心,再使用分水岭算法(以距离图的负值为地形)将从中心扩散的"水域"分离,从而将连成一片的前景分割成具有独立ID的多个实例,完成了从"像素是什么"到"像素属于哪个物体"的实例分割任务。

5. 图像生成

任务定义 :生成符合某些条件的图像,如从噪声中生成真实图像。### (1)GAN网络

GAN(生成对抗网络)的完整工作流程与核心对抗思想。从最左侧的"noise"(噪声) 开始,这通常是随机生成的一组数字(如100维高斯噪声向量),它作为生成器G(Generator) 的"种子"或创意来源。生成器(红色框G) 的核心任务是一个"造假者":它接收这片混沌的噪声,并通过一个复杂的神经网络(通常是全连接层和转置卷积层)进行学习与变换,最终"无中生有"地输出一批"Fake Sample"(假样本),如图中叠放的生成图像。

与此同时,从数据集来的"Real Sample"(真实样本) 被提供在流程上方。真假两路样本在此汇合,共同输入给"检验员"------判别器D(Discriminator,蓝色框)。判别器的职责是一个"鉴定家":它需要仔细观察和分析输入样本的特征,并通过另一个神经网络(通常是卷积网络)进行判断,最终输出一个"predict label"(预测标签),即一个概率值(比如0.8表示"很可能为真",0.1表示"很可能是假")。

GAN的精髓在于图中箭头所描绘的、贯穿始终的动态对抗博弈

  1. 生成器G的目标是:不断学习,使自己生成的"假样本"越来越逼真,直到能"以假乱真",成功骗过判别器D(让D对假样本也输出高的真实概率)。
  2. 判别器D的目标是:不断进化自己的鉴定能力,越来越精准地区分真假样本。

在训练初期,生成器水平很差,判别器很容易识破。但随着训练进行,判别器给出的反馈(预测标签)会作为关键信号,通过反向传播"指导"生成器改进。这个过程循环往复,如同"造假工艺"与"鉴定技术"在相互比拼中不断升级。最终,理想状态下双方会达到一个平衡点(纳什均衡),此时生成器能产生足以乱真的高质量样本,而判别器则难以判断(真假概率都接近0.5)。因此,整个系统的输入是随机噪声和真实数据,而输出则是一个能够从噪声中生成逼真数据的、训练好的生成器模型。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# ==================== 1. 生成器 (G Generator) ====================
# 对应图中的红色虚线框"G Generator"
class Generator(nn.Module):
    """
    生成器:从噪声生成假样本
    输入: 噪声向量 (noise) → 输出: 假样本 (Fake Sample)
    对应图中: noise → 红色箭头 → G Generator → 黑色箭头 → Fake Sample
    """
    def __init__(self, noise_dim=100, img_channels=1, img_size=28):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        
        # 将噪声向量逐步上采样为图像
        self.model = nn.Sequential(
            # 第一层: 全连接层,将噪声向量展开
            nn.Linear(noise_dim, 128),
            nn.ReLU(),
            
            # 第二层: 进一步扩展特征
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            
            # 第三层: 扩展到与图像尺寸匹配
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            
            # 输出层: 映射到图像空间
            nn.Linear(512, img_channels * img_size * img_size),
            nn.Tanh()  # 输出值在[-1, 1]之间,对应归一化的图像
        )
        
        self.img_channels = img_channels
        self.img_size = img_size
    
    def forward(self, noise):
        """
        前向传播: 从噪声生成假样本
        对应图中流程: noise → G → Fake Sample
        """
        # 输入: noise (batch_size, noise_dim)
        batch_size = noise.shape[0]
        
        # 通过生成器网络
        output = self.model(noise)  # 形状: (batch_size, img_channels*img_size*img_size)
        
        # 重塑为图像格式
        fake_images = output.view(batch_size, self.img_channels, self.img_size, self.img_size)
        
        return fake_images  # 输出: Fake Sample

# ==================== 2. 判别器 (D Discriminator) ====================
# 对应图中的蓝色虚线框"D Discriminator"
class Discriminator(nn.Module):
    """
    判别器:判断输入样本是真实还是虚假
    输入: 真实样本(Real Sample) 或 假样本(Fake Sample)
    输出: 预测标签(predict label) - 样本为真的概率
    对应图中: Real Sample/Fake Sample → 黑色箭头 → D Discriminator → 黑色箭头 → predict label
    """
    def __init__(self, img_channels=1, img_size=28):
        super(Discriminator, self).__init__()
        
        # 将图像展平为向量
        self.input_dim = img_channels * img_size * img_size
        
        # 判别器网络结构
        self.model = nn.Sequential(
            # 第一层: 处理展平的图像
            nn.Linear(self.input_dim, 512),
            nn.LeakyReLU(0.2),  # 使用LeakyReLU防止梯度消失
            
            # 第二层: 提取特征
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            
            # 第三层: 进一步提取特征
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            
            # 输出层: 输出一个概率值
            nn.Linear(128, 1),
            nn.Sigmoid()  # 输出0-1之间的概率,对应"predict label"
        )
    
    def forward(self, img):
        """
        前向传播: 判断图像是真实还是生成
        对应图中流程: Real Sample/Fake Sample → D → predict label
        """
        # 输入: img (batch_size, img_channels, img_size, img_size)
        batch_size = img.shape[0]
        
        # 将图像展平为向量
        flattened = img.view(batch_size, -1)  # 形状: (batch_size, img_channels*img_size*img_size)
        
        # 通过判别器网络
        validity = self.model(flattened)  # 形状: (batch_size, 1)
        
        return validity  # 输出: predict label (概率值)

# ==================== 3. 完整的GAN模型 ====================
class GAN(nn.Module):
    """
    完整的GAN模型,整合生成器和判别器
    实现图中所示的完整数据流
    """
    def __init__(self, noise_dim=100, img_channels=1, img_size=28):
        super(GAN, self).__init__()
        self.noise_dim = noise_dim
        self.generator = Generator(noise_dim, img_channels, img_size)
        self.discriminator = Discriminator(img_channels, img_size)
    
    def generate_fake_samples(self, batch_size, device='cpu'):
        """
        生成假样本
        对应图中: noise → G → Fake Sample
        """
        # 1. 生成随机噪声 (对应图中"noise")
        noise = torch.randn(batch_size, self.noise_dim).to(device)
        
        # 2. 通过生成器生成假样本 (对应图中G Generator)
        fake_samples = self.generator(noise)  # 输出Fake Sample
        
        return fake_samples
    
    def discriminate(self, samples):
        """
        判别样本
        对应图中: Real Sample/Fake Sample → D → predict label
        """
        # 通过判别器得到预测标签
        predictions = self.discriminator(samples)  # 输出predict label
        
        return predictions

# ==================== 4. 训练过程示例 ====================
def train_gan_step_by_step():
    """
    逐步演示GAN的训练过程,对应图中完整的数据流
    """
    # 设置参数
    batch_size = 64
    noise_dim = 100
    img_size = 28
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 初始化模型
    gan = GAN(noise_dim=noise_dim, img_channels=1, img_size=img_size)
    gan.generator.to(device)
    gan.discriminator.to(device)
    
    # 定义损失函数和优化器
    adversarial_loss = nn.BCELoss()  # 二分类交叉熵损失
    optimizer_G = optim.Adam(gan.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(gan.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    # 模拟一批真实样本 (对应图中"Real Sample")
    real_samples = torch.randn(batch_size, 1, img_size, img_size).to(device)
    real_labels = torch.ones(batch_size, 1).to(device)  # 真实样本标签为1
    
    print("=" * 60)
    print("GAN训练步骤 - 对应图中的数据流")
    print("=" * 60)
    
    # 第一步: 生成假样本 (对应图中上半部分流程)
    print("\n1. 生成假样本 (对应图中: noise → G → Fake Sample):")
    
    # 生成随机噪声 (对应图中"noise"的五个黑点)
    noise = torch.randn(batch_size, noise_dim).to(device)
    print(f"   输入噪声形状: {noise.shape}  (对应图中的'noise')")
    
    # 通过生成器生成假样本
    fake_samples = gan.generator(noise)
    print(f"   生成假样本形状: {fake_samples.shape}  (对应图中的'Fake Sample')")
    
    # 第二步: 判别真实样本 (对应图中右上部分流程)
    print("\n2. 判别真实样本 (对应图中: Real Sample → D → predict label):")
    
    # 判别器处理真实样本
    real_validity = gan.discriminator(real_samples)
    print(f"   真实样本预测: {real_validity[:5].squeeze().detach().cpu().numpy()}  (对应图中的'predict label')")
    
    # 第三步: 判别假样本 (对应图中右下部分流程)
    print("\n3. 判别假样本 (对应图中: Fake Sample → D → predict label):")
    
    # 判别器处理假样本
    fake_validity = gan.discriminator(fake_samples.detach())
    print(f"   假样本预测: {fake_validity[:5].squeeze().detach().cpu().numpy()}  (对应图中的'predict label')")
    
    # 第四步: 训练判别器
    print("\n4. 训练判别器:")
    optimizer_D.zero_grad()
    
    # 计算真实样本的损失
    real_loss = adversarial_loss(real_validity, real_labels)
    print(f"   真实样本损失: {real_loss.item():.4f}")
    
    # 计算假样本的损失
    fake_labels = torch.zeros(batch_size, 1).to(device)
    fake_loss = adversarial_loss(fake_validity, fake_labels)
    print(f"   假样本损失: {fake_loss.item():.4f}")
    
    # 总判别器损失
    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer_D.step()
    print(f"   判别器总损失: {d_loss.item():.4f}")
    
    # 第五步: 训练生成器
    print("\n5. 训练生成器 (对应完整流程: noise → G → Fake Sample → D → predict label):")
    optimizer_G.zero_grad()
    
    # 重新生成假样本
    fake_samples = gan.generator(noise)
    
    # 判别器对生成样本的预测
    validity = gan.discriminator(fake_samples)
    
    # 生成器希望判别器将假样本判断为真
    g_loss = adversarial_loss(validity, real_labels)
    g_loss.backward()
    optimizer_G.step()
    
    print(f"   生成器损失: {g_loss.item():.4f} (目标: 使判别器输出接近1)")
    print(f"   最终流程完成: noise → G → Fake Sample → D → predict label = {validity.mean().item():.4f}")
    
    return gan
相关推荐
一条闲鱼_mytube2 小时前
智能体设计模式(二)反思-工具使用-规划
网络·人工智能·设计模式
m0_748254662 小时前
CSS AI 编程
前端·css·人工智能
愚公搬代码2 小时前
【愚公系列】《AI+直播营销》030-主播的选拔和人设设计(选拔匹配的主播)
人工智能
三不原则2 小时前
故障案例:告警风暴处理,用 AI 实现告警聚合与降噪
人工智能
这张生成的图像能检测吗2 小时前
(论文速读)GNS:学习用图网络模拟复杂物理
人工智能·图神经网络·物理模型
童话名剑2 小时前
神经风格迁移(吴恩达深度学习笔记)
深度学习·机器学习·计算机视觉·特征检测·神经风格迁移
HySpark2 小时前
基于语音转文字与语义分析的智能语音识别技术
人工智能·语音识别
Coder_Boy_2 小时前
基于SpringAI的在线考试系统-考试模块前端页面交互设计及优化
java·数据库·人工智能·spring boot
Maddie_Mo2 小时前
智能体设计模式 第一章:提示链
人工智能·python·语言模型·rag