基于 RNN(GRU, LSTM)+CNN 的红点位置检测(pytorch)

文章目录

  • [1 项目背景](#1 项目背景)
  • [2 数据集](#2 数据集)
  • [3 思路](#3 思路)
  • [4 实验结果](#4 实验结果)
  • [5 代码](#5 代码)

1 项目背景

需要在图片精确识别三跟红线所在的位置,并输出这三个像素的位置。

其中,每跟红线占据不止一个像素,并且像素颜色也并不是饱和度和亮度极高的红黑配色,每个红线放大后可能是这样的。

而我们的目标是精确输出每个红点的位置,需要精确到像素。也就是说,对于每根红线,模型需要输出橙色箭头所指的像素而不是蓝色箭头所指的像素的位置。

之前尝试过纯 RNN 的实验,也试过在 RNN 前用 CNN,给数据带上卷积的信息。在图片长度为1080、低噪声环境时,对比实验的结果如下:

实验 loss 完全准确的点
GRU 129.6641 1762.0/9000 (20%)
LSTM 249.2053 1267.0/9000 (14%)
CNN+GRU 1419.5781 601.0/9000 (7%)
CNN+LSTM 1166.4599 762.0/9000 (8%)

对的,这个方法甚至起到反效果了。问了做过类似尝试的同事,他表示效果其实跟直接使用 RNN 区别不大。

2 数据集

还是之前那个代码合成的数据集数据集,每个数据集规模在15000张图片左右,在没有加入噪音的情况下,每个样本预览如图所示:

加入噪音后,每个样本的预览如下图所示:

图中黑色部分包含比较弱的噪声,并非完全为黑色。

数据集包含两个文件,一个是文件夹,里面包含了jpg压缩的图像数据:

另一个是csv文件,里面包含了每个图像的名字以及3根红线所在的像素的位置。

3 思路

之前 CNN+RNN 的思路是把 CNN 作为一个特征提取器,RNN 作为决策模型。这次主要是想看看直接用 CNN 做决策会比 RNN 强多少,因为其实 CNN 在这类任务上的优势应该会大很多。也就是说把RNN当作一个特征提取器处理图片数据,再用CNN找到这三个点的位置。按照这个思路,RNN+CNN 的处理流程如下:

然后再在模型上加一点Attention:

4 实验结果

实验 train loss val loss test loss test 完全准确样本 点1平均偏移量 点2平均偏移量 点3平均偏移量
GRU 17.1150 16.2752 233.5694 536.0/4500 (12%) 3.3181 3.0701 3.3957
LSTM 378.7690 47.6191 367.7041 499.0/4500 (11%) 4.2166 3.6437 4.0777
CNN 6.6049 13.6372 231.4501 650.0/4500 (14%) 2.1816 3.0884 3.9680
CNN+RNN 5.3883 6.6833 76.0979 821.0/4500 (18%) 1.8977 2.5229 1.8854
Multi-Head Attention + RNN 174.5019 18.1041 149.0297 645.0/4500 (14%) 2.6598 3.2243 2.4309
RNN+CNN 2.6558 1.7714 28.4280 1318.0/4500 (29%) 1.4926 1.3679 1.5234
RNN+CNN+Attention 6.5938 42.4060 41.9453 1264.0/4500 (28%) 1.5860 1.5557 1.8804

GRU那个妥妥过拟合,CNN 做决策效果确实暴打之前的 RNN,只能说卷积还是适合图像类的任务,RNN 这种针对序列信息的可能效果还是有限。

5 代码

GRU+CNN+Attention

python 复制代码
import torch
import torch.nn as nn

class Config(object):
    def __init__(self, device, csv_file, img_dir, width, input_size):
        self.device = device
        self.model_name = 'GRU_CNN_Attention'
        self.input_size = input_size
        self.hidden_size = 128
        self.num_layers = 2
        self.epoch_number = 150
        self.batch_size = 32
        self.learn_rate = 0.0002
        self.csv_file = csv_file
        self.img_dir = img_dir
        self.width = width

class GRU_CNN(nn.Module):
    def __init__(self, config):
        super(GRU_CNN, self).__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device
        self.sequence_length = config.width
        self.channels = config.input_size

        self.gru = nn.GRU(input_size=self.channels, hidden_size=self.hidden_size, num_layers=self.num_layers,
                          batch_first=True, bidirectional=True, dropout=0.6)

        self.attention = nn.MultiheadAttention(embed_dim=2 * self.hidden_size, num_heads=4, batch_first=True)

        self.fc = nn.Linear(2 * self.hidden_size, 4)

        self.conv1 = nn.Conv2d(4 + self.channels, 32, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.se1 = SEAttention(32)
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.se2 = SEAttention(64)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.se3 = SEAttention(128)
        self.pool3 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.fc1 = nn.Linear(128 * (self.sequence_length // 8), 128)
        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        rnn_x = x.squeeze(2).permute(0, 2, 1)
        # x = x + self.pos_encoding[:, :x.size(1), :].to(x.device)
        h0 = torch.zeros(self.num_layers * 2, rnn_x.size(0), self.hidden_size).to(x.device)
        gru_output, _ = self.gru(rnn_x, h0) # batch_size, sequence_length, 2 * hidden_size
        context_vector, _ = self.attention(gru_output, gru_output, gru_output) # batch_size, sequence_length, 2 * hidden_size
        gru_output_fc = self.fc(context_vector)  # batch_size, sequence_length, 3
        gru_output_fc = gru_output_fc.transpose(1, 2).unsqueeze(2)  # batch_size, 3, 1, sequence_length

        x = torch.cat((x, gru_output_fc), dim=1)
        x = self.pool1(self.se1(self.relu(self.conv1(x))))
        x = self.pool2(self.se2(self.relu(self.conv2(x))))
        x = self.pool3(self.se3(self.relu(self.conv3(x))))
        x = x.view(-1, 128 * (self.sequence_length // 8))
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class SEAttention(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

GRU+CNN

python 复制代码
import torch
import torch.nn as nn

class Config(object):
    def __init__(self, device, csv_file, img_dir, width, input_size):
        self.device = device
        self.model_name = 'GRU_CNN'
        self.input_size = input_size
        self.hidden_size = 128
        self.num_layers = 2
        self.epoch_number = 100
        self.batch_size = 32
        self.learn_rate = 0.001
        self.csv_file = csv_file
        self.img_dir = img_dir
        self.width = width

class GRU_CNN(nn.Module):
    def __init__(self, config):
        super(GRU_CNN, self).__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device
        self.sequence_length = config.width
        self.channels = config.input_size

        self.gru = nn.GRU(input_size=self.channels, hidden_size=self.hidden_size, num_layers=self.num_layers,
                          batch_first=True, bidirectional=True, dropout=0.6)

        self.fc = nn.Linear(2 * self.hidden_size, 3)

        self.conv1 = nn.Conv2d(3 + self.channels, 32, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.relu = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.pool3 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.fc1 = nn.Linear(128 * (self.sequence_length // 8), 128)
        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        rnn_x = x.squeeze(2).permute(0, 2, 1)
        # x = x + self.pos_encoding[:, :x.size(1), :].to(x.device)
        h0 = torch.zeros(self.num_layers * 2, rnn_x.size(0), self.hidden_size).to(x.device)
        gru_output, _ = self.gru(rnn_x, h0) # batch_size, sequence_length, 2 * hidden_size
        gru_output_fc = self.fc(gru_output)  # batch_size, sequence_length, 3
        gru_output_fc = gru_output_fc.transpose(1, 2).unsqueeze(2)  # batch_size, 3, 1, sequence_length

        x = torch.cat((x, gru_output_fc), dim=1)

        x = self.pool1(self.relu(self.conv1(x)))
        x = self.pool2(self.relu(self.conv2(x)))
        x = self.pool3(self.relu(self.conv3(x)))
        x = x.view(-1, 128 * (self.sequence_length // 8))
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
相关推荐
Seeklike3 小时前
12.04 深度学习-用CNN做图像分类+训练可视化
深度学习·分类·cnn
pzx_0013 小时前
【时间序列预测】基于Pytorch实现CNN_LSTM算法
人工智能·pytorch·python·算法·cnn·lstm
被制作时长两年半的个人练习生10 小时前
【pytorch】pytorch的缓存策略——计算机分层理论的另一大例证
人工智能·pytorch·python
曼城周杰伦16 小时前
自然语言处理:第七十二章 微软推出超越自己GraphRAG的LazyGraphRAG
人工智能·pytorch·神经网络·microsoft·自然语言处理·nlp
骑猪玩狗17 小时前
第N9周:seq2seq翻译实战-Pytorch复现-小白版
人工智能·pytorch·python
不灭蚊香18 小时前
YOLOv2 (You Only Look Once Version 2)
深度学习·神经网络·yolo·目标检测·计算机视觉·cnn
算力魔方AIPC18 小时前
PyTorch 2.5.1: Bugs修复版发布
人工智能·pytorch·python
Joyner201818 小时前
pytorch中有哪些归一化的方式?
人工智能·pytorch·python
Niuguangshuo19 小时前
PyTorch 实现动态输入
人工智能·pytorch·python
总有一天你的谜底会解开19 小时前
pytorch加载预训练权重失败
人工智能·pytorch·python