CNN卷积神经网络Python实现

python 复制代码
import torch
from torch import nn

# ①定义互相关运算
def corr2d(X, K):
    """计算二维互相关运算。"""
    # 获取K的形状 行为h,列为w
    h, w = K.shape
    # 生成全0的矩阵,行为X的行减去h加上1,列为X的列减去w加上1
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # 两层循环,相乘,求和
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    # 返回Y
    return Y


# ②实现二维卷积层
class Conv2D(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        # 定义权重
        self.weight = nn.Parameter(torch.rand(kernel_size))
        # 定义偏移
        self.bias = nn.Parameter(torch.zeros(1))

    # 定义正向传播
    def forward(self, x):
        return corr2d(x, self.weight) + self.bias

if __name__ == '__main__':
    # 定义模型
    conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)
    # 定义X
    X = torch.ones((6, 8))
    X[:, 2:6] = 0
    # 定义K
    K = torch.tensor([[1.0, -1.0]])
    # 计算Y
    Y = corr2d(X, K)
    X = X.reshape((1, 1, 6, 8))
    Y = Y.reshape((1, 1, 6, 7))
    # 训练10轮
    for i in range(10):
        # 计算Y_hat
        Y_hat = conv2d(X)
        # 损失
        l = (Y_hat - Y)**2
        # 梯度归零
        conv2d.zero_grad()
        # 后向传播
        l.sum().backward()
        # 优化函数 学习率=3e-2
        conv2d.weight.data[:] -= 3e-2 * conv2d.weight.grad
        if (i + 1) % 2 == 0:
            print(f'batch {i+1}, loss {l.sum():.3f}')
    # 经过10轮学习的权重为
    print(conv2d.weight.data.reshape((1, 2)))

结果

python 复制代码
batch 2, loss 1.463
batch 4, loss 0.358
batch 6, loss 0.106
batch 8, loss 0.037
batch 10, loss 0.014
tensor([[ 1.0066, -0.9830]])

Process finished with exit code 0
相关推荐
AI人工智能+1 天前
炫光活体检测技术:通过光学技术实现高效、安全的身份验证,有效防御多种伪造手段。
人工智能·深度学习·人脸识别·活体检测
多恩Stone1 天前
【3DV 进阶-2】Hunyuan3D2.1 训练代码详细理解下-数据读取流程
人工智能·python·算法·3d·aigc
xiaopengbc1 天前
在 Python 中实现观察者模式的具体步骤是什么?
开发语言·python·观察者模式
Python大数据分析@1 天前
python用selenium怎么规避检测?
开发语言·python·selenium·网络爬虫
ThreeAu.1 天前
Miniconda3搭建Selenium的python虚拟环境全攻略
开发语言·python·selenium·minicoda·python环境配置
偷心伊普西隆1 天前
Python EXCEL 理论探究:格式转换时处理缺失值方法
python·excel
东方佑1 天前
打破常规:“无注意力”神经网络为何依然有效?
人工智能·深度学习·神经网络
Francek Chen1 天前
【深度学习计算机视觉】03:目标检测和边界框
人工智能·pytorch·深度学习·目标检测·计算机视觉·边界框
九章云极AladdinEdu1 天前
AI集群全链路监控:从GPU微架构指标到业务Metric关联
人工智能·pytorch·深度学习·架构·开源·gpu算力