基于RNNs(LSTM, GRU)的红点位置检测(pytorch)

文章目录

  • [1 项目背景](#1 项目背景)
  • [2 数据集](#2 数据集)
  • [3 思路](#3 思路)
  • [4 结果](#4 结果)
  • [5 代码](#5 代码)
  • [6 总结](#6 总结)

1 项目背景

需要在图片精确识别三跟红线所在的位置,并输出这三个像素的位置。

其中,每跟红线占据不止一个像素,并且像素颜色也并不是饱和度和亮度极高的红黑配色,每个红线放大后可能是这样的。

而我们的目标是精确输出每个红点的位置,需要精确到像素。也就是说,对于每根红线,模型需要输出橙色箭头所指的像素而不是蓝色箭头所指的像素的位置。

2 数据集

仅仅只是个实验,所以就用代码合成了个数据集,每个数据集规模在15000张图片左右,在没有加入噪音的情况下,每个样本预览如图所示:

加入噪音后,每个样本的预览如下图所示:

图中黑色部分包含比较弱的噪声,并非完全为黑色。

数据集包含两个文件,一个是文件夹,里面包含了jpg压缩的图像数据:

另一个是csv文件,里面包含了每个图像的名字以及3根红线所在的像素的位置。

3 思路

直觉上,一般会想到 CNN 来做这个算法,但是这里我更想采取 RNN 来完成任务。因为在思路上,上面的图片完全可以简化为高度为1,长度为length的图像。此时,如果忽略不计高度,这个图像其实就可以退化为一个长度为 length,特征数为3(RGB)的序列,因此可以采用RNN来抓取数据的序列信息完成这个任务。

站在 NLP 的角度来看,这个就像是一个单词数为length的句子,词向量为3维,任务是去寻找最"奇怪"的三个单词所在的位置。因此想法是用整个序列训练 LSTM 模型,并且取每个length的结果。通过全连接层fc1得到每一个单词的奇怪程度后,通过全连接层f2输出每句话最奇怪的三个单词的位置。

因为 label 是三个位置,所以这里的输出并不能是简单的三个点

这里有两个需要注意的点:

1. LSTM 的结果

LSTM的结果应该取output而不是output[:, -1, :],我一开始认为后者是合理的,因为后者包含了最后一步之前所有步骤的序列信息。但事实上,应该取前者。因为后者是对整个序列信息的一个汇总,而前者才是对每个单词的一个评价。

output[:, -1, :]有点类似于做文本分类任务,去看每个句子的情感倾向啊之类的;而取output对标的任务更像是词性标注,实体抽取。

2. 位置信息的插入

实验的时候一直输不出想要的结果,debug之后发现 LSTM 没什么问题,但是没有办法把lstm的结果映射成红点的位置。因此对 LSTM 的结果再合并上一层位置信息,方便模型计算出位置。

以 LSTM 为例,模型的流程基本可以简化为下面这个流程图:

4 结果

在低噪声低长度的图片里,该模型预测效果尚可,但是仍然不够精确。

在长度较长的图片和低噪声的情况下,模型预测的不精确的问题被进一步放大。

5 代码

LSTM结构:

python 复制代码
class Config(object):
    def __init__(self, device, csv_file, img_dir, width):
        self.device = device
        self.model_name = 'Bi_LSTM'
        self.input_size = 3
        self.hidden_size = 128
        self.num_layers = 3
        self.epoch_number = 200
        self.batch_size = 32
        self.learn_rate = 0.001
        self.csv_file = csv_file
        self.img_dir = img_dir
        self.width = width


class LSTM(nn.Module):
    def __init__(self, config):
        super(LSTM, self).__init__()
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device

        self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,
                            batch_first=True, bidirectional=True, dropout=0.5)

        self.fc1 = nn.Linear(self.hidden_size * 2, 1)
        self.fc2 = nn.Linear(config.width * 2, 3)
        self.sigmoid = nn.Sigmoid()
        self.scale = config.width
        self.device = config.device

    def forward(self, x):
        # LSTM need (batch_size, seq_length, input_size): (batch_size, 3, height, width) -> (batch_size, width, 3)
        x = x.squeeze(2)
        x = x.permute(0, 2, 1)
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        output, _ = self.lstm(x, (h0, c0))
        scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 270, 1)
        positional_info = torch.arange(self.scale).unsqueeze(0).repeat(scores.shape[0], 1).to(
            x.device)  # (batch_size, seq_length)
        positional_info = positional_info / self.scale
        combined_info = torch.cat((scores, positional_info.float()), dim=-1)  # (batch_size, seq_length * 2)

        # Predict top 3 positions
        predicted_positions = self.fc2(combined_info)
        predicted_positions = self.sigmoid(predicted_positions)
        scaled_predicted_positions = predicted_positions * self.scale
        final_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)
        return final_predicted_positions

GRU结构

python 复制代码
class GRU(nn.Module):
    def __init__(self, config):
        super(GRU, self).__init__()
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device

        self.gru = nn.GRU(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,
                            batch_first=True, bidirectional=True, dropout=0.5)

        self.fc1 = nn.Linear(self.hidden_size * 2, 1)
        self.fc2 = nn.Sequential(
            nn.Linear(config.width * 2, config.width),
            nn.ReLU(),
            nn.Linear(config.width, 3), # predict 3 points
            nn.Sigmoid(),
        )
        self.scale = config.width
        self.device = config.device

    def forward(self, x):
        # LSTM need (batch_size, seq_length, input_size): (batch_size, 3, height, width) -> (batch_size, width, 3)
        x = x.squeeze(2)
        x = x.permute(0, 2, 1)
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        output, _ = self.gru(x, h0)
        scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080, 1)

        positional_info = torch.arange(self.scale).unsqueeze(0).repeat(scores.shape[0], 1).to(x.device)  # (batch_size, seq_length)
        positional_info = positional_info / self.scale
        combined_info = torch.cat((scores, positional_info.float()), dim=-1)  # (batch_size, seq_length * 2)

        # Predict top 3 positions
        predicted_positions = self.fc2(combined_info)
        # predicted_positions = self.sigmoid(predicted_positions)
        scaled_predicted_positions = predicted_positions * self.scale
        final_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)
        return final_predicted_positions

6 总结

用RNN检测红点或许是个好主意,但是就目前的结果来,在精确度上还是有很大的问题。

个人的想法是分为三个方向的:

  1. 优化模型,提高精度,加 attention。或者直接换其他模型(暴力 vision transformer?)
  2. 优化判断机制。我个人是认为RNNs的学习能力足以找到几个红点的位置,但是后面几个全连接层 fc1,fc2 可能没有很好地把这些信息转化成位置信息。因此可以优化 fc1,fc2 的设计,加深层数
  3. 处理数据,可能RGB的数据不太适合,考虑把图片分别用 RGB, HSV 多种格式读入后拼接,发挥神经网络处理高维数据的能力。或者加入新设计的特征,比如与前一个像素和后一个像素在R通道的差等等。还可以引入自编码器等方法进一步提纯信息。再或者是优化位置信息的嵌入方式,抛弃简单的位置信息合并,采取位置嵌入。

这几个想法我都会在后面的文章中尝试。

路过的大佬有什么建议 ball ball 在评论区打出来

相关推荐
机器学习之心43 分钟前
LSTM-SVM时序预测 | Matlab基于LSTM-SVM基于长短期记忆神经网络-支持向量机时间序列预测
神经网络·支持向量机·lstm
love you joyfully9 小时前
目标检测与R-CNN——pytorch与paddle实现目标检测与R-CNN
人工智能·pytorch·目标检测·cnn·paddle
这个男人是小帅16 小时前
【AutoDL】通过【SSH远程连接】【vscode】
运维·人工智能·pytorch·vscode·深度学习·ssh
四口鲸鱼爱吃盐16 小时前
Pytorch | 利用MI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
WeeJot嵌入式19 小时前
长短期记忆网络(LSTM):深度学习中的序列数据处理利器
人工智能·深度学习·lstm
沅_Yuan19 小时前
基于LSTM长短期记忆神经网络的多分类预测【MATLAB】
神经网络·分类·lstm
微臣愚钝1 天前
【作业】LSTM
人工智能·机器学习·lstm
qq_273900231 天前
pytorch repeat方法和expand方法的区别
人工智能·pytorch·python
AI程序猿人1 天前
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
人工智能·pytorch·深度学习·自然语言处理·大模型·transformer·llms
四口鲸鱼爱吃盐2 天前
Pytorch | 从零构建Vgg对CIFAR10进行分类
人工智能·pytorch·分类