PyTorch:深度学习领域的强大工具

PyTorch 是一个深度学习框架,具有动态计算图和自动微分的功能,广泛应用于各种深度学习任务中。下面是一些 PyTorch 的详细代码介绍,涵盖了从定义模型、数据加载、训练到评估的整个流程:

1. 定义模型

首先,我们定义一个简单的神经网络模型。

python 复制代码
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

2. 准备数据

接下来,我们准备训练和测试数据。

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms

# 定义数据转换
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 加载训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

3. 定义损失函数和优化器

我们选择交叉熵损失函数和随机梯度下降优化器。

python 复制代码
import torch.optim as optim

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 训练模型

接下来,我们进行模型的训练。

python 复制代码
# 循环遍历数据集多次
for epoch in range(2):  

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 获取输入数据
        inputs, labels = data

        # 梯度清零
        optimizer.zero_grad()

        # 正向传播 + 反向传播 + 优化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印统计信息
        running_loss += loss.item()
        if i % 2000 == 1999:    # 每2000个小批量数据打印一次
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

5. 测试模型

最后,我们使用测试数据集评估模型的性能。

python 复制代码
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

以上就是一个完整的 PyTorch 示例,涵盖了模型定义、数据准备、损失函数和优化器的选择、模型训练以及模型评估的过程。通过这些代码,你可以开始构建和训练自己的深度学习模型了。

相关推荐
youcans_4 小时前
【HALCON机器视觉实战】专栏介绍
图像处理·人工智能·计算机视觉·halcon
火山引擎开发者社区4 小时前
火山引擎 veRoCE 获权威认证:IANA 官方为 veRoCE 分配专属 UDP 端口号 4794
人工智能
飘落的数码折腾日记4 小时前
你的AI Agent可能正在“叛变“ | 5类真实威胁与四层防御
人工智能
放羊郎4 小时前
基于ORB-SLAM2算法的优化工作
人工智能·算法·计算机视觉
AI袋鼠帝5 小时前
字节的技术决心,都藏在这个动作里
人工智能
AI袋鼠帝5 小时前
企微又偷偷进化AI,并开始不对劲了..
人工智能
工业机器人销售服务5 小时前
2026 年,探索专业伯朗特机器人的奇妙世界
人工智能·机器人
摆烂大大王5 小时前
AI 日报|2026年5月9日:四部门力推AI与能源双向赋能,AI终端国标出台,中国大模型融资潮涌
人工智能
萑澈5 小时前
编程能力强和多模态模型的模型后训练
人工智能·深度学习·机器学习
LaughingZhu5 小时前
Product Hunt 每日热榜 | 2026-05-08
人工智能·经验分享·深度学习·神经网络·产品运营