PyTorch:深度学习领域的强大工具

PyTorch 是一个深度学习框架,具有动态计算图和自动微分的功能,广泛应用于各种深度学习任务中。下面是一些 PyTorch 的详细代码介绍,涵盖了从定义模型、数据加载、训练到评估的整个流程:

1. 定义模型

首先,我们定义一个简单的神经网络模型。

python 复制代码
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

2. 准备数据

接下来,我们准备训练和测试数据。

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms

# 定义数据转换
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 加载训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

3. 定义损失函数和优化器

我们选择交叉熵损失函数和随机梯度下降优化器。

python 复制代码
import torch.optim as optim

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 训练模型

接下来,我们进行模型的训练。

python 复制代码
# 循环遍历数据集多次
for epoch in range(2):  

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 获取输入数据
        inputs, labels = data

        # 梯度清零
        optimizer.zero_grad()

        # 正向传播 + 反向传播 + 优化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印统计信息
        running_loss += loss.item()
        if i % 2000 == 1999:    # 每2000个小批量数据打印一次
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

5. 测试模型

最后,我们使用测试数据集评估模型的性能。

python 复制代码
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

以上就是一个完整的 PyTorch 示例,涵盖了模型定义、数据准备、损失函数和优化器的选择、模型训练以及模型评估的过程。通过这些代码,你可以开始构建和训练自己的深度学习模型了。

相关推荐
蹦蹦跳跳真可爱5892 小时前
Python----计算机视觉处理(Opencv:二值化,阈值法,反阈值法,截断阈值法,OTSU阈值法)
人工智能·python·opencv·计算机视觉
袁袁袁袁满3 小时前
Blackbox.Ai体验:AI编程插件如何提升开发效率
人工智能·ai编程·ai插件·chatgpt-4o·deepseek-r1满血版·免费大模型·gemini pro
摸鱼仙人~4 小时前
预训练微调类型分类
人工智能·自然语言处理·分类
申耀的科技观察4 小时前
【观察】拓展大模型应用交付领域“新赛道”,亚信科技为高质量发展“加速度”...
大数据·人工智能·科技
lboyj5 小时前
新能源汽车电控系统的大尺寸PCB需求:猎板PCB的技术突围
大数据·网络·人工智能
HABuo5 小时前
【YOLOv8】YOLOv8改进系列(5)----替换主干网络之EfficientFormerV2
人工智能·深度学习·yolo·目标检测·计算机视觉
訾博ZiBo6 小时前
AI日报 - 2025年3月16日
人工智能
(initial)6 小时前
大型语言模型与强化学习的融合:迈向通用人工智能的新范式——基于基础复现的实验平台构建
人工智能·强化学习
subject625Ruben6 小时前
Matlab多种算法解决未来杯B的多分类问题
人工智能·算法·机器学习·数学建模·matlab·分类·未来杯
不知江月待何人..6 小时前
conda install 和 pip install 的区别
深度学习