基于RNNs(LSTM, GRU)的红点位置检测(pytorch)

文章目录

  • [1 项目背景](#1 项目背景)
  • [2 数据集](#2 数据集)
  • [3 思路](#3 思路)
  • [4 结果](#4 结果)
  • [5 代码](#5 代码)
  • [6 总结](#6 总结)

1 项目背景

需要在图片精确识别三跟红线所在的位置,并输出这三个像素的位置。

其中,每跟红线占据不止一个像素,并且像素颜色也并不是饱和度和亮度极高的红黑配色,每个红线放大后可能是这样的。

而我们的目标是精确输出每个红点的位置,需要精确到像素。也就是说,对于每根红线,模型需要输出橙色箭头所指的像素而不是蓝色箭头所指的像素的位置。

2 数据集

仅仅只是个实验,所以就用代码合成了个数据集,每个数据集规模在15000张图片左右,在没有加入噪音的情况下,每个样本预览如图所示:

加入噪音后,每个样本的预览如下图所示:

图中黑色部分包含比较弱的噪声,并非完全为黑色。

数据集包含两个文件,一个是文件夹,里面包含了jpg压缩的图像数据:

另一个是csv文件,里面包含了每个图像的名字以及3根红线所在的像素的位置。

3 思路

直觉上,一般会想到 CNN 来做这个算法,但是这里我更想采取 RNN 来完成任务。因为在思路上,上面的图片完全可以简化为高度为1,长度为length的图像。此时,如果忽略不计高度,这个图像其实就可以退化为一个长度为 length,特征数为3(RGB)的序列,因此可以采用RNN来抓取数据的序列信息完成这个任务。

站在 NLP 的角度来看,这个就像是一个单词数为length的句子,词向量为3维,任务是去寻找最"奇怪"的三个单词所在的位置。因此想法是用整个序列训练 LSTM 模型,并且取每个length的结果。通过全连接层fc1得到每一个单词的奇怪程度后,通过全连接层f2输出每句话最奇怪的三个单词的位置。

因为 label 是三个位置,所以这里的输出并不能是简单的三个点

这里有两个需要注意的点:

1. LSTM 的结果

LSTM的结果应该取output而不是output[:, -1, :],我一开始认为后者是合理的,因为后者包含了最后一步之前所有步骤的序列信息。但事实上,应该取前者。因为后者是对整个序列信息的一个汇总,而前者才是对每个单词的一个评价。

output[:, -1, :]有点类似于做文本分类任务,去看每个句子的情感倾向啊之类的;而取output对标的任务更像是词性标注,实体抽取。

2. 位置信息的插入

实验的时候一直输不出想要的结果,debug之后发现 LSTM 没什么问题,但是没有办法把lstm的结果映射成红点的位置。因此对 LSTM 的结果再合并上一层位置信息,方便模型计算出位置。

以 LSTM 为例,模型的流程基本可以简化为下面这个流程图:

4 结果

在低噪声低长度的图片里,该模型预测效果尚可,但是仍然不够精确。

在长度较长的图片和低噪声的情况下,模型预测的不精确的问题被进一步放大。

5 代码

LSTM结构:

python 复制代码
class Config(object):
    def __init__(self, device, csv_file, img_dir, width):
        self.device = device
        self.model_name = 'Bi_LSTM'
        self.input_size = 3
        self.hidden_size = 128
        self.num_layers = 3
        self.epoch_number = 200
        self.batch_size = 32
        self.learn_rate = 0.001
        self.csv_file = csv_file
        self.img_dir = img_dir
        self.width = width


class LSTM(nn.Module):
    def __init__(self, config):
        super(LSTM, self).__init__()
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device

        self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,
                            batch_first=True, bidirectional=True, dropout=0.5)

        self.fc1 = nn.Linear(self.hidden_size * 2, 1)
        self.fc2 = nn.Linear(config.width * 2, 3)
        self.sigmoid = nn.Sigmoid()
        self.scale = config.width
        self.device = config.device

    def forward(self, x):
        # LSTM need (batch_size, seq_length, input_size): (batch_size, 3, height, width) -> (batch_size, width, 3)
        x = x.squeeze(2)
        x = x.permute(0, 2, 1)
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        output, _ = self.lstm(x, (h0, c0))
        scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 270, 1)
        positional_info = torch.arange(self.scale).unsqueeze(0).repeat(scores.shape[0], 1).to(
            x.device)  # (batch_size, seq_length)
        positional_info = positional_info / self.scale
        combined_info = torch.cat((scores, positional_info.float()), dim=-1)  # (batch_size, seq_length * 2)

        # Predict top 3 positions
        predicted_positions = self.fc2(combined_info)
        predicted_positions = self.sigmoid(predicted_positions)
        scaled_predicted_positions = predicted_positions * self.scale
        final_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)
        return final_predicted_positions

GRU结构

python 复制代码
class GRU(nn.Module):
    def __init__(self, config):
        super(GRU, self).__init__()
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.device = config.device

        self.gru = nn.GRU(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers,
                            batch_first=True, bidirectional=True, dropout=0.5)

        self.fc1 = nn.Linear(self.hidden_size * 2, 1)
        self.fc2 = nn.Sequential(
            nn.Linear(config.width * 2, config.width),
            nn.ReLU(),
            nn.Linear(config.width, 3), # predict 3 points
            nn.Sigmoid(),
        )
        self.scale = config.width
        self.device = config.device

    def forward(self, x):
        # LSTM need (batch_size, seq_length, input_size): (batch_size, 3, height, width) -> (batch_size, width, 3)
        x = x.squeeze(2)
        x = x.permute(0, 2, 1)
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        output, _ = self.gru(x, h0)
        scores = self.fc1(output).squeeze(-1)  # shape: (batch_size, 1080, 1)

        positional_info = torch.arange(self.scale).unsqueeze(0).repeat(scores.shape[0], 1).to(x.device)  # (batch_size, seq_length)
        positional_info = positional_info / self.scale
        combined_info = torch.cat((scores, positional_info.float()), dim=-1)  # (batch_size, seq_length * 2)

        # Predict top 3 positions
        predicted_positions = self.fc2(combined_info)
        # predicted_positions = self.sigmoid(predicted_positions)
        scaled_predicted_positions = predicted_positions * self.scale
        final_predicted_positions = torch.clamp(scaled_predicted_positions, min=0, max=self.scale - 1)
        return final_predicted_positions

6 总结

用RNN检测红点或许是个好主意,但是就目前的结果来,在精确度上还是有很大的问题。

个人的想法是分为三个方向的:

  1. 优化模型,提高精度,加 attention。或者直接换其他模型(暴力 vision transformer?)
  2. 优化判断机制。我个人是认为RNNs的学习能力足以找到几个红点的位置,但是后面几个全连接层 fc1,fc2 可能没有很好地把这些信息转化成位置信息。因此可以优化 fc1,fc2 的设计,加深层数
  3. 处理数据,可能RGB的数据不太适合,考虑把图片分别用 RGB, HSV 多种格式读入后拼接,发挥神经网络处理高维数据的能力。或者加入新设计的特征,比如与前一个像素和后一个像素在R通道的差等等。还可以引入自编码器等方法进一步提纯信息。再或者是优化位置信息的嵌入方式,抛弃简单的位置信息合并,采取位置嵌入。

这几个想法我都会在后面的文章中尝试。

路过的大佬有什么建议 ball ball 在评论区打出来

相关推荐
小锋学长生活大爆炸23 分钟前
【教程】Cupy、Numpy、Torch互相转换
pytorch·numpy·cupy
拓端研究室TRL3 小时前
Python注意力机制Attention下CNN-LSTM-ARIMA混合模型预测中国银行股票价格|附数据代码...
开发语言·人工智能·python·cnn·lstm
qq_273900235 小时前
torch.stack 张量维度的变化
人工智能·pytorch·深度学习
啊文师兄7 小时前
使用 Pytorch 搭建视频车流量检测资源(基于YOLO)
人工智能·pytorch·yolo
使者大牙7 小时前
【LLM学习笔记】第三篇:模型微调及LoRA介绍(附PyTorch实例)
人工智能·pytorch·python·深度学习
scdifsn7 小时前
动手学深度学习10.1. 注意力提示-笔记&练习(PyTorch)
pytorch·笔记·深度学习·注意力机制·注意力提示
小毕超8 小时前
基于 PyTorch 从零手搓一个GPT Transformer 对话大模型
pytorch·gpt·transformer
YRr YRr20 小时前
ubuntu20.04 解决Pytorch默认安装CPU版本的问题
人工智能·pytorch·python
代码猪猪傻瓜coding21 小时前
pytorch torch.tile用法
人工智能·pytorch·python