基于pytorch.nn模块实现线性模型

课程:b站up 跟李沐学AI 的系列课程 《动手学深度学习》

笔记:本笔记基于以上课程所写,有疑问的地方,可评论区留言,或自行观看b站视频

完整代码

python 复制代码
import torch
from torch import nn, optim
from torch.nn import Sequential, Linear
from torch.utils import data

# ========== 数据生成与模型定义部分 ==========

# 生成真实参数(用于模拟数据生成)
true_w = torch.tensor([2, -3.4])  # 真实权重向量(2维输入)
true_b = 4.2  # 真实偏置项


# 定义合成数据生成函数(模拟真实数据分布)
def synthetic_data(w, b, num_examples):
    """生成线性关系数据集 y = Xw + b + 高斯噪声

    参数:
    w -- 真实权重向量 (2,)
    b -- 真实偏置 (标量)
    num_examples -- 样本数量

    返回:
    X -- 特征矩阵 (num_examples, 2)
    y -- 标签向量 (num_examples, 1)
    """
    # 生成特征矩阵:从标准正态分布采样(均值0,方差1)
    X = torch.normal(0, 1, (num_examples, len(w)))
    # 计算线性关系结果
    y = torch.matmul(X, w) + b
    # 添加高斯噪声(模拟测量误差,标准差0.01)
    y += torch.normal(0, 0.01, y.shape)
    # 调整标签形状为列向量
    return X, y.reshape((-1, 1))


# 生成1000个样本的合成数据集
features, labels = synthetic_data(true_w, true_b, 1000)
print("features shape : ", features.shape)  # (1000, 2)
print("labels shape : ", labels.shape)  # (1000, 1)
print('features:', features[:3], '\nlabel:', labels[:3])

# 创建数据加载器(批量处理)
batch_size = 10
# 将特征和标签组合为数据集,打乱顺序后分批加载
data_iter = data.DataLoader(
    dataset=data.TensorDataset(features, labels),  # 数据集包装
    batch_size=batch_size,  # 每批10个样本
    shuffle=True  # 训练前打乱顺序
)

# ========== 模型定义与初始化 ==========

# 定义线性回归模型(使用Sequential容器)
net = Sequential(Linear(2, 1))  # 输入2维,输出1维的线性层

# 参数初始化(重要!影响模型收敛)
# 权重初始化:均值0,标准差0.01的正态分布
net[0].weight.data.normal_(0, 0.01)
# 偏置初始化:全零
net[0].bias.data.fill_(0)

# ========== 训练配置 ==========

# 超参数设置
lr = 0.03  # 学习率(控制参数更新步长)
num_epochs = 3  # 训练轮数(遍历整个数据集的次数)

# 损失函数:均方误差(MSE)
loss = nn.MSELoss()

# 优化器:随机梯度下降(SGD)
optimizer = optim.SGD(
    params=net.parameters(),  # 需优化的参数
    lr=lr  # 学习率
)

# ========== 训练循环 ==========

for epoch in range(num_epochs):
    # 遍历每个批次的数据
    for X, y in data_iter:
        # 前向传播:计算预测值
        y_hat = net(X)
        # 计算当前批次的损失
        l = loss(y_hat, y)

        # 反向传播准备
        optimizer.zero_grad()  # 清空梯度(防止梯度累积)
        l.backward()  # 自动计算梯度

        # 参数更新
        optimizer.step()  # 根据梯度更新参数

    # 每轮结束后计算整个数据集的损失(验证效果)
    with torch.no_grad():  # 禁用梯度计算(节省内存)
        # 计算所有样本的预测值
        y_hat_all = net(features)
        # 计算平均损失
        total_loss = loss(y_hat_all, labels)
        print(f'epoch {epoch + 1}, loss {total_loss.item():.4f}')

# ========== 结果验证 ==========

# 提取训练得到的参数
train_w = net[0].weight.data  # 权重参数
train_b = net[0].bias.data  # 偏置参数

# 打印训练结果与真实参数对比
print("训练得到的权重:", train_w, "\n真实权重:", true_w)
print("训练得到的偏置:", train_b, "\n真实偏置:", true_b)

输出结果

bash 复制代码
features shape :  torch.Size([1000, 2])
labels shape :  torch.Size([1000, 1])
features: tensor([[-0.7123,  0.7392],
        [ 0.9638, -1.7860],
        [ 1.6612, -0.6634]]) 
label: tensor([[ 0.2535],
        [12.2078],
        [ 9.7877]])
epoch 1, loss 0.0002
epoch 2, loss 0.0001
epoch 3, loss 0.0001
训练得到的权重: tensor([[ 1.9998, -3.4003]]) 
真实权重: tensor([ 2.0000, -3.4000])
训练得到的偏置: tensor([4.2002]) 
真实偏置: 4.2

Process finished with exit code 0
相关推荐
数据知道14 分钟前
将英文PDF文件完整地翻译成中文的4类方式
人工智能·学习·自然语言处理·pdf·机器翻译
大千AI助手15 分钟前
RAGFoundry:面向检索增强生成的模块化增强框架
人工智能·大模型·llm·微调·rag·检索·ragfoundry
dxnb2219 分钟前
Datawhale+AI夏令营_让AI读懂财报PDF task2深入赛题笔记
人工智能·笔记·pdf
那就摆吧20 分钟前
AI赋能6G网络安全研究:智能威胁检测与自动化防御
人工智能·web安全·ai·自动化·6g
产品经理独孤虾21 分钟前
流程优化点识别与分析:从混沌到清晰的产品体验突破法
人工智能·产品经理·需求分析·产品设计·提示词工程·deepseek·业务流程优化
2501_9247474532 分钟前
强光干扰下误报率↓82%!陌讯多模态算法在睡岗检测的落地优化
人工智能·深度学习·算法·目标检测·计算机视觉
加速财经1 小时前
WEEX参与欧洲两场重要Web3线下活动,助力社区协作与技术交流
人工智能·web3
Aousdu1 小时前
算法_python_学习记录_01
python·学习·算法
说私域1 小时前
基于开源AI大模型、AI智能名片与S2B2C商城小程序的零售智能化升级路径研究
人工智能·小程序·开源
AI科技分享1 小时前
仅需8W,无人机巡检系统落地 AI 低空智慧城市!可源码交付
人工智能·无人机·智慧城市