CNN卷积神经网络Python实现

python 复制代码
import torch
from torch import nn

# ①定义互相关运算
def corr2d(X, K):
    """计算二维互相关运算。"""
    # 获取K的形状 行为h,列为w
    h, w = K.shape
    # 生成全0的矩阵,行为X的行减去h加上1,列为X的列减去w加上1
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # 两层循环,相乘,求和
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    # 返回Y
    return Y


# ②实现二维卷积层
class Conv2D(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        # 定义权重
        self.weight = nn.Parameter(torch.rand(kernel_size))
        # 定义偏移
        self.bias = nn.Parameter(torch.zeros(1))

    # 定义正向传播
    def forward(self, x):
        return corr2d(x, self.weight) + self.bias

if __name__ == '__main__':
    # 定义模型
    conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)
    # 定义X
    X = torch.ones((6, 8))
    X[:, 2:6] = 0
    # 定义K
    K = torch.tensor([[1.0, -1.0]])
    # 计算Y
    Y = corr2d(X, K)
    X = X.reshape((1, 1, 6, 8))
    Y = Y.reshape((1, 1, 6, 7))
    # 训练10轮
    for i in range(10):
        # 计算Y_hat
        Y_hat = conv2d(X)
        # 损失
        l = (Y_hat - Y)**2
        # 梯度归零
        conv2d.zero_grad()
        # 后向传播
        l.sum().backward()
        # 优化函数 学习率=3e-2
        conv2d.weight.data[:] -= 3e-2 * conv2d.weight.grad
        if (i + 1) % 2 == 0:
            print(f'batch {i+1}, loss {l.sum():.3f}')
    # 经过10轮学习的权重为
    print(conv2d.weight.data.reshape((1, 2)))

结果

python 复制代码
batch 2, loss 1.463
batch 4, loss 0.358
batch 6, loss 0.106
batch 8, loss 0.037
batch 10, loss 0.014
tensor([[ 1.0066, -0.9830]])

Process finished with exit code 0
相关推荐
weixin_418813872 小时前
Python-可视化学习笔记
笔记·python·学习
Danceful_YJ2 小时前
4.权重衰减(weight decay)
python·深度学习·机器学习
Zonda要好好学习3 小时前
Python入门Day5
python
电商数据girl4 小时前
有哪些常用的自动化工具可以帮助处理电商API接口返回的异常数据?【知识分享】
大数据·分布式·爬虫·python·系统架构
CoooLuckly4 小时前
numpy数据分析知识总结
python·numpy
超龄超能程序猿4 小时前
(六)PS识别:源数据分析- 挖掘图像的 “元语言”技术实现
python·组合模式
amazinging5 小时前
北京-4年功能测试2年空窗-报培训班学测开-第四十四天
python·学习·appium
UrbanJazzerati5 小时前
Xlwings安装报错:Connection timed out & WinError 32?一招解决你的安装难题!
python
Tipriest_5 小时前
Python异常类型介绍
开发语言·python·异常
前端付豪5 小时前
21、用 Python + Pillow 实现「朋友圈海报图生成器」📸(图文合成 + 多模板 + 自动换行)
后端·python