前言
本文同步发布于MindSpore社区,欢迎加入MindSpore社区,一同探索更多可能!

WaveNet 是一种典型的自回归生成模型,能够直接对原始波形建模。相比"先提特征再生成"的路线,WaveNet 直接在采样点级别学习概率分布,因此在音频生成/语音合成中能获得更自然的效果。本实践基于 MindSpore复现一个简化版 WaveNet,用少量音频数据训练后,生成一段新的音乐/音频片段。
第一步:打开实验平台
首先,我们需要登录实验平台,找到对应的实训项目。

点击"打开 Jupyter 在线编程"后,选择合适的运行环境。本案例推荐使用 Ascend-snt9b 环境,镜像选择包含 mindspore 的版本(如 python3.9-ms2.7.1-cann8.3.RC1)。

等待环境启动完成后,我们就可以开始编写/运行代码了。

第二步:环境准备(MindSpore + 音频依赖)
1)检查 MindSpore 版本
bash
pip show mindspore

2)安装音频相关依赖(用于读写 wav、μ-law 量化)
bash
pip install librosa
pip install soundfile
pip install nnmnkwii
代码中会用到:
librosa:加载音频soundfile:保存音频nnmnkwii.preprocessing:μ-law 量化与反量化

第三步:数据预处理
原始 16-bit 音频每个采样点有 65536 个可能值,如果把预测当成分类,Softmax 输出会非常大、训练难度也高。实践中常用经典的 μ-law companding,把波形压缩并量化到 256 类(μ=255),将任务变成 256 类分类问题,显著降低参数规模与训练难度。

第四步:构建训练数据集
如图所示,WaveNet在推理时 ,我们根据前n个时刻的样本预测当前时刻的样本值(即网络的输入序列长度n为网络的感受野),然后我们将当前时刻的预测值也作为n个输入中的一个输入网络中,预测下一时刻的样本点。
而在训练时 ,我们只训练网络根据n个输入预测第 n + 1 n+1 n+1个值。为了提高效率,我们通常设定网络一次性预测长度为 o o o的输出,根据一个预测样本对应网络感受野大小的样本的输入,网络的输入长度应为 n + o − 1 n+o-1 n+o−1。
我们首先调用generate_dataset方法将原始音频文件进行μ率压缩及量化得到用于网络训练和推理的数据集。
python
def generate_dataset(file_location, out_file, sampling_rate=16000, mono=True):
audio_files = Path(file_location).glob("*")
processed_files = []
for idx, file_wav in enumerate(audio_files):
audio, _ = librosa.load(str(file_wav), sr=sampling_rate, mono=mono)
if idx == 0:
# 分割数据集中的第一个音频,将其十分之一长度作为预测音乐的开头样本(独立同分布)
pred_head = audio[:len(audio) // 10]
audio = audio[len(audio) // 10:]
sf.write("pred_head.wav", pred_head, 16000, subtype='PCM_24')
sf.write("train_audio.wav", audio, 16000, subtype='PCM_24')
np.savez("pred_head.npz", pre.mulaw_quantize(pred_head, 256))
wav_quantized = pre.mulaw_quantize(audio, 256)
print("generated from audio file: " + str(file_wav.name))
processed_files.append(wav_quantized)
np.savez(out_file, *processed_files)
class WaveDataset:
def __init__(self,
dataset_file,
receptive_feild,
output_length,
classes=256,
sampling_rate=16000,
mono=True,
wave_location=None):
self.dataset_file = dataset_file
self.receptive_feild = receptive_feild
self.output_length = output_length
self.item_length = receptive_feild + output_length
self.classes = classes
if not Path(dataset_file).exists():
print(f"{dataset_file} not found, generating dataset file ...")
generate_dataset(wave_location, dataset_file, sampling_rate, mono)
print(f"datset file {dataset_file} generated.")
self.data = []
dataset = np.load(self.dataset_file, mmap_mode="r")
print(f"dataset file {self.dataset_file} loaded")
for i in range(len(dataset)):
self.data.append(dataset["arr_" + str(i)])
self.index_seg = []
self.collect_segment_index()
def collect_segment_index(self):
for i, audio in enumerate(self.data):
len_audio = len(audio)
N_seg = (len_audio - self.item_length) // self.output_length
for j in range(N_seg):
pos_start = j * self.output_length
pos_end = pos_start + self.item_length
if pos_end < len_audio:
self.index_seg.append((i, pos_start, pos_end))
self.index_seg.append((i, len_audio - self.item_length, len_audio))
def __len__(self):
return len(self.index_seg)
def get_onehot(self, slice):
onehot_ = np.eye(self.classes)[slice]
return onehot_.transpose()
def __getitem__(self, index):
num_audio, pos_start, pos_end = self.index_seg[index]
data_slice = self.data[num_audio][pos_start: pos_end]
onehot = self.get_onehot(data_slice[: -1])
target = data_slice[-self.item_length + 1:]
return onehot.astype(np.float32), target.astype(np.int32)
def create_dataset(dataset_file, receptive_feild, output_length, batch_size, classes=256, sampling_rate=16000, mono=True, wave_location=None):
dataset = ds.GeneratorDataset(WaveDataset(dataset_file=dataset_file,
receptive_feild=receptive_feild,
output_length=output_length,
classes=classes,
sampling_rate=sampling_rate,
mono=mono,
wave_location=wave_location),
["inputs", "targets"], num_parallel_workers=4, shuffle=True)
return dataset.batch(batch_size=batch_size)
第五步:搭建 WaveNet 网络
1)残差单元
在残差单元的实际实现中,除了第一层的扩张卷积,数据流在门控激活单元每个分支的子激活函数之前也同样经过了扩张卷积的处理,一个残差单元中的所有扩张卷积具有相同的扩张系数,具体实现如下:
python
from mindspore import nn,mint
import math
class ResidualConv1dGLU(nn.Cell):
"""Residual dilated conv1d with gated activation units"""
def __init__(self, residual_channels=None, gate_channels=None, kernel_size=None, skip_out_channels=None, bias=True,
dropout=1 - 0.95, dilation=1, cin_channels=-1, gin_channels=-1, padding=None, causal=True):
super(ResidualConv1dGLU, self).__init__()
self.dropout = dropout
self.dropout_op = mint.nn.Dropout(p=self.dropout)
padding = (kernel_size - 1) * dilation
self.conv = mint.nn.Conv1d(residual_channels, gate_channels, kernel_size,
padding=padding, dilation=dilation, bias=bias)
gate_out_channels = gate_channels // 2
self.conv1x1_out = mint.nn.Conv1d(gate_out_channels, residual_channels, kernel_size=1,
padding=0, dilation=1, bias=True)
self.conv1x1_skip = mint.nn.Conv1d(gate_out_channels, skip_out_channels, kernel_size=1,
padding=0, dilation=1, bias=True)
self.factor = math.sqrt(0.5)
def construct(self, x):
residual = x
x = self.dropout_op(x)
x = self.conv(x)
# remove future time steps
x = x[:, :, :residual.shape[-1]]
a, b = mint.chunk(x, chunks=2, dim=1)
x = mint.mul(mint.tanh(a), mint.sigmoid(b))
s = self.conv1x1_skip(x)
x = self.conv1x1_out(x)
x = mint.mul(mint.add(x, residual), self.factor)
return x, s
2)WaveNet
在WaveNet的实际实现中,我们将整个网络划分为多个块,每个块由若干层残差单元组成。在每个块中,膨胀系数从1开始以2的倍数递增。如以4层残差单元为一个块,由2个块组成的WaveNet的膨胀系数应为[1, 2, 4, 8, 1, 2, 4, 8]。WaveNet的具体实现如下:
python
from mindspore.ops import operations as P
import math
class WaveNet(nn.Cell):
def __init__(self, out_channels=256, layers=20, blocks=2,
residual_channels=512,
gate_channels=512,
skip_out_channels=512,
kernel_size=3, dropout=1 - 0.95):
super().__init__()
self.out_channels = out_channels
print(f"network info: \n\tlayers: {layers}\n\tblocks:{blocks}")
assert layers % blocks == 0
self.layers_per_block = layers // blocks # 24 / 4 = 6
self.first_conv = mint.nn.Conv1d(out_channels, residual_channels, kernel_size=1)
conv_layers = []
for layer in range(layers):
dilation = 2 ** (layer % self.layers_per_block) # 1, 2, 4, 8, 16, 32
conv = ResidualConv1dGLU(
residual_channels, gate_channels,
kernel_size=kernel_size,
skip_out_channels=skip_out_channels,
bias=True,
dropout=dropout,
dilation=dilation)
conv_layers.append(conv)
self.conv_layers = nn.CellList(conv_layers)
self.last_conv_layers = nn.CellList([
mint.nn.ReLU(),
mint.nn.Conv1d(skip_out_channels, skip_out_channels, kernel_size=1),
mint.nn.ReLU(),
mint.nn.Conv1d(skip_out_channels, out_channels, kernel_size=1)])
self.factor = math.sqrt(1.0 / len(self.conv_layers)) # sqrt( 1 / 24)
self.receptive_field = 1
for _ in range(blocks):
additional_scope = 2
for _ in range(self.layers_per_block):
self.receptive_field += additional_scope
additional_scope *= 2
print("receptive filed: ", self.receptive_field)
def construct(self, x, softmax=False):
B, _, T = x.shape
x = self.first_conv(x)
skips = None
for f in self.conv_layers:
x, hidden = f(x) # x=[B, 128, 10240], hidden=[B, 128, 10240]
if skips is None:
skips = hidden
else:
skips = mint.add(skips, hidden)
skips = mint.mul(skips, self.factor)
x = skips # x=[B, 128, 10240]
for f in self.last_conv_layers:
x = f(x) # x=[B, 2, 10240]
if softmax:
x = mint.softmax(x, dim=1)
return x
第六步:模型训练
在模型训练中,一个完整的训练过程(step)需要实现以下三步:
- 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)
- 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients
- 参数优化:将梯度更新到参数上
MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:
- 正向计算函数定义
- 通过函数变换获得梯度计算函数
- 训练函数定义,执行正向计算、反向传播和参数优化
python
from mindspore import ops
from mindspore.amp import all_finite
def train_loop(model, dataset, loss_fn, optimizer, logger):
# Define forward function
def forward_fn(data, label):
logits = model(data)
loss = loss_fn(logits, label)
return loss, logits
# Get gradient function
grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# Define function of one-step training
def train_step(data, label):
(loss, logits), grads = grad_fn(data, label)
if all_finite(grads):
optimizer(grads)
return loss
size = dataset.get_dataset_size()
model.set_train()
dataloader = dataset.create_tuple_iterator()
loss_sum, iter_sum = 0., 0
for batch, (data, label) in enumerate(dataloader):
loss = train_step(data, label)
logger.append(loss.asnumpy())
if batch % 20 == 0:
print("loss: {:>.3f}\t\t {:>4d}/{:>4d}".format(loss.asnumpy(), batch, size))
实例化数据集、网络模型、优化器并开始训练。
python
# 训练模型
import numpy as np
dataset_file = "./dataset.npz"
wave_location = "./dataset/"
# 超参数设定
layers = 24
blocks = 4
residual_channels = 512
skip_channels = 512
gate_channels = 512
output_length = 64
classes = 256
epochs = 20
batch_size = 32
learning_rate = 0.001
model = WaveNet(out_channels=256, layers=layers, blocks=blocks, residual_channels=residual_channels, skip_out_channels=skip_channels, gate_channels=gate_channels)
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), learning_rate=learning_rate)
print("dataset loading ...")
dataset = create_dataset(dataset_file=dataset_file, wave_location=wave_location, batch_size=batch_size, classes=classes, output_length=output_length, receptive_feild=model.receptive_field, sampling_rate=16000, mono=True)
print("dataset loaded.\n\tdataset size: {:d}\n\tbatch size: {:d}".format(dataset.get_dataset_size(), dataset.get_batch_size()))
# 开始训练
for t in range(epochs):
print(f"Epoch {t+1}\n---------------------------------------------------------------------------------------")
loss_recoder = []
train_loop(model, dataset, loss_fn, optimizer, loss_recoder)
np.save("loss.npy", loss_recoder)
ms.save_checkpoint(model, "./wavenet_{:d}.ckpt".format(t))
第七步:结果展示
运行下面代码,描绘损失与训练迭代关系图:从图中可以看到,标志着生成数据分布与真实数据分布的交叉熵损失随着迭代次数逐渐下降,网络能被正常优化
python
import matplotlib.pyplot as plt
loss = np.load("loss.npy")
plt.figure(figsize=(10, 5))
plt.title("Cross Entropy Loss During Training")
plt.plot(loss, label="WaveNet", color='blue')
plt.xlabel("")
plt.ylabel("loss")
plt.legend()
plt.show()
第八步:音乐生成
在这一部分中,我们加载训练好的WaveNet模型,并截取一小段与训练数据同分布的音频文件作为网络的预测的历史信息,并根据这些历史信息预测新的音频样本。
python
import random
from tqdm import tqdm
def gen_music(model, gen_time_length, head_location, head_length=1024):
'''
gen_time_length: 要生成的时间长度,单位: 分钟
head_location :预测头文件路径
head_length :用于预测音频的头文件样本点个数,通常设置为网络的感受野
'''
head_file = np.load(head_location, mmap_mode="r")["arr_0"]
random_start = random.randint(0, len(head_file) - head_length)
head = head_file[random_start: random_start + head_length]
total_length = int(gen_time_length * 16000 * 60)
for _ in tqdm(range(total_length), ncols=60):
# pred = pred_one(model, head[-1024:]).asnumpy()
current_input = head[-head_length:]
pred = pred_one(model, current_input).asnumpy()
head = np.append(head, pred)
return head
def pred_one(model, x):
# 构造 onehot: [Batch, Channels, Time]
onehot = np.eye(256)[x].transpose()
# 转换类型
input_tensor = ms.Tensor(onehot).astype(ms.float32)
input_tensor = mint.unsqueeze(input_tensor, 0)
# 推理
pred = model(input_tensor)
pred_sample = pred[0, :, -1]
return mint.argmax(pred_sample)
model = WaveNet(out_channels=256, layers=24, blocks=4)
ms.load_checkpoint("wavenet_1.ckpt", model)
model.set_train(False)
output = gen_music(model, gen_time_length=1/6, head_location="./pred_head.npz") # 生成一个10s(1/6分钟)的片段
output = pre.inv_mulaw_quantize(output, 256)
sf.write("gen.wav", output, 16000, subtype='PCM_24')
print("generated")
小结
本实践用 MindSpore 复现了一个端到端的 WaveNet 音频生成流程:从 μ-law 量化降低分类难度,到扩张卷积 + 门控残差结构建模长时依赖,再到逐点自回归生成与反量化写回 wav。完成后你会得到:
- 可训练的 WaveNet 模型(ckpt)
- loss 收敛曲线(用于验证训练有效)
- 生成的音频文件
gen.wav