text2sql方法:RESDSQL和DAIL-SQL

之前介绍了text2sql的综述,但是对一些方法的描述不够详细,所以将一些感兴趣的方法思路也整理一下。

RESDSQL

RESDSQL出自2023年2月的论文《RESDSQL: Decoupling Schema Linking and Skeleton Parsing for Text-to-SQL》(github)。它使用seq2seq PLM(pre-trained language model)模型来将自然语言问题翻译成SQL。为了提高准确率,将Schema Linking和Skeleton Parsing解耦,具体做法是先进行Schema Linking使得输入模型encoder的schema只包括与问题最相关的表和列;在生成SQL时,模型decoder先生成SQL骨架再生成实际的SQL查询。

RESDSQL是++R++ anking-enhanced ++E++ ncoding plus a ++S++ keleton-aware ++D++ ecoding framework for Text-to-++SQL++ 的简称,论文图2示意了该方法的思路。

在介绍RESDSQL的实现细节之前,先说明一些符号和概念。将一个关系数据库记为 D \mathcal{D} D,数据库的schema S \mathcal{S} S包括:

  • 数据表集合 T = { t 1 , t 2 , ⋯   , t N } \mathcal{T} = \{ t_1, t_2, \cdots, t_N\} T={t1,t2,⋯,tN}
  • 数据列集合 C = { c 1 1 , ⋯   , c n 1 1 , c 1 2 , ⋯   , c n 2 2 , ⋯   , c 1 N , ⋯   , c n N N } \mathcal{C} = \{ c^1_1, \cdots, c^1_{n_1}, c^2_1, \cdots, c^2_{n_2}, \cdots, c^N_1, \cdots, c^N_{n_N}\} C={c11,⋯,cn11,c12,⋯,cn22,⋯,c1N,⋯,cnNN}。每一个数据列都与一个数据表关联, c n i i c^i_{n_i} cnii是第i个表格的第 n i n_i ni列。
  • 外键关系集合 R = { ( c k i , c h j ) ∣ c k i , c h j ∈ C } \mathcal{R}=\{(c^i_k, c^j_h)|c^i_k, c^j_h \in \mathcal{C} \} R={(cki,chj)∣cki,chj∈C}, 其中的每一对 ( c k i , c h j ) (c^i_k, c^j_h) (cki,chj)表示列 c k i c^i_k cki和 c h j c^j_h chj之间存在外键关系。
  • 用 M = ∑ i = 1 N n i M=\sum^N_{i=1} n_i M=∑i=1Nni表示数据库 D \mathcal{D} D中的所有的列的个数。
  • 在表示一个schema的时候,可以用原始名字(在数据库中表示的名字)或语义名字(原始名字实际代表的语义含义),比如在论文的图1中,"uid"是原始名字,而"airline id"是语义名字,用语义名字相对比原始名字能更清晰地理解列数据的含义。注意有时候语义名字和原始名字是一样的,比如图示中的country。

Ranking-Enhanced Encoder

为了使RESDSQL模型的encoder输入只包括与问题最相关的schema元素,RESDSQL使用了一个cross-encoder模型来对数据表和数据列进行分类,然后基于分类概率来排序并过滤掉不相关的schema元素。这样做一方面可以过滤掉不相关的schema元素,另一方面可以输入seq2seq encoder的schema元素是排过序,模型可以捕捉到潜在的位置信息。

cross-encoder模型的输入 : 将schema元素按照其默认顺序组成一个schema元素序列,并将它和问题拼接起来组成cross-encoder模型的输入 X X X, X = q ∣ t 1 : c 1 1 , ⋯   , c n 1 1 ∣ ⋯ ∣ t N : c 1 N , ⋯   , c n N N X=q|t_1:c^1_1,\cdots, c^1_{n_1}|\cdots|t_N:c_1^N, \cdots, c^N_{n_N} X=q∣t1:c11,⋯,cn11∣⋯∣tN:c1N,⋯,cnNN ,|是分隔符,用来分隔问题和数据表。为了更好地表示数据元素的语义信息,这里使用的是它们的语义名称。

编码模块(Encoding Module) :使用RoBERTa作为cross-encoder模型。因为每一个schema元素可能被模型的分词器分成1个或多个token,比如数据列"airline id"会被分成两个token:"airline" 和"id"。而我们会希望在分类时将每一个schema元素作为整体,论文的解决思路使用了一个包含两层的BiLSTM和一个非线性全连接层作为pooling方法。经过pooling之后,每一个数据表的embedding可表示为 T i ∈ R 1 × d ( 1 ∈ { 1 , ... , N } ) \mathbf{T}_i \in \mathbb{R}^{1\times d} (1 \in \{ 1, \ldots, N\}) Ti∈R1×d(1∈{1,...,N}), 每一个数据列的embedding可表示为 C k i ∈ R 1 × d ( 1 ∈ { 1 , ... , N } , k ∈ { 1 , ... , n i } ) \mathbf{C}_k^i \in \mathbb{R}^{1\times d} (1 \in \{ 1, \ldots, N\}, k\in \{1, \ldots, n_i\}) Cki∈R1×d(1∈{1,...,N},k∈{1,...,ni}), d是隐藏层的大小。

列增强层(Column-Enhanced Layer) :有些问题里只会提到相关列名不会提到表名,比如论文图1中的例子提到了列名"city",但是没有提到表名"airports"。这种表名在问题中缺失的问题可能会影响表分类性能,所以论文作者提出了一个列增强层来将列信息注入到它对应的表embedding中,这样即使问题中只提到列名也能将对应的表给识别出来。列增强层通过multi-head scaled dot-product attention layer和一个特征融合层来实现,设对于第 i i i个表,其所有列信息表示为 C : i ∈ R n i × d \mathbf{C}^i_: \in \mathbb{R}^{n_i \times d} C:i∈Rni×d, 将其通过下式注入到表embeding T i \mathbf{T}_i Ti:
T i C = M u l t i H e a d A t t n ( T i , C : i , C : i , h ) , ( 1 ) T ^ i = N o r m ( T i + T i C ) \begin{aligned} \mathbf{T}i^C &= MultiHeadAttn(\mathbf{T}i, \mathbf{C}^i:, \mathbf{C}^i:, h), \qquad (1) \\ \hat{\mathbf{T}}_i &=Norm(\mathbf{T}_i + \mathbf{T}^C_i) \end{aligned} TiCT^i=MultiHeadAttn(Ti,C:i,C:i,h),(1)=Norm(Ti+TiC)

在上式中 T i \mathbf{T}i Ti作为self-attetion里的query, C : i \mathbf{C}^i: C:i同时作为key和value, h是head的个数, N o r m ( ⋅ ) Norm(\cdot) Norm(⋅) row-wise L 2 L_2 L2归一化函数。通过将原来的表embedding T i \mathbf{T}_i Ti 和列注意力机制表embedding T i C \mathbf{T}_i^C TiC一起获得列增强表embedding T ^ i ∈ R 1 × d \hat{\mathbf{T}}_i \in \mathbb{R}^{1 \times d} T^i∈R1×d。

Cross-Encoder的损失函数 :因为一个SQL查询通常只会包括数据库中的少量数据表和数据列,训练集的标签分布是非常不均匀的。所以论文使用focal loss作为分类损失。cross-encoder的损失函数是多任务学习方式,包括数据表分类损失和数据列分类损失:
L 1 = 1 N ∑ i = 1 N F L ( y i , y i ^ ) + 1 M ∑ i = 1 N ∑ k = 1 n i F L ( y k i , y k i ^ ) \mathcal{L}1 = \frac{1}{N}\sum^{N}{i=1}FL(y_i, \hat{y_i}) + \frac{1}{M}\sum^{N}{i=1}\sum^{n_i}{k=1}FL(y^i_k, \hat{y^i_k}) L1=N1i=1∑NFL(yi,yi^)+M1i=1∑Nk=1∑niFL(yki,yki^)

上式中的FL是focal loss 函数, y i y_i yi是第i个表格的真实标签。 y i = 1 y_i=1 yi=1表示这个表被SQL查询引用了, y i = 0 y_i=0 yi=0表示这个表没有被SQL查询引用。 y k i y^i_k yki是第i个表格的第k列的真实标签。 y k i = 1 y^i_k=1 yki=1表示这一列被SQL查询引用了, y k i = 0 y^i_k=0 yki=0表示这一列没有被SQL查询引用。 y ^ i \hat{y}_i y^i和 y ^ k i \hat{y}^i_k y^ki是预测概率,由基于数据表embedding T ^ i \hat{\mathbf{T}}_i T^i和数据列embedding C k i \mathbf{C}^i_k Cki的两个不同MLP模块估计得到:
y ^ i = σ ( ( T ^ i U 1 t + b 1 t ) U 2 t + b 2 t ) y ^ k i = σ ( ( C k i U 1 c + b 1 c ) U 2 c + b 2 c ) \begin{aligned} \hat{y}_i & =\sigma\left(\left(\hat{\boldsymbol{T}}_i \boldsymbol{U}_1^t+\boldsymbol{b}_1^t\right) \boldsymbol{U}_2^t+\boldsymbol{b}_2^t\right) \\ \hat{y}_k^i & =\sigma\left(\left(\boldsymbol{C}_k^i \boldsymbol{U}_1^c+\boldsymbol{b}_1^c\right) \boldsymbol{U}_2^c+\boldsymbol{b}_2^c\right) \end{aligned} y^iy^ki=σ((T^iU1t+b1t)U2t+b2t)=σ((CkiU1c+b1c)U2c+b2c)

上式中的 U 1 t , U 1 c ∈ R d × w \boldsymbol{U}_1^t, \boldsymbol{U}_1^c \in \mathbb{R}^{d \times w} U1t,U1c∈Rd×w, b 1 t , b 1 c ∈ R w \boldsymbol{b}_1^t, \boldsymbol{b}_1^c \in \mathbb{R}^{ w} b1t,b1c∈Rw, U 2 t , U 2 c ∈ R w × 2 \boldsymbol{U}_2^t, \boldsymbol{U}_2^c \in \mathbb{R}^{w \times 2} U2t,U2c∈Rw×2, b 2 t , b 2 c ∈ R 2 \boldsymbol{b}_2^t, \boldsymbol{b}_2^c \in \mathbb{R}^{2} b2t,b2c∈R2是可训练参数, σ ( ⋅ ) \sigma(\cdot) σ(⋅)表示Softmax函数。

seq2seq模型Encoder的输入准备 :在seq2seq模型推理时,用前面训练好的cross-encoder为每一个schema元素计算概率并排序,只保留数据库中的 top k 1 \text{top}\ k_1 top k1个数据表,并为每个表给保留 top k 2 \text{top}\ k_2 top k2个数据列。 k 1 k_1 k1和 k 2 k_2 k2是两个超参数,太小可能将必须的表或列给排除掉了,太大可能会包含一些不必要的schema元素(论文在后续的实验中对数据集进行统计后取了 k 1 = 4 k_1=4 k1=4, k 2 = 5 k_2=5 k2=5)。如论文图2所示,encoder的输入是由问题、排序后的schema元素,可选的外键关系拼接得到。在schema元素中使用的是原始名称,这样decoder可以直接成输入序列中拷贝所需的schema元素。(另外在论文实验时,还使用了与BRIDGE的方法从数据库中提取可能有用的内容来丰富数据列信息)

Skeleton-Aware Decoder

seq2seq的decoder将SQL生成分为两步:1. 基于问题的语义生成SQL骨架;2.从输入序列中选择所需"数据"(数据表、数据列、数据取值)来填充骨架中的槽位。

目标函数 :为了在不引入额外模块的前提下实现上述分解思想,论文提出了一个新的生成目标,在生成第t个token时不仅依赖于encoder的输出,同时依赖decode在t个时间步之前的输出。也就是不是直接解码得到目标SQL,鼓励解码器先解码出SQL语句的骨架后再解码SQL语句。论文作者认为先解析骨架再解析SQL语句,在每一个解码步骤,SQL生成更容易,因为解码器可以直接从输入序列中拷贝"数据"或者从之前解析的骨架中拷贝SQL关键字:
L 2 = 1 G ∑ i = 1 G p ( l i s , l i ∣ S i ) \mathcal{L}2 = \frac{1}{G} \sum^{G}{i=1} p(l^s_i, l_i |S_i) L2=G1i=1∑Gp(lis,li∣Si)

上式的G是Text2SQL实例数目, S i S_i Si是第i个实例的输入序列(包括问题、排序后的schema元素,可选的外键关系)。 l i l_i li是第i个目标SQL语句, l i s l^s_i lis是从 l i l_i li抽取到的骨架。

SQL归一化:Spider数据集是由有不同标注习惯的标注员创建的,所以标注的最终SQL有不同的风格。为了减少模型学习的难度,将原始SQL语句在训练前进行如下操作的归一化,论文的表1有一个归一化例子。

  • 将所有的关键字和schema元素转小写
  • 在圆括号附近添加空格,将双引号替换成单引号
  • 如果ORDER BY语句后面没有指定排序方式,在后面添加ASC关键词
  • 将所有的AS语句去掉,并将所有表的别名用他们的原始名字替换。

SQL骨架提取:基于归一化的SQL语句,提取只包括SQL关键词和槽位的骨架,即对于一个归一化SQL语句,只保留它的SQL关键字并将其余部分用槽位标识代替。 另外不保留JOIN ON关键词因为它很难从自然语言问题中找到对应的部分。

Execution-Guided SQL Selector:因为在解码时没有约束SQL语法,使用了与以前研究相同的Execution-Guided SQL Selector,即在解码过程中进行beam search并选择结果中的第一个可以执行的SQL作为最终的结果。(beam search 在论文实验中为8)

DAIL-SQL

DAIL SQL出自2023年8月的论文《Text-to-SQL Empowered by Large Language Models: A Benchmark Evaluation》,是一种通过prompt GPT-4来实现Text2SQL的方法。论文先对已有的prompt LLM来实现Text2SQL的提示工程(prompt engineering)进行了调研和比较,然后提出名为DAIL SQL的方法来prompt LLM生成SQL。

论文将text2sql的提示工程分成question representation(prompt里问题和数据库表的表示方式)、example selection(如何选择有效text2sql例子)、example organization(例子在prompt中的组织方式)。

DAIL-SQL的思路的伪代码如论文算法1,下面介绍DAIL-SQL的prompt是如何得到的。

  • question representation:将数据库表用SQL语法语法形式来表示(论文中称其为Code Representation Prompt( CR P \text{CR}_P CRP)),如论文的Listing 4所示意。注意这个表示方式里有将数据表的外键特意标记出来,让LLM在预测"JOIN"操作时更容易。

  • example selection: 在选择text2sql例子时,同时考虑自然语言问题和SQL查询的相似度。

    1. 先将目标问题 q q q和候选数据集 Q \mathcal{Q} Q的每一个问题 q i q_i qi中的领域相关的词(表名、列名、数据取值)都用一个mask token给替换掉(用与RAT-SQL一样的n-gram匹配方法来进行schema linking找到与数据库有关的词。将数据表和列名使用""来替换,将数据取值用""替换)。再计算mask后的 q q q和 q i q_i qi的embedding相似度,并按照它们的相似度大小来对数据集 Q \mathcal{Q} Q倒序排序。
    2. 用一个基础模型基于问题 q q q和数据库 D \mathcal{D} D预测出一个pre-predicted SQL查询 s ′ s^{'} s′(论文使用的基础模型是Graphix),然后使用RESDSQL里的方法对SQL查询提取骨架,比较提取骨架后的 s ′ s^{'} s′和候选数据集 Q \mathcal{Q} Q里的每一个SQL查询 s i s_i si的Jaccard相似度。
    3. 将第二步得到的相似度大于阈值 τ \tau τ的候选集保留。(论文实验里的阈值是0.9,提交到Spider时的阈值是0.85)
    4. 选择top k个例子作为text2sql的例子。(因为之前已经根据问题相似度排序过了,再卡相似度阈值来过滤sql查询例子,并选择top-k,就是同时考虑了自然语言问题和SQL查询的相似度)
  • example organization: text2sql例子的组织方式如论文的Listing 8所示。

  • DAIL-SQL还尝试了self-consistency方法,可以使得执行准确率提高0.4%,但时间效率和token效率都低很多。

相关推荐
bastgia2 小时前
Tokenformer: 下一代Transformer架构
人工智能·机器学习·llm
新智元7 小时前
李飞飞谢赛宁:多模态 LLM「空间大脑」觉醒,惊现世界模型雏形!
人工智能·llm
RWKV元始智能12 小时前
RWKV-7:极先进的大模型架构,长文本能力极强
人工智能·llm
七夜星七夜月1 天前
时间序列预测论文阅读和相关代码库
论文阅读·python·深度学习
zaim11 天前
计算机的错误计算(一百八十七)
人工智能·ai·大模型·llm·错误·正弦/sin·误差/error
张拭心1 天前
Google 提供的 Android 端上大模型组件:MediaPipe LLM 介绍
android·人工智能·llm
带电的小王1 天前
whisper.cpp: Android端测试 -- Android端手机部署音频大模型
android·智能手机·llm·whisper·音频大模型·whisper.cpp
WenBoo-1 天前
HIPT论文阅读
论文阅读
chnyi6_ya1 天前
论文笔记:Buffer of Thoughts: Thought-Augmented Reasoning with Large Language Models
论文阅读·人工智能·语言模型
带电的小王2 天前
whisper.cpp: PC端测试 -- 电脑端部署音频大模型
llm·whisper·音视频·音频大模型