1.Pytorch模型应用(线性与非线性预测)

学习 PyTorch 时如果只听术语解释(比如"张量是多维数组""自动求导用于反向传播"),确实容易懵,也记不住。

下面我用 一个完整的、超简单的例子,从零开始,带你一步步写一个模型,并在每一步告诉你:

  • 你在做什么

  • 为什么要做这个

  • 对应哪个术语

用 PyTorch 线性预测「学生是否通过考试」

  • 输入:学习时间(小时) → 比如 [2.0, 5.0, 1.0, ...]

  • 输出:是否通过(0 = 挂科,1 = 通过) → 比如 [0, 1, 0, ...]

第一步 张量准备

python 复制代码
import torch

# 学习时间(小时)
hours = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
# 是否通过(0 或 1)
passed = [0, 0, 0, 1, 1, 1, 1]

# 把列表变成 PyTorch 的张量(Tensor)
x = torch.tensor(hours).unsqueeze(1)  # 变成 [[1.0], [2.0], ..., [7.0]] → 形状 [7, 1]
y = torch.tensor(passed).float()      # 标签必须是 float(因为后面用 BCELoss)
  • 把普通数字列表变成 PyTorch 能处理的格式。
  • 模型只能处理张量(Tensor),不能直接处理 Python 列表。
  • 张量(Tensor)→ 就是"PyTorch 专用的数组",支持 GPU、自动求导等。

第二步 模型定义

python 复制代码
import torch.nn as nn

class PassPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        # 定义一个"线性层":输入1个数(学习时间),输出1个数(通过分数)
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        # 前向传播:输入x → 经过线性层 → 输出一个数
        return torch.sigmoid(self.linear(x))  # 用 sigmoid 把输出变成 0~1 的概率

model = PassPredictor()
  • 创建了一个"预测器",内部只有一层:y = w * x + b,再套个 sigmoid 变成概率。
  • 模型需要有"可学习的参数"(w 和 b),并且要能从输入得到输出。
  • nn.Module → 所有 PyTorch 模型都要继承它,里面写 __init__(定义结构)和 forward(定义计算流程)。

第三步 训练工具设计

python 复制代码
# 损失函数:衡量预测和真实答案的差距(这里用二分类交叉熵)
loss_fn = nn.BCELoss()

# 优化器:负责更新模型参数(这里用 Adam,一种聪明的更新方法)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
  • 损失函数:告诉模型"你错得多离谱"
  • 优化器:根据错误,自动调整 w 和 b,让下次错得少一点

第四步 模型训练

python 复制代码
for epoch in range(100):  # 训练 100 轮
    # 1. 前向计算:输入x,得到预测
    y_pred = model(x)  # 形状 [7, 1]

    # 2. 计算损失
    loss = loss_fn(y_pred.squeeze(), y)  # squeeze 把 [7,1] 变成 [7]

    # 3. 清空旧梯度(非常重要!)
    optimizer.zero_grad()

    # 4. 反向传播:计算每个参数该往哪边改(自动求导!)
    loss.backward()

    # 5. 更新参数(w 和 b)
    optimizer.step()

    if epoch % 20 == 0:
        print(f"第 {epoch} 轮,损失 = {loss.item():.4f}")

loss.backward自动求导(Autograd):PyTorch 自动算出 w 和 b 的梯度(该增大还是减小)

第五步 模型预测

python 复制代码
# 预测:学习 2.5 小时,能通过吗?
test_input = torch.tensor([[2.5]])
prob = model(test_input)
print(f"学习 2.5 小时,通过概率 = {prob.item():.2%}")

向训练好的模型输入待预测值,得到真实预测值。

非线性应用

在上面的内容中,学习时间越久,通过率越高。但是这很不合理,比如一个人熬夜学习,那他的成绩也会下降,所以我们将数据这么改

python 复制代码
# 学习时间(小时)
hours = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
# 是否通过(0 或 1)
passed = [0, 0, 0, 1, 1, 1, 1, 0, 0, 0]

再用之前的线性模型预测就没有效果。

这本质上是在拟合一条 S型曲线(因为加了 sigmoid),但底子还是直线!

  • 它只能表示:"x 越大,y 越大" 或 "x 越大,y 越小"

  • 它无法表示"先升后降"(即非单调关系)

隐藏层

所以我们需要弄非线性模型,使用隐藏层。

python 复制代码
class PassPredictor(nn.Module):
    def __init__(self):
        super().__init__()
        # 输入1 → 隐藏层10个神经元 → 输出1
        self.fc1 = nn.Linear(1, 10)   # 第一层
        self.fc2 = nn.Linear(10, 1)   # 第二层
        self.relu = nn.ReLU()         # 非线性激活函数

    def forward(self, x):
        x = self.relu(self.fc1(x))    # 先线性变换,再加非线性
        x = torch.sigmoid(self.fc2(x)) # 输出概率
        return x
  • ReLU 激活函数让模型能"弯曲"
  • 隐藏层(10个神经元)提供了拟合复杂形状的能力
  • 这种结构叫 多层感知机(MLP),能逼近任意连续函数(万能近似定理)

大模型(如 LLaMA、Qwen)内部有成千上万个非线性层(Transformer + 激活函数),所以它能学极其复杂的模式。

其它的非线性激活函数

激活函数 公式 图像形状 输出范围
Tanh tanh⁡(x)=ex−e−xex+e−xtanh(x)=ex+e−xex−e−x​ S型,对称于原点 (-1, 1)
SiLU(又叫 Swish) SiLU(x)=x⋅σ(x)=x⋅11+e−xSiLU(x)=x⋅σ(x)=x⋅1+e−x1​ 平滑、右升左降,不对称 (-0.278, ∞)

你可以把 SiLU 理解为:"带权重的 ReLU" ------ 负数部分不直接砍成 0,而是保留一点小尾巴。

nn.Tanh()用于早期 RNN、需要输出对称的场景。在 0 附近梯度 ≈ 1,但两端会梯度消失(→0)

nn.SiLU()用于现代大模型 (如 Vision Transformer、LLaMA、Qwen)。梯度更平滑,不易消失,尤其在负区仍有微弱梯度

一般来说,我们更推荐用SiLU,这是经验之谈。

相关推荐
会飞的老朱1 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º3 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
寻星探路3 小时前
【深度长文】万字攻克网络原理:从 HTTP 报文解构到 HTTPS 终极加密逻辑
java·开发语言·网络·python·http·ai·https
Codebee5 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º6 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys6 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56786 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子6 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder6 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能6 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算