Extending Pretrained Module for MetricGAN+


I have a trained MetricGAN model and want to use it as a pretrained model that I can access directly from the package itself. To do so, I cloned the repo and extended the pretrained module. The problem is that I could not figure out how to set the parameters in the hyperparams.yaml file. My trained model has three saved outputs: generator, discriminator and model. I tried to use the other enhancement interface (SpectralMaskEnhancement) as a reference. Below are the class definition I added to interface.py and the hyperparams.yaml file describing the model.

```python
"""Defines interfaces for simple inference with pretrained models"""

import torch
import torchaudio
from types import SimpleNamespace
from torch.nn import SyncBatchNorm
from torch.nn import DataParallel as DP
from hyperpyyaml import load_hyperpyyaml
from speechbrain.pretrained.fetching import fetch
from speechbrain.dataio.preprocess import AudioNormalizer
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.distributed import run_on_main


class Pretrained:
    """Takes a trained model and makes predictions on new data.

    This is a base class which handles some common boilerplate.
    It intentionally has an interface similar to ``Brain`` - these base
    classes handle similar things.

    Subclasses of Pretrained should implement the actual logic of how
    the pretrained system runs, and add methods with descriptive names
    (e.g. transcribe_file() for ASR).
    """

    # ... (rest of the Pretrained base class unchanged) ...


class MetricGANPlus(Pretrained):
    """MetricGANPlus model.

    The class can be used to run on single .wav files.

    >>> from speechbrain.pretrained import MetricGANPlus
    >>> tmpdir = getfixture("tmpdir")
    >>> metricgan = MetricGANPlus.from_hparams(
    ...     source="",
    ...     savedir=tmpdir,
    ... )
    >>> noisy, fs = torchaudio.load("samples/noisy.wav")
    >>> enhanced = metricgan.enhance_batch(noisy)
    """

    # These names must match the keys defined in hyperparams.yaml
    HPARAMS_NEEDED = ["compute_STFT", "spectral_magnitude", "resynth"]
    MODULES_NEEDED = ["model", "generator", "discriminator"]

    def compute_features(self, wavs):
        """Feature computation pipeline: STFT -> magnitude -> log1p."""
        feats = self.hparams.compute_STFT(wavs)
        # power=0.5 returns the magnitude |X| instead of the power |X|^2
        feats = self.hparams.spectral_magnitude(feats, power=0.5)
        return torch.log1p(feats)

    def enhance_batch(self, noisy, lengths=None):
        """Enhance a batch of noisy waveforms with shape [batch, time]."""
        noisy = noisy.to(self.device)
        noisy_features = self.compute_features(noisy)

        # In MetricGAN the generator is the enhancement model: it predicts a mask
        if lengths is not None:
            mask = self.mods.generator(noisy_features, lengths=lengths)
        else:
            mask = self.mods.generator(noisy_features)
        enhanced = torch.mul(mask, noisy_features)

        # Return resynthesized waveforms
        return self.hparams.resynth(torch.expm1(enhanced), noisy)

    def enhance_file(self, filename, output_filename=None):
        """Enhance a single audio file and optionally save the result."""
        noisy = self.load_audio(filename)
        noisy = noisy.to(self.device)

        # Fake a batch containing one utterance with relative length 1.0
        batch = noisy.unsqueeze(0)
        lengths = torch.tensor([1.0], device=self.device)
        enhanced = self.enhance_batch(batch, lengths=lengths)

        if output_filename is not None:
            # enhanced has shape [1, time], i.e. channels first
            torchaudio.save(
                output_filename, enhanced.cpu(), self.hparams.Sample_rate
            )
        return enhanced.squeeze(0)
```
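The feature pipeline in `compute_features` and `enhance_batch` works in the log1p-magnitude domain. Below is a minimal pure-Python sketch of that masking on a few hand-picked STFT bins, assuming the mask multiplies `log1p` magnitudes as in the class above. The helper names, bin values and mask values are made up for illustration; the real code operates on tensors.

```python
import math


def magnitude(re, im, power=0.5, eps=1e-14):
    """Magnitude in the style of spectral_magnitude(power=0.5):
    (re^2 + im^2 + eps) ** power, i.e. roughly |X| for power=0.5."""
    return (re * re + im * im + eps) ** power


def enhance_bins(stft_bins, mask):
    """Apply a mask to log1p-compressed magnitudes, then undo the compression.

    stft_bins: list of (real, imag) pairs for one frame (illustrative data).
    mask: list of non-negative floats, one per bin (e.g. a generator's output).
    Returns the enhanced linear magnitudes that would go to resynthesis.
    """
    feats = [math.log1p(magnitude(re, im)) for re, im in stft_bins]
    enhanced = [m * f for m, f in zip(mask, feats)]
    # expm1 inverts log1p, bringing the masked features back to magnitudes
    return [math.expm1(e) for e in enhanced]


frame = [(3.0, 4.0), (0.0, 0.0), (1.0, 0.0)]
print(enhance_bins(frame, [1.0, 0.5, 0.0]))
```

Because the mask scales log1p(|X|) rather than |X|, a mask of 0.5 on a bin of magnitude 5 yields sqrt(6) - 1 ≈ 1.45, not 2.5, which is worth keeping in mind when inspecting predicted masks.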

```yaml
# STFT arguments
Sample_rate: 16000
N_fft: 512
Win_length: 32
Hop_length: 16
window_fn: !name:torch.hamming_window

# Enhancement Model Arguments

compute_STFT: !new:speechbrain.processing.features.STFT
  sample_rate: !ref <Sample_rate>
  win_length: !ref <Win_length>
  hop_length: !ref <Hop_length>
  n_fft: !ref <N_fft>
  window_fn: !ref <window_fn>
compute_ISTFT: !new:speechbrain.processing.features.ISTFT
  sample_rate: !ref <Sample_rate>
  win_length: !ref <Win_length>
  hop_length: !ref <Hop_length>
  window_fn: !ref <window_fn>
spectral_magnitude: !name:speechbrain.processing.features.spectral_magnitude
resynth: !name:speechbrain.processing.signal_processing.resynthesize
  stft: !ref <compute_STFT>
  istft: !ref <compute_ISTFT>
  normalize_wavs: False

# Neural parameters
kernel_size: (5,5)
base_channels: 15

generator: !new:speechbrain.lobes.models.MetricGAN.EnhancementGenerator

discriminator: !new:speechbrain.lobes.models.MetricGAN.MetricDiscriminator
  kernel_size: !ref <kernel_size>
  base_channels: !ref <base_channels>

modules:
  generator: !ref <generator>
  discriminator: !ref <discriminator>

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
    generator: !ref <generator>
    discriminator: !ref <discriminator>
```
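To see which names the interface class and the YAML have to agree on, here is a hypothetical helper (`check_interface` is not part of SpeechBrain) that takes the hyperparameters as a plain dict and reports missing `HPARAMS_NEEDED`/`MODULES_NEEDED` entries, which is the kind of check the Pretrained base class performs when loading. It also lists the checkpoint files the pretrainer would look for, assuming the Pretrainer's default of fetching `<name>.ckpt` from the source for each loadable that has no explicit `paths` entry; verify that against your SpeechBrain version. The source path `path/to/my_metricgan` is a placeholder.

```python
def check_interface(hparams, hparams_needed, modules_needed, source):
    """Hypothetical consistency check between an interface class and its YAML.

    hparams: dict standing in for the loaded hyperparams.yaml.
    Returns (missing_hparams, missing_modules, expected_ckpt_files).
    """
    missing_hparams = [name for name in hparams_needed if name not in hparams]
    modules = hparams.get("modules", {})
    missing_modules = [name for name in modules_needed if name not in modules]
    # Assumed default naming: one "<loadable>.ckpt" file per loadable in source
    loadables = hparams.get("pretrainer", {}).get("loadables", {})
    expected = [f"{source}/{name}.ckpt" for name in loadables]
    return missing_hparams, missing_modules, expected


# A dict standing in for a YAML with `modules:` and `pretrainer: loadables:`
# sections (the real objects are replaced by placeholder strings)
hparams = {
    "compute_STFT": "STFT(...)",
    "spectral_magnitude": "spectral_magnitude",
    "resynth": "resynthesize(...)",
    "modules": {"generator": "G", "discriminator": "D"},
    "pretrainer": {"loadables": {"generator": "G", "discriminator": "D"}},
}
print(check_interface(
    hparams,
    ["compute_STFT", "spectral_magnitude", "resynth"],
    ["model", "generator", "discriminator"],
    source="path/to/my_metricgan",
))
```

With a three-entry `MODULES_NEEDED`, the check flags `model`: every name listed there has to appear under `modules:` in the YAML, and every loadable needs a matching checkpoint file in the source.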

Could someone help me load the model parameters as expected?


If you’re looking for a pretrained interface for MetricGAN+, there’s already one on HuggingFace:

If not, let me know and I can take a look at your code.