PyTorch:深度学习领域的强大工具

PyTorch 是一个深度学习框架,具有动态计算图和自动微分的功能,广泛应用于各种深度学习任务中。下面是一些 PyTorch 的详细代码介绍,涵盖了从定义模型、数据加载、训练到评估的整个流程:

1. 定义模型

首先,我们定义一个简单的神经网络模型。

python 复制代码
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

2. 准备数据

接下来,我们准备训练和测试数据。

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms

# 定义数据转换
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 加载训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 加载测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

3. 定义损失函数和优化器

我们选择交叉熵损失函数和随机梯度下降优化器。

python 复制代码
import torch.optim as optim

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 定义优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 训练模型

接下来,我们进行模型的训练。

python 复制代码
# 循环遍历数据集多次
for epoch in range(2):  

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 获取输入数据
        inputs, labels = data

        # 梯度清零
        optimizer.zero_grad()

        # 正向传播 + 反向传播 + 优化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印统计信息
        running_loss += loss.item()
        if i % 2000 == 1999:    # 每2000个小批量数据打印一次
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

5. 测试模型

最后,我们使用测试数据集评估模型的性能。

python 复制代码
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

以上就是一个完整的 PyTorch 示例,涵盖了模型定义、数据准备、损失函数和优化器的选择、模型训练以及模型评估的过程。通过这些代码,你可以开始构建和训练自己的深度学习模型了。

相关推荐
牛客企业服务2 分钟前
2025年AI面试深度测评:3款主流工具实战对比
人工智能·面试·职场和发展
北京耐用通信4 分钟前
唤醒沉睡的“钢铁手臂”:耐达讯自动化PROFINET转DeviceNet网关如何让老旧焊接机器人融入智能产线
人工智能·物联网·网络协议·自动化·信息与通信
延凡科技4 分钟前
延凡 APM 应用性能管理系统:AI+eBPF 驱动全栈智能可观测
大数据·人工智能·科技·能源
free-elcmacom4 分钟前
机器学习高阶教程<4>因果机器学习:因果推断、可解释AI与科学发现的新革命
人工智能·python·机器学习·因果机器学习
SACKings5 分钟前
神经元是什么?在深度学习中的数学表达是什么?
人工智能·深度学习
蓝卓工业操作系统5 分钟前
开源赋能全球智造,蓝卓Open supOS捐赠仪式开启产业协作新纪元
人工智能·开源·工业互联网·工厂操作系统·蓝卓
Mintopia6 分钟前
🌍 技术向善:WebAIGC如何通过技术设计规避负面影响?
人工智能·aigc
淮北4948 分钟前
图神经网络与pytorch
人工智能·pytorch·神经网络
模型启动机9 分钟前
微软确认:Windows 11 AI 智能体访问用户文件前会先请求许可
人工智能·microsoft·ai·大模型
Coovally AI模型快速验证9 分钟前
复杂工业场景如何实现3D实例与部件一体化分割?多视角贝叶斯融合的分层图像引导框
人工智能·深度学习·计算机视觉·3d·语言模型·机器人