【深度学习】复习温故而知新1

数据集ref:https://raw.githubusercontent.com/justinge/pic-go-for-xbotgo/master/Income1.csv

复制代码
X = torch.from_numpy(data.Education.values.reshape(-1,1).astype(np.float32))
Y = torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float32))

y_pred = model(x) # 预测
loss = loss_fn(y, y_pred) # 计算损失
opt.zero_grad() # 梯度清零
loss.backward() # 反向传播
opt.step() # 下一次

总的来说

python 复制代码
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

data = pd.read_csv("/home/data_for_ai_justin/01learn/dataset/Income1.csv")
X = torch.from_numpy(data.Education.values.reshape(-1,1).astype(np.float32))
Y = torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float32))
loss_fn = nn.MSELoss()
int_features, out_features = 1, 1
model = nn.Linear(int_features,out_features)
opt = torch.optim.SGD(
    model.parameters(),
    lr=0.0001
)
for i in range(5000):
    for x,y in zip(X,Y):
        y_pred = model(x)
        loss = loss_fn(y, y_pred)
        opt.zero_grad()
        loss.backward()
        opt.step()

plt.scatter(data.Education, data.Income)
plt.plot(X.numpy(), model(X).data.numpy(),c='r')
plt.savefig("1.jpg")

分解写法

python 复制代码
df = pd.read_csv("/home/data_for_ai_justin/01learn/dataset/Income1.csv")
X=torch.from_numpy(df.Education.values.reshape(-1,1).astype(np.float32))
Y=torch.from_numpy(df.Income.values.reshape(-1,1).astype(np.float32))
w = torch.randn(1,requires_grad=True)
b = torch.zeros(1,requires_grad=True)
learning_rate = 0.001
for i in range(5000):
    for x,y in zip(X,Y):
        y_pred = torch.matmul(x,w) + b
        loss=(y-y_pred).pow(2).mean()
        if w.grad is not None:
            w.grad.data.zero_()
        if b.grad is not None:
            b.grad.data.zero_()
        loss.backward()
        with torch.no_grad():
            w.data -= w.grad.data * learning_rate
            b.data -= b.grad.data * learning_rate
plt.scatter(df.Education, df.Income)
plt.xlabel = "x"
plt.ylabel = "y"
plt.plot(X.numpy(), (torch.matmul(X,w) + b).data.numpy(),c='r')
plt.show()
        
相关推荐
Jamence几秒前
多模态大语言模型arxiv论文略读(三十九)
人工智能·语言模型·自然语言处理
ai大模型木子25 分钟前
嵌入模型(Embedding Models)原理详解:从Word2Vec到BERT的技术演进
人工智能·自然语言处理·bert·embedding·word2vec·ai大模型·大模型资料
普if加的帕2 小时前
java Springboot使用扣子Coze实现实时音频对话智能客服
java·开发语言·人工智能·spring boot·实时音视频·智能客服
KoiC2 小时前
Dify接入RAGFlow无返回结果
人工智能·ai应用
lilye663 小时前
精益数据分析(20/126):解析经典数据分析框架,助力创业增长
大数据·人工智能·数据分析
盈达科技3 小时前
盈达科技:登顶GEO优化全球制高点,以AICC定义AI时代内容智能优化新标杆
大数据·人工智能
安冬的码畜日常3 小时前
【AI 加持下的 Python 编程实战 2_10】DIY 拓展:从扫雷小游戏开发再探问题分解与 AI 代码调试能力(中)
开发语言·前端·人工智能·ai·扫雷游戏·ai辅助编程·辅助编程
古希腊掌管学习的神3 小时前
[LangGraph教程]LangGraph04——支持人机协作的聊天机器人
人工智能·语言模型·chatgpt·机器人·agent
FIT2CLOUD飞致云3 小时前
问答页面支持拖拽和复制粘贴文件,MaxKB企业级AI助手v1.10.6 LTS版本发布
人工智能·开源
起个破名想半天了3 小时前
计算机视觉cv入门之答题卡自动批阅
人工智能·opencv·计算机视觉