PyTorch中Linear全连接层

在 PyTorch 中,torch.nn.Linear 是一个实现全连接层(线性变换)的模块,用于神经网络中的线性变换操作。它的数学表达式为:

其中:

  • x是输入数据

  • W是权重矩阵

  • b是偏置项

  • y是输出数据

基本用法

复制代码
import torch
import torch.nn as nn

# 创建一个线性层,输入特征数为5,输出特征数为3
linear_layer = nn.Linear(in_features=5, out_features=3)

# 创建一个随机输入张量(batch_size=2, 特征数=5)
input_tensor = torch.randn(2, 5)

# 前向传播
output = linear_layer(input_tensor)
print(output.shape)  # 输出 torch.Size([2, 3])

主要参数

  1. in_features - 输入特征的数量

  2. out_features - 输出特征的数量

  3. bias - 是否使用偏置项(默认为True)

重要属性

  1. weight - 可学习的权重参数(形状为[out_features, in_features])

  2. bias - 可学习的偏置参数(形状为[out_features])

示例:构建简单神经网络

复制代码
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 20)  # 输入10维,输出20维
        self.fc2 = nn.Linear(20, 2)   # 输入20维,输出2维
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()
input_data = torch.randn(5, 10)  # batch_size=5
output = model(input_data)
print(output.shape)  # torch.Size([5, 2])

初始化权重

复制代码
# 自定义权重初始化
nn.init.xavier_uniform_(linear_layer.weight)
nn.init.zeros_(linear_layer.bias)

# 或者使用PyTorch内置初始化
linear_layer = nn.Linear(5, 3)
torch.nn.init.kaiming_normal_(linear_layer.weight, mode='fan_out')

注意事项

  1. 输入数据的最后一维必须等于in_features

  2. 线性层通常与激活函数配合使用(如ReLU)

  3. 在GPU上使用时,确保数据和模型都在同一设备上。

相关推荐
苏苏susuus9 小时前
深度学习:PyTorch张量基本运算、形状改变、索引操作、升维降维、维度转置、张量拼接
人工智能·pytorch·深度学习
凡人的AI工具箱11 小时前
PyTorch深度学习框架60天进阶学习计划 - 第58天端到端对话系统(一):打造你的专属AI语音助手
人工智能·pytorch·python·深度学习·mcp·a2a
知舟不叙14 小时前
深度学习——基于PyTorch的MNIST手写数字识别详解
人工智能·pytorch·深度学习·手写数字识别
Crabfishhhhh16 小时前
神经网络学习-神经网络简介【Transformer、pytorch、Attention介绍与区别】
pytorch·python·神经网络·学习·transformer
whyeekkk19 小时前
python打卡第52天
pytorch·python·深度学习
猎嘤一号1 天前
使用 PyTorch 和 SwanLab 实时可视化模型训练
人工智能·pytorch·深度学习
福大大架构师每日一题1 天前
pytorch v2.7.1 发布!全面修复关键BUG,性能与稳定性再升级,2025年深度学习利器必备!
pytorch·深度学习·bug
凡人的AI工具箱1 天前
PyTorch深度学习框架60天进阶学习计划-第57天:因果推理模型(二)- 高级算法与深度学习融合
人工智能·pytorch·深度学习·学习·mcp·a2a
四川兔兔1 天前
pytorch 之 nn 库与调试
人工智能·pytorch·python
啊哈哈哈哈哈啊哈哈1 天前
G1周打卡——GAN入门
pytorch·深度学习·生成对抗网络