一文讲清 nn.Linear 线性变换

我们用 最通俗的语言 + 生活比喻 + 代码示例 + 手动实现,向初学者彻底讲清楚:

torch.nn.Linear 是什么?

✅ 它的原理和作用是什么?

✅ 怎么使用它?


🧩 一句话总结(先记住这个!):

nn.Linear 就是一个"计算器",你给它一组数字,它按固定公式(矩阵乘 + 加法)算出另一组数字 ------ 而且这个公式里的"系数"是可以学习的!


🍎 生活化比喻:学生综合评分器

想象你是一个班主任,要给学生算"综合能力分":

  • 输入:学生的三科成绩 → 语文、数学、英语
  • 你给每科分配一个"权重"(重要性):
    • 语文 × 0.3
    • 数学 × 0.5
    • 英语 × 0.2
  • 再加一个"基础分":+10分
  • 输出:综合能力分

👉 这就是一个 Linear 变换

复制代码
综合分 = 语文×0.3 + 数学×0.5 + 英语×0.2 + 10

nn.Linear 中:

  • 语文, 数学, 英语 → 输入向量 x
  • 0.3, 0.5, 0.2 → 权重 weight
  • 10 → 偏置 bias
  • 综合分 → 输出 y

🧮 数学公式(别怕,很简单!)

ini 复制代码
y = x @ W^T + b

复制代码
输出 = 输入 × 权重转置 + 偏置
  • x:输入,形状 [n](n个特征)
  • W:权重,形状 [m, n](m个输出,n个输入)
  • b:偏置,形状 [m]
  • y:输出,形状 [m]

@ 表示矩阵乘法,^T 表示转置


💻 代码示例(从零开始)

python 复制代码
import torch
import torch.nn as nn

# 创建一个 Linear 层:输入3个特征 → 输出1个分数
linear = nn.Linear(in_features=3, out_features=1)

# 查看参数(初始是随机的)
print("权重 W:", linear.weight)  # [1, 3]
print("偏置 b:", linear.bias)    # [1]

# 输入:一个学生的三科成绩
x = torch.tensor([[80.0, 90.0, 85.0]])  # shape: [1, 3]

# 前向计算
y = linear(x)
print("综合分:", y)  # shape: [1, 1]

🔍 手动验证计算(超重要!)

我们手动计算一遍,看是否和 linear(x) 一致:

python 复制代码
# 手动计算:y = x @ W^T + b
W = linear.weight  # [1, 3]
b = linear.bias    # [1]
x = torch.tensor([[80.0, 90.0, 85.0]])

y_manual = x @ W.T + b  # 注意:W.T 是转置
print("手动计算:", y_manual)
print("框架计算:", y)
print("两者相等:", torch.allclose(y, y_manual))  # True ✅

🎯 作用和意义

1. 特征变换

把输入特征线性组合成新特征 → 降维、升维、特征融合

2. 可学习

权重和偏置是可训练参数 → 模型自动学会最佳"系数"

3. 基础构建块

几乎所有神经网络都由 Linear + 激活函数 堆叠而成

4. 应用场景

  • 几乎无处不在!任何需要将一组特征转换为另一组特征的地方
  • 多层感知机(MLP)的主要组成部分
  • Transformer 中的前馈网络(Feed-Forward Network)
  • 分类器或回归器的最后一层

🧱 在神经网络中的典型用法

python 复制代码
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)  # 输入10维 → 输出20维
        self.fc2 = nn.Linear(20, 5)   # 输入20维 → 输出5维

    def forward(self, x):
        x = torch.relu(self.fc1(x))   # Linear + 激活函数
        x = self.fc2(x)
        return x

model = SimpleNet()
x = torch.randn(5, 10)  # batch=5, input=10
output = model(x)       # [5, 5]

📊 形状变化示例

输入形状 Linear(in, out) 输出形状
[3] Linear(3, 1) [1]
[2, 3] Linear(3, 5) [2, 5]
[4, 7, 3] Linear(3, 2) [4, 7, 2]

Linear 会自动处理前置维度(如 batch, sequence),只对最后一维做变换!


⚙️ 参数初始化

默认使用 Kaiming Uniform(适合 ReLU),也可自定义:

python 复制代码
linear = nn.Linear(3, 1)
nn.init.xavier_uniform_(linear.weight)  # Xavier 初始化
nn.init.zeros_(linear.bias)             # 偏置初始化为0

🚫 常见错误

错误1:输入维度不匹配

python 复制代码
linear = nn.Linear(3, 1)
x = torch.randn(2, 4)  # ❌ 最后一维是4,不是3
y = linear(x)          # RuntimeError!

错误2:忘记加激活函数(纯线性模型能力有限)

python 复制代码
# ❌ 纯线性模型只能拟合线性关系
x = torch.relu(self.fc1(x))  # ✅ 加激活函数,变成非线性!

✅ 总结卡片

项目 说明
中文名 线性层 / 全连接层
公式 y = x @ W^T + b
参数 weight(权重), bias(偏置)
输入形状 [..., in_features]
输出形状 [..., out_features]
典型用法 fc = nn.Linear(in, out)y = fc(x)
必须注意 输入最后一维必须等于 in_features

🧠 记忆口诀:

"Linear 是计算器,输入乘权重加偏置;
形状匹配别忘记,最后一维要对齐;
加上激活变非线,神经网络基石器!"

相关推荐
你好~每一天9 小时前
2025 中小企业 AI 转型:核心岗技能 “怎么证、怎么用”?
人工智能·百度·数据挖掘·数据分析·职业·转行
飞哥数智坊10 小时前
3B参数差点干翻32B模型,Qwen3 Next 是如何做到的?
人工智能
人工智能技术派10 小时前
Whisper推理源码解读
人工智能·语言模型·whisper·语音识别
编码追梦人11 小时前
AI 重塑行业格局:从金融风控到智能制造的深度实践
人工智能·制造
Lululaurel11 小时前
提示工程深度解析:驾驭大语言模型的艺术与科学
人工智能·ai·aigc·提示词
simon_skywalker11 小时前
第7章 n步时序差分 n步时序差分预测
人工智能·算法·强化学习
唐兴通个人12 小时前
清华大学AI领导力AI时代领导力AI变革领导力培训师培训讲师专家唐兴通讲授数字化转型人工智能组织创新实践领导力国央企国有企业金融运营商制造业
人工智能·数据挖掘
云卓SKYDROID12 小时前
无人机定点派送技术要点与运行方式
人工智能·无人机·航电系统·高科技·云卓科技
码界筑梦坊12 小时前
206-基于深度学习的胸部CT肺癌诊断项目的设计与实现
人工智能·python·深度学习·flask·毕业设计
通往曙光的路上13 小时前
国庆回来的css
人工智能·python·tensorflow