基于扩散模型的表单插补

任务介绍

表单数据中的缺失值可能会影响到后续的数据建模、统计分析以及决策制定等环节。表单插补任务是指在表单数据中,当某些字段存在缺失值 时,使用方法对缺失值进行填补 。填补后得到的完整数据可以提高分析结果的准确性和可靠性,避免因缺失值而导致的偏差。

常见表单插补方法

  1. 均值插补
    • 对于数值型字段,可以使用该字段的均值来填补缺失值。例如,在一组学生成绩数据中,如果某个学生的数学成绩缺失,可以用所有学生数学成绩的平均值来进行插补。
  2. 中位数插补
    • 同样适用于数值型字段。当数据中存在极端值时,中位数插补可能比均值插补更稳健。
  3. 众数插补
    • 对于分类型字段,可以使用该字段的众数进行插补。比如在一组关于颜色喜好的调查数据中,如果某些人的喜好颜色缺失,可以用出现次数最多的颜色作为插补值。
  4. 回归插补
    • 利用其他相关字段建立回归模型,预测缺失值。例如,根据房屋的面积、位置等因素建立房价的回归模型,当某个房屋的房价缺失时,可以通过该模型进行预测插补。

模型介绍

CSDI模型

CSDI(conditional score-based diffusion model)模型[1]是由斯坦福大学学者提出应用于时间序列数据插补,插补流程如图1所示。

图1 时间序列插补流程

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> c o co </math>co和 <math xmlns="http://www.w3.org/1998/Math/MathML"> t a ta </math>ta分是conditionaltarget的缩写。

扩散模型就不多介绍了,可以看我之前关于扩散模型的记录。因为插补数据对应的标签是缺失的,作者提出了一种自监督的训练流程,如图2所示。

图2 自监督训练流程

粗糙来讲,就是在已有观测数据上做插补。具体来讲,训练,验证的时候,绿色表示观测数据,作者在观测数据中取一部分当插补目标,其他当作条件值;随后对插补目标进行加噪,用 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ \epsilon_{\theta} </math>ϵθ预测噪声,还原目标值。测试的时候,在真正缺失数据上做插补。

由于是序列数据,作者分别使用了时间Transformer Layer和特征Transformer Layer来学习特征,如图3所示。

图3 注意力架构

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K是特征数, <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L是时间步数, <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C是通道数。

CSDI_T模型

CSDI_T模型[2]是由东京大学学者提出用于处理表单数据缺失值。与CSDI模型非常类似,作者将其拓展到表单数据,并且兼容类别变量。

由于模型与CSDI非常类型,就不介绍了。作者提出了三种方法处理类别变量,分别是独热编码,模拟比特编码和特征标记化(Tokenization)。其中,二位的模拟比特可以编码四种状态,优点是编码长度短,缺点是复杂。说说特征标记化吧,源码如下所示。

python 复制代码
# partially stole from https://github.com/Yura52/tabular-dl-revisiting-models/blob/main/bin/ft_transformer.py
class Tokenizer(nn.Module):
    def __init__(self, d_numerical: int, categories: ty.Optional[ty.List[int]], d_token: int, bias: bool,)-> None:
        super().__init__()
        d_bias= d_numerical+ len(categories)
        # 9是指该特征有9个取值
        # [0]+ [9, 16, 8, 15, 6, 5, 2, 42]
        # [  0,   9,  25,  33,  48,  54,  59,  61, 103]
        category_offsets= torch.tensor([0]+ categories[:-1]).cumsum(0)
        # 8
        self.d_token= d_token
        self.register_buffer("category_offsets", category_offsets)
        self.category_embeddings= nn.Embedding(sum(categories)+ 1, self.d_token)
        self.category_embeddings.weight.requires_grad= False
        nn_init.kaiming_uniform_(self.category_embeddings.weight, a= math.sqrt(5))
        self.weight= nn.Parameter(Tensor(d_numerical, self.d_token))
        self.weight.requires_grad= False
        self.bias= nn.Parameter(Tensor(d_bias, self.d_token)) if bias else None
        nn_init.kaiming_uniform_(self.weight, a= math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a= math.sqrt(5))
            self.bias.requires_grad= False
    @property
    def n_tokens(self)-> int:
        return len(self.weight)+ (0 if self.category_offsets is None else len(self.category_offsets))
    # x_num, (B, 8, 6); x_cat, (B, 8, 9). 
    def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor])-> Tensor:
        x_some= x_num if x_cat is None else x_cat
        x_cat= x_cat.type(torch.int32)
        assert x_some is not None
        # self.weight.T.shape, (8, 6); x_num, (B, 1, 6);
        # 数值变量直接embedding
        x= self.weight.T* x_num
        # x_cat, (B, 1, 9).
        if x_cat is not None:
            x= x[:, np.newaxis, :, :]
            x= x.permute(0, 1, 3, 2)
            # x_cat, (B, 1, 9); self.category_offsets[None], (1, 9)
            # x, torch.cat([(B, 1, 6, 8), (B, 1, 9, 8)])-> (B, 1, 15, 8)
            # 确定分类变量嵌入的范围,再进行嵌入
            x= torch.cat([x, self.category_embeddings(x_cat + self.category_offsets[None])], dim= 2,)
        if self.bias is not None:
            x= x+ self.bias[None]
        return x
    # d_numerical, 6
    def recover(self, Batch, d_numerical):
        # 128, 120, 1
        B, L, K= Batch.shape
        L_new= int(L/ self.d_token)
        # Batch, (B, 15, 8)
        Batch= Batch.reshape(B, L_new, self.d_token)
        Batch= Batch- self.bias
        Batch_numerical= Batch[:, :d_numerical, :]
        # 对数值变量嵌入进行还原, (B, 9, 8), 除以嵌入取平均
        Batch_numerical= Batch_numerical/ self.weight
        # Batch_numerical, (B, 9, 8)-> (B, 9)
        Batch_numerical= torch.mean(Batch_numerical, 2, keepdim= False)
        # (B, 9, 8)
        Batch_cat = Batch[:, d_numerical:, :]
        # (B, 9)
        new_Batch_cat = torch.zeros([Batch_cat.shape[0], Batch_cat.shape[1]])
        # for each cata feature
        for i in range(Batch_cat.shape[1]):
            # token_start, 1, 10, 26, 34, 49, 55, 60, 62, 104.
            token_start = self.category_offsets[i] + 1
            if i == Batch_cat.shape[1] - 1:
                token_end = self.category_embeddings.weight.shape[0] - 1
            else:
                token_end = self.category_offsets[i + 1]
            # emb_vec, (9, 8).
            emb_vec = self.category_embeddings.weight[token_start : token_end + 1, :]
            # 量化操作(根据嵌入间的距离,为学到的嵌入分配固定嵌入)
            for j in range(Batch_cat.shape[0]):
                distance = torch.norm(emb_vec - Batch_cat[j, i, :], dim=1)
                nearest = torch.argmin(distance)
                new_Batch_cat[j, i] = nearest + 1
            new_Batch_cat = new_Batch_cat.to(Batch_numerical.device)
        return torch.cat([Batch_numerical, new_Batch_cat], dim=1)

其中, d_numericalcategoriesd_tokenbias分别是数值变量的个数,各个类别特征取值数量,词元的嵌入维度和是否使用偏执项。思想是编码时,数值变量直接乘以嵌入,各个分类变量的取值对应不同嵌入范围(一个分类变量)中的一个嵌入(一个取值)。解码时,数值变量直接除以嵌入,计算分类变量嵌入与嵌入范围内所有嵌入(该变量下所有取值)的距离,根据距离指派对应分类变量值

Breast数据集介绍

我们使用的是威斯康星州乳腺癌数据集(Wisconsin Breast Cancer Dataset),总共600多条记录,10个特征。

  1. "Sample code number":这是样本的编号,用于唯一标识每个数据点,在数据分析过程中主要起到索引的作用,方便对特定样本进行追踪和管理。
  2. "Clump Thickness":细胞团的厚度,取值范围为 1 - 10。较厚的细胞团可能暗示着异常的细胞生长,与肿瘤的发展可能相关。
  3. "Uniformity of Cell Size":细胞大小的均匀性,同样取值 1 - 10。均匀的细胞大小通常是正常组织的特征,而不均匀的细胞大小可能是肿瘤组织的表现之一。
  4. "Uniformity of Cell Shape":细胞形状的均匀性,1 - 10 的取值范围。正常组织的细胞形状通常较为一致,肿瘤组织可能会出现细胞形状的多样性。
  5. "Marginal Adhesion":边缘粘连程度,1 - 10 的等级划分。边缘粘连情况可以反映肿瘤细胞与周围组织的关系,对于判断肿瘤的侵袭性有一定的参考价值。
  6. "Single Epithelial Cell Size":单个上皮细胞的大小,1 - 10 的取值。上皮细胞的大小变化可能与肿瘤的类型和发展阶段有关。
  7. "Bare Nuclei":裸核情况,1 - 10 的范围。裸核的出现可能暗示着细胞的异常状态,对肿瘤的诊断有一定的提示作用。
  8. "Bland Chromatin":染色质的平淡程度,1 - 10 的等级。染色质的特征可以反映细胞的生理状态,异常的染色质平淡程度可能与肿瘤相关。
  9. "Normal Nucleoli":正常核仁的情况,1 - 10 的取值。核仁的形态和数量变化可以为肿瘤的判断提供线索。
  10. "Mitoses":有丝分裂的情况,1 - 10 的范围。有丝分裂的活跃程度与肿瘤的生长速度和恶性程度密切相关。

实验结果

下面给出的是在UCI breast数据集上的损失曲线和RMSEMAE曲线。具体实验结果可以去看[2].

图4 训练和验证上的损失曲线

图5 验证集上的RMSE和MAE曲线

完整代码

完整代码见我的github.

参考资料

[1] [2107.03502] CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation

[2] [2210.17128] Diffusion models for missing value imputation in tabular data

[3] Breast Cancer Wisconsin (Original) - UCI Machine Learning Repository

相关推荐
池央10 分钟前
丹摩征文活动 | 搭建 CogVideoX-2b详细教程:用短短6秒展现创作魅力
人工智能
九年义务漏网鲨鱼11 分钟前
【人脸伪造检测后门攻击】 Exploring Frequency Adversarial Attacks for Face Forgery Detection
论文阅读·python·算法·aigc
_OLi_17 分钟前
力扣 LeetCode 977. 有序数组的平方(Day1:数组)
数据结构·算法·leetcode
QYR市场调研22 分钟前
5G时代的关键元件:射频微波MLCCs市场前景广阔
人工智能
AI小白日记32 分钟前
深入探索AutoDL平台:深度学习GPU算力最佳选择
人工智能·深度学习·gpu算力
励志成为嵌入式工程师32 分钟前
c语言选择排序
c语言·算法·排序算法
風清掦36 分钟前
C/C++每日一练:编写一个查找子串的位置函数
c语言·c++·算法
一尘之中43 分钟前
元宇宙及其技术
人工智能
电子手信1 小时前
AI知识库在行业应用中的未来趋势与案例分析
大数据·人工智能·自然语言处理·数据挖掘
A charmer1 小时前
算法每日双题精讲——滑动窗口(最大连续1的个数 III,将 x 减到 0 的最小操作数)
c++·算法·leetcode