手动构建线性回归(PyTorch)

python 复制代码
import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import random
#1.构建数据
#构建数据集
def create_dataset():
    x,y,coef=make_regression(n_samples=100,
                             n_features=1,
                             random_state=0,
                             noise=10,
                             coef=True,
                             bias=14.5)
    #将构建数据转换为张量类型
    x=torch.tensor(x)
    y=torch.tensor(y)
    return x,y

#构建数据加载器
def data_loader(x,y, batch_size):
    #计算下样本的数量
    data_len = len(y)
    #构建数据索引
    data_index=list(range(data_len))

    random.shuffle(data_index)
    #计算总的batch数量
    batch_number=data_len//batch_size
    for idx in range(batch_number):
        start=idx+batch_size
        end=start+batch_size
        batch_train_x=x[start:end]
        batch_train_y=y[start:end]
        yield batch_train_x,batch_train_y

def test01():
    x,y=create_dataset()
    plt.scatter(x,y)
    plt.show()

    for x,y in data_loader(x,y,batch_size=10):
        print(y)
#2.假设函数、损失函数、优化方法
#损失函数:平均损失
#优化方法:梯度下降
#假设函数
w=torch.tensor(0.1,requires_grad=True,dtype=torch.float64)
b=torch.tensor(0.1,requires_grad=True,dtype=torch.float64)



def linear_regression(x):
    return w*x+b

#损失函数
def square_loss(y_pred,y_true):
    return torch.square(y_pred - y_true)

#优化方法
def sqd(lr=1e-2):
    #除以16是使用的是批次样本的平均梯度
    w.data=w.data-lr*w.grad.data/16
    b.data=b.data-lr*b.grad.data/16
    

if __name__ == '__main__':
    test01()
相关推荐
TGITCIC14 分钟前
AI Agent竞争进入下半场:模型只是入场券,系统架构决定胜负
人工智能·ai产品经理·ai产品·ai落地·大模型架构·ai架构·大模型产品
斐夷所非2 小时前
人工智能 AI. 机器学习 ML. 深度学习 DL. 神经网络 NN 的区别与联系
人工智能
Funny_AI_LAB4 小时前
OpenAI DevDay 2025:ChatGPT 进化为平台,开启 AI 应用新纪元
人工智能·ai·语言模型·chatgpt
深瞳智检4 小时前
YOLO算法原理详解系列 第002期-YOLOv2 算法原理详解
人工智能·算法·yolo·目标检测·计算机视觉·目标跟踪
深眸财经5 小时前
机器人再冲港交所,优艾智合能否破行业困局?
人工智能·机器人
小宁爱Python5 小时前
从零搭建 RAG 智能问答系统1:基于 LlamaIndex 与 Chainlit实现最简单的聊天助手
人工智能·后端·python
新知图书6 小时前
Encoder-Decoder架构的模型简介
人工智能·架构·ai agent·智能体·大模型应用开发·大模型应用
大模型真好玩6 小时前
低代码Agent开发框架使用指南(一)—主流开发框架对比介绍
人工智能·低代码·agent
tzc_fly6 小时前
AI作为操作系统已经不能阻挡了,尽管它还没来
人工智能·chatgpt
PKNLP6 小时前
深度学习之神经网络1(Neural Network)
人工智能·深度学习·神经网络