PyTorch 揭秘 :构建MNIST数据集

👋 今天我们继续来聊聊PyTorch,这个在深度学习领域火得一塌糊涂的开源机器学习库。PyTorch以其灵活性和直观的操作被广大研究人员和开发者所青睐。

火种一:PyTorch的简洁性

对于初学者来说,PyTorch的简洁易懂是它的一大卖点。看这段代码:

python

ini 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的线性模型
model = nn.Linear(in_features=1, out_features=1)

# 损失函数和优化器
loss_function = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 假设我们有一些训练数据
x_train = torch.tensor([[1.], [2.], [3.]])
y_train = torch.tensor([[2.], [4.], [6.]])

# 训练模型
for epoch in range(100):
    model.train()
    optimizer.zero_grad() # 清零梯度
    y_predicted = model(x_train)
    loss = loss_function(y_predicted, y_train)
    loss.backward() # 反向传播
    optimizer.step() # 更新参数

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

不需要复杂的配置,我们就搭建好了一个能进行训练的线性回归模型。这种直观的操作使得PyTorch非常适合快速原型开发和研究。

火种二:动态计算图的强大

PyTorch使用动态计算图(Dynamic Computation Graph),也就是说,图的构建是在代码运行时动态进行的,这允许你进行更为直观的模型构建和调试。

这让PyTorch在处理可变长度的输入,如不同长度的文本序列或时间序列数据时,显得游刃有余。动态图的特性也使得在网络中嵌入复杂的控制流成为可能,比如循环和条件语句,这些都是静态图难以做到的。

火种三:社区支持和生态系统

PyTorch背后有着强大的社区支持。从论坛到GitHub,从学术研究到工业应用,无数的开发者和研究者都在为之贡献代码,分享经验和见解。

另外,PyTorch有着丰富的生态系统。无论是高级抽象库如torchvision用于图像处理,torchaudio为音频分析,还是与其他库的无缝对接,如ONNX用于模型导出,PyTorch都让深度学习工程师的工作变得更加简单。

火种四:实践举例

看一个实际的例子,如何用PyTorch来构建一个卷积神经网络(CNN)来识别手写数字,也就是著名的MNIST数据集:

python

ini 复制代码
import torch.optim as optim
import torch.nn as nn

# 我们继续为模型定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# 训练过程
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        # 梯度置零
        optimizer.zero_grad()

        # 正向传播以及损失计算
        outputs = net(inputs)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:    # 每100个批次打印一次统计信息
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

# 保存模型参数
torch.save(net.state_dict(), 'mnist_cnn.pth')

这段代码完成了训练循环,包括损失计算、反向传播和网络参数的优化。

每100个batch打印一次训练过程中的平均损失,方便我们观察模型学习的情况。

将训练好的模型参数保存到文件中,便于后续的评估或者继续训练。

小结

PyTorch 以其简洁性、强大的动态计算图和活跃的社区支持让学习和研发都变得轻松。我们还通过构建一个CNN模型来识别MNIST数据集中的手写数字,讲述了整个模型的设计、训练和评估过程。

希望你能有所收获~~

相关推荐
中杯可乐多加冰3 小时前
RAG 深度实践系列(七):从“能用”到“好用”——RAG 系统优化与效果评估
人工智能·大模型·llm·大语言模型·rag·检索增强生成
山顶夕景11 小时前
【LLM】大模型数据清洗&合成&增强方法
大模型·llm·训练数据
tiger11913 小时前
FPGA 在大模型推理中的应用
人工智能·llm·fpga·大模型推理
AndrewHZ13 小时前
【AI黑话日日新】什么是大模型的test-time scaling?
人工智能·深度学习·大模型·llm·推理加速·测试时缩放
GPUStack15 小时前
vLLM、SGLang 融资背后,AI 推理正在走向系统化与治理
大模型·llm·vllm·模型推理·sglang·高性能推理
Tadas-Gao17 小时前
大模型幻觉治理新范式:SCA与[PAUSE]注入技术的深度解析与创新设计
人工智能·深度学习·机器学习·架构·大模型·llm
猿小羽17 小时前
基于 Spring AI 与 Streamable HTTP 构建 MCP Server 实践
java·llm·spring ai·mcp·streamable http
AndrewHZ18 小时前
【AI黑话日日新】什么是隐式CoT?
人工智能·深度学习·算法·llm·cot·复杂推理
一个处女座的程序猿1 天前
CV之VLM之LLM-OCR:《DeepSeek-OCR 2: Visual Causal Flow》翻译与解读
llm·ocr·cv·vlm
dawdo2222 天前
自己动手从头开始编写LLM推理引擎(9)-KV缓存实现和优化
缓存·llm·transformer·qwen·kv cache