PyTorch:深度学习领域的强大工具

PyTorch 是一个深度学习框架,具有动态计算图和自动微分的功能,广泛应用于各种深度学习任务中。下面是一些 PyTorch 的详细代码介绍,涵盖了从定义模型、数据加载、训练到评估的整个流程:

1. 定义模型

首先,我们定义一个简单的神经网络模型。

python 复制代码
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

2. 准备数据

接下来,我们准备训练和测试数据。

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms

# 定义数据转换
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 加载训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

3. 定义损失函数和优化器

我们选择交叉熵损失函数和随机梯度下降优化器。

python 复制代码
import torch.optim as optim

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 训练模型

接下来,我们进行模型的训练。

python 复制代码
# 循环遍历数据集多次
for epoch in range(2):  

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 获取输入数据
        inputs, labels = data

        # 梯度清零
        optimizer.zero_grad()

        # 正向传播 + 反向传播 + 优化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印统计信息
        running_loss += loss.item()
        if i % 2000 == 1999:    # 每2000个小批量数据打印一次
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

5. 测试模型

最后,我们使用测试数据集评估模型的性能。

python 复制代码
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

以上就是一个完整的 PyTorch 示例,涵盖了模型定义、数据准备、损失函数和优化器的选择、模型训练以及模型评估的过程。通过这些代码,你可以开始构建和训练自己的深度学习模型了。

相关推荐
极限实验室18 分钟前
Coco AI 实战(一):Coco Server Linux 平台部署
人工智能
杨过过儿29 分钟前
【学习笔记】4.1 什么是 LLM
人工智能
巴伦是只猫38 分钟前
【机器学习笔记Ⅰ】13 正则化代价函数
人工智能·笔记·机器学习
伍哥的传说44 分钟前
React 各颜色转换方法、颜色值换算工具HEX、RGB/RGBA、HSL/HSLA、HSV、CMYK
深度学习·神经网络·react.js
大千AI助手1 小时前
DTW模版匹配:弹性对齐的时间序列相似度度量算法
人工智能·算法·机器学习·数据挖掘·模版匹配·dtw模版匹配
AI生存日记1 小时前
百度文心大模型 4.5 系列全面开源 英特尔同步支持端侧部署
人工智能·百度·开源·open ai大模型
LCG元1 小时前
自动驾驶感知模块的多模态数据融合:时序同步与空间对齐的框架解析
人工智能·机器学习·自动驾驶
why技术1 小时前
Stack Overflow,轰然倒下!
前端·人工智能·后端
超龄超能程序猿2 小时前
(三)PS识别:基于噪声分析PS识别的技术实现
图像处理·人工智能·计算机视觉