深度学习的疑问--综合【2】:像CNN,GNN,transformer等这些模型都是用于提取特征,然后经过全连接层实现分类的吗?

**总结:**CNN,GNN,transformer等这些模型都是用于提取特征;FC、MLP等用于实现分类,MLP即是多个FC组成的。

是的,从高层次来看,CNN(卷积神经网络)、GNN(图神经网络)、Transformer 等模型的主要作用是提取特征 ,然后通常通过 全连接层(FC,Fully Connected Layer) 或者其他分类头(如 softmax 层)来进行分类或回归任务。

但是,不同模型的特征提取方式、数据结构和任务适用场景有所不同,下面是详细的分析。


1. CNN(卷积神经网络)

适用数据类型: 主要用于图像网格结构数据(如2D/3D数据)。
特征提取方式:

  • CNN 通过卷积层 (Conv)和池化层(Pooling)来提取局部空间特征,并逐层组合高层次特征。

  • 特征提取部分一般由多个卷积层+池化层堆叠而成,形成深层的特征表示。

  • 最后,提取到的高层次特征展平成向量(Flatten),然后通过 **全连接层(FC)** 进行分类或其他任务。

示例结构(用于图像分类):

复制代码
输入(image) -> 卷积层(Conv) -> ReLU -> 池化(Pooling) -> ... -> FC层 -> Softmax -> 分类输出

示例代码(简单 CNN 分类器):

复制代码
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc_layers(x)
        return x

总结:CNN 主要用于图像等数据的局部特征提取,然后通过全连接层分类。


2. GNN(图神经网络)

适用数据类型: 主要用于图结构数据 (如社交网络、分子结构、知识图谱等)。
特征提取方式:

  • GNN 通过图卷积 (如 GCN, GAT, GraphSAGE)等方式聚合邻居节点的特征,从而生成每个节点的高层次表示。

  • 经过多层 GNN 传播后,每个节点的特征包含了多个 hop(层)范围内的邻居信息。

  • 最后,可以使用全连接层(或 MLP,多层感知机)进行分类或回归任务。

示例结构(用于节点分类,如Cora论文分类):

` 注:上图的outputs不是分类输出,是经过GCN之后提取的特征。 `

复制代码
输入(图数据) -> GCN/GAT -> ReLU -> GCN/GAT -> ReLU -> 全连接层(FC) -> Softmax -> 分类输出

示例代码(GCN 进行节点分类):

复制代码
import torch
import torch.nn as nn
import torch_geometric.nn as pyg_nn

class GCN(nn.Module):
    def __init__(self, in_feats, hidden_feats, out_feats):
        super(GCN, self).__init__()
        self.conv1 = pyg_nn.GCNConv(in_feats, hidden_feats)
        self.conv2 = pyg_nn.GCNConv(hidden_feats, out_feats)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        return x  # 直接输出用于分类

总结:GNN 主要用于图数据,通过邻居信息传播提取特征,最后可以通过全连接层进行分类或其他任务。


3. Transformer

适用数据类型: 主要用于序列数据(文本、时间序列、DNA序列) ,也扩展到图像和多模态数据
特征提取方式:

  • Transformer 主要依靠 自注意力机制(Self-Attention) 来捕捉输入数据的全局关系,不同于 CNN 仅关注局部信息。

  • 经过多个 Transformer 层后,可以得到上下文感知的序列表示

  • 对于分类任务,通常使用 Transformer 提取特征后,将第一个 token(如 BERT 的 [CLS] token) 或者平均池化后的特征送入 全连接层(FC) 进行分类。

示例结构(用于文本分类,如 BERT):

注:右图为Encoder

复制代码
输入(文本) -> Token Embedding -> Transformer Encoder -> Pooling(CLS或平均) -> FC层 -> Softmax -> 分类输出

示例代码(BERT 进行文本分类):

复制代码
from transformers import BertModel
import torch.nn as nn

class TransformerClassifier(nn.Module):
    def __init__(self, num_classes):
        super(TransformerClassifier, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.fc = nn.Linear(768, num_classes)  # BERT 输出 768 维特征

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # 取 [CLS] 位置的特征
        logits = self.fc(cls_embedding)
        return logits

总结:Transformer 通过自注意力机制提取特征,通常在最后加上全连接层进行分类。

4. MLP 和 FC

多层感知机(MLP)是一种前馈神经网络 ,它是由多个全连接层(FC)堆叠而成的。MLP的核心思想是利用多层全连接层(FC)构建一个深度网络,并通过非线性激活函数增加模型的表示能力。

  • 结构 :MLP通常包含多个隐藏层(也就是多个全连接层),每一层都会通过激活函数进行非线性变换。MLP的最终目标是通过多层处理,学到输入到输出之间的复杂映射。

  • 特点 :一般来说,MLP包含一个输入层、多个隐藏层和一个输出层。每一层都是全连接层


总结对比

模型类型 适用数据 特征提取方式 分类层
CNN 图像、结构化数据 卷积层提取局部特征,池化层降维 FC + Softmax
GNN 图结构数据(社交网络、分子等) 图卷积聚合邻居信息 FC + Softmax
Transformer 序列数据(文本、时间序列、图像等) 自注意力提取全局特征 FC + Softmax

共同点

  • 这三种模型主要用于特征提取,可以看作是自动提取深层次特征的特征工程器。

  • 经过特征提取后,通常使用 全连接层(FC)+ Softmax 进行分类任务。

不同点

  • CNN 适用于图像数据 ,关注局部特征

  • GNN 适用于图数据 ,依赖邻居信息进行特征聚合

  • Transformer 适用于序列数据(文本、时间序列、图像等) ,关注全局依赖关系


结论

是的,CNN、GNN、Transformer 的主要作用是特征提取,最后通常用全连接层(FC)进行分类。但不同模型适用于不同类型的数据,特征提取方式也不同。实际应用中,也可以用更复杂的分类头(如 MLP、注意力机制等)来替代简单的 FC 层。

相关推荐
大江东去浪淘尽千古风流人物6 小时前
【VLN】VLN仿真与训练三要素 Dataset,Simulators,Benchmarks(2)
深度学习·算法·机器人·概率论·slam
cyyt6 小时前
深度学习周报(2.2~2.8)
人工智能·深度学习
阿杰学AI6 小时前
AI核心知识92——大语言模型之 Self-Attention Mechanism(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·transformer·自注意力机制
2401_836235867 小时前
财务报表识别产品:从“数据搬运”到“智能决策”的技术革命
人工智能·科技·深度学习·ocr·生活
holeer7 小时前
【V2.0】王万良《人工智能导论》笔记|《人工智能及其应用》课程教材笔记
神经网络·机器学习·ai·cnn·nlp·知识图谱·智能计算
啊森要自信7 小时前
CANN runtime 深度解析:异构计算架构下运行时组件的性能保障与功能增强实现逻辑
深度学习·架构·transformer·cann
kyle~7 小时前
深度学习---长短期记忆网络LSTM
人工智能·深度学习·lstm
DatGuy7 小时前
Week 36: 量子深度学习入门:辛量子神经网络与物理守恒
人工智能·深度学习·神经网络
肾透侧视攻城狮7 小时前
《解锁计算机视觉:深度解析 PyTorch torchvision 核心与进阶技巧》
人工智能·深度学习·计算机视觉模快·支持的数据集类型·常用变换方法分类·图像分类流程实战·视觉模快高级功能
CoovallyAIHub7 小时前
让本地知识引导AI追踪社区变迁,让AI真正理解社会现象
深度学习·算法·计算机视觉