PyTorch: classifying cat and dog images with a pretrained VGG16 model

python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

python
# Data preprocessing
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),       # random crop, then resize to a fixed size
    transforms.RandomRotation(20),           # random rotation angle (degrees)
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
    transforms.ToTensor()
])

# Read the data
root = 'image'
train_dataset = datasets.ImageFolder(root + '/train', transform)
test_dataset = datasets.ImageFolder(root + '/test', transform)

# Load the data
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True)

python
classes = train_dataset.classes
classes_index = train_dataset.class_to_idx
print(classes)
print(classes_index)

Output:
['cat', 'dog']
{'cat': 0, 'dog': 1}

python
model = models.vgg16(pretrained=True)
print(model)

Output:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

python
# If we only want to train the model's fully connected layers, freeze everything first:
# for param in model.parameters():
#     param.requires_grad = False

# Build a new fully connected classifier head
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2))

python
LR = 0.0001
# Loss function
entropy_loss = nn.CrossEntropyLoss()
# Optimizer
optimizer = optim.SGD(model.parameters(), LR, momentum=0.9)

python
def train():
    model.train()
    for i, data in enumerate(train_loader):
        # Get a batch of inputs and the corresponding labels
        inputs, labels = data
        # Model predictions, shape (batch, 2)
        out = model(inputs)
        # Cross-entropy loss: out is (batch, C), labels is (batch,)
        loss = entropy_loss(out, labels)
        # Zero the gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update the weights
        optimizer.step()


def test():
    model.eval()
    correct = 0
    for i, data in enumerate(test_loader):
        # Get a batch of inputs and the corresponding labels
        inputs, labels = data
        # Model predictions
        out = model(inputs)
        # The index of the largest output is the predicted class
        _, predicted = torch.max(out, 1)
        # Count correct predictions
        correct += (predicted == labels).sum()
    print("Test acc: {0}".format(correct.item() / len(test_dataset)))

    correct = 0
    for i, data in enumerate(train_loader):
        # Get a batch of inputs and the corresponding labels
        inputs, labels = data
        # Model predictions
        out = model(inputs)
        # The index of the largest output is the predicted class
        _, predicted = torch.max(out, 1)
        # Count correct predictions
        correct += (predicted == labels).sum()
    print("Train acc: {0}".format(correct.item() / len(train_dataset)))

python
for epoch in range(0, 10):
    print('epoch:', epoch)
    train()
    test()

torch.save(model.state_dict(), 'cat_dog_cnn.pth')

Output:
epoch: 0
Test acc: 0.785
Train acc: 0.825
epoch: 1
Test acc: 0.885
Train acc: 0.865
epoch: 2
Test acc: 0.845
Train acc: 0.8675
epoch: 3
Test acc: 0.945
Train acc: 0.885
epoch: 4
Test acc: 0.89
Train acc: 0.8675
epoch: 5
Test acc: 0.93
Train acc: 0.945
epoch: 6
Test acc: 0.915
Train acc: 0.93
epoch: 7
Test acc: 0.925
Train acc: 0.935
epoch: 8
Test acc: 0.9
Train acc: 0.9325
epoch: 9
Test acc: 0.91
Train acc: 0.9425

This is the content of a Jupyter Notebook that uses a pretrained VGG16 model to classify images of cats and dogs. Each part is explained in detail below:

1. Import the required libraries

python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
  • Imports the torch-related libraries needed for deep learning: neural network building blocks (nn), optimizers (optim), the predefined models (models), and the dataset/transform utilities (datasets, transforms). DataLoader is also imported for batching and loading the data.

2. Data preprocessing and loading

python
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])
root = 'image'
train_dataset = datasets.ImageFolder(root + '/train', transform)
test_dataset = datasets.ImageFolder(root + '/test', transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True)
  • Defines the preprocessing pipeline: a random resized crop to 224×224, a random rotation of up to 20 degrees, a random horizontal flip, and conversion to a tensor.
  • Reads the images from the train and test subfolders of the image directory with datasets.ImageFolder and applies the same transform to both splits.
  • Wraps the training and test sets in DataLoaders with a batch size of 8 and shuffling enabled. Note that the same random augmentations are applied to the test set and no normalization is used, even though the pretrained VGG16 weights were trained on ImageNet-normalized inputs; a more conventional setup is sketched below.
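
The following is a minimal sketch of separate train/test transforms with ImageNet normalization. The image/train and image/test layout is taken from the notebook; the mean/std values are the standard ImageNet statistics that the pretrained VGG16 weights expect. This is an adjustment, not part of the original code.

python
# Sketch: separate train/test transforms with ImageNet normalization (not in the original notebook)
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet channel means
                                 std=[0.229, 0.224, 0.225])   # ImageNet channel stds

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),       # augmentation only on the training split
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    normalize
])

test_transform = transforms.Compose([
    transforms.Resize(256),                  # deterministic preprocessing for evaluation
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

train_dataset = datasets.ImageFolder('image/train', train_transform)
test_dataset = datasets.ImageFolder('image/test', test_transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)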

3. Inspect the class information

python
classes = train_dataset.classes
classes_index = train_dataset.class_to_idx
print(classes)
print(classes_index)
  • Retrieves the list of class names and the class-to-index mapping from the training set and prints them. There are two classes, cat and dog, mapped to indices 0 and 1; ImageFolder derives these indices from the alphabetically sorted subfolder names.
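
Because these indices are what the model predicts later, the classes list can be used to turn a predicted index back into a readable label. A small illustrative helper (the function name predict_label is hypothetical, not part of the notebook), assuming a single preprocessed image tensor:

python
# Sketch: map a predicted class index back to its name (illustrative helper, not in the notebook)
import torch

def predict_label(model, image_tensor, classes):
    # image_tensor: one preprocessed image of shape (3, 224, 224)
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.unsqueeze(0))  # add a batch dimension -> shape (1, 2)
        index = logits.argmax(dim=1).item()        # index of the largest logit
    return classes[index]                          # e.g. 'cat' or 'dog'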

4. Load the pretrained model

python
model = models.vgg16(pretrained=True)
print(model)
  • Loads a VGG16 model with weights pretrained on ImageNet and prints its structure: the convolutional features block, the adaptive average-pooling layer, and the fully connected classifier ending in a 1000-way linear layer.
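
In newer torchvision releases (0.13 and later) the pretrained=True argument is deprecated in favor of the weights enum; on such a version the equivalent call looks roughly like this (a version-dependent sketch, not part of the original notebook):

python
# Sketch: load the same ImageNet weights with the newer torchvision weights API
from torchvision import models

model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)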

5. Modify the model structure (freezing is optional)

python
# If we only want to train the model's fully connected layers, freeze everything first:
# for param in model.parameters():
#     param.requires_grad = False

# Build a new fully connected classifier head
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2))
  • This block shows how to adapt the model. The loop that freezes all pretrained parameters is commented out; uncomment it if you only want to train the new fully connected layers. The classifier is then replaced with a smaller head (25088 → 100 → 2) whose final layer outputs two classes (cat and dog); since newly created layers have requires_grad=True by default, the new head is trained either way. A quick check of what will actually be updated is sketched below.
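
To confirm which parameters will be updated after (optionally) freezing the backbone, you can simply count them. A minimal diagnostic sketch, assuming model has already been modified as above:

python
# Sketch: count trainable vs. total parameters (diagnostic, not part of the original notebook)
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("trainable parameters: {0} / {1}".format(trainable, total))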

6. Define the learning rate, loss function, and optimizer

python
LR = 0.0001
entropy_loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), LR, momentum=0.9)
  • Sets the learning rate to 0.0001, uses the cross-entropy loss, and optimizes with stochastic gradient descent (SGD) with momentum 0.9. The optimizer is handed all model parameters; if the backbone has been frozen, it is cleaner to pass only the trainable ones, as sketched below.
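
A minimal sketch of restricting the optimizer to the parameters that still require gradients (an adjustment for the frozen-backbone case, not in the original code; it reuses model and optim from the notebook):

python
# Sketch: optimize only the parameters that still require gradients (e.g. the new classifier head)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.0001, momentum=0.9)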

7. Define the training and evaluation functions

python
def train():
  ...
def test():
  ...
  • The train function runs one pass over the training set: it fetches a batch of inputs and labels, computes the model output, evaluates the loss, zeroes the gradients, backpropagates, and updates the weights.
  • The test function evaluates the model on both the test set and the training set, counting correct predictions and printing the accuracy for each. As written, everything stays on the CPU and the evaluation loops still track gradients; a device-aware, gradient-free variant is sketched below.
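
A minimal sketch of the same loops with explicit device handling and torch.no_grad() during evaluation; it reuses model, entropy_loss, optimizer, and the data loaders from the notebook, and the evaluate helper name is an addition of this sketch:

python
# Sketch: device-aware training step and gradient-free evaluation (adjusted from the notebook)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

def train():
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        loss = entropy_loss(model(inputs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def evaluate(loader, dataset, name):
    model.eval()
    correct = 0
    with torch.no_grad():  # no gradient tracking needed during evaluation
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
    print("{0} acc: {1}".format(name, correct / len(dataset)))

# Usage: evaluate(test_loader, test_dataset, 'Test'); evaluate(train_loader, train_dataset, 'Train')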

8. Train and save the model

python
for epoch in range(0, 10):
    print('epoch:', epoch)
    train()
    test()
torch.save(model.state_dict(), 'cat_dog_cnn.pth')
  • Runs 10 epochs; each epoch prints its number and then calls the train and test functions in turn.
  • After training, the model's weights (its state_dict) are saved to the file cat_dog_cnn.pth; a sketch of reloading them for inference follows below.
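
Because only the state_dict is saved, loading it later requires rebuilding a VGG16 with the same replaced classifier before calling load_state_dict. A minimal reload sketch (the file name comes from the notebook; the rest is the standard PyTorch pattern):

python
# Sketch: rebuild the modified VGG16 and reload the saved weights for inference
import torch
from torchvision import models

model = models.vgg16()  # same architecture; the saved weights will overwrite everything
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2))
model.load_state_dict(torch.load('cat_dog_cnn.pth', map_location='cpu'))
model.eval()  # switch to evaluation mode before inference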