用Pytorch实现线性回归(Linear Regression with Pytorch)

使用pytorch写神经网络的第一步就是需要准备好数据集,设计模型(用于计算y_hat(y的预测值)),构造损失函数和优化器(使用PyTorch API),写训练周期(前馈(算loss)+反馈(算梯度)+更新(更新权重))

一:准备数据

现在使用mini-batch的方式,X和Y为3x1(可以变,但是x和y要相同)的矩阵形式。

从代码中也可以看出来,x和y都是3x1的矩阵。

二:设计模型(构造计算图)

此处使用了一个仿射模型(在pytorch中叫做线性单元)

在我们设计的例子中,我们需要设置权重w的数值,和偏置量b。

那w和b的形状(几x几的矩阵),是由y_hat和x来共同确定。

之后将y_hat和y放入loss函数中进行计算,得出loss的值(一定是一个标量)。

看下模型设计的代码:

python 复制代码
#需要继承自module ,因为module中有很多方法我们需要使用
class LinearModel(torch.nn.Module):
    def __init__(self): #构造函数 在初始化对象时默认调用的函数
        super(LinearModel,self).__init__() #super调用父类的构造
        self.linear = torch.nn.Linear(1,1) #构造一个对象 linear Unit中的w和b(linear来自父类,可以自动反向传播)
    
    def forward(self,x): #前馈需要进行的计算 发现没有backword模块,因为Module中自动根据计算图实现backword过程
        y_pred = self.linear(x)
        return y_pred

model = LinearModel() #实例化 在之后既可以使用model(x)将x传入forword中的x,求得y_pred

其中torch.nn.Linear 的使用方法如下

三:构造loss和optimizer

此处我们使用MSEloss,需要的参事时y_hat和y,就可以求出loss。

代码如下:

python 复制代码
criterion = torch.nn.MSELoss(size_average=False)

我们使用SGD优化器(不会构建计算图),代码如下

python 复制代码
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

四:训练过程

python 复制代码
for epoch in range(100):
    y_pred = model(x_data)  #先计算出y_hat
    loss = criterion(y_pred,y_data) #再计算出loss
    print(epoch,loss.item()) 
    
    optimizer.zero_grad()#在反馈前将梯度清0
    loss.backward()#反馈
    optimizer.step()#更新

最后打印一些相关内容

python 复制代码
# w b
print('w=',model.linear.weight.item())
print('b=',model.linear.weight.item())

#Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=',y_test.data)

发现当range为1000时,已经达到了我们的预期。

五:整体流程

相关推荐
隔窗听雨眠7 分钟前
原生一体化多模态大模型技术研究:从拼接到统一的架构革命
人工智能·架构
羊羊小栈17 分钟前
Uplift营销供应链协同决策系统(基于Uplift因果推断与运筹优化算法)
前端·人工智能·算法·毕业设计·大作业
苏州邦恩精密20 分钟前
江苏三维扫描仪厂家如何选择合适的工业测量方案?
人工智能·科技·机器学习·3d·自动化·制造
humors22120 分钟前
100种社会实践
人工智能·程序人生
保卫大狮兄25 分钟前
什么是WBS项目管理?WBS有哪些核心功能?
大数据·人工智能
标书畅畅行27 分钟前
钛投标:全流程企业级AI标书解决方案,重构投标数字化生产力
大数据·人工智能
叫我:松哥34 分钟前
基于深度卷积神经网络的水果图片分类算法设计与实现,有ResNet50的迁移学习模型,准确率达95%
人工智能·python·神经网络·机器学习·分类·cnn·迁移学习
大囚长35 分钟前
大模型API的上下文缓存(Contextual Cache)
人工智能·缓存
无心水35 分钟前
【Hermes:团队、企业、生态与边界】47、Hermes 在 CI/CD 中的完整 DevOps 流水线:从 PR 审查到自动部署,让 Agent 接管你的发布流程
运维·人工智能·devops·openclaw·养龙虾·hermes·honcho
名不经传的养虾人40 分钟前
从0到1:企业级AI项目迭代日记 Vol.44|功能建好,和功能接通,是两件完全不同的事
人工智能·架构·agent·ai编程·企业ai