读3dsr代码②训练

train_dada

首先初始化权重

python 复制代码
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

他的训练数据是imagenet的rgb,然后利用Perlin 噪声来模拟深度图像

generate_perlin_noise

  1. 初始化旋转变换:

    • rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) 创建一个旋转变换,旋转角度在 -90 到 90 度之间。这个变换后续会应用到 Perlin 噪声上,增加噪声图的多样性。
  2. 确定 Perlin 噪声的尺度:

    • perlin_scalexperlin_scaley 分别代表在 x 和 y 方向上的 Perlin 噪声尺度。这些尺度是随机选取的(默认是2的0~6次幂),用于控制噪声的粗糙度或细腻度。
  3. 生成 Perlin 噪声:

    • perlin_noise = rand_perlin_2d_np((resize_shape[0], resize_shape[1]), res=(perlin_scalex, perlin_scaley)) 生成一个二维的 Perlin 噪声图。
      resize_shape=256,256

    rand_perlin_2d_np

    1. 定义 Perlin 噪声的分辨率和形状:

        delta = (res[0] / shape[0], res[1] / shape[1])
        d = (shape[0] // res[0], shape[1] // res[1])
      

      delta 计算 Perlin 噪声网格的间隔
      d 计算 Perlin 每个细胞内重复的次数,以适应最终生成的噪声图像的形状。

    2. 创建网格和梯度:

      • grid 创建一个在 [0, 1) 区间内均匀分布的二维网格,用于计算 Perlin 噪声。
      • grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 创建一个二维数组,其范围在 [0, 1) 内。这是通过 np.mgrid 实现的,该函数生成一个密集的多维"网格",其中每个维度的步长由 delta 确定。
      • delta 的计算基于所需的噪声图像的分辨率和噪声的尺度,确保 grid 覆盖了 [0, 1) 的范围,且分布均匀。
      • 使用 % 1 是为了确保所有值都在 [0, 1) 的范围内,这是因为 Perlin 噪声是在一个规范化的空间内生成的,其中每个点的值应该位于 [0, 1) 的区间内。
      • angles 生成随机角度,用于确定每个网格点的梯度向量。
      • gradients 根据 angles 生成的二维单位向量表示梯度。
      • tt 是在整个噪声图像区域重复 gradients,以确保每个网格细胞内部有相同的梯度向量。
    3. 计算噪声值:

      • dot 函数计算网格点与梯度向量的点积。
      • n00, n10, n01, n11 计算四个角的点积值。
      • t 使用预定义的渐变函数 fade 对网格坐标进行调整,以实现平滑过渡。
      • 最终的噪声值通过插值函数 lerp_np (一个预定义的线性函数)和上述点积值结合,生成整个噪声图。
  4. 应用旋转变换:

    • perlin_noise = rot(image=perlin_noise) 对生成的 Perlin 噪声图应用旋转变换。
  5. 设置阈值并应用:

    • threshold = torch.rand(1).numpy()[0] * beta + beta 计算阈值,用于后续的二值化处理。
    • perlin_thr = np.where(np.abs(perlin_noise) > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) 根据阈值二值化处理 Perlin 噪声,生成阈值化后的噪声图 perlin_thr
  6. 生成归一化 Perlin 噪声:

    • norm_perlin = np.where(np.abs(perlin_noise) > threshold, perlin_noise, np.zeros_like(perlin_noise)) 生成归一化的 Perlin 噪声图 norm_perlin,在噪声值低于阈值的地方设为0。

函数最终返回 norm_perlin(归一化 Perlin 噪声)->perlin_norm、perlin_thr(阈值化 Perlin 噪声)、原始的 perlin_noise 和使用的阈值 threshold->p_thr。

随机缩放噪声:

生成一个 [0, 1] 范围内的随机数 beta

image = beta * perlin_noise:使用这个随机数 beta 对 Perlin 噪声进行缩放,模拟不同深度的变化。

随机平移噪声:

生成另一个 [0, 1] 范围内的随机数 beta2。

image = image + (beta2 * (1 - beta)):将缩放后的噪声图进一步平移,以增加深度图的变化性。

裁剪和调整深度图:

image = np.clip(image, 0.0, 1.0) 确保深度值在 [0, 1] 范围内。

image = np.expand_dims(image, 2) 增加一个维度,使图像从二维变为三维。

image = np.transpose(image, (2, 0, 1)) 调整深度图的维度顺序,适配 PyTorch 的要求。

所以这个深度图是通过模拟而非直接从现实世界的深度传感器获取,和rgb也没有半点关系,完全随机出来的

训练过程中,VectorQuantizerEMA的两个实例里多了几步

首先补充说明一下几个类内初始化变量

python 复制代码
self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
这种初始化定义方式是为了准备在后续过程中使用指数移动平均(EMA)来更新聚类大小
register_buffer是nn.Module的一个方法
用于注册一个不需要梯度的缓冲区
这是因为_ema_cluster_size不是模型的参数(不需要学习)
但它是模型的一部分
并且在模型的训练过程中会更新
通过注册为缓冲区
确保在模型保存和加载时_ema_cluster_size也会被保存和加载。
和requires_grad=False的区别可能在于register_buffer的变量一定可以被加载或保存

self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
self._ema_w.data.normal_()

self._decay = decay
self._epsilon = epsilon
python 复制代码
quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
从获得quantized 1,48,48,256开始
# Use EMA to update the embedding vectors
if self.training:
    self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                             (1 - self._decay) * torch.sum(encodings, 0)
	这步更新是反映每个嵌入向量当前被选中的频率
	这是通过将当前的_ema_cluster_size乘以衰减因子_decay
	然后加上新的观察值(由encodings的求和得到)来实现的。
	
    # Laplace smoothing of the cluster size
    n = torch.sum(self._ema_cluster_size.data)
    self._ema_cluster_size = (
            (self._ema_cluster_size + self._epsilon)
            / (n + self._num_embeddings * self._epsilon) * n)
    对_ema_cluster_size应用拉普拉斯平滑(加一个小的常数_epsilon)
    以避免任何聚类大小变为零
    这个操作确保了即使某些聚类在当前批次中未被观察到(即其聚类大小为零)
    它们的大小也会被设置为一个小的非零值
    这有助于增加数值稳定性。

    dw = torch.matmul(encodings.t(), flat_input)
    self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

    self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
    更新嵌入向量,确保了嵌入向量的更新考虑了不同嵌入被选择的频率

# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)

这种基于EMA的更新策略有助于平滑训练过程中嵌入向量的更新,使模型在训练时更稳定,同时减少由于少数嵌入向量频繁更新而导致的过拟合风险。

最后输出就要前仨

输出loss_b, loss_t, recon_out

python 复制代码
loss_vq = loss_b + loss_t

recon_depth = recon_out[:,:1,:,:]
recon_rgb = recon_out[:,1:,:,:]

# Using L1 loss may work better and lead to improved reconstructions
#l2_recon_d_loss = torch.mean((depth_image - recon_depth)**2)
l2_recon_d_loss = torch.mean(torch.abs(depth_image - recon_depth))
#l2_recon_rgb_loss = torch.mean((rgb_image - recon_rgb)**2)
l2_recon_rgb_loss = torch.mean(torch.abs(rgb_image - recon_rgb))
#l2_recon_loss = torch.mean((model_in - recon_out)**2)
l2_recon_loss = torch.mean(torch.abs(model_in - recon_out))
recon_loss = l2_recon_loss + loss_vq
loss = recon_loss

相当于一个mse损失,一个l1范数损失

train_dsr

训练集相比测试集:

  1. 初始化时通过遍历整个训练集,直接定义全局的im_max和im_min,但是其实后面根本没有用上(???这岂不是浪费启动时间)
  2. 没有读取gt(毕竟还是无监督)
  3. plane_mask和生成的perlin噪声相乘得到msk作为getitem的anomaly_mask。
    随机取一个0~1之间的no_anomaly值,如果>0.5,那么anomaly_mask全为0
  4. 最后rgb和深度图分别随机旋转-15, 15之间的角度

训练过程:

  1. 从拼接深度和rgb得到in_image开始,到sub_res_model_hi之前为止,这段过程是不计算梯度的(其实也就是DiscreteLatentModelGroups这部分,但是中间加了些步骤,所以通过model.xxx的形式实现)
  2. ①得到quantized_t之后先经过下面的generate_fake_anomalies_joined函数得到anomaly_embedding_lo,然后分别和quantized_t经过upsample_t得到up_quantized_t和up_quantized_t_real,二者都是16,256,96,96。
    然后这二者再分别依次和enc_b拼接、经过_pre_vq_conv_bot、经过_vq_vae_bot,最终分别得到quantized_b和quantized_b_real,二者都是16,256,96,96。

generate_fake_anomalies_joined

输入zt是16,256,48,48->features,quantized_t是16,256,48,48->embeddings,embedder_lo._embedding.weight是2048, 256->memory_torch_original->memory_torch,anomaly_mask是16,1,384,384->mask,anomaly_strength_lo = (torch.rand(in_image.shape[0]) * 0.90 + 0.10).cuda()(16,)->strength

异常嵌入生成过程详细分析:

  • 初始化随机嵌入矩阵: random_embeddings 初始化为一个16,2304,256的零张量。这个张量用于存储每个位置随机选择的嵌入。

  • features重塑得到inputs 16,48,48,256

  • 遍历每个样本: 对于 embeddings 中的每个样本 k,执行以下操作:

    • 重塑inputs[k]得到flat_input 2304,256
    • 计算距离: 对于给定的样本 k,计算它的每个特征向量 flat_inputmemory_torch 中所有嵌入向量之间的距离distances_b。
    • 确定替换的嵌入数量: 根据 strength[k]->percentage_vectors 这么一个随机的比重确定每个样本要替换的嵌入数量 topk
    • 选择嵌入: 根据距离选择 topk 个最近的嵌入values,并从这些嵌入中随机选择一个来替换原始嵌入,形成新的嵌入 random_embeddings[k]
  • 调整随机嵌入矩阵形状:random_embeddings 重塑并转置,使其形状与输入嵌入 embeddings 一致。16,48,48,256

  • 可选的嵌入打乱: 以随机概率使用 shuffle_patches 函数打乱嵌入块,增加异常嵌入的多样性。

    • 图像分割成块:
      使用 unfold 操作将每个图像分割成大小为 patch_size 的块,不考虑重叠,无填充。结果 u 包含了图像中所有块的展平形式。
    • 打乱块的顺序:
      u 中的每个图像的块进行随机打乱。这里使用了列表推导式和 torch.randperm 来为每个图像生成一个随机索引,以此来打乱块的顺序。
    • 重构图像:
      使用 fold 操作将打乱后的块重新组合成图像。

    返回重新组合后的图像,它的形状与输入相同,但每个图像内部的块已经被随机打乱。

  • 应用异常掩码:

    • 调整掩码大小: 使用 max_pool2dmask 进行下采样,使其空间维度与 embeddings 相匹配,得到anomaly_mask16,1,48,48
    • 合并嵌入: 通过异常掩码将随机异常嵌入和原始嵌入线性组合,生成最终的异常嵌入 anomaly_embedding

函数返回 anomaly_embedding->anomaly_embedding_lo,它在原始特征中引入了模拟的异常,可用于训练模型进行异常检测。16,256,48,48


第二次的输入是zb,quantized_b,embedder_hi._embedding.weight,anomaly_mask,anomaly_strength_hi= (torch.rand(in_image.shape[0]) * 0.90 + 0.10).cuda()

  1. ②然后再把quantized_b和quantized_b_real分别输入给generate_fake_anomalies_joined,分别输出anomaly_embedding和anomaly_embedding_hi_usebot
python 复制代码
use_both = torch.randint(0, 2,(in_image.shape[0],1,1,1)).cuda().float()
use_lo = torch.randint(0, 2,(in_image.shape[0],1,1,1)).cuda().float()
use_hi = (1 - use_lo)
anomaly_embedding_hi_usebot = generate_fake_anomalies_joined(zb_real,
                                                    quantized_b_real,
                                                    embedder_hi._embedding.weight,
                                                    anomaly_mask, strength=anomaly_strength_hi)
anomaly_embedding_lo_usebot = quantized_t
anomaly_embedding_hi_usetop = quantized_b_real
anomaly_embedding_lo_usetop = anomaly_embedding_lo
anomaly_embedding_hi_not_both =  use_hi * anomaly_embedding_hi_usebot + use_lo * anomaly_embedding_hi_usetop
anomaly_embedding_lo_not_both =  use_hi * anomaly_embedding_lo_usebot + use_lo * anomaly_embedding_lo_usetop
anomaly_embedding_hi = (anomaly_embedding * use_both + anomaly_embedding_hi_not_both * (1.0 - use_both)).detach().clone()
anomaly_embedding_lo = (anomaly_embedding_lo * use_both + anomaly_embedding_lo_not_both * (1.0 - use_both)).detach().clone()

anomaly_embedding_hi_copy = anomaly_embedding_hi.clone()
anomaly_embedding_lo_copy = anomaly_embedding_lo.clone()
  1. 开始计算梯度后,下面的部分也有变动
python 复制代码
recon_feat_hi, recon_embeddings_hi, _ = sub_res_model_hi(anomaly_embedding_hi_copy, embedder_hi)
recon_feat_lo, recon_embeddings_lo, _ = sub_res_model_lo(anomaly_embedding_lo_copy, embedder_lo)
这里之前分别输入的是embeddings_hi和embeddings_lo
recon_feat_xx也就是unet部分的输出output

# Reconstruct the image from the anomalous features with the general appearance decoder
up_quantized_anomaly_t = model.upsample_t(anomaly_embedding_lo)
quant_join_anomaly = torch.cat((up_quantized_anomaly_t, anomaly_embedding_hi), dim=1)
recon_image_general = model._decoder_b(quant_join_anomaly)
虽然model.upsample_t和model._decoder_b在torch.no_grad()之外,
但是model并没有计入优化器
所以就算计算了梯度也不会更新它的权重

# Reconstruct the image from the reconstructed features
# with the object-specific image reconstruction module
up_quantized_recon_t = model.upsample_t(recon_embeddings_lo)
quant_join = torch.cat((up_quantized_recon_t, recon_embeddings_hi), dim=1)
recon_image_recon = model_decode(quant_join)

out_mask = decoder_seg(recon_image_recon,recon_image_general)
out_mask_sm = torch.softmax(out_mask, dim=1)

# Calculate losses
loss_feat_hi = torch.nn.functional.mse_loss(recon_feat_hi, quantized_b_real.detach())
loss_feat_lo = torch.nn.functional.mse_loss(recon_feat_lo, quantized_t.detach())
loss_l2_recon_img = torch.nn.functional.mse_loss(in_image, recon_image_recon)
total_recon_loss = loss_feat_lo + loss_feat_hi + loss_l2_recon_img*10


# Resize the ground truth anomaly map to closely match the augmented features
down_ratio_x_hi = int(anomaly_mask.shape[3] / quantized_b.shape[3])
anomaly_mask_hi = torch.nn.functional.max_pool2d(anomaly_mask,
                                                (down_ratio_x_hi, down_ratio_x_hi)).float()
anomaly_mask_hi = torch.nn.functional.interpolate(anomaly_mask_hi, scale_factor=down_ratio_x_hi)
down_ratio_x_lo = int(anomaly_mask.shape[3] / quantized_t.shape[3])
anomaly_mask_lo = torch.nn.functional.max_pool2d(anomaly_mask,
                                                (down_ratio_x_lo, down_ratio_x_lo)).float()
anomaly_mask_lo = torch.nn.functional.interpolate(anomaly_mask_lo, scale_factor=down_ratio_x_lo)
anomaly_mask = anomaly_mask_lo * use_both + (
          anomaly_mask_lo * use_lo + anomaly_mask_hi * use_hi) * (1.0 - use_both)

#anomaly_mask = anomaly_mask * anomaly_type_sum
# Calculate the segmentation loss
segment_loss = loss_focal(out_mask_sm, anomaly_mask)
Focal Loss主要用于解决类别不平衡的问题
在像素级的异常检测任务中,
"类别不平衡"问题通常是指异常区域像素与正常区域像素之间的不平衡。
它通过减少对易分类对象的关注(通过降低它们的损失贡献)
来提高模型对困难或少见类别的关注度。
Focal Loss 的公式是-1 * alpha * (1 - pt)^gamma * log(pt)
其中 pt 是模型对正确类别的预测概率
alpha 和 gamma 是调节损失贡献的超参数。

l1_mask_loss = torch.mean(torch.abs(out_mask_sm - torch.cat((1.0 - anomaly_mask, anomaly_mask), dim=1)))
如果模型的预测在某些像素点上极度不准确(即预测值与真实值之间的差异很大),
L1 损失不会像平方差损失(L2 损失)那样对这些错误赋予过高的权重,
从而避免让模型过度适应那些极端的误差。

segment_loss = segment_loss + l1_mask_loss 
# L1 is different than in the paper but may improve results in some cases
相关推荐
IE066 分钟前
深度学习系列76:流式tts的一个简单实现
人工智能·深度学习
GIS数据转换器11 分钟前
城市生命线安全保障:技术应用与策略创新
大数据·人工智能·安全·3d·智慧城市
无须logic ᭄14 分钟前
CrypTen项目实践
python·机器学习·密码学·同态加密
m0_743106464 小时前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_743106464 小时前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学
Coovally AI模型快速验证7 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩8 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
orion-orion9 小时前
贝叶斯机器学习:高斯分布及其共轭先验
机器学习·统计学习
IE0610 小时前
深度学习系列75:sql大模型工具vanna
深度学习
不惑_10 小时前
深度学习 · 手撕 DeepLearning4J ,用Java实现手写数字识别 (附UI效果展示)
java·深度学习·ui