一文讲清 nn.Linear 线性变换

我们用 最通俗的语言 + 生活比喻 + 代码示例 + 手动实现,向初学者彻底讲清楚:

torch.nn.Linear 是什么?

✅ 它的原理和作用是什么?

✅ 怎么使用它?


🧩 一句话总结(先记住这个!):

nn.Linear 就是一个"计算器",你给它一组数字,它按固定公式(矩阵乘 + 加法)算出另一组数字 ------ 而且这个公式里的"系数"是可以学习的!


🍎 生活化比喻:学生综合评分器

想象你是一个班主任,要给学生算"综合能力分":

  • 输入:学生的三科成绩 → 语文、数学、英语
  • 你给每科分配一个"权重"(重要性):
    • 语文 × 0.3
    • 数学 × 0.5
    • 英语 × 0.2
  • 再加一个"基础分":+10分
  • 输出:综合能力分

👉 这就是一个 Linear 变换

复制代码
综合分 = 语文×0.3 + 数学×0.5 + 英语×0.2 + 10

nn.Linear 中:

  • 语文, 数学, 英语 → 输入向量 x
  • 0.3, 0.5, 0.2 → 权重 weight
  • 10 → 偏置 bias
  • 综合分 → 输出 y

🧮 数学公式(别怕,很简单!)

ini 复制代码
y = x @ W^T + b

复制代码
输出 = 输入 × 权重转置 + 偏置
  • x:输入,形状 [n](n个特征)
  • W:权重,形状 [m, n](m个输出,n个输入)
  • b:偏置,形状 [m]
  • y:输出,形状 [m]

@ 表示矩阵乘法,^T 表示转置


💻 代码示例(从零开始)

python 复制代码
import torch
import torch.nn as nn

# 创建一个 Linear 层:输入3个特征 → 输出1个分数
linear = nn.Linear(in_features=3, out_features=1)

# 查看参数(初始是随机的)
print("权重 W:", linear.weight)  # [1, 3]
print("偏置 b:", linear.bias)    # [1]

# 输入:一个学生的三科成绩
x = torch.tensor([[80.0, 90.0, 85.0]])  # shape: [1, 3]

# 前向计算
y = linear(x)
print("综合分:", y)  # shape: [1, 1]

🔍 手动验证计算(超重要!)

我们手动计算一遍,看是否和 linear(x) 一致:

python 复制代码
# 手动计算:y = x @ W^T + b
W = linear.weight  # [1, 3]
b = linear.bias    # [1]
x = torch.tensor([[80.0, 90.0, 85.0]])

y_manual = x @ W.T + b  # 注意:W.T 是转置
print("手动计算:", y_manual)
print("框架计算:", y)
print("两者相等:", torch.allclose(y, y_manual))  # True ✅

🎯 作用和意义

1. 特征变换

把输入特征线性组合成新特征 → 降维、升维、特征融合

2. 可学习

权重和偏置是可训练参数 → 模型自动学会最佳"系数"

3. 基础构建块

几乎所有神经网络都由 Linear + 激活函数 堆叠而成

4. 应用场景

  • 几乎无处不在!任何需要将一组特征转换为另一组特征的地方
  • 多层感知机(MLP)的主要组成部分
  • Transformer 中的前馈网络(Feed-Forward Network)
  • 分类器或回归器的最后一层

🧱 在神经网络中的典型用法

python 复制代码
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)  # 输入10维 → 输出20维
        self.fc2 = nn.Linear(20, 5)   # 输入20维 → 输出5维

    def forward(self, x):
        x = torch.relu(self.fc1(x))   # Linear + 激活函数
        x = self.fc2(x)
        return x

model = SimpleNet()
x = torch.randn(5, 10)  # batch=5, input=10
output = model(x)       # [5, 5]

📊 形状变化示例

输入形状 Linear(in, out) 输出形状
[3] Linear(3, 1) [1]
[2, 3] Linear(3, 5) [2, 5]
[4, 7, 3] Linear(3, 2) [4, 7, 2]

Linear 会自动处理前置维度(如 batch, sequence),只对最后一维做变换!


⚙️ 参数初始化

默认使用 Kaiming Uniform(适合 ReLU),也可自定义:

python 复制代码
linear = nn.Linear(3, 1)
nn.init.xavier_uniform_(linear.weight)  # Xavier 初始化
nn.init.zeros_(linear.bias)             # 偏置初始化为0

🚫 常见错误

错误1:输入维度不匹配

python 复制代码
linear = nn.Linear(3, 1)
x = torch.randn(2, 4)  # ❌ 最后一维是4,不是3
y = linear(x)          # RuntimeError!

错误2:忘记加激活函数(纯线性模型能力有限)

python 复制代码
# ❌ 纯线性模型只能拟合线性关系
x = torch.relu(self.fc1(x))  # ✅ 加激活函数,变成非线性!

✅ 总结卡片

项目 说明
中文名 线性层 / 全连接层
公式 y = x @ W^T + b
参数 weight(权重), bias(偏置)
输入形状 [..., in_features]
输出形状 [..., out_features]
典型用法 fc = nn.Linear(in, out)y = fc(x)
必须注意 输入最后一维必须等于 in_features

🧠 记忆口诀:

"Linear 是计算器,输入乘权重加偏置;
形状匹配别忘记,最后一维要对齐;
加上激活变非线,神经网络基石器!"

相关推荐
之歆1 天前
Spring AI入门到实战到原理源码-MCP
java·人工智能·spring
知乎的哥廷根数学学派1 天前
面向可信机械故障诊断的自适应置信度惩罚深度校准算法(Pytorch)
人工智能·pytorch·python·深度学习·算法·机器学习·矩阵
且去填词1 天前
DeepSeek :基于 Schema 推理与自愈机制的智能 ETL
数据仓库·人工智能·python·语言模型·etl·schema·deepseek
待续3011 天前
订阅了 Qoder 之后,我想通过这篇文章分享一些个人使用心得和感受。
人工智能
weixin_397578021 天前
人工智能发展历史
人工智能
强盛小灵通专卖员1 天前
基于深度学习的山体滑坡检测科研辅导:从论文实验到系统落地的完整思路
人工智能·深度学习·sci·小论文·山体滑坡
OidEncoder1 天前
从 “粗放清扫” 到 “毫米级作业”,编码器重塑环卫机器人新能力
人工智能·自动化·智慧城市
Hcoco_me1 天前
大模型面试题61:Flash Attention中online softmax(在线softmax)的实现方式
人工智能·深度学习·自然语言处理·transformer·vllm
阿部多瑞 ABU1 天前
`chenmo` —— 可编程元叙事引擎 V2.3+
linux·人工智能·python·ai写作
极海拾贝1 天前
GeoScene解决方案中心正式上线!
大数据·人工智能·深度学习·arcgis·信息可视化·语言模型·解决方案