回归实战(小白版本)

一.完整代码

python 复制代码
import torch
import matplotlib.pylab as plt#画图
import random #产生随机数

def create_data(w,b,data_num):#生成数据,w系数,b截距,data_num表样本数量
     x=torch.normal(0,1,(data_num,len(w)))#注:系数个数必须等于特征数
     y=torch.matmul(x,w)+b#表矩阵相乘

     noise=torch.normal(0,0.01,(y.shape))#噪声要加到y上,生成的数据与y一样的维度
     y+=noise
     return x,y

num=500

true_w=torch.tensor([8.1,2,2,4])#四个特征系数
true_b=torch.tensor(1.1)#是一个标量张量,对应线性模型的截距

X,Y=create_data(true_w,true_b,num)

#X[:,3]表取所有行,取第四列,1表示散点大小
plt.scatter(X[:,2],Y,1)#画一个散点图,展示第四个特征与y的关系
plt.show()



# 不知道w,b来推测其值
#label特征对应的标签集(比如线性回归中的 y 值,是模型预测的目标)
def data_provider(data,label,batchsize):#每次访问函数,就会提供一批数据(即一组一组计算)
     length=len(label)
     indices=list(range(length))#样本索引

     for each in range(0,length,batchsize):#每次循环的步长为batchsize
          get_indices=indices[each:each+batchsize]#当前批次索引
          get_data=data[get_indices]
          get_label=label[get_indices]

          yield get_data,get_label#有存档点的return 暂停点

batchsize=16
# for batch_x,batch_y in data_provider(X,Y,batchsize):
#      print(batch_x,batch_y)
#      break


#定义模型
def fun(x,w,b):
     pred_y=torch.matmul(x,w)+b#预测值
     return pred_y

#损失函数
def maeLoss(pre_y,y):
     return torch.sum(abs(pre_y-y))/len(y)

#优化函数
def sgd(paras,lr):#随机梯度下降,更新参数
     with torch.no_grad():#属于这句代码的部分,不计算梯度
          for para in paras:
                  para-=para.grad*lr#往损失函数减小的方向移动
                  para.grad.zero_()#使用过的梯度,归0


lr=0.03
w_0=torch.normal(0,0.01,true_w.shape,requires_grad=True)
b_0=torch.tensor(0.01,requires_grad=True)
print(w_0,b_0)


epochs=50#训练的轮数

for epoch in range(epochs):
     data_loss=0
     for batch_x,batch_y in data_provider(X,Y,batchsize):
          pred_y=fun(batch_x,w_0,b_0)
          loss=maeLoss(pred_y,batch_y)
          loss.backward()
          sgd([w_0,b_0],lr)
          data_loss+=loss

     print("epoch %03d: loss:%.6f"%(epoch,data_loss))

print("真实的函数值是",true_w,true_b)
print("训练得到的函数值是",w_0,b_0)


#只能看某一列的y值图
#第一列
idx=0
plt.plot(X[:,idx].detach().numpy(),X[:,idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())
plt.scatter(X[:,idx],Y,1)
plt.show()

二.具体细节(有疑惑的部分)

python 复制代码
#优化函数
def sgd(paras,lr):#随机梯度下降,更新参数
     with torch.no_grad():#属于这句代码的部分,不计算梯度重点
          for para in paras:
                  para-=para.grad*lr#往损失函数减小的方向移动
                  para.grad.zero_()#使用过的梯度,归0

首先就是分批进行优化参数w,b;每一轮末尾要将使用过的梯度归0,防止梯度累积影响下一轮数据的参数优化更新;

但不是分批的进行吗 那么每次不是一直在覆盖之前存储的w吗?

  • 不是丢失信息 ,而是在改进参数
  • 每次批处理都让w变得更接近"正确答案"(每一批都在优化参数)
相关推荐
华玥作者14 小时前
[特殊字符] VitePress 对接 Algolia AI 问答(DocSearch + AI Search)完整实战(下)
前端·人工智能·ai
AAD5558889914 小时前
YOLO11-EfficientRepBiPAN载重汽车轮胎热成像检测与分类_3
人工智能·分类·数据挖掘
王建文go14 小时前
RAG(宠物健康AI)
人工智能·宠物·rag
ALINX技术博客14 小时前
【202601芯动态】全球 FPGA 异构热潮,ALINX 高性能异构新品预告
人工智能·fpga开发·gpu算力·fpga
易营宝14 小时前
多语言网站建设避坑指南:既要“数据同步”,又能“按市场个性化”,别踩这 5 个坑
大数据·人工智能
fanstuck14 小时前
从0到提交,如何用 ChatGPT 全流程参与建模比赛的
大数据·数学建模·语言模型·chatgpt·数据挖掘
春日见14 小时前
vscode代码无法跳转
大数据·人工智能·深度学习·elasticsearch·搜索引擎
Drgfd15 小时前
真智能 vs 伪智能:天选 WE H7 Lite 用 AI 人脸识别 + 呼吸灯带,重新定义智能化充电桩
人工智能·智能充电桩·家用充电桩·充电桩推荐
萤丰信息15 小时前
AI 筑基・生态共荣:智慧园区的价值重构与未来新途
大数据·运维·人工智能·科技·智慧城市·智慧园区
盖雅工场15 小时前
排班+成本双管控,餐饮零售精细化运营破局
人工智能·零售餐饮·ai智能排班