自定义数据集使用框架的线性回归方法对其进行拟合

代码

python 复制代码
import torch
import numpy as np
import torch.nn as nn
 
criterion = nn.MSELoss()
 
data = np.array([[-0.5, 7.7],
                 [1.8, 98.5],
                 [0.9, 57.8],
                 [0.4, 39.2],
                 [-1.4, -15.7],
                 [-1.4, -37.3],
                 [-1.8, -49.1],
                 [1.5, 75.6],
                 [0.4, 34.0],
                 [0.8, 62.3]])
 
x_data = data[:, 0]
y_data = data[:, 1]
 
x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)
 
model = nn.Sequential(nn.Linear(1, 1))
 
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epoch = 500
for n in range(1,epoch+1):
    y_prd = model(x_train.unsqueeze(1))
    loss = criterion(y_prd.squeeze(1), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if n % 10 == 0 or n == 1:
        print(f'epoch: {n}, loss: {loss}')

运行结果:

python 复制代码
epoch: 1, loss: 2957.510986328125
epoch: 10, loss: 1785.2095947265625
epoch: 20, loss: 1031.366455078125
epoch: 30, loss: 606.9622192382812
epoch: 40, loss: 366.8883361816406
epoch: 50, loss: 230.34017944335938
epoch: 60, loss: 152.19020080566406
epoch: 70, loss: 107.1496810913086
epoch: 80, loss: 80.98954772949219
epoch: 90, loss: 65.66667175292969
epoch: 100, loss: 56.6099967956543
epoch: 110, loss: 51.205726623535156
epoch: 120, loss: 47.949058532714844
epoch: 130, loss: 45.966827392578125
epoch: 140, loss: 44.74833297729492
epoch: 150, loss: 43.99201583862305
epoch: 160, loss: 43.51824188232422
epoch: 170, loss: 43.218894958496094
epoch: 180, loss: 43.02821731567383
epoch: 190, loss: 42.90589141845703
epoch: 200, loss: 42.82688903808594
epoch: 210, loss: 42.77559280395508
epoch: 220, loss: 42.74211120605469
epoch: 230, loss: 42.72018051147461
epoch: 240, loss: 42.705726623535156
epoch: 250, loss: 42.696205139160156
epoch: 260, loss: 42.68989181518555
epoch: 270, loss: 42.68572235107422
epoch: 280, loss: 42.68293380737305
epoch: 290, loss: 42.68109130859375
epoch: 300, loss: 42.67984390258789
epoch: 310, loss: 42.679046630859375
epoch: 320, loss: 42.678489685058594
epoch: 330, loss: 42.67811584472656
epoch: 340, loss: 42.67786407470703
epoch: 350, loss: 42.67770004272461
epoch: 360, loss: 42.677608489990234
epoch: 370, loss: 42.677520751953125
epoch: 380, loss: 42.67747116088867
epoch: 390, loss: 42.677452087402344
epoch: 400, loss: 42.67742156982422
epoch: 410, loss: 42.67741775512695
epoch: 420, loss: 42.67739486694336
epoch: 430, loss: 42.67738342285156
epoch: 440, loss: 42.67737579345703
epoch: 450, loss: 42.677391052246094
epoch: 460, loss: 42.67737579345703
epoch: 470, loss: 42.67738723754883
epoch: 480, loss: 42.67738342285156
epoch: 490, loss: 42.67738342285156
epoch: 500, loss: 42.67737579345703
相关推荐
用户83562907805118 小时前
从手动编辑到代码生成:Python 助你高效创建 Word 文档
后端·python
c8i18 小时前
python中类的基本结构、特殊属性于MRO理解
python
隐语SecretFlow18 小时前
国人自研开源隐私计算框架SecretFlow,深度拆解框架及使用【开发者必看】
深度学习
liwulin050618 小时前
【ESP32-CAM】HELLO WORLD
python
Doris_202319 小时前
Python条件判断语句 if、elif 、else
前端·后端·python
Doris_202319 小时前
Python 模式匹配match case
前端·后端·python
Billy_Zuo19 小时前
人工智能深度学习——卷积神经网络(CNN)
人工智能·深度学习·cnn
这里有鱼汤19 小时前
Python量化实盘踩坑指南:分钟K线没处理好,小心直接亏钱!
后端·python·程序员
羊羊小栈19 小时前
基于「YOLO目标检测 + 多模态AI分析」的遥感影像目标检测分析系统(vue+flask+数据集+模型训练)
人工智能·深度学习·yolo·目标检测·毕业设计·大作业
l12345sy19 小时前
Day24_【深度学习—广播机制】
人工智能·pytorch·深度学习·广播机制