基于矩阵乘积态的生成模型:量子力学与生成任务的结合

🏡作者主页:点击!

🤖编程探索专栏:点击!

⏰️创作时间:2024年12月23日11点02分


神秘男子影,

秘而不宣藏。

泣意深不见,

男子自持重,

子夜独自沉。


文章源链接(有视频):Aspiringcode - 编程抱负 即刻实现传知代码只专注开箱即用的代码https://www.aspiringcode.com/content?id=17212664745068&uid=c77f3951d4d6474f97aa889c88064311

概述

生成模型,通过从数据中学习联合概率分布并据此生成样本,是机器学习和人工智能中的一个重要任务。受量子物理学中概率解释的启发,该文章提出了一种使用矩阵积状态的生成模型,这是一种最初用于描述(特别是一维)纠缠量子态的张量网络。其模型享有类似于密度矩阵重正化群方法的高效学习能力,该方法允许动态调整张量的维度,并提供了一种高效的直接采样方法用于生成任务。本文试图复现该文章的工作,利用该文章的思想,方法去实现MNIST手写数字的生成任务。

  • Han Z-Y, Wang J, Fan H, et al. Unsupervised Generative Modeling Using Matrix Product States[J]. Physical Review X, 2018, 8(3): 031012.

演示效果

方法

量子力学的概率解释自然地建议使用量子态来建模数据分布。假设我们将概率分布编码到一个量子波函数Ψ(v)Ψ(v ) 中,测量会使其坍缩并生成一个结果 v=(v1,v2,...,vN)v =(v 1,v 2,...,vN ),其概率与 ∣Ψ(v)∣2∣Ψ(v)∣2 成正比。受到量子力学生成特性的启发,通过以下方式表示模型概率分布:

其中值得说明的是Ψ(v1,v2,...,vN)Ψ(v 1,v 2,...,v N )和P(v1,v2,...,vN)P(v 1,v 2,...,v N )都是张量(tensor),因为他受制于多个不同的指标v1,v2,...,vNv 1,v 2,...,v N ,此外ZZ 作为配分函数代表的是归一化系数,即Z=∑vi∣Ψ(v1,v2,...,vN)∣2Z =∑v i ∣Ψ(v 1,v 2,...,v N)∣2.

如何对∣Ψ(v1,v2,...,vN)∣2∣Ψ(v 1,v 2,...,vN)∣2进行合适的建模使得模型既不复杂,又在一定程度上能够表示更多不同种类的构型成为现在需要解决的问题。许多已经开发的表示方法和算法可以用于高效的概率建模。在这里,我们使用矩阵积状态(MPS)对波函数进行参数化:

上面的图示意思为,左边是我们需要表示的波函数,线代表它依赖的指标(或者变量),右边则是对应的MPS表示,两个方括号直接的连线代表求和,即将对应的指标(或者变量求和,类似于矩阵的乘积)进行收缩。我们可以看出我们把一个复杂的波函数变成了有限个3指标张量的收缩。

实现

导入训练集(MNIST)

1000张MNIST图像已存储为mnist784_bin_1000.npy。每张图像包含 n=28×28=784n =28×28=784 个像素,每个像素的取值为0或1。每张图像被视为维度为 2n2n 的希尔伯特空间中的一个乘积态。

n = 784 
m = 1000
data = np.load("mnist784_bin_1000.npy").astype(np.int32)
data = data[:m,:]
data = torch.LongTensor(data)\
plt.figure(figsize=(10,2))
imgs = data.cpu().reshape([-1,28,28])
_, ax = plt.subplots(2,10)
for i in range(2): 
    for j in range(10):
        index = i * 2 + j
        if(a >= imgs.shape[0]):
            break
        ax[i][j].imshow(imgs[index,:,:],cmap='bone')
        ax[i][j].set_xticks([])
        ax[i][j].set_yticks([])
plt.show()

这可以让我们观察以下MNIST数据集的样子

定义MPS

现在我们要构造一个初始的MPS, 根据上面的阐述,我们的MPS是由一系列3指标的张量的所构成的,如下所示

解释一下,χχ表示的是最大截断指标,也就是为了控制整个MPS的维度,防止维度灾难;左右两边的11可以认为是一个空指标,即没有任何含义,只是为了在运算过程中方便,实际是没有这个指标也可以;下面的22表示的是这些指标的维度是2,因为我们只是考虑的像素点只有黑白(1,0)两种状态,也因此下面的22的数目需要28∗28=78428∗28=784个。

chi = 30 
mydevice = 'cuda' if torch.cuda.is_available() else torch.device("cpu")
print(mydevice)
data = data.to(mydevice)
bond_dims = [chi for i in range(n-1)]+[1]
tensors= [ torch.randn(bond_dims[i-1],2,bond_dims[i],device=mydevice) for i in range(n)]

我们可以输出从而看到这些张量的输出维度

概率计算

概率计算可以遵循前面的Born公式,即

在这里,带有一个小边(常称之为脚)是一个向量,代表的是对应像素的状态,是一个二维向量,用来表示对应的像素是黑还是白

现在难以计算的是配分函数,即

这个东西,这涉及到张量网络的缩并,在张量网络这个领域中由非常多的缩并方式,一个常用的方法是正交化,即把MPS右边的那些三阶张量全部正交化使得他们收缩刚好是一个单位张量。这个过程如下

通过不断的对左边的张量作用QR分解从而使得左边张量全部正交化(黄色的)。据此我们可以计算出对应的波函数

def getPsi():
    psi = torch.ones([m, 1, 1], device=mydevice)
    for site in range(n):
        selected_tensor = tensors[site][:, data[:, site], :].permute(1, 0, 2)
        psi = torch.matmul(psi, selected_tensor)
    return psi

生成图片

生成图片的过程可以采用条件概率的方法,即先采样一个边缘概率,再从这个边缘概率对应的变量继续采样,重复这个过程即可

核心代码为

def generateSamples(batch):
    n = 784
    samples = torch.zeros([batch, n],device=mydevice)
    for site in range(n - 1):
        orthogonalize(site, True) 
    for s in range(batch):
        vec = torch.ones(1,1,device=mydevice)
        for site in range(n-1, -1, -1):
            vec = (tensors[site].view(-1, bond_dims[site]) @ vec).view(-1, 2)
            p0 = vec[:, 0].norm()**2 / (vec.norm()**2)
            x = (0 if np.random.rand() < p0 else 1)
            vec = vec[:, x]
            samples[s][site] = x
    return samples

训练

因此我们可以根据该模型去训练一个生成模型,损失函数是交叉熵损失即将模型的概率与图片本身分布的概率相拟合。交叉熵损失函数为

除以m就相当于是平均。

ψ′ψ ′和Z′Z′是比较好得到,因为对整体的MPS对各个参数是分立。比如说对第100个张量的里面的参数求梯度,那么其他的张量相对于第100个来说都是常数,这跟对矩阵的乘积求梯度是一样的,并没有参数的过度耦合,因此训练过程中核心代码为

for i in ...
      ....
      gradients[:, i, :] = torch.sum(left_vec.permute(0, 2, 1) @ right_vec.permute(0, 2, 1) / psi, 0) 
gradients = 2.0 * (gradients / m - tensors[site])  
tensors[site] += learning_rate * gradients/gradients.norm()

使用方式

  • jupyter notebook 运行

安装依赖

  • Python 3.11.4
  • torch 2.0.1
相关推荐
没学上了2 小时前
加强版十六章视频读写
opencv·计算机视觉
MUTA️3 小时前
专业版pycharm与服务器连接
人工智能·python·深度学习·计算机视觉·pycharm
四口鲸鱼爱吃盐4 小时前
Pytorch | 利用IE-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python·深度学习·计算机视觉
红色的山茶花5 小时前
YOLOv9-0.1部分代码阅读笔记-loss_tal_dual.py
笔记·深度学习·yolo
呆头鹅AI工作室6 小时前
基于特征工程(pca分析)、小波去噪以及数据增强,同时采用基于注意力机制的BiLSTM、随机森林、ARIMA模型进行序列数据预测
人工智能·深度学习·神经网络·算法·随机森林·回归
huhuhu15326 小时前
第P4周:猴痘病识别
图像处理·python·深度学习·cnn
一勺汤7 小时前
YOLO11改进-注意力-引入自调制特征聚合模块SMFA
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·目标跟踪
白兔12057 小时前
联邦大模型微调
人工智能·深度学习
Stara05119 小时前
基于YOLOV5+Flask安全帽RTSP视频流实时目标检测
python·yolo·目标检测·flask
数据分析能量站9 小时前
神经网络-ResNet
人工智能·深度学习·神经网络