PyTorch初探:基本函数与案例实践

正文

在熟悉了PyTorch的安装和环境配置后,接下来让我们深入了解PyTorch的基本函数,并通过一个简单的案例来实践这些知识。

1. 基本函数

PyTorch的核心是张量(Tensor),它类似于多维数组,但可以在GPU上运行以加速计算。张量上的操作是构建神经网络层的基础。


以下是PyTorch中一些常用的张量操作函数:

  • torch.tensor(): 创建一个新的张量。
  • torch.ones(), torch.zeros(): 创建全1或全0的张量。
  • torch.randn(): 创建一个具有随机数的张量,这些随机数服从均值为0和标准差为1的正态分布(标准正态分布)。
  • torch.matmul(): 执行矩阵乘法。
  • torch.sum(): 计算张量中所有元素的和。

此外,PyTorch还提供了自动求导机制,这是训练神经网络的关键。通过设置张量的requires_grad属性为True,PyTorch会跟踪对该张量执行的所有操作,以便后续计算梯度。

2. 案例实践:线性回归

为了演示PyTorch的基本用法,我们将实现一个简单的线性回归模型。线性回归是一种预测模型,其中输出是输入的线性组合。

步骤如下:

导入必要的库

python 复制代码
import torch  
import torch.nn as nn  
import torch.optim as optim

准备数据

这里我们使用简单的人工数据来演示。

python 复制代码
# 输入数据  
x_data = torch.tensor([[1.0], [2.0], [3.0]])  
# 输出数据  
y_data = torch.tensor([[2.0], [4.0], [6.0]])

定义模型

线性回归模型可以表示为y = wx + b,其中w是权重,b是偏置。

python 复制代码
class LinearRegressionModel(nn.Module):  
    def __init__(self):  
        super(LinearRegressionModel, self).__init__()  
        self.linear = nn.Linear(1, 1)  # 输入和输出都是1维的  
  
    def forward(self, x):  
        y_pred = self.linear(x)  
        return y_pred  
  
model = LinearRegressionModel()

定义损失函数和优化器

python 复制代码
criterion = nn.MSELoss()  # 均方误差损失  
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降

训练模型

python 复制代码
# 训练周期  
epochs = 100  
  
for epoch in range(epochs):  
    # 前向传播  
    outputs = model(x_data)  
    loss = criterion(outputs, y_data)  
  
    # 反向传播和优化  
    optimizer.zero_grad()  
    loss.backward()  
    optimizer.step()  
  
    # 打印损失  
    if (epoch+1) % 10 == 0:  
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')  
  
# 测试模型  
with torch.no_grad():  
    prediction = model(torch.tensor([[4.0]]))  
    print(f'Prediction after training: {4.0} => {prediction.item()}')

总结

  • 在这个简单的案例中,我们展示了如何使用PyTorch构建、训练和测试一个基本的线性回归模型。通过这个过程,你应该对PyTorch的基本函数和工作流程有了更深刻的理解在实际应用中,你可能会处理更复杂的模型和数据集,但基本的原理和操作是相似的。
相关推荐
网安CILLE几秒前
Wireshark 抓包实战演示
linux·网络·python·测试工具·web安全·网络安全·wireshark
ThinkPet5 分钟前
【AI】大模型知识入门扫盲以及SpringAi快速入门
java·人工智能·ai·大模型·rag·springai·mcp
汽车仪器仪表相关领域5 分钟前
双组分精准快检,汽修年检利器:MEXA-324M汽车尾气测量仪项目实战全解
大数据·人工智能·功能测试·测试工具·算法·机器学习·压力测试
renhongxia15 分钟前
从文本到仿真:多智能体大型语言模型(LLM)自动化化学工艺设计工作流程
人工智能·语言模型·自动化
王夏奇6 分钟前
python中的基础知识点-1
开发语言·windows·python
叫我辉哥e17 分钟前
新手进阶Python:办公看板集成多数据源+ECharts高级可视化
开发语言·python·echarts
程序员敲代码吗13 分钟前
如何从Python初学者进阶为专家?
jvm·数据库·python
AI工具指南18 分钟前
实测教程:三种主流AI生成PPT工作流详解
人工智能·ppt
DO_Community18 分钟前
技术解码:Character.ai 如何实现大模型实时推理性能 2 倍提升
人工智能·算法·llm·aigc·moe·aiter
Kakaxiii19 分钟前
【2024ACL】Mind Map :知识图谱激发大型语言模型中的思维图谱
人工智能·语言模型·知识图谱