CNN卷积神经网络Python实现

python 复制代码
import torch
from torch import nn

# ①定义互相关运算
def corr2d(X, K):
    """计算二维互相关运算。"""
    # 获取K的形状 行为h,列为w
    h, w = K.shape
    # 生成全0的矩阵,行为X的行减去h加上1,列为X的列减去w加上1
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # 两层循环,相乘,求和
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    # 返回Y
    return Y


# ②实现二维卷积层
class Conv2D(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        # 定义权重
        self.weight = nn.Parameter(torch.rand(kernel_size))
        # 定义偏移
        self.bias = nn.Parameter(torch.zeros(1))

    # 定义正向传播
    def forward(self, x):
        return corr2d(x, self.weight) + self.bias

if __name__ == '__main__':
    # 定义模型
    conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)
    # 定义X
    X = torch.ones((6, 8))
    X[:, 2:6] = 0
    # 定义K
    K = torch.tensor([[1.0, -1.0]])
    # 计算Y
    Y = corr2d(X, K)
    X = X.reshape((1, 1, 6, 8))
    Y = Y.reshape((1, 1, 6, 7))
    # 训练10轮
    for i in range(10):
        # 计算Y_hat
        Y_hat = conv2d(X)
        # 损失
        l = (Y_hat - Y)**2
        # 梯度归零
        conv2d.zero_grad()
        # 后向传播
        l.sum().backward()
        # 优化函数 学习率=3e-2
        conv2d.weight.data[:] -= 3e-2 * conv2d.weight.grad
        if (i + 1) % 2 == 0:
            print(f'batch {i+1}, loss {l.sum():.3f}')
    # 经过10轮学习的权重为
    print(conv2d.weight.data.reshape((1, 2)))

结果

python 复制代码
batch 2, loss 1.463
batch 4, loss 0.358
batch 6, loss 0.106
batch 8, loss 0.037
batch 10, loss 0.014
tensor([[ 1.0066, -0.9830]])

Process finished with exit code 0
相关推荐
Ray Liang14 小时前
用六边形架构与整洁架构对比是伪命题?
java·python·c#·架构设计
AI攻城狮14 小时前
如何给 AI Agent 做"断舍离":OpenClaw Session 自动清理实践
python
千寻girling14 小时前
一份不可多得的 《 Python 》语言教程
人工智能·后端·python
AI攻城狮17 小时前
用 Playwright 实现博客一键发布到稀土掘金
python·自动化运维
曲幽18 小时前
FastAPI分布式系统实战:拆解分布式系统中常见问题及解决方案
redis·python·fastapi·web·httpx·lock·asyncio
孟健1 天前
Karpathy 用 200 行纯 Python 从零实现 GPT:代码逐行解析
python
码路飞1 天前
写了个 AI 聊天页面,被 5 种流式格式折腾了一整天 😭
javascript·python
曲幽2 天前
FastAPI压力测试实战:Locust模拟真实用户并发及优化建议
python·fastapi·web·locust·asyncio·test·uvicorn·workers
敏编程2 天前
一天一个Python库:jsonschema - JSON 数据验证利器
python