PyTorch中 nn.Linear详解和实战示例

1. nn.Linear 的作用

在 PyTorch 中,torch.nn.Linear 表示一个全连接层(Fully Connected Layer) ,也叫 仿射变换层(Affine Layer)

它的计算公式是:

y=xWT+b y = x W^T + b y=xWT+b

  • 输入 x:形状 [batch_size, in_features]
  • 权重矩阵 W:形状 [out_features, in_features]
  • 偏置 b:形状 [out_features]
  • 输出 y:形状 [batch_size, out_features]

2. 初始化方式

python 复制代码
torch.nn.Linear(
    in_features: int,
    out_features: int,
    bias: bool = True,
    device=None,
    dtype=None
)
  • in_features:输入特征维度
  • out_features:输出特征维度
  • bias :是否使用偏置项(默认 True
  • device/dtype:指定设备与数据类型

3. 参数说明

一个 nn.Linear 层包含两个参数(均可训练):

  1. weight :形状 [out_features, in_features]
  2. bias :形状 [out_features](可选)

初始化时:

  • 权重 weight 默认使用 Kaiming 均匀分布初始化(a=√5)

  • 偏置 bias 默认使用 均匀分布 U(-bound, bound),其中

    bound=1in_features bound = \frac{1}{\sqrt{in\_features}} bound=in_features 1


4. 前向传播公式

假设输入张量 x 形状为 [batch_size, in_features]

output[i]=∑j=1in_featuresx[j]⋅W[i][j]+b[i] \text{output}[i] = \sum_{j=1}^{in\_features} x[j] \cdot W[i][j] + b[i] output[i]=j=1∑in_featuresx[j]⋅W[i][j]+b[i]

即对每个样本进行线性变换。

PyTorch 内部实现是:

python 复制代码
output = input.matmul(weight.T) + bias

5. 反向传播(梯度)

PyTorch 自动求导会自动处理梯度,但核心推导如下:

  • 输入 x ∈ R^{B×I},权重 W ∈ R^{O×I},偏置 b ∈ R^{O}

  • 前向传播:

    Y=XWT+b Y = X W^T + b Y=XWT+b

  • 梯度:

    • 对权重:

      ∂L∂W=∂L∂YTX \frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y}^T X ∂W∂L=∂Y∂LTX

    • 对偏置:

      ∂L∂b=∑samples∂L∂Y \frac{\partial L}{\partial b} = \sum_{samples} \frac{\partial L}{\partial Y} ∂b∂L=samples∑∂Y∂L

    • 对输入:

      ∂L∂X=∂L∂YW \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W ∂X∂L=∂Y∂LW


6. 使用示例

python 复制代码
import torch
import torch.nn as nn

# 定义线性层
fc = nn.Linear(in_features=5, out_features=3, bias=True)

# 输入张量 [batch=2, in_features=5]
x = torch.randn(2, 5)

# 前向传播
y = fc(x)
print("Input shape:", x.shape)   # [2, 5]
print("Output shape:", y.shape)  # [2, 3]

# 查看参数
print(fc.weight.shape)  # [3, 5]
print(fc.bias.shape)    # [3]

7. 常见用法

  1. 作为全连接层

    python 复制代码
    model = nn.Sequential(
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )
  2. 替代矩阵乘法

    python 复制代码
    W = torch.randn(3, 5)
    b = torch.randn(3)
    x = torch.randn(2, 5)
    
    y1 = x @ W.T + b
    y2 = nn.Linear(5, 3)(x)
    
    print(torch.allclose(y1, y2, atol=1e-6))  # True
  3. 作为嵌入层最后一步投影

    • Transformer 中 decoder 的输出用 nn.Linear 投影到词表大小 vocab_size

8. 源码关键点

PyTorch 源码(torch/nn/modules/linear.py)核心部分:

python 复制代码
def forward(self, input: Tensor) -> Tensor:
    return F.linear(input, self.weight, self.bias)

其中 F.linear 实现就是 input.matmul(weight.T) + bias


9. 常见坑点

  1. 输入维度不对
    nn.Linear 要求输入最后一维是 in_features

    如果输入是 [batch, channels, height, width],要先 flattenpermute

    python 复制代码
    x = torch.randn(32, 3, 28, 28)
    fc = nn.Linear(3*28*28, 100)
    y = fc(x.view(32, -1))  # 展平
  2. 权重转置

    注意公式是 y = x @ W^T,而不是 x @ W

  3. 和卷积的区别

    • nn.Conv2d:局部连接 + 权重共享
    • nn.Linear:全连接,不共享权重

10. 总结

  • nn.Linear = 全连接层 = 仿射变换
  • 参数:weight [out, in]bias [out]
  • 前向公式:y = x @ W^T + b
  • 常用于:MLP、分类器最后一层、Transformer 投影层等
  • 注意输入最后一维要匹配 in_features

11. 综合应用示例

下面是一个完整的综合示例,涵盖以下内容:

定义 nn.Linear
模拟输入数据
前向传播
查看权重与偏置
反向传播 + 梯度查看
x @ W^T + b 对比验证一致性
图示化输入输出形状变化(文字+可视化)


综合示例:手写数字分类 MLP(含 nn.Linear

我们构建一个简单的 2 层感知机,模拟对输入向量进行分类。


Step 1:导入依赖 & 定义模型

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义一个简单的两层感知机(MLP)
class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 8)   # 第一层:16 -> 8
        self.fc2 = nn.Linear(8, 4)    # 第二层:8 -> 4(比如分类 4 类)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

Step 2:模拟输入数据

python 复制代码
# 模拟输入数据:batch_size = 3,feature_dim = 16
x = torch.randn(3, 16)

# 实例化模型
model = SimpleMLP()

# 前向传播
output = model(x)

print("输入形状:", x.shape)        # [3, 16]
print("第一层权重形状:", model.fc1.weight.shape)  # [8, 16]
print("第二层权重形状:", model.fc2.weight.shape)  # [4, 8]
print("输出形状:", output.shape)  # [3, 4]

Step 3:手动验证 x @ W^T + bnn.Linear 一致性

python 复制代码
# 取第一层验证
fc = model.fc1
x_input = x

# 手动计算:y = x @ W^T + b
manual_output = x_input @ fc.weight.T + fc.bias

# 与 forward 一致性验证
auto_output = fc(x_input)

print("是否一致:", torch.allclose(manual_output, auto_output, atol=1e-6))

Step 4:反向传播 + 查看梯度

python 复制代码
# 假设一个简单的损失函数
target = torch.tensor([0, 1, 3])     # 假设分类标签
criterion = nn.CrossEntropyLoss()

# 正向计算输出
out = model(x)

# 计算损失
loss = criterion(out, target)

# 反向传播
loss.backward()

# 查看第一层权重梯度
print("fc1 权重梯度形状:", model.fc1.weight.grad.shape)
print("fc1 偏置梯度形状:", model.fc1.bias.grad.shape)

输入输出形状变化总结

层级 输入形状 权重形状 输出形状
输入 [3, 16] -- [3, 16]
fc1 [3, 16] [8, 16] [3, 8]
ReLU [3, 8] -- [3, 8]
fc2 [3, 8] [4, 8] [3, 4]

可视化理解(流程图)

下面这个示意图帮助你直观理解 nn.Linear 是如何做维度映射的:

复制代码
输入张量 x:         [batch_size=3, in_features=16]
        │
        ▼
nn.Linear(16 → 8):  权重 [8, 16],输出 [3, 8]
        │
        ▼
      ReLU 激活
        │
        ▼
nn.Linear(8 → 4):   权重 [4, 8],输出 [3, 4]
        │
        ▼
分类输出 logits:    [batch_size=3, out_features=4]

模型结构和张量流动的视觉图


12.nn.Linear源码关键实现和典型应用

1. nn.Linear 源码关键实现

在 PyTorch 2.0+ 的源码中(torch/nn/modules/linear.py),核心实现非常精简:

python 复制代码
class Linear(Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor
    bias: Tensor | None

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # 权重参数(out_features x in_features)
        self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        
        # 偏置参数(out_features)
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # 权重 Kaiming 均匀初始化,偏置 U(-bound, bound)
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

关键点拆解:

  1. 参数存储

    • self.weight[out_features, in_features]
    • self.bias[out_features]
  2. 初始化

    • 权重:Kaiming Uniform(适合 ReLU 激活)
    • 偏置:Uniform(-1/√fan_in, 1/√fan_in)
  3. 前向计算

    python 复制代码
    def linear(input, weight, bias=None):
        return input.matmul(weight.T) + bias
  4. 梯度计算

    PyTorch 自动在 C++/CUDA backend 里定义好了 matmuladd 的梯度传播,不需要 Python 层手写。


2. nn.Linear 在不同架构中的用途

(1) 在 MLP(多层感知机)
  • 作用:核心构建模块,层层映射特征维度。

  • 例子

    python 复制代码
    model = nn.Sequential(
        nn.Linear(784, 256),  # 输入层 (28*28)
        nn.ReLU(),
        nn.Linear(256, 128),  # 隐藏层
        nn.ReLU(),
        nn.Linear(128, 10)    # 输出层 (分类)
    )
  • 解释 :每个 nn.Linear 就是一次仿射变换,把输入映射到新空间。最后一层通常对应分类 logits。


(2) 在 CNN 分类器
  • 作用 :CNN 负责提取空间特征,最后通过 nn.Linear 将卷积特征映射到分类输出空间。

  • 例子(ResNet 中的最后一层):

    python 复制代码
    class CNNClassifier(nn.Module):
        def __init__(self, num_classes=10):
            super().__init__()
            self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
            self.pool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc   = nn.Linear(64, num_classes)  # 分类层
    
        def forward(self, x):
            x = F.relu(self.conv(x))
            x = self.pool(x)
            x = torch.flatten(x, 1)  # [batch, 64]
            x = self.fc(x)           # [batch, num_classes]
            return x
  • 解释nn.Linear 负责 卷积特征 → 类别预测


(3) 在 Transformer(如 BERT, GPT)

Transformer 内部大量用到 nn.Linear,主要场景有:

a. Attention 的 Q、K、V 投影
python 复制代码
self.q_proj = nn.Linear(d_model, d_k)  # 生成 Query
self.k_proj = nn.Linear(d_model, d_k)  # 生成 Key
self.v_proj = nn.Linear(d_model, d_v)  # 生成 Value
  • 输入:[batch, seq_len, d_model]
  • 输出:[batch, seq_len, d_k]
  • 用于将 embedding 投影到不同子空间。
b. Attention 输出的投影
python 复制代码
self.out_proj = nn.Linear(d_v, d_model)
  • 将多头拼接后的结果映射回 d_model 维度。
c. Feed-Forward 网络(FFN)

Transformer Block 里的 FFN 是:

FFN(x)=Linear(dmodel,dff)→ReLU/GELU→Linear(dff,dmodel) FFN(x) = \text{Linear}(d_{model}, d_{ff}) \to \text{ReLU/GELU} \to \text{Linear}(d_{ff}, d_{model}) FFN(x)=Linear(dmodel,dff)→ReLU/GELU→Linear(dff,dmodel)

例子:

python 复制代码
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
d. BERT 最后一层分类头
  • BERT 用 nn.Linear(hidden_size, vocab_size) 投影到词表维度,得到预测 logits。

  • 例如 Masked LM 任务:

    python 复制代码
    self.cls = nn.Linear(hidden_size, vocab_size)

3. 小结对比表

架构 nn.Linear 用途 输入维度 输出维度
MLP 特征逐层映射 [batch, in_features] [batch, out_features]
CNN 卷积特征 → 分类 [batch, channels] [batch, num_classes]
Transformer (a) Q/K/V 投影 (b) Attention 输出投影 © FFN 映射 (d) 分类/词表投影 [batch, seq, d_model] [batch, seq, d_k/d_ff/d_model/vocab]
BERT Masked LM / NSP 分类头 [batch, hidden_size] [batch, vocab_size]

相关推荐
做一个快乐的小傻瓜1 分钟前
机器学习笔记
人工智能·决策树·机器学习
居然JuRan6 分钟前
MCP:基础概念、快速应用和背后原理
人工智能
不枯石24 分钟前
Python实现RANSAC进行点云直线、平面、曲面、圆、球体和圆柱拟合
python·计算机视觉
1ucency31 分钟前
Dify插件“Database”安装及配置
人工智能
eqwaak043 分钟前
科技信息差(8.26)
大数据·开发语言·人工智能·编辑器
站大爷IP44 分钟前
Python Lambda:从入门到实战的轻量级函数指南
python
深盾安全44 分钟前
Python 装饰器精要
python
念夏沫1 小时前
“华生科技杯”2025年全国青少年龙舟锦标赛在海宁举行
大数据·人工智能·科技
站大爷IP1 小时前
Python爬虫基本原理与HTTP协议详解:从入门到实践
python
2202_756749691 小时前
自然处理语言NLP: 基于双分支 LSTM 的酒店评论情感分析模型构建与实现
人工智能·自然语言处理·lstm