Pytorch:模型线性回归

目录

一、模拟线性回归

二、线性回归问题如何选择机器学习还是神经网络

三、实操建议

四、简单应用示例


一、神经网络模拟线性回归

代码如下:

python 复制代码
# 导入相关模块
import torch
from torch.utils.data import TensorDataset  # 构造数据集对象
from torch.utils.data import DataLoader  # 数据加载器
from torch import nn  # nn模块中有平方损失函数和假设函数
from torch import optim  # optim模块中有优化器函数
from sklearn.datasets import make_regression  # 创建线性回归模型数据集
import matplotlib.pyplot as plt  # 可视化

plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号


# 1. 定义函数, 创建线性回归样本数据.
def create_dataset():
    # 1. 创建数据集对象.
    x, y, coef = make_regression(
        n_samples=100,  # 100条样本(100个样本点)
        n_features=1,  # 1个特征(1个特征点)
        noise=10,  # 噪声, 噪声越大, 样本点越散, 噪声越小, 样本点越集中
        coef=True,  # 是否返回系数, 默认为False, 返回值为None
        bias=14.5,  # 偏置
        random_state=3  # 随机种子, 随机种子相同, 输出数据相同
    )

    # 2. 把上述的数据, 封装成 张量对象.
    x = torch.tensor(x, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.float32)

    # 3. 返回结果.
    return x, y, coef


# 2. 定义函数, 表示模型训练.
def train(x, y, coef):
    # 1. 创建数据集对象. 把 tensor -> 数据集对象 -> 数据加载器.
    dataset = TensorDataset(x, y)
    # 2. 创建数据加载器对象.
    # 参1: 数据集对象, 参2: 批次大小, 参3: 是否打乱数据(训练集打乱, 测试集不打乱)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
    # 3. 创建初始的 线性回归模型.
    # 参1: 输入特征维度, 参2: 输出特征维度.
    model = nn.Linear(1, 1)
    # 4. 创建损失函数对象.
    criterion = nn.MSELoss()
    # 5. 创建优化器对象.
    # 参1: 模型参数, 参2: 学习率.
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    # 6. 具体的训练过程.=
    # 6.1 定义变量, 分别表示: 训练轮数, 每轮的(平均)损失值, 训练总损失值, 训练的样本数.
    epochs, loss_list, total_loss, total_sample = 100, [], 0.0, 0
    # 6.2 开始训练, 按轮训练.
    for epoch in range(epochs):  # epoch的值: 0, 1, 2...99
        # 6.3 每轮是分 批次 训练的, 所以从 数据加载器中 获取 批次数据.
        for train_x, train_y in dataloader:  # 7批(16, 16, 16, 16, 16, 16, 4)
            # 6.4 模型预测.
            y_pred = model(train_x)
            # 6.5 计算(每批的平均)损失值.
            loss = criterion(y_pred, train_y.reshape(-1, 1))  # -1 自动计算.
            # 6.6 计算总损失 和 样本(批次)数
            total_loss += loss.item()
            total_sample += 1
            # 6.7 梯度清零 + 反向传播 + 梯度更新.
            optimizer.zero_grad()  # 梯度清零
            loss.backward()  # 反向传播, 计算梯度
            optimizer.step()  # 梯度更新

        # 6.8 把本轮的(平均)损失值, 添加到列表中.
        loss_list.append(total_loss / total_sample)
        print(f'轮数: {epoch + 1}, 平均损失值: {total_loss / total_sample}')

    # 7. 打印(最终的)训练结果.
    print(f'{epochs} 轮的平均损失分别为: {loss_list}')
    print(f'模型参数, 权重: {model.weight}, 偏置: {model.bias}')

    # 8. 绘制损失曲线.
    #                 100轮         每轮的平均损失值
    plt.plot(range(epochs), loss_list)
    plt.title('损失值曲线变化图')
    plt.grid()      # 绘制网格线
    plt.show()

    # 9. 绘制预测值和真实值的关系.
    # 9.1 绘制样本点分布情况.
    plt.scatter(x, y)
    # 9.2 绘制训练模型的预测值.
    # x: 100个样本点的特征.
    y_pred = torch.tensor(data = [v * model.weight + model.bias for v in x])
    # 9.3 计算真实值.
    y_true = torch.tensor(data = [v * coef + 14.5 for v in x])
    # 9.4 绘制预测值 和 真实值的 折线图.
    plt.plot(x, y_pred, color='red', label='预测值')
    plt.plot(x, y_true, color='green', label='真实值')
    # 9.5 图例, 网格.
    plt.legend()
    plt.grid()
    # 9.6 显示图像.
    plt.show()


    plt.show()


# 3. 测试.
if __name__ == '__main__':
    # 3.1 创建数据集.
    x, y, coef = create_dataset()
    # print(f'x: {x}, y: {y}, coef: {coef}')

    # 3.2 模型训练.
    train(x, y, coef)

二、线性回归问题应该选择机器学习还是神经网络

先给结论:

标准线性回归任务:优先用传统机器学习线性回归,完全没必要用神经网络版线性回归

为什么常规任务选「机器学习线性回归」

1. 优点:传统机器学习线性回归

速度极快:最小二乘法一步闭式解,不用迭代训练

可解释性极强:每个权重 w 代表特征影响力、正负相关性、系数大小一目了然

超简单、参数少、不易过拟合

结果稳定、可复现,无随机初始化扰动

不用调参:不需要学习率、epoch、batch、优化器

2. 缺点:神经网络做线性回归

❌ 杀鸡用牛刀,代码更复杂、训练慢

❌ 有随机初始化,结果每次略有差异

❌ 超参数一堆要调(学习率、迭代次数)

完全丢失可解释性

❌ 额外算力消耗,纯冗余设计

只有下面这些场景,才换神经网络 / 深度学习回归:

  1. 数据非线性强 普通线性拟合不动,需要加:隐藏层 + 激活函数(ReLU/Sigmoid),变成深度非线性回归
  2. 特征维度极高图像、文本、序列数据(几千 / 上万维特征),传统模型扛不住
  3. 后续要衔接深度学习链路比如:先回归、再接分类 / 预测,整套流水线都是 DL 框架
  4. 海量大数据需要 mini-batch 分批训练、GPU 加速

只要你是:表格结构化数据、想要线性拟合、看特征系数 → 坚决用传统 ML 线性回归

三、实操建议

场景 1:普通表格数据、预测连续值(房价、销量、产量)

👉 用:sklearn.linear_model.LinearRegression / 岭回归 / Lasso

场景 2:数据有非线性、曲线关系

👉 用:

  • 传统:多项式回归、随机森林、XGBoost 回归
  • 深度:多层神经网络回归(加隐藏层)

场景 3:图像 / 时序 / 文本回归

👉 用:深度学习神经网络

四、简单应用示例

1、机器学习

python 复制代码
import numpy as np
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt

# 构造模拟数据
np.random.seed(0)
X = np.linspace(0, 10, 100).reshape(-1, 1)  # 特征
y = 2.5 * X + 5 + np.random.randn(100, 1)    # y = 2.5x + 5 + 噪声

# 建模+训练
lr = LinearRegression()
lr.fit(X, y)

# 查看参数(极强可解释)
print("权重w =", lr.coef_[0][0])
print("偏置b =", lr.intercept_[0])

# 预测
y_pred = lr.predict(X)

# 画图
plt.scatter(X, y, label="真实数据")
plt.plot(X, y_pred, "r", lw=2, label="线性回归拟合")
plt.legend()
plt.show()

2、神经网络

代码如下:

python 复制代码
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 1. 构造同样数据
np.random.seed(0)
torch.manual_seed(0)
X_np = np.linspace(0, 10, 100).reshape(-1, 1)
y_np = 2.5 * X_np + 5 + np.random.randn(100, 1)

# 转张量
X = torch.tensor(X_np, dtype=torch.float32)
y = torch.tensor(y_np, dtype=torch.float32)

# 2. 定义 单层线性网络
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # 输入1维,输出1维,无激活
        self.linear = nn.Linear(in_features=1, out_features=1)
        
    def forward(self, x):
        return self.linear(x)

model = Net()

# 3. 损失函数 & 优化器
loss_fn = nn.MSELoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. 迭代训练
epochs = 1000
for epoch in range(epochs):
    y_pred = model(X)
    loss = loss_fn(y_pred, y)
    
    opt.zero_grad()
    loss.backward()
    opt.step()
    
    if (epoch+1) % 200 == 0:
        print(f"epoch:{epoch+1}, loss:{loss.item():.4f}")

# 查看神经网络学到的 w 和 b
w, b = model.linear.weight.item(), model.linear.bias.item()
print(f"\n神经网络学到:w={w:.3f}, b={b:.3f}")

# 预测绘图
y_pred_all = model(X).detach().numpy()
plt.scatter(X_np, y_np, label="真实数据")
plt.plot(X_np, y_pred_all, "r", lw=2, label="NN线性回归拟合")
plt.legend()
plt.show()

只要加一层隐藏层 + 激活函数,就从「线性」变成「非线性回归」。

相关推荐
执于代码2 小时前
python 环境知多少
开发语言·python
itzixiao2 小时前
L1-054 福到了(15 分)[java][python]
java·python·算法
斯维赤2 小时前
Python学习超简单第十一弹:邮件发送
开发语言·python·学习
qq_372154232 小时前
如何配置表中某列的排序权重_全文索引配置与权重分配
jvm·数据库·python
还是阿落呀2 小时前
如何判断一个年份是否为闰年?
python
2501_914245933 小时前
CSS如何使用-nth-of-type精确选择列表项_通过元素类型限制提升样式健壮性
jvm·数据库·python
overmind3 小时前
oeasy Python 124 序列_字符串_string_str
开发语言·python
吕源林3 小时前
Golang如何做本地缓存加速_Golang本地缓存教程【核心】
jvm·数据库·python
_深海凉_3 小时前
LeetCode热题100-26. 删除有序数组中的重复项
python·算法·leetcode