PyTorch实现二维卷积与边缘检测:从原理到实战

本文通过PyTorch实现二维互相关运算、自定义卷积层,并演示如何通过卷积核检测图像边缘。同时,我们将训练一个卷积核参数,使其能够从数据中学习边缘特征。


1. 二维互相关运算的实现

互相关运算(Cross-Correlation)是卷积操作的基础。以下代码实现了二维互相关运算:

python 复制代码
import torch
from torch import nn

def corr2d(x, k):
    h, w = k.shape
    y = torch.zeros((x.shape[0] - h + 1, x.shape[1] - w + 1))
    for i in range(y.shape[0]):
        for j in range(y.shape[1]):
            y[i, j] = (x[i:i+h, j:j+w] * k).sum()  # 逐元素相乘后求和
    return y

验证输出

输入矩阵和卷积核如下,输出结果为互相关运算后的张量:

python 复制代码
x = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
k = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
print(corr2d(x, k))

输出

bash 复制代码
tensor([[19., 25.],
        [37., 43.]])

2. 自定义二维卷积层

通过继承nn.Module实现一个自定义卷积层,包含可学习的权重和偏置:

python 复制代码
class Conv2D(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(kernel_size))
        self.bias = nn.Parameter(torch.zeros(1))
    
    def forward(self, x):
        return corr2d(x, self.weight) + self.bias

3. 边缘检测应用

3.1 构造输入图像

创建一个6x8的矩阵,中间4列为黑色(值为0),两侧为白色(值为1):

python 复制代码
x = torch.ones(6, 8)
x[:, 2:6] = 0
print(x)

输出

bash 复制代码
tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])

3.2 定义卷积核

使用卷积核[[1, -1]]检测垂直边缘:

python 复制代码
k = torch.tensor([[1.0, -1.0]])
y = corr2d(x, k)
print(y)

输出

bash 复制代码
tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])
  • 结果解释

    输出中1表示从白到黑的边缘,-1表示从黑到白的边缘。

3.3 水平边缘检测

若将输入矩阵转置,原卷积核无法检测水平边缘:

python 复制代码
print(corr2d(x.T, k))

输出:全零矩阵(无法检测到水平边缘)

bash 复制代码
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        ...])

4. 学习卷积核参数

使用PyTorch内置的nn.Conv2d,通过梯度下降学习卷积核参数:

python 复制代码
# 定义模型
conv2d = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(1, 2), bias=False)

# 调整输入输出形状
x = x.reshape((1, 1, 6, 8))  # (batch_size, channels, height, width)
y = y.reshape((1, 1, 6, 7))

# 训练过程
for i in range(10):
    y_hat = conv2d(x)
    loss = (y_hat - y).pow(2)
    conv2d.zero_grad()
    loss.sum().backward()
    conv2d.weight.data[:] -= 3e-2 * conv2d.weight.grad  # 更新权重
    if (i+1) % 2 == 0:
        print(f'batch{i+1}, loss{loss.sum():.3f}')

输出

bash 复制代码
batch2, loss5.270
batch4, loss0.884
batch6, loss0.148
batch8, loss0.025
batch10, loss0.004

4.1 查看学习后的卷积核

训练后的权重接近理想值[1, -1]

python 复制代码
print(conv2d.weight.data.reshape((1, 2)))

输出

bash 复制代码
tensor([[ 0.9883, -0.9878]])

5. 总结

  1. 互相关运算:通过逐窗口计算实现基础的卷积操作。

  2. 边缘检测:方向特定的卷积核可提取图像边缘特征。

  3. 参数学习:利用梯度下降可自动学习卷积核参数,无需手动设计。

完整代码已验证,读者可自行调整输入或卷积核探索更多效果。


提示 :实际项目中建议使用PyTorch内置的高效卷积层(如nn.Conv2d),而非手动实现,以充分利用GPU加速。

相关推荐
芯盾时代20 分钟前
数据出境的安全合规思考
大数据·人工智能·安全·网络安全·信息与通信
Sylvan Ding38 分钟前
PyTorch Lightning实战 - 训练 MNIST 数据集
人工智能·pytorch·python·lightning
大白技术控40 分钟前
浙江大学 deepseek 公开课 第三季 第3期 - 陈喜群 教授 (附PPT下载) by 突破信息差
人工智能·互联网·deepseek·deepseek公开课·浙大deepseek公开课课件·deepseek公开课ppt·人工智能大模型
Silence4Allen44 分钟前
大模型微调指南之 LLaMA-Factory 篇:一键启动LLaMA系列模型高效微调
人工智能·大模型·微调·llama-factory
江鸟19981 小时前
AI日报 · 2025年05月11日|传闻 OpenAI 考虑推出 ChatGPT “永久”订阅模式
人工智能·gpt·ai·chatgpt·github
weifont1 小时前
Ai大模型训练从零到1第一节(共81节)
人工智能
kyle~1 小时前
C++匿名函数
开发语言·c++·人工智能
水银嘻嘻1 小时前
web 自动化之 Unittest 应用:报告&装饰器&断言
前端·python·自动化
知来者逆1 小时前
AI 在模仿历史语言方面面临挑战:大型语言模型在生成历史风格文本时的困境与研究进展
人工智能·深度学习·语言模型·自然语言处理·chatgpt
攻城狮7号1 小时前
Python爬虫第20节-使用 Selenium 爬取小米商城空调商品
开发语言·数据库·爬虫·python·selenium