用Pytorch实现线性回归(Linear Regression with Pytorch)

使用pytorch写神经网络的第一步就是需要准备好数据集,设计模型(用于计算y_hat(y的预测值)),构造损失函数和优化器(使用PyTorch API),写训练周期(前馈(算loss)+反馈(算梯度)+更新(更新权重))

一:准备数据

现在使用mini-batch的方式,X和Y为3x1(可以变,但是x和y要相同)的矩阵形式。

从代码中也可以看出来,x和y都是3x1的矩阵。

二:设计模型(构造计算图)

此处使用了一个仿射模型(在pytorch中叫做线性单元)

在我们设计的例子中,我们需要设置权重w的数值,和偏置量b。

那w和b的形状(几x几的矩阵),是由y_hat和x来共同确定。

之后将y_hat和y放入loss函数中进行计算,得出loss的值(一定是一个标量)。

看下模型设计的代码:

python 复制代码
#需要继承自module ,因为module中有很多方法我们需要使用
class LinearModel(torch.nn.Module):
    def __init__(self): #构造函数 在初始化对象时默认调用的函数
        super(LinearModel,self).__init__() #super调用父类的构造
        self.linear = torch.nn.Linear(1,1) #构造一个对象 linear Unit中的w和b(linear来自父类,可以自动反向传播)
    
    def forward(self,x): #前馈需要进行的计算 发现没有backword模块,因为Module中自动根据计算图实现backword过程
        y_pred = self.linear(x)
        return y_pred

model = LinearModel() #实例化 在之后既可以使用model(x)将x传入forword中的x,求得y_pred

其中torch.nn.Linear 的使用方法如下

三:构造loss和optimizer

此处我们使用MSEloss,需要的参事时y_hat和y,就可以求出loss。

代码如下:

python 复制代码
criterion = torch.nn.MSELoss(size_average=False)

我们使用SGD优化器(不会构建计算图),代码如下

python 复制代码
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

四:训练过程

python 复制代码
for epoch in range(100):
    y_pred = model(x_data)  #先计算出y_hat
    loss = criterion(y_pred,y_data) #再计算出loss
    print(epoch,loss.item()) 
    
    optimizer.zero_grad()#在反馈前将梯度清0
    loss.backward()#反馈
    optimizer.step()#更新

最后打印一些相关内容

python 复制代码
# w b
print('w=',model.linear.weight.item())
print('b=',model.linear.weight.item())

#Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=',y_test.data)

发现当range为1000时,已经达到了我们的预期。

五:整体流程

相关推荐
渡我白衣17 分钟前
多路转接之epoll:理论篇
人工智能·神经网络·网络协议·tcp/ip·自然语言处理·信息与通信·tcpdump
明月照山海-18 分钟前
机器学习周报二十八
人工智能·机器学习
weixin_437497776 小时前
读书笔记:Context Engineering 2.0 (上)
人工智能·nlp
喝拿铁写前端6 小时前
前端开发者使用 AI 的能力层级——从表面使用到工程化能力的真正分水岭
前端·人工智能·程序员
goodfat6 小时前
Win11如何关闭自动更新 Win11暂停系统更新的设置方法【教程】
人工智能·禁止windows更新·win11优化工具
北京领雁科技7 小时前
领雁科技反洗钱案例白皮书暨人工智能在反洗钱系统中的深度应用
人工智能·科技·安全
落叶,听雪7 小时前
河南建站系统哪个好
大数据·人工智能·python
清月电子7 小时前
杰理AC109N系列AC1082 AC1074 AC1090 芯片停产替代及资料说明
人工智能·单片机·嵌入式硬件·物联网
Dev7z7 小时前
非线性MPC在自动驾驶路径跟踪与避障控制中的应用及Matlab实现
人工智能·matlab·自动驾驶
七月shi人7 小时前
AI浪潮下,前端路在何方
前端·人工智能·ai编程