从0开始深度学习(10)——softmax的简洁实现

同样的,本章将使用torch自带的API简洁的实现softmax回归

1 读取数据

使用自带的DataLoader

python 复制代码
import torch
from torch import nn,optim
import torchvision
from torch.utils import data
from torchvision import transforms,datasets
from torch.utils.data import DataLoader

# 定义超参数
batch_size = 256
learning_rate = 0.01
epochs = 5

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  
    transforms.Normalize((0.5,), (0.5,))  # 标准化到[-1, 1]区间,加快计算
])

# 加载Fashion-MNIST数据集
train_dataset = datasets.FashionMNIST(root='D:/DL_Data/', train=True, download=False, transform=transform)
test_dataset = datasets.FashionMNIST(root='D:/DL_Data/', train=False, download=False, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

2 定义模型,初始化参数

使用torch自带的nn模型,输入层用Flatten(),因为要把2828的展开成一维,输出层用Linear,前面我们说过,全连接层可以看作线性模型,也符合softmax的特征,输入是784,因为2828展开后是784,输出是10,因为有10和可能预测到的类别

python 复制代码
# 定义模型
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784,10)
)
# 初始化参数
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

3 定义损失函数和优化器

使用torch自带的

python 复制代码
# 损失函数与优化器
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失,因为它包含了softmax
optimizer = optim.SGD(net.parameters(), lr=learning_rate)

4 训练

python 复制代码
# 训练模型
for epoch in range(epochs):
    net.train()
    running_loss = 0.0
    running_corrects = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        # 计算正确率
        _, preds = torch.max(output, 1)
        running_loss += loss.item() * data.size(0)
        running_corrects += torch.sum(preds == target.data)

        if batch_idx % 10 == 0:# 每训练10步输出一次loss和acc
            epoch_loss = running_loss / ((batch_idx + 1) * batch_size)
            epoch_acc = running_corrects.double() / ((batch_idx + 1) * batch_size)
            print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}')

    # 输出每个epoch的平均损失和正确率
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects.double() / len(train_dataset)
    print(f'Epoch [{epoch+1}/{epochs}] Summary - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}')

5 预测

python 复制代码
# 定义 Fashion-MNIST 标签的文本描述
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

# 预测并显示结果
def predict(net, test_iter, n=6):
    for X, y in test_iter:
        break  # 只取一个批次的数据
    trues = get_fashion_mnist_labels(y)
    preds = get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    n = min(n, X.shape[0])
    fig, axs = plt.subplots(1, n, figsize=(12, 3))
    for i in range(n):
        axs[i].imshow(X[i].permute(1, 2, 0).squeeze().numpy(), cmap='gray')
        axs[i].set_title(titles[i])
        axs[i].axis('off')
    plt.show()

# 调用预测函数
predict(net, test_iter, n=10)
相关推荐
Daphnis_z几秒前
大模型应用编排工具Dify之常用编排组件
人工智能·chatgpt·prompt
yuanbenshidiaos1 小时前
【大数据】机器学习----------强化学习机器学习阶段尾声
人工智能·机器学习
盼小辉丶6 小时前
TensorFlow深度学习实战——情感分析模型
深度学习·神经网络·tensorflow
好评笔记6 小时前
AIGC视频生成模型:Stability AI的SVD(Stable Video Diffusion)模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·面试·aigc
算家云6 小时前
TangoFlux 本地部署实用教程:开启无限音频创意脑洞
人工智能·aigc·模型搭建·算家云、·应用社区·tangoflux
AI街潜水的八角7 小时前
工业缺陷检测实战——基于深度学习YOLOv10神经网络PCB缺陷检测系统
pytorch·深度学习·yolo
叫我:松哥8 小时前
基于Python django的音乐用户偏好分析及可视化系统设计与实现
人工智能·后端·python·mysql·数据分析·django
熊文豪8 小时前
深入解析人工智能中的协同过滤算法及其在推荐系统中的应用与优化
人工智能·算法
Vol火山9 小时前
AI引领工业制造智能化革命:机器视觉与时序数据预测的双重驱动
人工智能·制造
tuan_zhang9 小时前
第17章 安全培训筑牢梦想根基
人工智能·安全·工业软件·太空探索·战略欺骗·算法攻坚