Pytorch入门实例

数据集是受教育年限和收入,如下图

代码如下

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn as nn
import torch.optim as optim

data = pd.read_csv('./Income.csv')

X = torch.from_numpy(data.Education.values.reshape(-1,1).astype(np.float32))
Y = torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float32))

learning_rate = 0.0001
model = nn.Linear(1,1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),lr=learning_rate)

for epoch in range(50):
    for x,y in zip(X,Y):
        output = model(x)
        loss = loss_fn(output,y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

plt.scatter(data.Education,data.Income)
plt.plot(X.numpy(),model(X).detach().numpy(),c='r')
plt.xlabel('Education')
plt.ylabel('Income')
plt.show()

输出如下

相关推荐
刺猬的温驯4 分钟前
Flow Matching 训练的输入分布问题:从 VAE Latent 统计性质到归一化工程实践——以 VoxFlash-TTS 为例
人工智能·语音合成·tts
机器之心8 分钟前
近80年后,埃尔德什经典「拉姆齐数下界」,被三位中国学者首次指数级改进
人工智能·openai
机器之心11 分钟前
Nvidia都在点赞的LoopWM世界模型,竟然来自一家中国初创FaceMind?
人工智能·openai
美团技术团队1 小时前
LongCat 开源 VitaBench 2.0:长期动态智能体基准新标杆
人工智能·算法
moMo1 小时前
从“你好”到 1024 维坐标:大模型怎么识字
人工智能
ShallWeL1 小时前
【机器学习】(2)—— 线性回归:损失函数
人工智能·机器学习
美团技术团队2 小时前
ICML 2026 | 美团技术团队学术论文精选
人工智能
moMo2 小时前
你的每一次对话,都是第一次
人工智能
不加辣椒2 小时前
第13章 检索增强提示工程
人工智能
小爷毛毛_卓寿杰2 小时前
我把 397B 的「Agentic 大脑」塞进了 Xinference,一键部署 Nex-N2
人工智能·架构·github