我们用 最通俗的语言 + 生活比喻 + 代码示例 + 手动实现,向初学者彻底讲清楚:
✅
torch.nn.Linear
是什么?✅ 它的原理和作用是什么?
✅ 怎么使用它?
🧩 一句话总结(先记住这个!):
nn.Linear
就是一个"计算器",你给它一组数字,它按固定公式(矩阵乘 + 加法)算出另一组数字 ------ 而且这个公式里的"系数"是可以学习的!
🍎 生活化比喻:学生综合评分器
想象你是一个班主任,要给学生算"综合能力分":
- 输入:学生的三科成绩 → 语文、数学、英语
- 你给每科分配一个"权重"(重要性):
- 语文 × 0.3
- 数学 × 0.5
- 英语 × 0.2
- 再加一个"基础分":+10分
- 输出:综合能力分
👉 这就是一个 Linear 变换!
综合分 = 语文×0.3 + 数学×0.5 + 英语×0.2 + 10
在 nn.Linear
中:
语文, 数学, 英语
→ 输入向量x
0.3, 0.5, 0.2
→ 权重weight
10
→ 偏置bias
综合分
→ 输出y
🧮 数学公式(别怕,很简单!)
ini
y = x @ W^T + b
或
输出 = 输入 × 权重转置 + 偏置
x
:输入,形状[n]
(n个特征)W
:权重,形状[m, n]
(m个输出,n个输入)b
:偏置,形状[m]
y
:输出,形状[m]
@
表示矩阵乘法,^T
表示转置
💻 代码示例(从零开始)
python
import torch
import torch.nn as nn
# 创建一个 Linear 层:输入3个特征 → 输出1个分数
linear = nn.Linear(in_features=3, out_features=1)
# 查看参数(初始是随机的)
print("权重 W:", linear.weight) # [1, 3]
print("偏置 b:", linear.bias) # [1]
# 输入:一个学生的三科成绩
x = torch.tensor([[80.0, 90.0, 85.0]]) # shape: [1, 3]
# 前向计算
y = linear(x)
print("综合分:", y) # shape: [1, 1]
🔍 手动验证计算(超重要!)
我们手动计算一遍,看是否和 linear(x)
一致:
python
# 手动计算:y = x @ W^T + b
W = linear.weight # [1, 3]
b = linear.bias # [1]
x = torch.tensor([[80.0, 90.0, 85.0]])
y_manual = x @ W.T + b # 注意:W.T 是转置
print("手动计算:", y_manual)
print("框架计算:", y)
print("两者相等:", torch.allclose(y, y_manual)) # True ✅
🎯 作用和意义
1. 特征变换
把输入特征线性组合成新特征 → 降维、升维、特征融合
2. 可学习
权重和偏置是可训练参数 → 模型自动学会最佳"系数"
3. 基础构建块
几乎所有神经网络都由 Linear + 激活函数
堆叠而成
4. 应用场景
- 几乎无处不在!任何需要将一组特征转换为另一组特征的地方
- 多层感知机(MLP)的主要组成部分
- Transformer 中的前馈网络(Feed-Forward Network)
- 分类器或回归器的最后一层
🧱 在神经网络中的典型用法
python
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20) # 输入10维 → 输出20维
self.fc2 = nn.Linear(20, 5) # 输入20维 → 输出5维
def forward(self, x):
x = torch.relu(self.fc1(x)) # Linear + 激活函数
x = self.fc2(x)
return x
model = SimpleNet()
x = torch.randn(5, 10) # batch=5, input=10
output = model(x) # [5, 5]
📊 形状变化示例
输入形状 | Linear(in, out) | 输出形状 |
---|---|---|
[3] |
Linear(3, 1) | [1] |
[2, 3] |
Linear(3, 5) | [2, 5] |
[4, 7, 3] |
Linear(3, 2) | [4, 7, 2] |
✅
Linear
会自动处理前置维度(如 batch, sequence),只对最后一维做变换!
⚙️ 参数初始化
默认使用 Kaiming Uniform(适合 ReLU),也可自定义:
python
linear = nn.Linear(3, 1)
nn.init.xavier_uniform_(linear.weight) # Xavier 初始化
nn.init.zeros_(linear.bias) # 偏置初始化为0
🚫 常见错误
错误1:输入维度不匹配
python
linear = nn.Linear(3, 1)
x = torch.randn(2, 4) # ❌ 最后一维是4,不是3
y = linear(x) # RuntimeError!
错误2:忘记加激活函数(纯线性模型能力有限)
python
# ❌ 纯线性模型只能拟合线性关系
x = torch.relu(self.fc1(x)) # ✅ 加激活函数,变成非线性!
✅ 总结卡片
项目 | 说明 |
---|---|
中文名 | 线性层 / 全连接层 |
公式 | y = x @ W^T + b |
参数 | weight (权重), bias (偏置) |
输入形状 | [..., in_features] |
输出形状 | [..., out_features] |
典型用法 | fc = nn.Linear(in, out) → y = fc(x) |
必须注意 | 输入最后一维必须等于 in_features |
🧠 记忆口诀:
"Linear 是计算器,输入乘权重加偏置;
形状匹配别忘记,最后一维要对齐;
加上激活变非线,神经网络基石器!"