"本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!"
一、前言
数字人是目前比较火的一个概念,在直播带货、新闻播报、授课等方面都有应用。目前实现数字人主要有建模和生成两种方式。前者控制数字人嘴型、肢体动作等,但是人物不够真实。后者更专注嘴型的控制,效果会更加真实。
本文将以生成式的数字人为主题,讲解其实现原理,并从零开始实现生成式数字人。
二、唇形同步
2.1 唇形同步任务
数字人中主要面临的问题是唇形同步,即让人物嘴型与输入音频同步。在当前任务中,模型会输入人脸相关信息、音频特征,输出唇形同步后的人脸图像。
这里有两个输入,一个是人脸相关信息。这里并没有非常严格的要求,不过目前普遍的做法是将目标图像的下半部分(嘴唇部分)遮盖(masked_image),同时选取与目标图像不相同的图像(reference_image),最后把masked_image和reference_image作为人脸相关信息输入模型。
其中masked_image提供了人脸上半部分信息,在预测时只需要原样输出即可。而reference_image选取与目标图像不同的图像,主要是为了防止泄露标签。同时提供下半面部的参考图像。
第二个输入是音频特征,这里也没有非常严格的要求。我们可以用MFCC特征、Hubert音频特征、whisper音频特征等。我们只需要注意选取与目标图像位置对齐的音频特征即可。
2.2 已有模型
目前已有许多数字人项目,其思路大致如2.1所描述的。这里以MuseTalk和Wav2Lip为例。
(1)Wav2Lip
Wav2Lip的原始论文发表在2020年,算比较先驱的数字人项目。Wav2Lip人脸信息选取的是2.1中提到的masked_image和reference_image,而音频特征则选取的MFCC。
在Wav2Lip中还添加了SyncNet网络,用于判断图像序列和音频序列是否同步。其架构大致如下:
在开始训练wav2lip模型前,需要先训练好syncnet。训练syncnet时,会采样T张连续图像并随机生成Label,如图Label为1则选取与之对应的音频,否则随机选取音频。然后输入syncnet,训练syncnet中的audio_embedding模型和video_embedding模型。
(2)MuseTalk
MuseTalk是2024年的新模型,其源码中参考了大量Wav2Lip的源码。在结构上两者非常相似,MuseTalk结构如下:
MuseTalk对Wav2Lip做了以下修改:
- 不直接输入图像,而是输入latents变量
- 不使用MFCC特征,而使用whisper特征
- 使用stable diffusion中的UNet作为backbone
- 舍去syncnet
三、准备工作
为了后续方便的处理,我们编写一个common/utils.py文件,该文件中编写一个通用的工具函数。这里介绍一部分代码:
python
import os
import subprocess
def run_command(command, verbose=False):
try:
if verbose:
subprocess.run(command, check=True)
else:
subprocess.run(command, check=True, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
except subprocess.CalledProcessError as e:
print(f"An error occurred: {e}")
except Exception as e:
print(f"An unexpected error occurred: {e}")
def video2audio(video_path, save_path, sr=16000, verbose=False):
audio_path = os.path.join(save_path, Path(video_path).stem + ".wav")
commands = f"ffmpeg -i {video_path} -vn -ar {sr} -ac 1 -y {audio_path}".split()
run_command(commands, verbose)
return audio_path
def resample_video(video_path, save_path, fps=25, verbose=False):
output_path = os.path.join(save_path, os.path.basename(video_path))
commands = f"ffmpeg -i {video_path} -r {fps} {output_path}".split()
run_command(commands, verbose)
return output_path
其中run_command用来执行命令行参数,然后我们用ffmpeg的命令行编写了两个函数,分别是video2audio函数(用于提取视频中的音频)和resample_video函数(用于重采样视频)。
四、数据预处理
我们以Wav2Lip为例,实现数字人代码。首先第一步是数据预处理,这一步主要内容就是提取视频的图像帧以及裁剪出人脸。
第一步我们创建一个Wav2Lip-From-Scratch项目。
4.1 检测人脸
为了简化问题,Wav2Lip会将人脸裁剪出来,只输入面部进行唇形同步。检测人脸的方式有很多,这里我们选择使用wav2lip中的方案,首先下载Wav2Lip的源码:github.com/Rudrabha/Wa...
bash
git clone https://github.com/Rudrabha/Wav2Lip
我们将face_detection文件夹放入Wav2Lip-From-Scratch,并下载s3fd.pth模型放在face_detection\detection\sfd\s3fd.pth。
然后我们就可以识别人脸了,代码如下:
ini
import cv2
import numpy as np
from face_detection import FaceAlignment, LandmarksType
fa = FaceAlignment(
LandmarksType._2D, flip_input=False, device='cuda'
)
image = cv2.imread('00000001.png')
fs = fa.get_detections_for_batch(np.array([image]))
x1, y1, x2, y2 = fs[0]
face = image[y1:y2, x1:x2]
cv2.imshow('face', face)
cv2.waitKey()
cv2.destroyAllWindows()
fa.get_detections_for_batch输入一个batch的图像,输出人脸Bbox数组。下面我们就可以编写数据预处理的代码了。
4.2 preprocess.py
preprocess的作用是对音视频重采样,并检测人脸。我们先看单个视频如何处理,然后再批量处理,代码如下:
python
from pathlib import Path
from typing import List
import cv2
import numpy as np
from tqdm import tqdm
from face_detection import FaceAlignment, LandmarksType
from common.utils import video2audio, resample_video
fa = FaceAlignment(
LandmarksType._2D, flip_input=False, device='cuda'
)
def get_video_fps(video_path):
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release()
return fps
def generate_frame_batch(video_path: str, batch_size=16) -> List[np.ndarray]:
cap = cv2.VideoCapture(video_path)
batch = []
while True:
ret, frame = cap.read()
if not ret:
cap.release()
break
batch.append(frame)
if len(batch) >= batch_size:
yield batch
batch = []
if len(batch):
yield batch
def process_video_file(video_file, output_path, batch_size=16):
if isinstance(video_file, str):
video_file = Path(video_file)
frames_save_dir = Path(output_path) / video_file.stem
if frames_save_dir.exists():
return
frames_save_dir.mkdir(exist_ok=True, parents=True)
for i, fb in enumerate(generate_frame_batch(str(video_file), batch_size)):
preds = fa.get_detections_for_batch(np.array(fb))
for j, f in enumerate(preds):
if f is None:
continue
x1, y1, x2, y2 = f
cv2.imwrite(str(frames_save_dir / f'{i * batch_size + j}.jpg'), fb[j][y1:y2, x1:x2])
这里有两个函数,第一个是generate_frame_batch,该函数负责读取视频,然后每次返回一个batch。
第二个是process_video_file,该函数利用generate_frame_batch遍历视频,然后逐batch检测人脸,保存到指定目录。执行process_video_file后可以得到人脸图像。
然后是梳理一下整体的操作,我们要做下面几件事情:
- 将视频重采样到25fps
- 提取视频帧,并逐帧检测人脸(process_video_file的操作)
- 提取音频,并重采样到16000
下面我们在process_videos完成上面的操作:
scss
def process_videos(video_dir: str, output_path, batch_size=16):
# 遍历视频
video_list = list(Path(video_dir).glob('*.mp4'))
for video_file in tqdm(video_list, total=len(video_list)):
# 如果视频不是25fps,就重采样视频
if get_video_fps(str(video_file)) != 25:
video_file = Path(resample_video(str(video_file), 'tmp'))
# 检测单个视频的人脸
process_video_file(video_file, output_path, batch_size)
# 提取音频并重采样
video2audio(str(video_file), str(Path(output_path) / video_file.stem))
if __name__ == '__main__':
process_videos('assets/videos', '111', batch_size=2)
这里视频被处理成25fps,音频的sample rate被处理成16000,此时一帧图像对应成640个音频采样点。处理完成后,我们可以得到下面样式的目录结构:
erlang
datasets
video01
-0.jpg
-1.jpg
...
video02
-0.jpg
...
五、数据集
在原始的Wav2Lip中,添加了SyncNet,但是实际测试中,发现SyncNet没有太大作用,在不需要SyncNet的情况下也可以得到比较好的结果。因此我们只创建Wav2Lip网络的数据集。
5.1 提取音频特征
这里我们不直接使用采样点作为特征,而是使用log_mel_spectrogram作为音频特征。这里我们借助openai-whisper模块,编写三个函数,放入common/utils.py:
ini
from whisper import audio
def extract_mel_frames(wav, return_tensor=True):
spec = audio.log_mel_spectrogram(wav)
n_frames = spec.shape[-1] // 4
print('mel的帧数为:', n_frames, spec.shape)
frames = []
for i in range(n_frames):
frames.append(spec[:, i * 4: (i + 1) * 4])
if return_tensor:
return torch.stack(frames)
return np.array(frames)
def get_mel(mels, idx):
mel = torch.zeros((80, 4, 4))
mel_idxes = [idx - 2, idx - 1, idx, idx + 1]
for i, mel_idx in enumerate(mel_idxes):
if mel_idx < 0 or mel_idx > len(mels) - 1:
continue
mel[:, :, i] = mels[mel_idx]
return mel.view((80, 16))
def get_mel_chunks(mels):
n_frames = mels.shape[0]
mel_chunks = torch.zeros((n_frames, 1, 80, 16))
for i in range(n_frames):
mel_chunks[i] = get_mel(mels, i).unsqueeze(0)
return mel_chunks
首先是extract_mel_frames,我们提取了log_mel_spectrogram特征,加入输入形状(16000,)的音频,会返回形状为(80,100)的特征。其中80是特征的维度,100是特征数量。16000是一秒的音频,而视频是25fps,因此一帧的图像对应特征为(80,4)。extract_mel_frames做的就是每4个特征作为一个frame,最后返回音频特征。
get_mel函数则是从提取的mel特征中,选取指定帧的特征。这里选择当前帧的前两帧、当帧、后一帧的特征,我们可以说是带上下文的音频特征。图示如下:
此时得到每一帧图像对应的音频特征形状为(80,16)。最后get_mel_chunks则是为视频所有帧获取音频特征。
5.2 Wav2LipDataset
下面我们看看Dataset类如何编写,先看看整体结构:
python
import random
from pathlib import Path
from typing import List
from collections import namedtuple
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from common.utils import extract_mel_frames, get_mel
IMG_SIZE = 96
VideoItem = namedtuple('VideoItem', ['images', 'audio'])
class Wav2LipDataset(Dataset):
def __init__(self, dataset_dir, split, sync_t=5):
super().__init__()
def get_mel_window(self, mels, start_idx):
pass
def load_all_videos(self):
pass
def get_window(self, vidx, start_idx):
pass
def __getitem__(self, item):
pass
def __len__(self):
pass
if __name__ == '__main__':
ds = Wav2LipDataset('../dst', 'train')
dl = DataLoader(ds, batch_size=16)
for i in dl:
x, m, t = i
print(x.shape, m.shape, t.shape)
break
这里有三个新建的函数,其作用分别如下:
- get_mel_window,获取连续sync_t个音频帧
- get_window,在视频vidx中获取连续sync_t个视频帧
- load_all_videos,加载数据集中的视频路径信息
下面我们一一实现。
(1)init
init的内容非常简单:
ini
def __init__(self, dataset_dir, split, sync_t=5):
super().__init__()
self.dataset_dir = dataset_dir
self.split = split
self.sync_t = sync_t
self.all_videos: List[VideoItem] = []
self.load_all_videos()
这里不过多解释。
(2)load_all_videos
scss
def load_all_videos(self):
dataset_file = Path(self.dataset_dir) / f'{self.split}.txt'
if not dataset_file.exists():
raise FileNotFoundError(f'Missing File {dataset_file}')
with open(dataset_file, 'r', encoding='utf-8') as f:
for video_frame_path in f.readlines():
frame_path_list = list(Path(video_frame_path.strip()).glob("*.jpg"))
if len(frame_path_list) <= 3 * self.sync_t:
continue
frame_path_list = sorted(
frame_path_list,
key=lambda x: int(x.stem)
)
self.all_videos.append(VideoItem(
images=frame_path_list,
audio=next(Path(video_frame_path.strip()).glob('*.wav'))
))
这里做的事情就是变量数据集目录,将单个视频的信息封装成VideoItem对象。该对象保存了图像路径和音频路径的信息。
(3)getitem
这里是数据集的关键部分代码如下:
ini
def __getitem__(self, item):
# 随机选取一个视频
vidx = random.randint(0, len(self.all_videos) - 1)
n_frames = len(self.all_videos[vidx].images)
# 选取图像
image_idx = random.randint(0, n_frames - 1 - self.sync_t)
# 选取与image_idx不一样的图像
wrong_image_idx = random.randint(0, n_frames - 1 - self.sync_t)
while wrong_image_idx == image_idx:
wrong_image_idx = random.randint(0, n_frames - 1 - self.sync_t)
# 加载sync_t个gt_image和wrong_image
gt_window = self.get_window(vidx, image_idx)
ng_window = self.get_window(vidx, wrong_image_idx)
x = np.concatenate([gt_window, ng_window], axis=1)
# 将gt_window的下半脸mask掉
x[:, :3, IMG_SIZE // 2:, :] = 0
x = torch.FloatTensor(x)
# 选取与gt_window对应的音频特征
mels = extract_mel_frames(str(self.all_videos[vidx].audio))
mel_window = self.get_mel_window(mels, image_idx)
return x, mel_window, torch.FloatTensor(gt_window)
在上面我们选取了两次图像,分别是image_idx和wrong_image_idx,两者必须不同。而后我们将image_idx对应的人脸遮罩了下半部分。
这里有三个图像,分别是target_image(image_idx对应的图像),reference_image(wrong_image_idx对应的图像)和masked_image(mask后的target_image)。对照第二节提到的原理,masked_image提供人脸真实的上半部分信息,reference_image提供人脸参考信息,而target_image则是训练目标。
因此我们会将masked_image、reference_image以及音频特征作为输入,target_image作为目标。
(4)get_window和get_mel_window
下面看读取连续视频帧和音频帧的代码:
scss
def get_window(self, vidx, start_idx):
window = []
for idx in range(start_idx, start_idx + self.sync_t):
image_path = str(self.all_videos[vidx].images[idx])
image = cv2.imread(image_path)
image = np.transpose(cv2.resize(image, (IMG_SIZE, IMG_SIZE)) / 255., (2, 0, 1))
window.append(image)
return np.array(window)
def get_mel_window(self, mels, start_idx):
mel_window = torch.zeros((self.sync_t, 1, 80, 16))
for i in range(self.sync_t):
mel = get_mel(mels, start_idx + i)
mel_window[i] = mel.unsqueeze(0)
return mel_window
首先是get_window,这里就是简单从vidx视频中,读取start_idx - start_idx+sync_t这几帧图像。读取后做了resize和除255的操作。
get_mel_window作用与get_window相似,这里我们将音频特征形状转换成(1,80,16)。
(5)len
最后获取长度的函数,代码如下:
python
def __len__(self):
return sum([len(v.images) - self.sync_t for v in self.all_videos])
这里就是简单计算总帧数。
从整体上看,我们数据集返回x的形状为(batch_size,sync_t,6,96,96),因为是两个图像合并,所以通道数是6;mel_window的形状为(batch_size,sync_t,1,80,16);gt_window的形状为(batch_size,sync_t,3,96,96)。
下面我们来测试一下:
ini
if __name__ == '__main__':
ds = Wav2LipDataset('../dst', 'train')
dl = DataLoader(ds, batch_size=16)
for i in dl:
x, m, t = i
print(x.shape, m.shape, t.shape)
break
输出结果如下:
css
torch.Size([16, 5, 6, 96, 96]) torch.Size([16, 5, 1, 80, 16]) torch.Size([16, 5, 3, 96, 96])
到此我们已经实现了一部分代码,在下一篇我们会实现网络结构及训练代码。