Pytorch入门实例

数据集是受教育年限和收入,如下图

代码如下

python 复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn as nn
import torch.optim as optim

data = pd.read_csv('./Income.csv')

X = torch.from_numpy(data.Education.values.reshape(-1,1).astype(np.float32))
Y = torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float32))

learning_rate = 0.0001
model = nn.Linear(1,1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),lr=learning_rate)

for epoch in range(50):
    for x,y in zip(X,Y):
        output = model(x)
        loss = loss_fn(output,y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

plt.scatter(data.Education,data.Income)
plt.plot(X.numpy(),model(X).detach().numpy(),c='r')
plt.xlabel('Education')
plt.ylabel('Income')
plt.show()

输出如下

相关推荐
weixin_42499936几秒前
Python yield关键字原理_生成器执行机制
jvm·数据库·python
2301_803875611 分钟前
Golang如何做分布式ID生成_Golang雪花算法教程【详解】
jvm·数据库·python
YJlio3 分钟前
4月14日热点新闻解读:从金融数据到平台治理,一文看懂今天最值得关注的6个信号
java·前端·人工智能·金融·eclipse·电脑·eixv3
weixin_408717774 分钟前
实现鼠标滚轮在容器滚动到底部后无缝传递至页面的平滑过渡
jvm·数据库·python
薛定猫AI4 分钟前
【技术干货】OpenAI Codex 重大更新:从代码补全工具到全流程智能开发平台
运维·人工智能
gc_22994 分钟前
学习python使用Ultralytics的YOLO26进行旋转框检测的基本用法
python·ultralytics·yolo26·旋转框检测
格林威6 分钟前
工业相机“心跳”监测脚本(C# 版) 支持海康 / Basler / 堡盟工业相机
开发语言·人工智能·数码相机·opencv·计算机视觉·c#·视觉检测
404号扳手6 分钟前
03大模型核心原理
人工智能·llm
沪漂阿龙在努力8 分钟前
深度拆解LangChain Chains与LCEL:从Runnable到生产级AI工作流
人工智能
浩安9 分钟前
【Python网络编程】02_面向对象的三大特征
python