超越Sora的开源思路:如何用预训练组件高效训练你的视频扩散模型?(附训练代码)

当我们开始思考3D数据或视频时,一个很自然的想法就是把它们视为一系列2D帧,然后通过简单地把时间作为额外维度来应用同样的模型。

从直觉上看,这种方法似乎可行,但实际上它很快会遇到瓶颈。随着输入变得更高维,模型通常需要变得更大才能表现良好。这导致内存使用增加、计算成本上升,训练也更难稳定,尤其是在GPU资源有限的情况下。

对于像视频生成这样的任务,这个问题变得更加严重。视频中的帧并不是独立的,它们在时间上紧密相连。由于这种强烈的时间依赖性,单一的生成模型很难同时学会物体的外观和运动方式。

在这篇文章中,将介绍隐式流扩散模型,这是一个为应对上述挑战而设计的两阶段框架。第一阶段专注于学习像素级的空间关系,而第二阶段则建模视频帧之间的时间依赖性。

到文章结尾,你将看到LFDM如何用于在MHAD数据集上进行条件式的图像到视频生成,并配有PyTorch的实战代码实现。

隐式流扩散模型的核心思想

正如前文提到的,视频之所以难以建模,是因为空间外观和时间运动高度纠缠。当扩散模型直接在像素空间逐帧生成视频时,结果往往不稳定,导致闪烁、物体形状扭曲或帧间突然不一致。

与其直接生成视频,LFDM采用了一种不同的策略:通过光流来建模运动。在第一阶段,模型被训练来预测源帧和目标帧之间的光流。然后,这个流被用来扭曲源帧,从而生成对应的目标帧。

给定一个源帧 x0,我们可以进一步训练一个扩散模型来生成一系列光流,例如 flow(x0, x1), flow(x0, x2)......通过用这些预测的流来扭曲同一个源帧 x0,就可以合成并组装出连贯的未来帧。

因为所有帧都是由同一张源图像变形生成的,空间结构保持一致,并且运动随时间演化更平滑。这使得扩散模型可以专注于学习运动动态,而非细节外观,让学习问题变得更简单,生成的视频也更稳定。

LFDM背后的另一个关键思想是:光流是在隐式空间而不是直接在像素级应用的。这显著降低了内存使用和计算成本。在我的实验中,源帧被缩放到128×128,并编码成32×32的隐式特征。模型预测一个流场来扭曲这个隐式表示,然后解码器将扭曲后的隐式表示重建为目标帧。

  • 阶段一:隐式流自编码器
  • 训练一个自编码器,将源帧编码到隐式空间,并将经过流扭曲的隐式表示解码回图像。
  • 训练一个流预测器,用于估计源帧和目标帧之间在隐式空间(32×32)的光流。
  • 阶段二:扩散模型
  • 使用第一阶段训练好的流预测器,从一个固定的源帧 x0 生成一系列隐式光流(如上图所示)。
  • 使用扩散模型(或其他生成模型)来建模这个流序列的分布,并生成新的隐式流序列。
  • 光流与反向采样

在计算机视觉中,光流通过为每个像素估计一个2D位移向量,来描述两个连续图像帧 x0 和 x1 之间像素的表观运动。

流扭曲可以用上面的方程表示,其中 p = (x, y) 表示像素坐标,(u, v) 表示光流的水平和垂直分量。这个方程定义了一个像素级的映射,描述了源帧 x0 中的像素如何被移动到目标帧 x1 中,将每个在 (x, y) 的像素映射到 (x + u, y + v)。

根据流动方向,扭曲可以分为前向扭曲和后向扭曲。前向扭曲使用流场将源帧的每个像素映射到目标帧,但有些像素可能会落在有效图像区域之外,导致缺失或未定义的区域(见A图)。而后向扭曲则指定每个目标像素如何从源帧采样,从而避免了这些缺失区域。

在LFDM中,我们选择后向扭曲,因为所有未来帧(x₁, x₂, ...)都是从同一个源帧(x₀)采样得到的,确保了跨时间的空间和外观一致性。此外,后向扭曲避免了目标帧中的缺失或未定义区域。

关节动画的运动表示

对于简单的运动,传统的光流方法(例如经典的或基于学习的方法)通常效果不错。然而,当运动变得更加复杂时------例如,当多个物体独立运动或发生关节式、非刚性的变形时------准确估计光流就变得困难得多。

因此,LFDM采用了基于MRAA的运动表示,它使用结构化的关键点位移来建模运动,并将它们聚合成一个密集的变形场。这种设计使得LFDM即使在复杂的运动场景下,也能以更稳定、更连贯的方式预测光流和遮挡图。

MRAA背后的核心思想相当简单:它不试图一次性建模复杂的整体运动,而是将运动分解成几个较小的局部运动。每个部分被单独建模,然后所有这些局部运动被组合起来形成整体运动。

例如,当一个人将手从低位举过头顶时,多个身体部位------如肩膀、上臂和前臂------会一起运动。MRAA不是直接估计一个单一的、全局的像素级运动场,而是对连续帧之间每个局部组件的相对运动进行建模,然后通过加权聚合将它们组合起来,产生最终的整体运动场(光流)。

  • MRAA的组成部分
  • 区域预测器: 从输入图像中提取一组关键区域,每个区域代表一个局部运动组件。这些区域提供了参与运动的物体部件的结构化和紧凑表示。
  • 流预测器: 通过组合区域级运动,建模源帧和驱动帧之间的相对运动。它将运动聚合成一个密集的变形场以及一个遮挡图。
  • 生成器: 使用预测的流和遮挡图来扭曲源图像,以重建目标帧。通过重建损失,它实现了对有意义的、稳定的运动表示的端到端学习。
  • 背景运动预测器: 估计静态背景的全局运动,例如相机移动。

注:遮挡图指示了由于帧间遮挡而变得不可见的区域,在这些区域,直接从源图像进行扭曲并不可靠。它允许模型在重建过程中降低这些区域的权重或忽略它们,从而防止伪影并提高时间一致性。

区域预测器

区域预测器的目标是学习图像中多个区域的可微分局部表示。它使用一个U-Net来预测每个区域的软热图,以概率方式将像素分配到不同的区域。

从这些热图中,我们计算每个区域的中心和协方差:

基于这些一阶和二阶矩,我们进一步利用PCA推导出每个区域相对于规范坐标系的仿射表示。这些稀疏的、区域级的线索随后被流预测器用来建模密集的光流。

另一个细节是,流预测器在隐式空间运行,并预测一个空间分辨率低得多的光流,而不是直接在像素级。为了保持区域表示与这个隐式流对齐,我们在训练区域预测器时,将输入图像缩放到与隐式特征相同的分辨率。在我的实现中,128×128的图像被缩放到32×32作为模型输入。

ini 复制代码
class RegionPredictor(nn.Module):
    def __init__(self, block_expansion=32, num_regions=10, num_channels=3, \
                 max_features=1024, num_blocks=5, temperature=0.1, scale_factor=0.25, pad=0):
        super().__init__()
        self.temperature = temperature
        self.scale_factor = scale_factor
       
        self.predictor = Hourglass(
            block_expansion,
            in_features=num_channels,
            max_features=max_features,
            num_blocks=num_blocks
        )
        
        self.regions = nn.Conv2d(
            in_channels=self.predictor.out_filters,
            out_channels=num_regions,
            kernel_size=(7, 7),
            padding=pad
        )
        
        # Rescale 128 * 128 frame into 32 * 32
        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
        else:
            self.down = None
    
    def _region_to_pca_params(self, region: torch.Tensor):
        # region: (B, K, H, W) softmax heatmap
        B, K, H, W = region.shape
        grid = make_coordinate_grid((H, W), region.dtype).to(region.device)
        grid = grid.unsqueeze(0).unsqueeze(0)
        
        # mean/shift: (B, K, 2)   
        region_w = region.unsqueeze(-1)               
        mean = (region_w * grid).sum(dim=(2, 3))        
        
        # covariance: (B, K, 2, 2)
        mean_sub = grid - mean.unsqueeze(-2).unsqueeze(-2)
        covar = mean_sub.unsqueeze(-1) * mean_sub.unsqueeze(-2)  
        covar = covar * region.unsqueeze(-1).unsqueeze(-1)       
        covar = covar.sum(dim=(2, 3))
        I = torch.eye(2, device=covar.device, dtype=covar.dtype).view(1, 1, 2, 2)
        covar = covar + 1e-6 * I
        covar = 0.5 * (covar + covar.transpose(-1, -2))
        
        # SVD to get sqrt(covar) as "affine"
        covar_flat = covar.view(-1, 2, 2)           
        U, S, Vh = torch.linalg.svd(covar_flat, full_matrices=False)
        
        # sqrt matrix: U * diag(sqrt(S))
        D = torch.diag_embed(torch.sqrt(torch.clamp(S, min=1e-6))) 
        sqrt = U @ D                                               
        sqrt = sqrt.view(B, K, 2, 2)
        U = U.view(B, K, 2, 2)
        D = D.view(B, K, 2, 2)
        return {"shift": mean, "covar": covar, "affine": sqrt, "u": U, "d": D}
    
    def forward(self, x: torch.Tensor):
        # x: (B, 3, H, W)
        if self.down is not None:
            x = self.down(x)
        
        feature_map = self.predictor(x)
        logits = self.regions(feature_map)  # (B, K, H, W)
        B, K, H, W = logits.shape
        region = logits.view(B, K, -1)
        region = F.softmax(region / self.temperature, dim=2)
        region = region.view(B, K, H, W)
        
        region_params = self._region_to_pca_params(region)
        region_params["heatmap"] = region
        return region_params

像素级流预测器

在区域预测器中,我们获得了每个区域的粗略的、区域级的运动信息。流预测器然后通过以下方程聚合这些区域级线索,合成一个密集的光流场:(其中A: 仿射矩阵,s: 区域中心,x: 像素坐标)

思路很简单。我们首先通过 x - s_drv 移动像素坐标,使得一切都以驱动区域为中心。然后,我们使用驱动帧的逆仿射变换 A_drv^-1 将坐标带入一个规范空间。之后,我们应用源帧的仿射变换 A_src 将其映射回源姿态。最后,我们加上 s_src 得到用于后向扭曲的流网格。

class 复制代码
    def __init__(self, block_expansion=64, num_blocks=5, max_features=1024, \
                 num_regions=10, num_channels=3, scale_factor=0.25):
        super().__init__()
        self.num_regions = num_regions
        self.scale_factor = scale_factor

        in_ch = (num_regions + 1) * (num_channels + 1)
        self.hourglass = Hourglass(
            block_expansion=block_expansion,
            in_features=in_ch,
            max_features=max_features,
            num_blocks=num_blocks,
        )
        self.mask = nn.Conv2d(self.hourglass.out_filters, num_regions + 1, kernel_size=7, padding=3)
        self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=7, padding=3)
        if self.scale_factor != 1:
            self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
    
    def create_heatmap_representations(self, source_image, driving_region_params, source_region_params):
        spatial_size = source_image.shape[2:]  # (h, w)
        covar_d = driving_region_params["covar"]
        covar_s = source_region_params["covar"]
        
        gaussian_driving = region2gaussian(driving_region_params["shift"], covar=covar_d, spatial_size=spatial_size)
        gaussian_source = region2gaussian(source_region_params["shift"], covar=covar_s, spatial_size=spatial_size)
        heatmap = gaussian_driving - gaussian_source  # (B, K, H, W)
        # add background channel (zeros), can consider adding background feature
        zeros = torch.zeros(heatmap.size(0), 1, spatial_size[0], spatial_size[1],
                            device=heatmap.device, dtype=heatmap.dtype)
        heatmap = torch.cat([zeros, heatmap], dim=1)  # (B, K+1, H, W)
        return heatmap.unsqueeze(2)  # (B, K+1, 1, H, W)
    
    def create_sparse_motions(self, source_image, driving_region_params, source_region_params):
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid((h, w), type=source_region_params["shift"].type())
        identity_grid = identity_grid.view(1, 1, h, w, 2)  # (1,1,H,W,2)
        # region-wise coords centered at driving shift
        coordinate_grid = identity_grid - driving_region_params["shift"].view(bs, self.num_regions, 1, 1, 2)
        
        affine = torch.matmul(source_region_params["affine"], torch.inverse(driving_region_params["affine"]))
        affine = affine * torch.sign(affine[:, :, 0:1, 0:1])
        affine = affine.unsqueeze(-3).unsqueeze(-3)             # (B,K,1,1,2,2)
        affine = affine.repeat(1, 1, h, w, 1, 1)                # (B,K,H,W,2,2)
        
        coordinate_grid = torch.matmul(affine, coordinate_grid.unsqueeze(-1)).squeeze(-1)  # (B,K,H,W,2)
        driving_to_source = coordinate_grid + source_region_params["shift"].view(bs, self.num_regions, 1, 1, 2)
        # background motion is always identity (no bg predictor)
        bg_grid = identity_grid.repeat(bs, 1, 1, 1, 1)  # (B,1,H,W,2)
        return torch.cat([bg_grid, driving_to_source], dim=1)
    
    def create_deformed_source_image(self, source_image, sparse_motions):
        bs, _, h, w = source_image.shape
        source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(bs, self.num_regions + 1, 1, 1, 1, 1)
        source_repeat = source_repeat.view(bs * (self.num_regions + 1), -1, h, w)
        sparse_motions = sparse_motions.view(bs * (self.num_regions + 1), h, w, 2)
        sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
        sparse_deformed = sparse_deformed.view(bs, self.num_regions + 1, -1, h, w)
        return sparse_deformed
    
    def forward(self, source_image, driving_region_params, source_region_params):
        if self.scale_factor != 1:
            source_image = self.down(source_image)
        bs, _, h, w = source_image.shape
        heatmap_representation = self.create_heatmap_representations(source_image, driving_region_params, source_region_params)
        sparse_motion = self.create_sparse_motions(source_image, driving_region_params, source_region_params) 
        deformed_source = self.create_deformed_source_image(source_image, sparse_motion) # (B, K+1, C, H, W)
        
        predictor_input = torch.cat([heatmap_representation, deformed_source], dim=2)    # (B, K+1, 1+C, H, W)
        predictor_input = predictor_input.view(bs, -1, h, w)
        prediction = self.hourglass(predictor_input)
        mask = F.softmax(self.mask(prediction), dim=1).unsqueeze(2)  # (B,K+1,1,H,W)
        
        sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)         # (B,K+1,2,H,W)
        deformation = (sparse_motion * mask).sum(dim=1).permute(0, 2, 3, 1)  # (B,H,W,2)
        out_dict = {"optical_flow": deformation}
        out_dict["occlusion_map"] = torch.sigmoid(self.occlusion(prediction))
        return out_dict

还使用一个U-Net来预测一个遮挡图。它告诉我们,对于每个像素,哪个区域的运动是可靠的。更重要的是,它让生成器知道哪些部分被遮挡了,无法通过扭曲获得,因此这些区域需要生成器自己来填充。

生成器与训练循环

最后,第一阶段训练的最后一个部分是生成器,它是一个基于自编码器的模型。它首先将源帧 x_s 编码成一个隐式表示 z_s。然后,这个隐式表示被光流扭曲,产生一个变形后的特征。

一个遮挡图 M 被用来混合变形后的隐式表示和原始隐式表示,指示哪些区域应该被信任,哪些需要被填充。解码器然后使用这个混合后的隐式表示来重建目标帧 x_t。

ini 复制代码
class Generator(nn.Module):
    def __init__(self, num_channels=3, num_regions=10, block_expansion=64, max_features=512,
                 num_down_blocks=2, num_bottleneck_blocks=6):   
        super().__init__()
        self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) 
        down_blocks = []
        up_blocks = []
        
        for i in range(num_down_blocks):
            in_features  = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))

        for i in range(num_down_blocks):
            in_features  = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
            out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
            up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))

        self.down_blocks = nn.ModuleList(down_blocks)
        self.up_blocks = nn.ModuleList(up_blocks)
        in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
        
        self.bottleneck = nn.Sequential(*[
            ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))
            for _ in range(num_bottleneck_blocks)
        ])
        self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
        self.num_channels = num_channels
   
    @staticmethod
    def deform_input(inp, grid):
        # grid: (B, H, W, 2) in normalized coords for grid_sample
        b, c, h, w = inp.shape
        gh, gw = grid.shape[1], grid.shape[2]
        if (gh, gw) != (h, w):
            grid = grid.permute(0, 3, 1, 2)  # (B,2,H,W)
            grid = F.interpolate(grid, size=(h, w), mode='bilinear', align_corners=True)
            grid = grid.permute(0, 2, 3, 1)  # (B,H,W,2)
        return F.grid_sample(inp, grid, align_corners=True)
    
    def apply_optical(self, x_skip, x_prev, motion_params):
        if motion_params is None:
            return x_prev if x_prev is not None else x_skip
        x = self.deform_input(x_skip, motion_params['optical_flow'])
        occ = motion_params['occlusion_map']
        if occ.shape[-2:] != x.shape[-2:]:
            occ = F.interpolate(occ, size=x.shape[-2:], mode='bilinear', align_corners=True)

        if x_prev is None:
            return x * occ
        return x * occ + x_prev * (1 - occ)
    
    def forward(self, source_image, motion_params):
        out = self.first(source_image)
        skips = [out]
        for block in self.down_blocks: # Encoder
            out = block(out)
            skips.append(out)
        
        output_dict = {
            "bottle_neck_feat": out,
            "deformed": self.deform_input(source_image, motion_params["optical_flow"]),
            "optical_flow": motion_params["optical_flow"],
        }
        output_dict["occlusion_map"] = motion_params["occlusion_map"]
        out = self.apply_optical(x_skip=out, x_prev=None, motion_params=motion_params)
        out = self.bottleneck(out)
        
        for i, up in enumerate(self.up_blocks): # Decoder
            out = self.apply_optical(
                x_skip=skips[-(i + 1)],
                x_prev=out,
                motion_params=motion_params
            )
            out = up(out)
        out = torch.sigmoid(self.final(out))
        output_dict["prediction"] = out
        return output_dict
    
    def compute_fea(self, source_image):
        out = self.first(source_image)
        for block in self.down_blocks:
            out = block(out)
        return out
  • UTD-MHAD数据集

在这个实现中,我使用UTD-MHAD数据集来训练模型的两个阶段。对于第一阶段,从同一个视频中随机采样两帧,一帧作为源帧,另一帧作为目标帧。对于第二阶段,随机选择一个视频,并提取一个连续的40帧序列作为模型输入。

ini 复制代码
from PIL import Image
from decord import VideoReader, cpu
class UTDDataset(Dataset):
    def __init__(self,
                 root="/content/drive/MyDrive/LFDM_data",
                 image_size=128,
                 mode="stage1",
                 clip_len=40):
        super().__init__()
        assert mode in ["stage1", "stage2"]
        
        self.mode = mode
        self.clip_len = clip_len
        self.video_files = sorted(glob(os.path.join(root, "**", "*.avi"), recursive=True))
        
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
    
    def __len__(self):
        return len(self.video_files)
    
    def __getitem__(self, idx):
        # Stage 1: random (x0, xt)
        if self.mode == "stage1":
            while True:
                vid_path = random.choice(self.video_files)
                vr = VideoReader(vid_path, ctx=cpu(0))
                n = len(vr)
                if n < 2:
                    continue
                
                i0, i1 = sorted(random.sample(range(n), 2))
                f0 = Image.fromarray(vr[i0].asnumpy())
                f1 = Image.fromarray(vr[i1].asnumpy())
                x0 = self.transform(f0)
                xt = self.transform(f1)
                return x0, xt
        
        # Stage 2: clip + action
        vid_path = self.video_files[idx]
        vr = VideoReader(vid_path, ctx=cpu(0))
        n = len(vr)
        if n == 0:
            return self.__getitem__((idx + 1) % len(self.video_files))
        
        K = self.clip_len
        if n >= K:
            start = random.randint(0, n - K)
            indices = range(start, start + K)
        else:
            indices = sorted(random.choices(range(n), k=K))
        
        frames = []
        for i in indices:
            img = Image.fromarray(vr[i].asnumpy())
            frames.append(self.transform(img))
        
        clip = torch.stack(frames, dim=1)  # (C, K, H, W)
        video_name = os.path.basename(vid_path)
        action_idx = int(video_name.split("_")[0][1:]) - 1  # a01 -> 0
        return clip, action_idx
  • 第一阶段训练循环
  • 从源帧 x0 和目标帧 xt 中提取区域信息。
  • 使用流预测器预测光流和遮挡图。
  • 用预测的运动扭曲源帧,并使用生成器重建目标帧 pred_xt。
  • 测量重建帧与目标帧之间的感知损失。

注:你也可以尝试应用一个等变性损失,这可以加强几何一致性,并提高学习到的运动表示的稳定性。

def 复制代码
    x_vgg = vgg(pred)
    y_vgg = vgg(real)
    loss = 0.0
    for i, w in enumerate(per_weights):
        loss = loss + w * torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
    return loss

region_predictor = RegionPredictor().to(device)
flow_predictor = PixelwiseFlowPredictor().to(device)
generator = Generator().to(device)
total_iters   = 200000
save_interval = 20000

# You may use any pretrained VGG model.
# Here, we use the outputs of the last five layers to compute the perceptual loss.
vgg = Vgg19().to(device)
vgg.eval()  
for p in vgg.parameters():
    p.requires_grad_(False)

dataset = UTDDataset(mode="stage1", image_size=128)
dataloader = DataLoader(
    dataset, 
    batch_size=16, 
    shuffle=True, 
    pin_memory=True, 
    drop_last=True, 
    num_workers=4
)
optimizer = torch.optim.Adam(
    list(region_predictor.parameters()) +
    list(flow_predictor.parameters()) +
    list(generator.parameters()), 
    lr=5e-5, 
    betas=(0.5, 0.999)
)
data_iter = iter(dataloader)
pbar = trange(1, total_iters + 1, desc="Training", dynamic_ncols=True)

for it in pbar:
    try:
        x0, xt = next(data_iter)
    except StopIteration:
        data_iter = iter(dataloader)
        x0, xt = next(data_iter)
    
    x0, xt = x0.to(device, non_blocking=True), xt.to(device, non_blocking=True)
    source_region  = region_predictor(x0)
    driving_region = region_predictor(xt)
    motion = flow_predictor(x0, source_region_params=source_region, driving_region_params=driving_region)             
    generated = generator(x0, motion)
    pred = generated["prediction"]        
    real = xt
    
    perc = perceptual_loss(vgg, pred, real)
    optimizer.zero_grad(set_to_none=True)
    perc.backward()
    optimizer.step()
    
    # ---- save ----
    if it % save_interval == 0:
        ckpt = {
            "it": it,
            "region_predictor": region_predictor.state_dict(),
            "flow_predictor": flow_predictor.state_dict(),
            "generator": generator.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(ckpt, f"stage1_ckpt_{it:06d}.pt")

LFDM第二阶段训练

终于到了最后一步。我们现在要做的事情很简单:训练一个生成模型,能够生成合理的隐式流序列,其形状为:

到了这一步,大部分困难的工作已经完成了。使用MRAA风格的方法,我们已经分离了空间结构,并将原始图像压缩到了一个低维隐式空间。因此,模型不再需要关心外观------它只需要学习运动如何随时间演化。

我们现在可以将隐式流序列视为一个3D体积,并使用扩散模型来生成它。当然,扩散模型不是唯一的选择------整流流、Transformer或GANs在这里也能很好地工作。

注:在我的实现中,我使用了"修正流"而不是扩散模型,因为它更容易实现且收敛更快。

python 复制代码
import ...
from einops import rearrange

class TemporalSelfAttention(nn.Module):
    def __init__(self, dim, heads=8, max_frames=64, dropout=0.0):
        super().__init__()
        self.max_frames = max_frames
        self.pos_emb = nn.Embedding(max_frames, dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
   
    def forward(self, x):
        b, c, f, h, w = x.shape
        if f > self.max_frames:
            raise ValueError(f"num_frames={f} exceeds max_frames={self.max_frames}")
        
        x_seq = rearrange(x, 'b c f h w -> (b h w) f c')   # (BHW, F, C)
        pos = torch.arange(f, device=x.device)
        x_seq = x_seq + self.pos_emb(pos).unsqueeze(0)     # (BHW, F, C)
        out, _ = self.attn(x_seq, x_seq, x_seq, need_weights=False)
        out = rearrange(out, '(b h w) f c -> b c f h w', b=b, h=h, w=w)
        return out

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    
    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

def Upsample(dim, padding_mode="reflect"):
    return nn.Sequential(
        nn.Upsample(scale_factor=(1, 2, 2), mode='nearest'),
        nn.Conv3d(dim, dim, (1, 3, 3), (1, 1, 1), (0, 1, 1), padding_mode=padding_mode)
    )

def Downsample(dim):
    return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))

class NormAttn(nn.Module):
    def __init__(self, norm, attn):
        super().__init__()
        self.norm = norm
        self.attn = attn
    
    def forward(self, x):
        return self.attn(self.norm(x))

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if time_emb_dim is None:
            self.mlp = None
        else:
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2)
            )
        self.conv1 = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1))
        self.norm1 = nn.GroupNorm(groups, dim_out)
        self.act1  = nn.SiLU()
        self.conv2 = nn.Conv3d(dim_out, dim_out, (1, 3, 3), padding=(0, 1, 1))
        self.norm2 = nn.GroupNorm(groups, dim_out)
        self.act2  = nn.SiLU()
        self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
    
    def forward(self, x, time_emb=None):
        scale = shift = None
        if self.mlp is not None:
            t = self.mlp(time_emb)
            t = rearrange(t, 'b c -> b c 1 1 1')
            scale, shift = t.chunk(2, dim=1)
        
        h = self.conv1(x)
        h = self.norm1(h)
        if scale is not None:
            h = h * (scale + 1) + shift
        h = self.act1(h)
        h = self.conv2(h)
        h = self.norm2(h)
        h = self.act2(h)
        return h + self.res_conv(x)

以上是U-Net模型的一些关键模块。这里,我们添加了一个时间注意力模块,通过关注时间(帧)维度来更好地保持时间一致性。

ini 复制代码
class Unet3D(nn.Module):
    def __init__(self, dim=64, num_classes=28, cond_embed_dim=128, dim_mults=(1, 2, 4, 8),
                 channels=3 + 256, attn_heads=8, max_frames=64, init_kernel_size=7, resnet_groups=8):
        super().__init__()
        self.channels = channels
        init_dim = dim
        init_padding = init_kernel_size // 2
        self.init_conv = nn.Conv3d(
            channels, init_dim,
            kernel_size=(1, init_kernel_size, init_kernel_size),
            padding=(0, init_padding, init_padding),
            padding_mode="zeros"
        )
        
        self.init_norm = nn.GroupNorm(1, init_dim, eps=1e-5)
        self.init_attn = TemporalSelfAttention(init_dim, heads=attn_heads, max_frames=max_frames)
        self.init_attn_res = Residual(lambda x: self.init_attn(self.init_norm(x)))
        dims = [init_dim] + [dim * m for m in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))
        num_resolutions = len(in_out)
        
        # time embedding
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )
        
        # class embedding
        self.cond_emb = nn.Embedding(num_classes, cond_embed_dim)
        cond_dim = time_dim + cond_embed_dim
        
        block_klass = partial(ResnetBlock, groups=resnet_groups)
        block_klass_cond = partial(block_klass, time_emb_dim=cond_dim)
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            norm = nn.GroupNorm(1, dim_out, eps=1e-5)
            attn = TemporalSelfAttention(dim_out, heads=attn_heads, max_frames=max_frames)
            attn_res = Residual(NormAttn(norm, attn))
            
            self.downs.append(nn.ModuleList([
                block_klass_cond(dim_in, dim_out),
                block_klass_cond(dim_out, dim_out),
                attn_res,
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))
        
        mid_dim = dims[-1]
        self.mid_block1 = block_klass_cond(mid_dim, mid_dim)
        norm = nn.GroupNorm(1, mid_dim, eps=1e-5)
        attn = TemporalSelfAttention(mid_dim, heads=attn_heads, max_frames=max_frames)
        self.mid_attn_res = Residual(NormAttn(norm, attn))
        self.mid_block2 = block_klass_cond(mid_dim, mid_dim)
        
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind >= (num_resolutions - 1)
            norm = nn.GroupNorm(1, dim_in, eps=1e-5)
            attn = TemporalSelfAttention(dim_in, heads=attn_heads, max_frames=max_frames)
            attn_res = Residual(NormAttn(norm, attn))
            
            self.ups.append(nn.ModuleList([
                block_klass_cond(dim_out * 2, dim_in),
                block_klass_cond(dim_in, dim_in),
                attn_res,
                Upsample(dim_in, padding_mode="zeros") if not is_last else nn.Identity()
            ]))
        
        self.out = nn.Sequential(
            block_klass(dim * 2, dim),
            nn.Conv3d(dim, 3, 1)
        )
    
    def forward(self, x, time, cond):
        device = x.device
        x = self.init_conv(x)
        r = x.clone()
        x = self.init_attn_res(x)
        t = self.time_mlp(time)
        if cond.dim() == 2 and cond.size(-1) == 1:
            cond = cond.squeeze(-1)
        cond = cond.long().to(device)
        t = torch.cat([t, self.cond_emb(cond)], dim=-1)
        
        h = [] # unet
        for block1, block2, attn_res, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn_res(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn_res(x)
        x = self.mid_block2(x, t)
        
        for block1, block2, attn_res, upsample in self.ups:
            x = torch.cat([x, h.pop()], dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn_res(x)
            x = upsample(x)
        x = torch.cat([x, r], dim=1)
        return self.out(x)

get_latent_flow_residual 函数收集源帧与所有其他帧之间的流。它还输出源帧的特征图,用作U-Net的条件输入。

@torch.no_grad() 复制代码
def get_latent_flow_residual(x, region_predictor, flow_predictor, generator):
    b, _, T, H, W = x.shape
    source_img = x[:, :, 0]  # (B,3,H,W)
    flow_list = []
    conf_list = []
    source_region = region_predictor(source_img)
    for t in range(T):
        driving_img = x[:, :, t]
        driving_region = region_predictor(driving_img)
        flow_motion = flow_predictor(
            source_img,
            source_region_params=source_region,
            driving_region_params=driving_region
        )
        grid = flow_motion["optical_flow"]  
        conf = flow_motion["occlusion_map"] 
        flow_list.append(grid)  # keep (B,H,W,2)
        conf_list.append(conf)
    
    fea = generator.compute_fea(source_img) 
    grid = torch.stack(flow_list, dim=1)  # (B,T,H,W,2)
    conf = torch.stack(conf_list, dim=2)  # (B,1,T,H,W)
    id_grid = make_coordinate_grid((32, 32), grid.dtype).to(grid.device)
    id_grid = id_grid.unsqueeze(0).repeat(b, 1, 1, 1)
    id_grid = id_grid.unsqueeze(1).repeat(1, T, 1, 1, 1) 
    delta = grid - id_grid  
    
    delta = delta.permute(0, 4, 1, 2, 3).contiguous()
    conf = conf * 2 - 1   # conf to [-1,1]
    out = torch.cat([delta, conf], dim=1)  
    # fea repeat on T (B,C,Hf,Wf)->(B,C,T,Hf,Wf)
    fea = fea.unsqueeze(2).repeat(1, 1, T, 1, 1)
    return out, fea

from torchdiffeq import odeint
class RectifiedFlow(nn.Module):
    def __init__(self, unet, p_uncond=0.25, null_label_id=27):
        super().__init__()
        self.unet = unet
        self.p_uncond = p_uncond
        self.null_label_id = null_label_id
    
    def forward(self, x0, fea, label):
        B = x0.size(0)
        device = x0.device
        # Classified Free Guidence
        if (self.null_label_id is not None) and (self.p_uncond > 0):
            drop = torch.rand(B, device=device) < self.p_uncond
            label = label.clone()
            label[drop] = self.null_label_id
        
        t = torch.rand(B, device=device) # (B,)
        t_view = t.view(B, 1, 1, 1, 1)
        x1 = torch.randn_like(x0)
        x_t = (1.0 - t_view) * x1 + t_view * x0
        v_gt = x0 - x1
        
        x_in = torch.cat([x_t, fea], dim=1) # (B, C+C_fea, F, H, W)
        v_pred = self.unet(x_in, t, label)
        loss = F.mse_loss(v_pred, v_gt)
        return loss

@torch.no_grad()
def rf_sample_target(model, fea, label,  device, steps=40):
    # You can use this function to get the latent flow
    # and use the 'decoder' in Generator to recover the video back
    B = label.size(0)
    x0 = torch.randn(B, 3, 40, 32, 32, device=device)
    def func(t, x):
        t_vec = torch.full((B,), float(t), device=device)
        x_in = torch.cat([x, fea], dim=1)
        v = model(x_in, t_vec, label)
        return v
    
    t_span = torch.linspace(0.0, 1.0, steps + 1, device=device)
    x_traj = odeint(func, x0, t_span, method="rk4", rtol=1e-5, atol=1e-5)
    x1 = x_traj[-1]
    x1 = x1.clamp(-1, 1)
    return x1, x0
  • 无分类器引导

对于视频生成,仅仅使用一个标签来控制输出通常是不够的。标签只告诉模型要生成什么动作,但它没有解释这个动作应该如何随时间演化。

因此,模型在训练过程中很容易走捷径,生成一个仅仅"看起来在动"的序列,而没有真正遵循给定标签的语义含义。换句话说,条件信号太弱了,所以模型可以轻易忽略它。

为了解决这个问题,引入无分类器引导来让标签更重要。CFG通过在生成过程中放大条件的影响,明确地加强了条件的约束力,迫使模型遵循标签,而不是退回到通用的、无条件的运动。这对于视频生成尤其重要,因为可能的运动空间比图像大得多。

注:在实践中,CFG非常容易实现。只需添加一个额外的类别来表示空(无条件)条件,并在训练过程中随机丢弃原始标签。

device 复制代码
best_path = "Your pretrain model"
ckpt = torch.load(best_path, map_location=device)
generator.load_state_dict(ckpt["generator"])
region_predictor.load_state_dict(ckpt["region_predictor"])
flow_predictor.load_state_dict(ckpt["flow_predictor"])

unet = Unet3D()
rf_model = RectifiedFlow(unet).to(device)
dataset = UTDDataset(mode='stage2)
optimizer = torch.optim.Adam(list(unet.parameters()), lr=1e-4, betas=(0.9, 0.999))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

total_iters = 200000
log_interval = 20000
running_loss = 0.0
loss_history = []
iter_history = []

generator.eval()
region_predictor.eval()
flow_predictor.eval()
unet.train()
data_iter = iter(dataloader)
pbar = trange(1, total_iters + 1, desc="Training", dynamic_ncols=True)

for it in pbar:
    try:
        clip, label = next(data_iter)
    except StopIteration:
        data_iter = iter(dataloader)
        clip, label = next(data_iter)
    
    clip, label = clip.to(device), label.to(device)
    B = clip.size(0)
    
    flow, fea = get_latent_flow_residual(clip, region_predictor, flow_predictor, generator)
    loss = rf_model(flow, fea, label)
    optimizer.zero_grad()
    loss.backward()
    
    torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
    optimizer.step()
    ema_update(ema_unet, unet, beta=0.995)
    running_loss += loss.item()
    pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    
    if it % log_interval == 0:
        avg_loss = running_loss / log_interval
        pbar.write(f"[Iter {it}] loss = {avg_loss:.4f}")
        iter_history.append(it)
        loss_history.append(avg_loss)
        running_loss = 0.0

        torch.save(
            {
                "unet": unet.state_dict(),
                "ema_unet": ema_unet.state_dict(),
                "optimizer": optimizer.state_dict(),
                "it": it,
            },
            f"{ckpt_dir}/unet_{it}.pt"
        )
        pbar.write(f"Saved checkpoint at iter {it}")
相关推荐
Coovally AI模型快速验证3 小时前
超越Sora的开源思路:如何用预训练组件高效训练你的视频扩散模型?(附训练代码)
人工智能·算法·yolo·计算机视觉·音视频·无人机
千金裘换酒3 小时前
Leetcode 有效括号 栈
算法·leetcode·职场和发展
空空潍4 小时前
hot100-最小覆盖字串(day12)
数据结构·算法·leetcode
Rui_Freely4 小时前
Vins-Fusion之 相机—IMU在线标定(十一)
人工智能·算法·计算机视觉
yyy(十一月限定版)4 小时前
算法——二分
数据结构·算法
七点半7704 小时前
c++基本内容
开发语言·c++·算法
嵌入式进阶行者4 小时前
【算法】基于滑动窗口的区间问题求解算法与实例:华为OD机考双机位A卷 - 最长的顺子
开发语言·c++·算法
嵌入式进阶行者4 小时前
【算法】用三种解法解决字符串替换问题的实例:华为OD机考双机位A卷 - 密码解密
c++·算法·华为od
罗湖老棍子4 小时前
信使(msner)(信息学奥赛一本通- P1376)四种做法
算法·图论·dijkstra·spfa·floyd·最短路算法
生成论实验室4 小时前
生成论之基:“阴阳”作为元规则的重构与证成——基于《易经》与《道德经》的古典重诠与现代显象
人工智能·科技·神经网络·算法·架构