超越CNN:GCN如何重塑图像处理

目录

写在前面

一、GCN处理图像的优势

二、构建数据集

三、定义模型

四、训练代码

五、推理代码

六、总结


写在前面

GCN 用于图像处理时,并没有 CNN 中 "固定形状、滑动遍历" 的卷积核,但存在承担 "特征变换" 功能的权重矩阵,其作用与 CNN 卷积核的 "参数化特征提取" 本质相通,只是适配图结构的操作形式不同。

下面我将用GCN完成一个简单的图像分类任务,这项任务的核心是数据处理------构建图数据。

要了解GCN的基础知识,戳这里:一图看懂图卷积网络GCN

要了解GCN处理图像的计算过程,戳这里:图卷积网络GCN:图像理解的新视角

一、GCN处理图像的优势

那GCN比CNN有哪些优势呢?举个例子,假设你要识别猫的图片。

CNN 会扫描整张图 → 识别猫 → 识别椅子 → 组合信息,它看到的是一个规则的像素网格。优点 是对规则、结构一致的图像(例如方形照片)非常高效。缺点是它只能处理"规则网格",也就是固定排列的像素点。

GCN 把图像看作一个图(Graph) ,节点(Node)可以是每个像素、每个超像素(superpixel)或者图像里的某些"关键区域"。边(Edge)代表这些节点之间的关系,比如:像素之间颜色相似度、空间距离、是否属于同一物体等。然后,GCN 在图上传播信息------每个节点都会根据它的邻居更新自己的特征。

GCN把"猫的身体各部分""椅子的腿""背景"等区域当作节点;根据它们之间的关系(比如"猫在椅子上")建立边;让这些节点相互传递信息;最终得出一个更"结构化"的理解。GCN 知道:"猫"和"椅子"不是孤立存在的,它们之间有关系。这类"结构关系"是 CNN 不容易直接捕捉的。

二、SLIC 算法

下面步骤会用到SLIC 算法,这里简单介绍一下。

SLIC(Simple Linear Iterative Clustering) 是一种常用的 超像素分割算法

它的作用是把一张图像切成一堆颜色相近、空间相邻的小块区域(称为"超像素")。这些超像素比像素更"聪明"------每个块大致代表图像中的一个局部区域(比如一块天空、一片草地、一只眼睛)。它能让后续算法(如 GCN、目标检测、分割)更高效地处理图像结构。

SLIC 的核心思想很简单:在颜色空间和空间位置上,把相似的像素聚成一类。它本质上是 在五维空间中做 K-means 聚类

这 5 个维度是:

  • 三个颜色维度(通常是 Lab 空间中的 L, a, b);

  • 两个位置维度(像素的 x, y 坐标)。

所以每个像素都可以表示为一个五维向量:(L,a,b,x,y)

计算步骤:

1.初始化聚类中心

  • 把图像均匀分成若干个格子;

  • 在每个格子的中心挑一个像素作为初始聚类中心。

2.定义距离度量(颜色+空间)

对每个像素,计算它到聚类中心的"距离":

其中:

  • ​:颜色差(Lab空间)

  • ​:空间距离

  • S:超像素的期望大小

  • m:平衡系数(控制颜色 vs 空间的重要性)

当 m小时,更注重颜色一致;当 m大时,更注重空间连续。

3.分配像素

每个像素根据距离 D 选择最近的中心归类。

4.更新聚类中心

对每个聚类,重新计算平均的颜色和坐标,然后更新中心点。

5.迭代

重复分配和更新,直到聚类中心稳定。

6.后处理(可选)

去掉孤立的小块,保证区域连通。

直观理解(打个比方)

想象你在画布上撒满了彩色小珠子:SLIC 就像在画布上放很多小"吸铁石",每个吸铁石会吸引周围 颜色相近 的小珠子;经过几轮吸附和调整,画布就被自然地分成了一些 颜色块 ------ 这就是超像素。

四、构建数据集

这里是任务的核心------生成"超像素"或者"关键区域"。我们可以使用目标检测或者分割模型先识别出猫和椅子等物体作为"超像素",然后再建立联系;这里我们简化问题,使用基于颜色的SLIC 算法来生成"超像素"。

通常,一个图像对应一个 Data 对象(来自 torch_geometric.data.Data)。我们可以把多个图像封装进 Dataset。构建数据集代码:

python 复制代码
from torch_geometric.data import Dataset, DataLoader
from skimage.segmentation import slic
from skimage.color import rgb2lab
import numpy as np
from PIL import Image
import os
import torch

class CatDogGraphDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.samples = []
        for label, cls in enumerate(['cat', 'dog']):
            folder = os.path.join(root_dir, cls)
            for img_name in os.listdir(folder):
                if img_name.endswith('.jpg'):
                    self.samples.append((os.path.join(folder, img_name), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB").resize((64, 64))
        img_np = np.array(img)

        # 生成超像素
        segments = slic(img_np, n_segments=75, compactness=10)
        num_nodes = segments.max() + 1

        # 节点特征
        lab_img = rgb2lab(img_np)
        node_features = []
        for i in range(num_nodes):
            mask = segments == i
            mean_color = lab_img[mask].mean(axis=0)
            node_features.append(mean_color)
        x = torch.tensor(node_features, dtype=torch.float)

        # 构造边(相邻的超像素)
        edges = set()
        for i in range(63):
            for j in range(63):
                if segments[i, j] != segments[i, j + 1]:
                    edges.add((segments[i, j], segments[i, j + 1]))
                if segments[i, j] != segments[i + 1, j]:
                    edges.add((segments[i, j], segments[i + 1, j]))
        if len(edges) == 0:
            edges.add((0, 0))  # 避免空图
        edge_index = torch.tensor(list(zip(*edges)), dtype=torch.long)

        return Data(x=x, edge_index=edge_index, y=torch.tensor([label], dtype=torch.long))s定义模型

模型很简单,由两层GCN组成,每层都之后是ReLU操作,然后经过global_mean_pool,最后送入全连接输出结果:

python 复制代码
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class SimpleGCN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # 聚合所有节点特征 → 图级别表示
        x = self.fc(x)
        return x

五、训练代码

使用交叉熵损失,Adam优化器:

python 复制代码
from torch_geometric.loader import DataLoader
import torch.optim as optim

# 加载数据
train_dataset = CatDogGraphDataset('data/cats_and_dogs/train')
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# 初始化模型
model = SimpleGCN(input_dim=3, hidden_dim=32, output_dim=2)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# 训练循环
for epoch in range(10):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")

六、推理代码

推理时,流程一样,只是不计算梯度。

python 复制代码
test_dataset = CatDogGraphDataset('data/cats_and_dogs/test')
test_loader = DataLoader(test_dataset, batch_size=1)

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        out = model(batch.x, batch.edge_index, batch.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == batch.y).sum())
        total += batch.y.size(0)

print(f"Test Accuracy: {correct / total:.2%}")

七、总结

一句话总结:

CNN:在像素网格上卷积提取局部特征;

GCN:在区域关系图上传递并融合结构信息。

对于"猫坐在椅子上"这种有结构关系的图像,GCN 能更好地理解语义。

GCN用于图像处理就介绍到这里。

关注不迷路(*^▽^*),暴富入口==》 https://bbs.csdn.net/topics/619691583

相关推荐
康语智能4 小时前
科技赋能成长,小康AI家庭医生守护童真
人工智能·科技
WLJT1231231235 小时前
科技赋能塞上农业:宁夏从黄土地到绿硅谷的蝶变
大数据·人工智能·科技
StarPrayers.5 小时前
旅行商问题(TSP)(2)(heuristics.py)(TSP 的两种贪心启发式算法实现)
前端·人工智能·python·算法·pycharm·启发式算法
koo3645 小时前
李宏毅机器学习笔记21
人工智能·笔记·机器学习
Bony-5 小时前
奶茶销售数据分析
人工智能·数据挖掘·数据分析·lstm
山烛5 小时前
YOLO v1:目标检测领域的单阶段革命之作
人工智能·yolo·目标检测·计算机视觉·yolov1
华仔AI智能体6 小时前
Qwen3(通义千问3)、OpenAI GPT-5、DeepSeek 3.2、豆包最新模型(Doubao 4.0)通用模型能力对比
人工智能·python·语言模型·agent·智能体
大千AI助手6 小时前
高斯隐马尔可夫模型:原理与应用详解
人工智能·高斯·hmm·高斯隐马尔可夫模型·ghmm·马尔科夫模型·混合高斯模型