NLP(4)--实现一个线性层

前言

仅记录学习过程,有问题欢迎讨论

感觉全连接层就像一个中间层转换数据的形态的,或者说预处理数据?

代码

里面有两个部分,一部分是自己实现的,一部分是利用模块的方法实现的。

java 复制代码
import torch
import torch.nn as nn
import numpy as np

"""
numpy手动实现模拟一个线性层
"""


# 搭建一个2层的神经网络模型
# 每层都是线性层
# 继承 nn.Module
class TorchModel(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2):
        super(TorchModel, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size1) # 5*3
        self.layer2 = nn.Linear(hidden_size1, hidden_size2)

    def forward(self, x):
        x = self.layer1(x)
        # 第二层输出就是y
        y_pred = self.layer2(x)
        return y_pred


# 自定义模型
class DiyModel:
    def __init__(self, w1, b1, w2, b2):
        self.w1 = w1
        self.b1 = b1
        self.w2 = w2
        self.b2 = b2

    def forward(self, x):
        # 点 * w1.T 是转置 2*3 * 3*5 === 2*5
        hidden = np.dot(x, self.w1.T) + self.b1  # 1*5
        # 2*5 * 5*2 ===  2* 2
        y_pred = np.dot(hidden, self.w2.T) + self.b2  # 1*2
        return y_pred


# 随便准备一个网络输入 2*3
x = np.array([[3.1, 1.3, 1.2],
              [2.1, 1.3, 13]])
# 建立torch 参数是隐藏层的维度
torch_model = TorchModel(3, 5, 2)
# 字典 包含了w b 因为隐藏层的函数为: y = w*x + b
print(torch_model.state_dict())
print("-----------")
#打印模型权重,权重为随机初始化
torch_model_w1 = torch_model.state_dict()["layer1.weight"].numpy()
torch_model_b1 = torch_model.state_dict()["layer1.bias"].numpy()
torch_model_w2 = torch_model.state_dict()["layer2.weight"].numpy()
torch_model_b2 = torch_model.state_dict()["layer2.bias"].numpy()
print(torch_model_w1, "torch w1 权重")
print(torch_model_b1, "torch b1 权重")
print("-----------")
print(torch_model_w2, "torch w2 权重")
print(torch_model_b2, "torch b2 权重")
print("-----------")

# 预测
torch_x = torch.FloatTensor(x)
y_pred = torch_model.forward(torch_x)
# 2*2 的矩阵
print("torch模型预测结果:", y_pred)


# #把torch模型权重拿过来自己实现计算过程
diy_model = DiyModel(torch_model_w1, torch_model_b1, torch_model_w2, torch_model_b2)
# #用自己的模型来预测
y_pred_diy = diy_model.forward(np.array(x))
print("diy模型预测结果:", y_pred_diy)
相关推荐
陈苏同学8 分钟前
MPC控制器从入门到进阶(小车动态避障变道仿真 - Python)
人工智能·python·机器学习·数学建模·机器人·自动驾驶
mahuifa12 分钟前
python实现usb热插拔检测(linux)
linux·服务器·python
努力毕业的小土博^_^29 分钟前
【深度学习|学习笔记】 Generalized additive model广义可加模型(GAM)详解,附代码
人工智能·笔记·深度学习·神经网络·学习
MyhEhud40 分钟前
kotlin @JvmStatic注解的作用和使用场景
开发语言·python·kotlin
狐凄1 小时前
Python实例题:pygame开发打飞机游戏
python·游戏·pygame
漫谈网络1 小时前
Telnet 类图解析
python·自动化·netdevops·telnetlib·网络自动化运维
小小鱼儿小小林1 小时前
用AI制作黑神话悟空质感教程,3D西游记裸眼效果,西游人物跳出书本
人工智能·3d·ai画图
浪淘沙jkp1 小时前
AI大模型学习二十、利用Dify+deepseekR1 使用知识库搭建初中英语学习智能客服机器人
人工智能·llm·embedding·agent·知识库·dify·deepseek
农夫山泉2号2 小时前
【python】—conda新建python3.11的环境报错
python·conda·python3.11
AndrewHZ3 小时前
【图像处理基石】什么是油画感?
图像处理·人工智能·算法·图像压缩·视频处理·超分辨率·去噪算法