Fitting a Custom Dataset with PyTorch's Linear Regression

Code

```python
import torch
import torch.nn as nn
import numpy as np

# Mean squared error loss for regression
criterion = nn.MSELoss()

# Custom dataset: each row is one (x, y) sample
data = np.array([[-0.5, 7.7],
                 [1.8, 98.5],
                 [0.9, 57.8],
                 [0.4, 39.2],
                 [-1.4, -15.7],
                 [-1.4, -37.3],
                 [-1.8, -49.1],
                 [1.5, 75.6],
                 [0.4, 34.0],
                 [0.8, 62.3]])

x_data = data[:, 0]
y_data = data[:, 1]

x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)

# A single linear layer: one input feature, one output
model = nn.Sequential(nn.Linear(1, 1))

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epochs = 500
for n in range(1, epochs + 1):
    y_pred = model(x_train.unsqueeze(1))        # (N,) -> (N, 1) for nn.Linear
    loss = criterion(y_pred.squeeze(1), y_train)
    optimizer.zero_grad()    # clear gradients accumulated from the previous step
    loss.backward()          # backpropagate
    optimizer.step()         # update the weight and bias
    if n % 10 == 0 or n == 1:
        print(f'epoch: {n}, loss: {loss.item()}')
```
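A note on the `unsqueeze(1)` / `squeeze(1)` pair in the loop above: `nn.Linear(1, 1)` expects input of shape `(N, 1)` (batch of N samples, 1 feature each), while our training data is a flat 1-D tensor. A minimal sketch, with a few made-up sample values, illustrating the shape round-trip:

```python
import torch
import torch.nn as nn

# A 1-D tensor of N samples has no feature dimension yet
x = torch.tensor([-0.5, 1.8, 0.9], dtype=torch.float32)

layer = nn.Linear(1, 1)
out = layer(x.unsqueeze(1))   # (3,) -> (3, 1): add the feature dimension
print(out.shape)              # torch.Size([3, 1])
print(out.squeeze(1).shape)   # torch.Size([3]): back to 1-D to match the targets
```

Without the `squeeze(1)`, the loss would be computed between a `(N, 1)` prediction and a `(N,)` target, which triggers broadcasting and silently produces the wrong MSE.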

Output:

```text
epoch: 1, loss: 2957.510986328125
epoch: 10, loss: 1785.2095947265625
epoch: 20, loss: 1031.366455078125
epoch: 30, loss: 606.9622192382812
epoch: 40, loss: 366.8883361816406
epoch: 50, loss: 230.34017944335938
epoch: 60, loss: 152.19020080566406
epoch: 70, loss: 107.1496810913086
epoch: 80, loss: 80.98954772949219
epoch: 90, loss: 65.66667175292969
epoch: 100, loss: 56.6099967956543
epoch: 110, loss: 51.205726623535156
epoch: 120, loss: 47.949058532714844
epoch: 130, loss: 45.966827392578125
epoch: 140, loss: 44.74833297729492
epoch: 150, loss: 43.99201583862305
epoch: 160, loss: 43.51824188232422
epoch: 170, loss: 43.218894958496094
epoch: 180, loss: 43.02821731567383
epoch: 190, loss: 42.90589141845703
epoch: 200, loss: 42.82688903808594
epoch: 210, loss: 42.77559280395508
epoch: 220, loss: 42.74211120605469
epoch: 230, loss: 42.72018051147461
epoch: 240, loss: 42.705726623535156
epoch: 250, loss: 42.696205139160156
epoch: 260, loss: 42.68989181518555
epoch: 270, loss: 42.68572235107422
epoch: 280, loss: 42.68293380737305
epoch: 290, loss: 42.68109130859375
epoch: 300, loss: 42.67984390258789
epoch: 310, loss: 42.679046630859375
epoch: 320, loss: 42.678489685058594
epoch: 330, loss: 42.67811584472656
epoch: 340, loss: 42.67786407470703
epoch: 350, loss: 42.67770004272461
epoch: 360, loss: 42.677608489990234
epoch: 370, loss: 42.677520751953125
epoch: 380, loss: 42.67747116088867
epoch: 390, loss: 42.677452087402344
epoch: 400, loss: 42.67742156982422
epoch: 410, loss: 42.67741775512695
epoch: 420, loss: 42.67739486694336
epoch: 430, loss: 42.67738342285156
epoch: 440, loss: 42.67737579345703
epoch: 450, loss: 42.677391052246094
epoch: 460, loss: 42.67737579345703
epoch: 470, loss: 42.67738723754883
epoch: 480, loss: 42.67738342285156
epoch: 490, loss: 42.67738342285156
epoch: 500, loss: 42.67737579345703
```

The loss plateaus around 42.677 rather than reaching zero: the data do not lie exactly on a line, so that plateau is the irreducible residual error of the best-fitting line.
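We can check that SGD actually converged to the optimum by comparing against the closed-form least-squares solution, which for a 1-D fit is given directly by `np.polyfit`:

```python
import numpy as np

# Same dataset as above
data = np.array([[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2],
                 [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1],
                 [1.5, 75.6], [0.4, 34.0], [0.8, 62.3]])
x, y = data[:, 0], data[:, 1]

# Closed-form least-squares fit: y ≈ w*x + b
w, b = np.polyfit(x, y, deg=1)

# Minimal achievable MSE for any line on this dataset
mse = np.mean((w * x + b - y) ** 2)
print(f"w = {w:.3f}, b = {b:.3f}, minimal MSE = {mse:.3f}")
```

The minimal MSE comes out at about 42.68, matching the plateau in the training log, so the trained `nn.Linear` weight and bias should agree with `w` and `b` here.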