pytorch+CRNN实现

最近接触了一个仪表盘识别的项目,简单调研以后发现可以用CRNN来做。但是手边缺少仪表盘数据集,就先用ICDAR2013试了一下。

结果遇到了一系列坑。为了不使读者和自己在以后的日子继续遭罪。我把正确的代码发到下面了。

1)超参数请不要调整!!!!CRNN前期训练极其离谱,需要良好的调参,loss才会慢慢下降。

我给出了一个训练曲线,可以看到确实贼几把怪,七拐八拐的。

2)千万不要用百度开源的那个ctc!!!

网络代码:

cpp 复制代码
#crnn.py
import torch.nn as nn
import torch.nn.functional as F

class BidirectionalLSTM(nn.Module):
    # Inputs hidden units Out
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output


class CRNN(nn.Module):
    #                   32    1   37     256
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        #print('---forward propagation---')
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2) # b *512 * width
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        output = F.log_softmax(self.rnn(conv), dim=2)
        return output

训练:

cpp 复制代码
#train.py
import os
import torch
import cv2
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import crnn
import time
import re
import matplotlib.pyplot as plt
dic={" ":0,"a":1,"b":2,"c":3,"d":4,"e":5,"f":6,"g":7,"h":8,"i":9,"j":10,"k":11,"l":12,"m":13,"n":14,"o":15,"p":16,"q":17,"r":18,"s":19,"t":20,"u":21,"v":22,"w":23,"x":24,"y":25,"z":26,
     "A":27,"B":28,"C":29,"D":30,"E":31,"F":32,"G":33,"H":34,"I":35,"J":36,"K":37,"L":38,"M":39,"N":40,"O":41,"P":42,"Q":43,"R":44,"S":45,"T":46,"U":47,"V":48,"W":49,"X":50,"Y":51,"Z":52}

STR=" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
n_class=53
label_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task1_GT"
image_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task12_Images"
use_gpu = True
learning_rate = 0.0001
max_epoch = 100
batch_size = 20
# 调整图像大小和归一化操作
class resizeAndNormalize():
    def __init__(self, size, interpolation=cv2.INTER_LINEAR):
        # 注意对于opencv,size的格式是(w,h)
        self.size = size
        self.interpolation = interpolation
        # ToTensor属于类  """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
        self.toTensor = transforms.ToTensor()

    def __call__(self, image):
        # (x,y) 对于opencv来说,图像宽对应x轴,高对应y轴
        image = cv2.resize(image, self.size, interpolation=self.interpolation)
        # 转为tensor的数据结构
        image = self.toTensor(image)
        # 对图像进行归一化操作
        #image = image.sub_(0.5).div_(0.5)
        return image

def load_data(label_folder,image_folder,label_suffix_name=".txt",image_suffix_name=".jpg"):
    image_file,label_file,num_file=[],[],[]
    for parent_folder, _, file_names in os.walk(label_folder):
        # 遍历当前子文件夹中的所有文件
        for file_name in file_names:
            # 只处理图片文件
            # if file_name.endswith(('jpg', 'jpeg', 'png', 'gif')):#提取jpg、jpeg等格式的文件到指定目录
            if file_name.endswith((label_suffix_name)):  # 提取json格式的文件到指定目录
                # 构造源文件路径和目标文件路径
                a,b=file_name.split("gt_")
                c,d=b.split(label_suffix_name)
                image_name=image_folder + "\\" + c + image_suffix_name
                if os.path.exists(image_name):
                    label_name = label_folder + "\\" + file_name
                    txt=open(label_name,'rb')
                    txtl=txt.readlines()
                    for line in range(len(txtl)):
                        image_file.append(image_name)
                        label_file.append(label_name)
                        num_file.append(line)
    return image_file,label_file,num_file

def zl2lable(zl):
    label_list=[]
    for str in zl:
        label_list.append(dic[str])
    return label_list

class NewDataSet(Dataset):
    def __init__(self, label_source,image_source,train=True):
        super(NewDataSet, self).__init__()
        self.image_file,self.label_file,self.num_file= load_data(label_source,image_source)

    def __len__(self):
        return len(self.image_file)

    def __getitem__(self, index):
        txt = open(self.label_file[index], 'rb')
        img=cv2.imread(self.image_file[index],cv2.IMREAD_GRAYSCALE)
        wordL = txt.readlines()
        word=str(wordL[self.num_file[index]])
        pl = re.findall(r'\d+',word)
        zl = re.findall(r"[a-zA-Z]+", word)[1]  #1

        #img tensor
        x1, y1, x2, y2 = pl[:4]
        img= img[int(y1):int(y2),int(x1):int(x2), ]
        (height, width)=img.shape
        # 由于crnn网络输入图像的高为32,故需要resize原始图像的height
        size_height = 32
        # ratio = 32 / float(height)
        size_width =100
        transform = resizeAndNormalize((size_width, size_height))
        # 图像预处理
        imageTensor = transform(img)

        #label tensor
        l = zl2lable(zl)
        labelTensor = torch.IntTensor(l)
        return imageTensor,labelTensor




class CRNNDataSet(Dataset):
    def __init__(self, imageRoot, labelRoot):
        self.image_root = imageRoot
        self.image_dict = self.readfile(labelRoot)
        self.image_name = [fileName for fileName, _ in self.image_dict.items()]

    def __getitem__(self, index):
        image_path = os.path.join(self.image_root, self.image_name[index])
        keys = self.image_dict.get(self.image_name[index])
        label = [int(x) for x in keys]

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # if image is None:
        #     return None,None
        (height, width) = image.shape

        # 由于crnn网络输入图像的高为32,故需要resize原始图像的height
        size_height = 32
        ratio = 32 / float(height)
        size_width = int(ratio * width)
        transform = resizeAndNormalize((size_width, size_height))
        # 图像预处理
        image = transform(image)
        # 标签格式转换为IntTensor
        label = torch.IntTensor(label)
        return image, label

    def __len__(self):
        return len(self.image_name)

    def readfile(self, fileName):
        res = []
        with open(fileName, 'r') as f:
            lines = f.readlines()
            for line in lines:
                res.append(line.strip())
        dic = {}
        total = 0
        for line in res:
            part = line.split(' ')
            # 由于会存在训练过程中取图像的时候图像不存在导致异常,所以在初始化的时候就判断图像是否存在
            if not os.path.exists(os.path.join(self.image_root, part[0])):
                print(os.path.join(self.image_root, part[0]))
                total += 1
            else:
                dic[part[0]] = part[1:]
        print(total)

        return dic


trainData =NewDataSet(label_sources,image_sources)

trainLoader = DataLoader(dataset=trainData, batch_size=1, shuffle=True, num_workers=0)




# valData = CRNNDataSet(imageRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\images\\",
#                       labelRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\lables\data_t.txt")
#
# valLoader = DataLoader(dataset=valData, batch_size=1, shuffle=True, num_workers=1)


#
# def decode(preds):
#     pred = []
#     for i in range(len(preds)):
#         if preds[i] != 5989 and ((i == 5989) or (i != 5989 and preds[i] != preds[i - 1])):
#             pred.append(int(preds[i]))
#     return pred
#
#
def toSTR(l):
    str_l=[]
    if isinstance(l, int):
        l=[l]
    for i in range(len(l)):
        str_l.append(STR[l[i]])
    return str_l
def toRES(l):
    new_l=[]
    new_str=' '
    for i in range(len(l)):
        if(l[i]==' '):
            new_str = ' '
            continue
        elif new_str!=l[i]:
            new_l.append(l[i])
            new_str=l[i]
    return new_l

def val(model=torch.load("pytorch-crnn.pth")):
    # 将模式切换为验证评估模式
    loss_func = torch.nn.CTCLoss(blank=0, reduction='mean')
    model.eval()

    test_n=10



    for i, (data, label) in enumerate(trainLoader):
        if(i>test_n):
            break;
        output = model(data.cuda())
        pred_label=output.max(2)[1]
        input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
        target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))
        # forward(self, log_probs, targets, input_lengths, target_lengths)
        #log_probs = output.log_softmax(2).requires_grad_()
        targets = label.cuda()
        loss = loss_func(output.cpu(), targets.cpu(), input_lengths, target_lengths)

        pred_l=np.array(pred_label.cpu().squeeze()).tolist()
        label_l=np.array(targets.cpu().squeeze()).tolist()
        print(i,":",loss,"pred:",toRES(toSTR(pred_l)),"label_l",toSTR(label_l))




def train():
    model = crnn.CRNN(32, 1, n_class, 256)
    if torch.cuda.is_available() and use_gpu:
        model.cuda()

    loss_func = torch.nn.CTCLoss(blank=0,reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,betas=(0.9, 0.999))

    lossTotal = 0.0
    k = 0
    printInterval = 100
    start_time = time.time()
    loss_list=[]
    total_list=[]
    for epoch in range(max_epoch):
        n=0
        data_list = []
        label_list = []
        label_len=[]
        for i, (data, label) in enumerate(trainLoader):
            #
            data_list.append(data)
            label_list.append(label)
            label_len.append(label.size(1))
            n=n+1
            if n%batch_size!=0:
                continue
            k=k+1
            data=torch.cat(data_list, dim=0)
            data_list.clear()

            label = torch.cat(label_list, dim=1).squeeze(0)
            label_list.clear()

            target_lengths=torch.tensor(np.array(label_len))
            label_len.clear()
            # 开启训练模式
            model.train()


            if torch.cuda.is_available and use_gpu:
                data = data.cuda()
                loss_func = loss_func.cuda()
                label = label.cuda()

            output = model(data)
            log_probs = output
            # example 建议使用这样,貌似直接把output送进去loss fun也没发现什么问题
            #log_probs = output.log_softmax(2).requires_grad_()
            targets = label.cuda()
            input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
            # forward(self, log_probs, targets, input_lengths, target_lengths)
            #targets =torch.zeros(targets.shape)
            loss = loss_func(log_probs.cpu(), targets, input_lengths, target_lengths)/batch_size
            lossTotal += float(loss)
            print("epoch:",epoch,"num:",i,"loss:",float(loss))
            loss_list.append(float(loss))
            if k % printInterval == 0:
                print("[%d/%d] [%d/%d] loss:%f" % (
                    epoch, max_epoch, i + 1, len(trainLoader), lossTotal / printInterval))
                total_list.append( lossTotal / printInterval)
                lossTotal = 0.0
                torch.save(model, 'pytorch-crnn.pth')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    plt.figure()
    plt.plot(loss_list)
    plt.savefig("loss.jpg")

    plt.clf()
    plt.figure()
    plt.plot(total_list)
    plt.savefig("total.jpg")
    end_time = time.time()
    print("takes {}s".format((end_time - start_time)))
    return model

if __name__ == '__main__':
    train()

测试结果如下:

最后给一些参考文献:

https://www.cnblogs.com/azheng333/p/7449515.html

https://blog.csdn.net/wzw12315/article/details/106643182

另外给出数据集和我训练好的模型:

链接:https://pan.baidu.com/s/1-jTA22bLKv2ut_1EJ1WMKA?pwd=jvk8

提取码:jvk8

相关推荐
冷眼看人间恩怨10 分钟前
【话题讨论】AI大模型重塑软件开发:定义、应用、优势与挑战
人工智能·ai编程·软件开发
2401_8830410811 分钟前
新锐品牌电商代运营公司都有哪些?
大数据·人工智能
魔道不误砍柴功16 分钟前
Java 中如何巧妙应用 Function 让方法复用性更强
java·开发语言·python
_.Switch40 分钟前
高级Python自动化运维:容器安全与网络策略的深度解析
运维·网络·python·安全·自动化·devops
AI极客菌1 小时前
Controlnet作者新作IC-light V2:基于FLUX训练,支持处理风格化图像,细节远高于SD1.5。
人工智能·计算机视觉·ai作画·stable diffusion·aigc·flux·人工智能作画
阿_旭1 小时前
一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理
人工智能·深度学习·计算机视觉·cross-attention·self-attention
王哈哈^_^1 小时前
【数据集】【YOLO】【目标检测】交通事故识别数据集 8939 张,YOLO道路事故目标检测实战训练教程!
前端·人工智能·深度学习·yolo·目标检测·计算机视觉·pyqt
测开小菜鸟2 小时前
使用python向钉钉群聊发送消息
java·python·钉钉
Power20246662 小时前
NLP论文速读|LongReward:基于AI反馈来提升长上下文大语言模型
人工智能·深度学习·机器学习·自然语言处理·nlp
数据猎手小k2 小时前
AIDOVECL数据集:包含超过15000张AI生成的车辆图像数据集,目的解决旨在解决眼水平分类和定位问题。
人工智能·分类·数据挖掘