CNN卷积神经网络Python实现

python 复制代码
import torch
from torch import nn

# ①定义互相关运算
def corr2d(X, K):
    """计算二维互相关运算。"""
    # 获取K的形状 行为h,列为w
    h, w = K.shape
    # 生成全0的矩阵,行为X的行减去h加上1,列为X的列减去w加上1
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # 两层循环,相乘,求和
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    # 返回Y
    return Y


# ②实现二维卷积层
class Conv2D(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        # 定义权重
        self.weight = nn.Parameter(torch.rand(kernel_size))
        # 定义偏移
        self.bias = nn.Parameter(torch.zeros(1))

    # 定义正向传播
    def forward(self, x):
        return corr2d(x, self.weight) + self.bias

if __name__ == '__main__':
    # 定义模型
    conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)
    # 定义X
    X = torch.ones((6, 8))
    X[:, 2:6] = 0
    # 定义K
    K = torch.tensor([[1.0, -1.0]])
    # 计算Y
    Y = corr2d(X, K)
    X = X.reshape((1, 1, 6, 8))
    Y = Y.reshape((1, 1, 6, 7))
    # 训练10轮
    for i in range(10):
        # 计算Y_hat
        Y_hat = conv2d(X)
        # 损失
        l = (Y_hat - Y)**2
        # 梯度归零
        conv2d.zero_grad()
        # 后向传播
        l.sum().backward()
        # 优化函数 学习率=3e-2
        conv2d.weight.data[:] -= 3e-2 * conv2d.weight.grad
        if (i + 1) % 2 == 0:
            print(f'batch {i+1}, loss {l.sum():.3f}')
    # 经过10轮学习的权重为
    print(conv2d.weight.data.reshape((1, 2)))

结果

python 复制代码
batch 2, loss 1.463
batch 4, loss 0.358
batch 6, loss 0.106
batch 8, loss 0.037
batch 10, loss 0.014
tensor([[ 1.0066, -0.9830]])

Process finished with exit code 0
相关推荐
Ulyanov4 分钟前
三维战场可视化核心原理(一):从坐标系到运动控制的全景指南
开发语言·前端·python·pyvista·gui开发
SNAKEpc1213812 分钟前
PyQtGraph应用(一):常用图表图形绘制
python·qt·pyqt
CSND74015 分钟前
anaconda 安装库,终端手动指定下载源
python
0思必得015 分钟前
[Web自动化] 爬虫基础
运维·爬虫·python·selenium·自动化·html
放飞自我的Coder18 分钟前
【Python 异步编程学习手册】
python
HyperAI超神经18 分钟前
在线教程丨微软开源3D生成模型TRELLIS.2,3秒生成高分辨率的全纹理资产
人工智能·深度学习·机器学习·3d
IT阳晨。18 分钟前
【CNN卷积神经网络(吴恩达)】目标检测学习笔记
深度学习·目标检测·cnn
ycydynq19 分钟前
django 数据库 多表操作
数据库·python·django
m0_5494166619 分钟前
自动化与脚本
jvm·数据库·python
gsgbgxp19 分钟前
安装库是优先用conda还是pip
深度学习·ubuntu·conda·pip