LSTM 模型“实现”整数相加运算

文章简介

本文主要介绍了使用 LSTM 模型完成简单的两个整数相加的运算。

数据准备

为了满足模型训练的需要,应该准备 50000 条样本,每个样本包含 query 字符串ans 字符串,如下所示:

makefile 复制代码
query:52+758
ans: 810

我们这里限定了加法运算的两个整数都是 1-999 的任意一个整数,所以 query 的长度最长为 7 (两个最大的三位数和一个加号组成的字符串长度),ans 的长度最长为 4 ,如果长度不足,则在后面用空格补齐。关键代码如下:

scss 复制代码
while len(questions) < TRAINING_SIZE:
    f = lambda: int("".join(np.random.choice(list("1234567890")) for _ in range(np.random.randint(1, DIGITS + 1))))
    a, b = f(), f()
    ...
    q = "{}+{}".format(a, b)
    query = q + " " * (MAXLEN - len(q))
    ans = str(a + b)
    ans += " " * (DIGITS + 1 - len(ans))
    questions.append(query)
    expected.append(ans)

样本创建好之后,需要对样本进行向量化处理,也就是将每个字符都转换成对应的 one-hot 表示,因为每个样本的 query 长度为 7 ,字符集合长度为 12 ,所以每个 query 改成 [7,12] 的 one-hot 向量;每个样本的 ans 长度为 4 ,字符集合长度为 12 ,所以每个 ans 改成 [4,12] 的 one-hot 向量。关键代码如下:

scss 复制代码
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, DIGITS + 1)

模型搭建

模型结构很简单,主要使用了 LSTM 层、RepeatVector 层、 Dense 层,都是基础知识,不做过多解释,编译模型时候设置损失函数为categorical_crossentropy,优化器为 adam 优化器,评估指标为准确率accuracy 关键代码如下:

ini 复制代码
model = keras.Sequential()
model.add(layers.LSTM(128, input_shape=(MAXLEN, len(chars))))
model.add(layers.RepeatVector(DIGITS + 1))
model.add(layers.LSTM(128, return_sequences=True))
model.add(layers.Dense(len(chars), activation="softmax"))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

模型训练

选取 90% 的样本为训练集, 10% 的样本为测试集,下面是模型训练的日志打印:

yaml 复制代码
Iter 1
1407/1407 [==============================] - 11s 6ms/step - loss: 1.7796 - accuracy: 0.3499 - val_loss: 1.5788 - val_accuracy: 0.4065
Iter 2
1407/1407 [==============================] - 9s 6ms/step - loss: 1.3928 - accuracy: 0.4762 - val_loss: 1.2489 - val_accuracy: 0.5346
...
Iter 28
1407/1407 [==============================] - 9s 6ms/step - loss: 0.0205 - accuracy: 0.9944 - val_loss: 0.0257 - val_accuracy: 0.9917
Iter 29
1407/1407 [==============================] - 9s 6ms/step - loss: 0.0256 - accuracy: 0.9926 - val_loss: 0.0747 - val_accuracy: 0.9827

效果展示

下面展示了 10 条样本结果,预测正确的有 表示,预测错误的有 表示,可以看出来结果基本正确,最终的验证集准确率能达到 0.9827 。

css 复制代码
Q 537+65  A 602  ☑ 602 
Q 0+998   A 998  ☑ 998 
Q 50+691  A 741  ☑ 741 
Q 104+773 A 877  ☑ 877 
Q 21+84   A 105  ☑ 105 
Q 318+882 A 1200 ☑ 1200
Q 850+90  A 940  ☑ 940 
Q 96+11   A 107  ☒ 907 
Q 1+144   A 145  ☑ 145 
Q 809+4   A 813  ☑ 813 

参考

github.com/wangdayaya/...

相关推荐
Dekesas969518 小时前
【深度学习】基于Faster R-CNN的黄瓜幼苗智能识别与定位系统,农业AI新突破
人工智能·深度学习·r语言
哥布林学者20 小时前
吴恩达深度学习课程四:计算机视觉 第二周:经典网络结构 (三)1×1卷积与Inception网络
深度学习·ai
鼾声鼾语20 小时前
matlab的ros2发布的消息,局域网内其他设备收不到情况吗?但是matlab可以订阅其他局域网的ros2发布的消息(问题总结)
开发语言·人工智能·深度学习·算法·matlab·isaaclab
【建模先锋】1 天前
特征提取+概率神经网络 PNN 的轴承信号故障诊断模型
人工智能·深度学习·神经网络·信号处理·故障诊断·概率神经网络·特征提取
轲轲011 天前
Week02 深度学习基本原理
人工智能·深度学习
smile_Iris1 天前
Day 40 复习日
人工智能·深度学习·机器学习
深度学习实战训练营1 天前
TransUNet:Transformer 成为医学图像分割的强大编码器,Transformer 编码器 + U-Net 解码器-k学长深度学习专栏
人工智能·深度学习·transformer
火山kim1 天前
经典论文研读报告:DAGGER (Dataset Aggregation)
人工智能·深度学习·机器学习
Coding茶水间1 天前
基于深度学习的水果检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
studytosky1 天前
深度学习理论与实战:反向传播、参数初始化与优化算法全解析
人工智能·python·深度学习·算法·分类·matplotlib