FastEnhancer DPRNN 的 PyTorch 实现

python 复制代码
import math
import typing as tp
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm as weight_norm_fn
from torch.nn.utils.parametrize import is_parametrized, remove_parametrizations
from torch import Tensor
from torchaudio.functional import melscale_fbanks

from functional import ONNXSTFT, CompressedSTFT


class StridedConv1d(nn.Conv1d):
    """Strided 1-D convolution evaluated as a stride-1 convolution on folded input.

    The time axis of the (zero-padded) input is split into groups of ``stride``
    samples and the position inside each group is moved into the channel axis.
    The underlying ``nn.Conv1d`` is therefore built with ``in_channels * stride``
    input channels, a kernel of ``kernel_size // stride`` taps, stride 1 and no
    padding.

    Restrictions (checked with assertions): ``kernel_size`` must be a multiple
    of ``stride``; ``groups`` and ``dilation`` must be 1; ``padding_mode`` must
    be ``'zeros'``.
    """

    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: int,
        stride: int = 1, padding: int = 0, dilation: int = 1,
        groups: int = 1, bias: bool = True, padding_mode: str = 'zeros',
        device=None, dtype=None
    ):
        # The folding trick only needs k to be a whole number of strides.
        assert kernel_size % stride == 0, (
            f'kernel_size k must be a multiple of stride s, but '
            f'got k={kernel_size}, s={stride}. Use naive Conv1d instead.'
        )
        assert groups == 1, (
            f'groups must be 1, but got {groups}. '
            f'Use naive Conv1d instead.'
        )
        assert dilation == 1, (
            f'dilation must be 1, but got {dilation}. '
            f'Use naive Conv1d instead.'
        )
        assert padding_mode == 'zeros', (
            f'Only `zeros` padding mode is supported for '
            f'StridedConv1d, but got {padding_mode}. '
            f'Use naive Conv1d instead.'
        )
        super().__init__(
            in_channels*stride, out_channels, kernel_size//stride,
            stride=1, padding=0, dilation=dilation,
            groups=groups, bias=bias, padding_mode=padding_mode,
            device=device, dtype=dtype
        )
        # The parent stores stride=1 / padding=0; keep the caller's values
        # separately because forward() applies them by hand.
        self.original_stride = stride
        self.original_padding = padding

    def forward(self, x: Tensor) -> Tensor:
        """Apply the folded convolution.

        x: [B, Ci, Ti] -> pad both ends -> fold to [B, Ci*s, T//s] -> conv1d,
        where T is the padded length and s the original stride.

        Raises:
            RuntimeError: if the padded length is not divisible by the stride.
        """
        stride = self.original_stride
        padding = self.original_padding
        x = F.pad(x, (padding, padding))
        B, C, T = x.shape
        if T % stride != 0:
            raise RuntimeError(
                f'padded input length {T} is not divisible by stride {stride}'
            )
        # [B, C, T//s, s] -> [B, s, C, T//s] -> [B, s*C, T//s]:
        # folded channel index = phase * C + original channel.
        x = x.view(B, C, T//stride, stride).permute(0, 3, 1, 2).reshape(B, C*stride, T//stride)
        return super().forward(x)


class ScaledConvTranspose1d(nn.ConvTranspose1d):
    """Transposed 1-D convolution whose kernel is multiplied by a learnable scalar.

    Args:
        *args, **kwargs: forwarded to ``nn.ConvTranspose1d``.
        normalize: if True, the kernel is passed through ``F.normalize`` over
            all three of its dimensions before being scaled.
        exp_scale: if True, the stored parameter is the logarithm of the scale
            and ``scale.exp()`` is used as the multiplier (initial multiplier 1).

    While ``self.weight_norm`` is True the effective kernel is recomputed on
    every forward pass. ``remove_weight_reparameterizations`` writes it into
    ``self.weight`` once and switches the recomputation off.
    """

    def __init__(
        self,
        *args,
        normalize: bool = False,
        exp_scale: bool = False,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.normalize = normalize
        self.exp_scale = exp_scale
        self.scale = nn.Parameter(torch.ones(1))
        if exp_scale:
            # Store log(1) = 0 so that exp(scale) starts at 1.
            self.scale.data.log_()
        self.weight_norm = True

    def _effective_weight(self) -> Tensor:
        """Return the (optionally normalized) kernel multiplied by the scale."""
        scale = self.scale.exp() if self.exp_scale else self.scale
        if self.normalize:
            return F.normalize(self.weight, dim=(0, 1, 2)).mul_(scale)
        return self.weight * scale

    def remove_weight_reparameterizations(self):
        """Bake the scale (and normalization) into ``self.weight``.

        Idempotent: once the kernel has been baked, further calls do nothing,
        so the scale can never be applied twice.
        """
        if not self.weight_norm:
            return
        with torch.no_grad():
            self.weight.data.copy_(self._effective_weight())
        self.weight_norm = False

    def forward(self, x: Tensor) -> Tensor:
        """Run the transposed convolution with the effective kernel."""
        weight = self._effective_weight() if self.weight_norm else self.weight
        return F.conv_transpose1d(
            x, weight, self.bias, stride=self.stride,
            padding=self.padding, output_padding=self.output_padding,
            groups=self.groups, dilation=self.dilation,
        )


def calculate_positional_embedding(channels: int, freq: int) -> Tensor:
    """Build a sinusoidal embedding table over ``freq`` positions.

    Positions are the angles ``pi/freq, 2*pi/freq, ..., pi``. Each one is
    multiplied by ``channels // 2`` rates that are log-spaced between 1 and
    ``freq - 1``. The sines of the products fill the first half of the columns
    and the cosines fill the second half.

    Returns:
        float32 tensor of shape ``[freq, 2 * (channels // 2)]``.

    Note:
        ``freq`` must be at least 2, because ``math.log(freq - 1)`` is evaluated.
    """
    n_rates = channels // 2
    angles = torch.arange(1, freq + 1, dtype=torch.float32) * (math.pi / freq)
    log_rates = torch.linspace(
        start=math.log(1),
        end=math.log(freq - 1),
        steps=n_rates,
        dtype=torch.float32,
    )
    rates = torch.exp(log_rates)
    phase = angles.unsqueeze(1) * rates.unsqueeze(0)    # [F, C//2]
    return torch.cat((torch.sin(phase), torch.cos(phase)), dim=1)   # [F, C]


class ChannelsLastBatchNorm(nn.BatchNorm1d):
    """``BatchNorm1d`` applied to a 4-D tensor whose last axis holds the channels."""

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` of shape [T, B, F, C] and return the same shape.

        The three leading axes are merged into one batch axis and a dummy
        length-1 axis is appended, giving the [N, C, 1] layout the parent
        class expects. ``view`` is used, so ``x`` must be viewable that way.
        """
        n_time, n_batch, n_freq, n_ch = x.shape
        flat = x.view(n_time * n_batch * n_freq, n_ch, 1)
        normed = super().forward(flat)
        return normed.view(n_time, n_batch, n_freq, n_ch)


class ChannelsLastSyncBatchNorm(nn.SyncBatchNorm):
    """``SyncBatchNorm`` applied to a 4-D tensor whose last axis holds the channels."""

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` of shape [T, B, F, C] and return the same shape.

        The three leading axes are merged into one batch axis and a dummy
        length-1 axis is appended before delegating to the parent class.
        ``view`` is used, so ``x`` must be viewable that way.
        """
        n_time, n_batch, n_freq, n_ch = x.shape
        flat = x.view(n_time * n_batch * n_freq, n_ch, 1)
        normed = super().forward(flat)
        return normed.view(n_time, n_batch, n_freq, n_ch)


class DPRNN(nn.Module):
    """Dual-path recurrent block operating on [T, B, F, C] tensors.

    Two residual branches are applied in sequence:

    1. time branch: optional pre-norm -> unidirectional GRU along T
       (every (batch, frequency) pair is an independent sequence) -> Linear
       -> batch norm -> residual add;
    2. frequency branch: optional pre-norm -> bidirectional GRU along F
       (every (time, batch) pair is an independent sequence) -> Linear
       -> batch norm -> residual add.

    Args:
        channels: size of the last (channel) axis; also the time-GRU hidden size.
        freq: number of frequency bins; only used by ``initialize_cache``.
        channels_frnn: hidden size per direction of the frequency GRU.
        eps: epsilon passed to every batch-norm layer.
        weight_norm: if True, all GRU weight matrices are wrapped with weight norm.
        pre_norm: if True, a non-affine batch norm precedes each GRU.
    """

    _TRNN_WEIGHTS = ("weight_ih_l0", "weight_hh_l0")
    _FRNN_WEIGHTS = (
        "weight_ih_l0", "weight_hh_l0",
        "weight_ih_l0_reverse", "weight_hh_l0_reverse",
    )

    def __init__(
        self,
        channels: int,
        freq: int,
        channels_frnn: int,
        eps: float,
        weight_norm: bool,
        pre_norm: bool,
    ) -> None:
        super().__init__()
        self.channels = channels
        self.freq = freq
        self.channels_frnn = channels_frnn
        self.pre_norm = pre_norm

        # is_available() must be checked first: on builds without distributed
        # support the other torch.distributed helpers do not exist.
        use_sync_bn = (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
            and torch.distributed.get_world_size() > 1
        )
        if use_sync_bn:
            BatchNorm = ChannelsLastSyncBatchNorm
        else:
            BatchNorm = ChannelsLastBatchNorm

        self.trnn_pre_norm = BatchNorm(channels, eps, affine=False) if pre_norm else nn.Identity()
        self.trnn = nn.GRU(channels, channels, batch_first=False)
        self.trnn_fc = nn.Linear(channels, channels, bias=False)
        self.trnn_post_norm = BatchNorm(channels, eps)

        self.frnn_pre_norm = BatchNorm(channels, eps, affine=False) if pre_norm else nn.Identity()
        self.frnn = nn.GRU(channels, channels_frnn, bidirectional=True, batch_first=True)
        self.frnn_fc = nn.Linear(channels_frnn*2, channels, bias=False)
        self.frnn_post_norm = BatchNorm(channels, eps)

        self.weight_norm = weight_norm
        if weight_norm:
            for name in self._TRNN_WEIGHTS:
                self.trnn = weight_norm_fn(self.trnn, name=name)
            for name in self._FRNN_WEIGHTS:
                self.frnn = weight_norm_fn(self.frnn, name=name)

    @staticmethod
    def _fold_pre_norm(rnn: nn.Module, norm: nn.Module, suffixes: tp.Tuple[str, ...]) -> None:
        """Fold a non-affine batch norm into the input projection(s) of ``rnn``.

        w @ ((x - mean) / std) + b = (w / std) @ x + (b + w @ beta),
        with beta = -mean / std. Applied to ``weight_ih_l0<suffix>`` and
        ``bias_ih_l0<suffix>`` for each suffix.
        """
        std = norm.running_var.add(norm.eps).sqrt()
        beta = -norm.running_mean / std
        for suffix in suffixes:
            weight = getattr(rnn, "weight_ih_l0" + suffix)
            bias = getattr(rnn, "bias_ih_l0" + suffix)
            # The bias shift must use the weight *before* it is rescaled.
            shift = (weight.data @ beta.view(-1, 1)).squeeze(1)
            weight.data /= std
            bias.data.add_(shift)

    def remove_weight_reparameterizations(self):
        """Remove weight norm and fold every batch norm into its neighbour.

        Uses the running statistics of the norm layers, so the result matches
        the eval-mode behaviour of the block. Each folded norm is replaced by
        ``nn.Identity``; calling the method again is a no-op.
        """
        if self.weight_norm:
            for name in self._TRNN_WEIGHTS:
                remove_parametrizations(self.trnn, name)
            for name in self._FRNN_WEIGHTS:
                remove_parametrizations(self.frnn, name)
            self.flatten_parameters()
            self.weight_norm = False

        # Remove Post Norm:
        # (fc(x) - mean) / std * gamma + beta_bn
        #   = (W * gamma / std) @ x + (beta_bn - mean * gamma / std)
        for fc_name, norm_name in (
            ("trnn_fc", "trnn_post_norm"),
            ("frnn_fc", "frnn_post_norm"),
        ):
            fc = getattr(self, fc_name)
            norm = getattr(self, norm_name)
            if isinstance(norm, nn.Identity):
                continue    # already folded by an earlier call
            std = norm.running_var.add(norm.eps).sqrt()
            fc.weight.data *= norm.weight.view(-1, 1) / std.view(-1, 1)
            fc.bias = nn.Parameter(norm.bias - norm.running_mean * norm.weight / std)
            setattr(self, norm_name, nn.Identity())

        if not self.pre_norm:
            return

        # Remove Pre Norm
        # 1. Time GRU (unidirectional: a single input projection).
        if not isinstance(self.trnn_pre_norm, nn.Identity):
            self._fold_pre_norm(self.trnn, self.trnn_pre_norm, ("",))
            self.trnn_pre_norm = nn.Identity()

        # 2. Frequency GRU (bidirectional: the reverse direction has its own
        #    input projection, which must be folded as well).
        if not isinstance(self.frnn_pre_norm, nn.Identity):
            self._fold_pre_norm(self.frnn, self.frnn_pre_norm, ("", "_reverse"))
            self.frnn_pre_norm = nn.Identity()

    def flatten_parameters(self):
        """Call ``flatten_parameters`` on both GRUs."""
        self.trnn.flatten_parameters()
        self.frnn.flatten_parameters()

    def initialize_cache(self, x: Tensor) -> Tensor:
        """Return a zero hidden state of shape [1, freq, hidden] like ``x``.

        The middle axis is ``freq`` only, i.e. the state matches the time GRU
        when the batch size is 1.
        """
        return x.new_zeros(1, self.freq, self.trnn.hidden_size)

    def forward(self, x: Tensor, h: tp.Optional[Tensor]) -> tp.Tuple[Tensor, Tensor]:
        """Run both branches.

        Args:
            x: input of shape [T, B, F, C].
            h: initial hidden state of the time GRU, or None.

        Returns:
            The output tensor [T, B, F, C] and the final time-GRU hidden state.
        """
        TIME, BATCH, FREQ, CH = x.shape     # [T, B, F, C]
        x_in = x
        x = self.trnn_pre_norm(x)           # [T, B, F, C]
        x = x.view(TIME, FREQ*BATCH, CH)    # [T, B*F, C]
        x, h = self.trnn(x, h)              # [T, B*F, C]
        x = x.view(TIME, BATCH, FREQ, CH)   # [T, B, F, C]
        x = self.trnn_fc(x)                 # [T, B, F, C]
        x = self.trnn_post_norm(x)          # [T, B, F, C]
        x = x.add_(x_in)                    # [T, B, F, C]

        x_in = x
        x = self.frnn_pre_norm(x)           # [T, B, F, C]
        x = x.view(TIME*BATCH, FREQ, CH)    # [T*B, F, C]
        x, _ = self.frnn(x)                 # [T*B, F, C_FRNN*2]
        x = self.frnn_fc(x)                 # [T*B, F, C]
        x = x.view(TIME, BATCH, FREQ, CH)   # [T, B, F, C]
        x = self.frnn_post_norm(x)          # [T, B, F, C]
        x = x.add_(x_in)                    # [T, B, F, C]
        return x, h


def pre_post_lin(
    freq: int,
    n_filter: int,
    init: tp.Optional[str],
    bias: bool,
    sr: int = 16_000
) -> tp.Tuple[nn.Module, nn.Module]:
    """Create the pair of linear layers that map ``freq`` bins to ``n_filter`` bands and back.

    Args:
        freq: number of input frequency bins.
        n_filter: number of bands produced by the first layer.
        init: ``None`` keeps the default ``nn.Linear`` initialization.
            ``"mel"`` / ``"linear"`` copy a mel / linearly spaced triangular
            filterbank into the weights. The ``"*_fixed"`` variants register
            the same matrices as buffers instead of parameters.
        bias: whether both layers have a bias.
        sr: sample rate; ``sr // 2`` is used as the highest frequency.

    Returns:
        ``(pre, post)`` with ``pre: freq -> n_filter`` and ``post: n_filter -> freq``.
    """
    assert init in [None, "mel", "mel_fixed", "linear", "linear_fixed"]
    pre = nn.Linear(freq, n_filter, bias=bias)
    post = nn.Linear(n_filter, freq, bias=bias)

    if init is None:
        return pre, post

    if init in ["mel", "mel_fixed"]:
        def hz_to_mel(hz: float) -> float:
            return 2595.0 * math.log10(1 + hz / 700)

        def mel_to_hz(mel: float) -> float:
            return 700.0 * (math.e ** (mel / 1127) - 1)

        def clip(value: int, lower: int, upper: int) -> int:
            return max(min(value, upper), lower)

        nyquist = sr // 2
        top_mel = hz_to_mel(nyquist)
        fbank = melscale_fbanks(
            n_freqs=freq,
            f_min=0.0,
            f_max=nyquist,
            n_mels=n_filter,
            sample_rate=sr,
            norm='slaney',
            mel_scale='htk'
        )
        analysis = fbank.float().transpose(0, 1) * nyquist / freq   # [n_filter, freq]
        zero = analysis.new_zeros(1)

        # A band whose filter sums to zero gets a single 1 at the bin nearest
        # to the frequency derived from its band index.
        for band, row in enumerate(analysis):
            if not row.sum().isclose(zero):
                continue
            bin_idx = round(mel_to_hz(band / n_filter * top_mel) * freq / nyquist)
            bin_idx = clip(bin_idx, 0, freq - 1)
            analysis[band, bin_idx] = 1.0

        synthesis = torch.linalg.pinv(analysis)                     # [freq, n_filter]
        # Same repair for rows of the pseudo-inverse that sum to zero.
        for bin_idx, row in enumerate(synthesis):
            if not row.sum().isclose(zero):
                continue
            band = round(hz_to_mel(bin_idx / freq * nyquist) * n_filter / top_mel)
            band = clip(band, 0, n_filter - 1)
            synthesis[bin_idx, band] = 1.0

        pre_weight = analysis
        post_weight = synthesis

    elif init in ["linear", "linear_fixed"]:
        centers = torch.linspace(0, sr // 2, n_filter)
        delta_f = sr // 2 / n_filter
        bins = torch.linspace(0, sr // 2, freq)

        # Falling slope of every band except the last, whose slope is a
        # constant row of ones.
        falling = (centers[1:, None] - bins[None, :]) / delta_f     # [n_filter - 1, freq]
        falling = F.pad(falling, (0, 0, 0, 1), value=1.0)           # [n_filter, freq]
        # Rising slope of every band except the first, whose slope is a
        # constant row of ones.
        rising = (bins[None, :] - centers[:-1, None]) / delta_f     # [n_filter - 1, freq]
        rising = F.pad(rising, (0, 0, 1, 0), value=1.0)             # [n_filter, freq]

        triangles = torch.maximum(rising.new_zeros(1), torch.minimum(falling, rising))
        transposed = triangles.transpose(0, 1)
        # Each matrix is normalized so that its rows sum to one.
        pre_weight = triangles / triangles.sum(dim=1, keepdim=True)
        post_weight = transposed / transposed.sum(dim=1, keepdim=True)

    if init.endswith("_fixed"):
        # Replace the trainable parameter by a buffer of the same name.
        for layer, weight in ((pre, pre_weight), (post, post_weight)):
            delattr(layer, "weight")
            layer.register_buffer("weight", weight.contiguous().clone())
    else:
        pre.weight.data.copy_(pre_weight)
        post.weight.data.copy_(post_weight)

    return pre, post


@dataclass
class DPRNNConfig:
    """Hyper-parameters of the DPRNN stack, built from ``dprnn_kwargs`` in ONNXModel."""
    num_blocks: int = 3         # number of DPRNN blocks created
    channels: int = 32          # channel size of each DPRNN block
    channels_frnn: int = 16     # `channels_frnn` argument of each DPRNN block
    freq: int = 32              # number of frequency bins inside the DPRNN stack
    eps: float = 1e-5           # `eps` argument of each DPRNN block
    pre_norm: bool = False      # `pre_norm` argument of each DPRNN block


class ONNXModel(nn.Module):
    """Spectrogram-in / spectrogram-out enhancement network.

    Every STFT frame is processed along the frequency axis by a strided
    convolutional encoder, a stack of DPRNN blocks (the only part that mixes
    information across frames) and a decoder with skip connections. The
    network output is a two-channel (real, imaginary) mask that ``forward``
    multiplies, as a complex number, with the compressed input spectrogram.

    Args:
        channels: number of channels of the encoder/decoder convolutions.
        kernel_size: kernel sizes; element 0 belongs to the strided input
            convolution (and the final transposed convolution), the remaining
            elements to the encoder/decoder layers.
        stride: stride of the input convolution along frequency.
        dprnn_kwargs: keyword arguments for ``DPRNNConfig``.
        activation: name of an activation class in ``torch.nn``.
        activation_kwargs: keyword arguments for that activation class.
        n_fft, hop_size, win_size, window, stft_normalized: forwarded to
            ``get_stft``; ``n_fft`` also fixes the number of frequency bins.
            The ``forward`` of this class receives spectrograms and does not
            call the STFT module itself.
        mask: ``None``, ``"sigmoid"`` or ``"tanh"``; applied to the output.
        input_compression: exponent applied to the input magnitude.
        weight_norm: passed to the DPRNN blocks; also controls whether
            ``remove_weight_reparameterizations`` removes reparameterizations.
        final_scale: ``False`` uses a plain ``nn.ConvTranspose1d`` as the last
            layer, ``True`` / ``"exp"`` a ``ScaledConvTranspose1d`` with a
            linear / exponential scale.
        normalize_final_conv: ``normalize`` flag of ``ScaledConvTranspose1d``.
        pre_post_init: ``init`` argument of ``pre_post_lin``.

    Note:
        The list/dict default arguments are only read, never mutated.
    """

    def __init__(
        self,
        channels: int = 64,
        kernel_size: tp.List[int] = [8, 3, 3],
        stride: int = 4,
        dprnn_kwargs: tp.Dict[str, tp.Any] = dict(),
        activation: str = "ReLU",
        activation_kwargs: tp.Dict[str, tp.Any] = dict(inplace=True),
        n_fft: int = 512,
        hop_size: int = 160,                # only forwarded to get_stft
        win_size: int = 400,                # only forwarded to get_stft
        window: tp.Optional[str] = "povey", # only forwarded to get_stft
        stft_normalized: bool = False,      # only forwarded to get_stft
        mask: tp.Optional[str] = None,
        input_compression: float = 0.25,
        weight_norm: bool = False,
        final_scale: tp.Union[bool, str] = "exp",
        normalize_final_conv: bool = False,
        pre_post_init: tp.Optional[str] = None,
    ):
        assert final_scale in [True, False, "exp"]
        super().__init__()
        # Must be set before get_stft(): subclasses may read it there.
        self.input_compression = input_compression
        self.stft = self.get_stft(n_fft, hop_size, win_size, window, stft_normalized)
        dprnn_config = DPRNNConfig(**dprnn_kwargs)
        self.dprnn_ch = dprnn_config.channels
        self.dprnn_freq = dprnn_config.freq
        if mask is None:
            self.mask = nn.Identity()
        elif mask == "sigmoid":
            self.mask = nn.Sigmoid()
        elif mask == "tanh":
            self.mask = nn.Tanh()
        else:
            raise RuntimeError(f"model_kwargs.mask={mask} is not supported.")
        self.weight_norm = weight_norm

        # is_available() must be checked first: on builds without distributed
        # support the other torch.distributed helpers do not exist.
        use_sync_bn = (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
            and torch.distributed.get_world_size() > 1
        )
        if use_sync_bn:
            BatchNorm = nn.SyncBatchNorm
        else:
            BatchNorm = nn.BatchNorm1d

        Act = getattr(nn, activation)

        # Encoder PreNet
        assert kernel_size[0] % stride == 0
        assert (kernel_size[0] - stride) % 2 == 0
        self.enc_pre = nn.Sequential(
            StridedConv1d(  # in_channels = 2 = [real, imag]
                2, channels, kernel_size[0], stride=stride,
                padding=(kernel_size[0] - stride) // 2, bias=False
            ),
            BatchNorm(channels),
            Act(**activation_kwargs),
        )

        # Encoder
        self.encoder = nn.ModuleList()
        for idx in range(1, len(kernel_size)):
            module = nn.Sequential(
                nn.Conv1d(
                    channels, channels, kernel_size[idx],
                    padding=(kernel_size[idx] - 1) // 2, bias=False
                ),
                BatchNorm(channels),
                Act(**activation_kwargs),
            )
            self.encoder.append(module)

        # DPRNN PreNet
        freq = n_fft // 2 // stride
        pre, post = pre_post_lin(freq, self.dprnn_freq, pre_post_init, bias=False)
        self.dprnn_pre = nn.Sequential(
            pre,
            nn.Conv1d(channels, self.dprnn_ch, 1, bias=False),
            BatchNorm(self.dprnn_ch),
        )

        # DPRNN Blocks
        dprnn_list = []
        for _ in range(dprnn_config.num_blocks):
            block = DPRNN(
                self.dprnn_ch, self.dprnn_freq,
                channels_frnn=dprnn_config.channels_frnn,
                eps=dprnn_config.eps, weight_norm=weight_norm,
                pre_norm=dprnn_config.pre_norm,
            )
            dprnn_list.append(block)
        self.dprnn_block = nn.ModuleList(dprnn_list)

        # DPRNN PostNet
        self.dprnn_post = nn.Sequential(
            post,
            nn.Conv1d(self.dprnn_ch, channels, 1, bias=False),
            BatchNorm(channels),
        )

        # Decoder
        self.decoder = nn.ModuleList()
        for idx in range(len(kernel_size)-1, 0, -1):
            module = nn.Sequential(
                nn.Conv1d(channels*2, channels, 1, bias=False),
                BatchNorm(channels),
                Act(**activation_kwargs),
                nn.Conv1d(
                    channels, channels, kernel_size[idx],
                    padding=(kernel_size[idx] - 1) // 2, bias=False
                ),
                BatchNorm(channels),
                Act(**activation_kwargs),
            )
            self.decoder.append(module)

        # Decoder PostNet
        padding = (kernel_size[0] - stride) // 2
        if final_scale:
            # out_channels = 2 = [real, imag] of the mask
            upsample = ScaledConvTranspose1d(
                channels, 2, kernel_size[0], stride=stride,
                padding=padding, bias=True, exp_scale=(final_scale == "exp"),
                normalize=normalize_final_conv
            )
        else:
            upsample = nn.ConvTranspose1d(
                channels, 2, kernel_size[0], stride=stride,
                padding=padding, bias=True
            )
        self.dec_post = nn.Sequential(
            nn.Conv1d(channels*2, channels, 1, bias=False),
            BatchNorm(channels),
            Act(**activation_kwargs),
            upsample
        )

    def get_stft(
        self, n_fft: int, hop_size: int, win_size: int,
        window: str, normalized: bool
    ) -> nn.Module:
        """Build the STFT module stored in ``self.stft`` (overridable)."""
        return ONNXSTFT(
            n_fft=n_fft, hop_size=hop_size, win_size=win_size,
            win_type=window, normalized=normalized
        )

    @torch.no_grad()
    def remove_weight_reparameterizations(self):
        """Prepare the model for inference/export.

        1. If the model was built with ``weight_norm=True``, remove the
           reparameterizations of the DPRNN blocks and of the final
           transposed convolution (when that layer supports it).
        2. Merge every BatchNorm that follows a bias-free convolution into
           that convolution, using the running statistics.

        Step 2 is skipped when the convolutions have already been merged.
        """
        # 1. Remove weight_norm
        if self.weight_norm:
            with torch.enable_grad():
                # DPRNN blocks
                for block in self.dprnn_block:
                    block.remove_weight_reparameterizations()

                # Decoder: a plain nn.ConvTranspose1d (final_scale=False) has
                # nothing to remove and no such method.
                upsample = self.dec_post[-1]
                if hasattr(upsample, "remove_weight_reparameterizations"):
                    upsample.remove_weight_reparameterizations()
            self.weight_norm = False

        # After merging, enc_pre holds only (conv, activation).
        if len(self.enc_pre) < 3:
            return

        # 2. Merge BatchNorm into Conv
        # y = (conv(x) - mean) / std * gamma + beta
        #   = conv(x) * (gamma / std) + (beta - mean * gamma / std)
        # <=> y = conv'(x) where
        #   W'[c, :, :] = W[c, :, :] * (gamma / std)
        #   b' = (beta - mean * gamma / std)
        def merge_conv_bn(conv: nn.Module, norm: nn.Module, error_message: str = "") -> nn.Module:
            assert conv.bias is None, error_message
            std = norm.running_var.add(norm.eps).sqrt()
            conv.weight.data *= norm.weight.view(-1, 1, 1) / std.view(-1, 1, 1)
            conv.bias = nn.Parameter(norm.bias - norm.running_mean * norm.weight / std)
            return conv

        # Encoder PreNet
        conv = merge_conv_bn(self.enc_pre[0], self.enc_pre[1], "enc_pre")
        self.enc_pre = nn.Sequential(conv, self.enc_pre[2])

        # Encoder
        new_encoder = nn.ModuleList()
        for idx, module in enumerate(self.encoder):
            conv = merge_conv_bn(module[0], module[1], f"enc.{idx}")
            new_module = nn.Sequential(
                conv,       # Conv-BN Merged
                module[2],  # Activation
            )
            new_encoder.append(new_module)
        self.encoder = new_encoder

        # DPRNN PreNet
        conv = merge_conv_bn(self.dprnn_pre[1], self.dprnn_pre[2], "dprnn_pre")
        self.dprnn_pre = nn.Sequential(
            self.dprnn_pre[0],  # Linear
            conv,               # Conv-BN Merged
        )

        # DPRNN PostNet
        conv = merge_conv_bn(self.dprnn_post[1], self.dprnn_post[2], "dprnn_post")
        self.dprnn_post = nn.Sequential(
            self.dprnn_post[0], # Linear
            conv,               # Conv-BN Merged
        )

        # Decoder
        new_decoder = nn.ModuleList()
        for idx, module in enumerate(self.decoder):
            conv1 = merge_conv_bn(module[0], module[1], f"dec.{idx}.0")
            conv2 = merge_conv_bn(module[3], module[4], f"dec.{idx}.1")
            new_module = nn.Sequential(
                conv1,      # Conv-BN Merged
                module[2],  # Activation
                conv2,      # Conv-BN Merged
                module[5],  # Activation
            )
            new_decoder.append(new_module)
        self.decoder = new_decoder

        # Decoder PostNet
        conv = merge_conv_bn(self.dec_post[0], self.dec_post[1], "dec_post")
        self.dec_post = nn.Sequential(
            conv,               # Conv-BN Merged
            self.dec_post[2],   # Activation
            self.dec_post[3]    # Transposed Convolution
        )

    def flatten_parameters(self):
        """Call ``flatten_parameters`` on every DPRNN block."""
        for dprnn in self.dprnn_block:
            dprnn.flatten_parameters()

    def initialize_cache(self, x: Tensor) -> tp.List[Tensor]:
        """Return one initial cache tensor per DPRNN block, created like ``x``."""
        return [block.initialize_cache(x) for block in self.dprnn_block]

    def model_forward(self, spec_noisy: Tensor, *args) -> tp.Tuple[Tensor, tp.List[Tensor]]:
        """Compute the mask for a (compressed) spectrogram.

        Args:
            spec_noisy: [B, F, T, 2] with real/imaginary parts in the last axis.
            *args: optional input caches, one per DPRNN block, in block order.
                With no caches every block starts from its default state.
                Caches beyond the number of blocks are ignored.

        Returns:
            The mask [B, F, T, 2] and the list of output caches (one per block).

        Raises:
            ValueError: if some, but fewer than ``len(self.dprnn_block)``,
                caches are given (blocks would otherwise be skipped silently).
        """
        n_blocks = len(self.dprnn_block)
        cache_in_list = list(args)
        cache_out_list = []
        if not cache_in_list:
            cache_in_list = [None] * n_blocks
        elif len(cache_in_list) < n_blocks:
            raise ValueError(
                f"expected {n_blocks} cache tensors (one per DPRNN block), "
                f"but got {len(cache_in_list)}"
            )
        x = spec_noisy

        B, FREQ, T, _ = x.shape
        x = x.permute(0, 2, 3, 1)       # [B, T, 2, F]
        x = x.reshape(B*T, 2, FREQ)     # [B*T, 2, F]

        # Encoder PreNet
        x = self.enc_pre(x)
        encoder_outs = [x]

        # Encoder
        for module in self.encoder:
            x = module(x)
            encoder_outs.append(x)      # [B*T, C, F']

        # DPRNN
        x = self.dprnn_pre(x)           # [B*T, C, F']
        _, C, _FREQ = x.shape
        x = x.view(B, T, C, _FREQ)      # [B, T, C, F']
        x = x.permute(1, 0, 3, 2)       # [T, B, F', C]
        x = x.contiguous()
        for block, cache_in in zip(self.dprnn_block, cache_in_list):
            x, cache_out = block(x, cache_in)   # [T, B, F', C]
            cache_out_list.append(cache_out)
        x = x.permute(1, 0, 3, 2)       # [B, T, C, F']
        x = x.reshape(B*T, C, _FREQ)    # [B*T, C, F']
        x = self.dprnn_post(x)          # [B*T, C, F']

        # Decoder: skip connections are consumed deepest-first.
        for module in self.decoder:
            x = torch.cat([x, encoder_outs.pop(-1)], dim=1)     # [B*T, 2*C, F]
            x = module(x)                                       # [B*T, C, F]

        # Decoder PostNet
        x = torch.cat([x, encoder_outs.pop(-1)], dim=1)     # [B*T, 2*C, F]
        x = self.dec_post(x)                                # [B*T, 2, F]
        x = x.reshape(B, T, 2, FREQ).permute(0, 3, 1, 2)    # [B, F, T, 2]

        # Mask
        mask = self.mask(x).contiguous()
        return mask, cache_out_list

    def forward(self, spec_noisy: Tensor, *args) -> tp.Tuple[Tensor, ...]:
        """Enhance a spectrogram.

        Args:
            spec_noisy: [B, n_fft//2+1, T_spec, 2]; the last frequency bin is
                dropped before processing.
            *args: optional input caches, see ``model_forward``.

        Returns:
            A flat tuple ``(spec_hat, *caches)``: the enhanced spectrogram with
            the same shape as the input (the last frequency bin is zero),
            followed by one output cache per DPRNN block.
        """
        # Compress: |X|^c * exp(j * angle(X)) = X * |X|^(c - 1)
        spec_noisy = spec_noisy[:, :-1, :, :]    # [B, F, T, 2]
        mag = torch.linalg.norm(
            spec_noisy,
            dim=-1,
            keepdim=True
        ).clamp(min=1.0e-5)
        spec_noisy = spec_noisy * mag.pow(self.input_compression - 1.0)

        # Model forward; the mask is applied as a complex multiplication.
        mask, cache_out_list = self.model_forward(spec_noisy, *args)
        spec_hat = torch.stack(
            [
                spec_noisy[..., 0] * mask[..., 0] - spec_noisy[..., 1] * mask[..., 1],
                spec_noisy[..., 0] * mask[..., 1] + spec_noisy[..., 1] * mask[..., 0],
            ],
            dim=3
        )

        # Uncompress & pad n_fft//2+1
        mag_compressed = torch.linalg.norm(
            spec_hat,
            dim=-1,
            keepdim=True
        )
        spec_hat = spec_hat * mag_compressed.pow(1.0 / self.input_compression - 1.0)
        spec_hat = F.pad(spec_hat, (0, 0, 0, 0, 0, 1))    # [B, F+1, T, 2]
        return (spec_hat, *cache_out_list)


class Model(ONNXModel):
    """Waveform-in / waveform-out wrapper around ``ONNXModel``.

    The STFT module is a ``CompressedSTFT`` built with this model's
    ``input_compression`` and ``discard_last_freq_bin=True``, so ``forward``
    can pass its output straight to ``model_forward``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_stft(
        self, n_fft: int, hop_size: int, win_size: int,
        window: str, normalized: bool
    ) -> nn.Module:
        """Build the compressed STFT used by ``forward``."""
        stft_kwargs = dict(
            n_fft=n_fft,
            hop_size=hop_size,
            win_size=win_size,
            win_type=window,
            normalized=normalized,
            compression=self.input_compression,
            discard_last_freq_bin=True,
        )
        return CompressedSTFT(**stft_kwargs)

    def forward(self, noisy: Tensor) -> tp.Tuple[Tensor, Tensor]:
        """Enhance a waveform.

        Args:
            noisy: input waveform (presumably [B, T_wav], to be confirmed
                against ``CompressedSTFT``).

        Returns:
            ``(wav_hat, spec_hat)``: the output of ``self.stft.inverse`` and the
            masked spectrogram as a real tensor with a trailing axis of size 2.
        """
        spec_noisy = self.stft(noisy)                   # [B, F, T, 2]
        mask, _ = self.model_forward(spec_noisy)        # [B, F, T, 2]
        noisy_complex = torch.view_as_complex(spec_noisy)
        mask_complex = torch.view_as_complex(mask)
        spec_hat = noisy_complex * mask_complex         # [B, F, T]
        wav_hat = self.stft.inverse(spec_hat)           # [B, T_wav]
        return wav_hat, torch.view_as_real(spec_hat)


def test():
    """Smoke test: forward/backward pass, then reparameterization removal.

    Requires the project's ``utils.get_hparams`` and the YAML config it loads.
    """
    # Draw the input first so the random stream is consumed in the same order.
    x = torch.randn(3, 16_000)
    from utils import get_hparams
    hparams = get_hparams("configs/rnnformer/size/dprnn_b.yaml")
    model = Model(**hparams["model_kwargs"])
    model.flatten_parameters()

    wav_out, spec_out = model(x)
    loss = wav_out.sum() + spec_out.sum()
    loss.backward()
    print(spec_out.shape)

    model.remove_weight_reparameterizations()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of total parameters: {total_params}")
    # To list individual parameters:
    # for n, p in model.named_parameters():
    #     print(n, p.shape)


# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()
相关推荐
机器学习之心2 小时前
TCN+SHAP分析深度学习多变量分类预测可解释性分析!Matlab完整代码
深度学习·matlab·分类·多变量分类预测可解释性分析
睡醒了叭2 小时前
目标检测-深度学习-SSD模型项目
人工智能·深度学习·目标检测
冰西瓜6002 小时前
从项目入手机器学习(五)—— 机器学习尝试
人工智能·深度学习·机器学习
Coding茶水间2 小时前
基于深度学习的狗品种检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
开发语言·人工智能·深度学习·yolo·目标检测·机器学习
小鸡吃米…3 小时前
机器学习 - 轮次(Epoch)
人工智能·深度学习·机器学习
Rabbit_QL4 小时前
PyTorch DataLoader `num_workers` 配置指南:从新手到进阶
人工智能·pytorch·python
Loacnasfhia95 小时前
【深度学习】【目标检测】YOLO11-C3k2-Faster-EMA模型实现草莓与番茄成熟度及病害识别系统
人工智能·深度学习·目标检测
阡陌..5 小时前
pytorch模型训练使用多GPU执行报错:Bus error (core dumped)(未解决)
人工智能·pytorch·python
XX風5 小时前
5.1 deep learning introduction
人工智能·深度学习