Extending Pretrained Module for MetricGAN+

Hello,

I have a trained MetricGAN+ model and want to use it as a pretrained model that I can access directly from the package itself. To do so, I cloned the repo and extended the pretrained module. The problem is that I could not figure out how to set the parameters of the hyperparams.yaml file. The trained model has three outputs: generator, discriminator and model. I used the other enhancement class (SpectralMaskEnhancement) as a reference. Below you can find my class definition from interfaces.py, followed by the hyperparams.yaml file that describes the model.


"""Defines interfaces for simple inference with pretrained models

Authors:
 * Aku Rouhe 2021
 * Peter Plantinga 2021
 * Loren Lugosch 2020
 * Mirco Ravanelli 2020
 * Titouan Parcollet 2021
"""
import torch
import torchaudio
from types import SimpleNamespace
from torch.nn import SyncBatchNorm
from torch.nn import DataParallel as DP
from hyperpyyaml import load_hyperpyyaml
from speechbrain.pretrained.fetching import fetch
from speechbrain.dataio.preprocess import AudioNormalizer
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.distributed import run_on_main


class Pretrained:
    """Takes a trained model and makes predictions on new data.

    This is a base class which handles some common boilerplate.
    It intentionally has an interface similar to ``Brain`` - these base
    classes handle similar things.

    Subclasses of Pretrained should implement the actual logic of how
    the pretrained system runs, and add methods with descriptive names
    (e.g. transcribe_file() for ASR).
   

     .....



class MetricGANPlus(Pretrained):
    """ MetricGANPlus model.
    The class can be used to run on single .wav files.

    Example
    -------
    >>> from speechbrain.pretrained import MetricGANPlus
    >>> tmpdir = getfixture("tmpdir")
    >>> metricgan = MetricGANPlus.from_hparams(
    ...     source="",
    ...     savedir=tmpdir,
    ... )
    >>> noisy, fs = torchaudio.load("samples/noisy.wav")
    >>> enhanced = metricgan.enhance_batch(noisy)
    """
    
    HPARAMS_NEEDED = ["compute_STFT", "spectral_magnitude", "resynth"]
    MODULES_NEEDED = ["model", "generator", "discriminator"]

    def compute_features(self, wavs):
        """Feature computation pipeline"""
        feats = self.hparams.compute_STFT(wavs)
        feats = self.hparams.spectral_magnitude(feats, power=0.5)
        feats = torch.log1p(feats)
        return feats

    def enhance_batch(self, noisy, lengths=None):
        """Enhance a batch of noisy waveforms"""
        noisy_features = self.compute_features(noisy)
        # The generator predicts a mask over the noisy features
        if lengths is not None:
            mask = self.modules.generator(noisy_features, lengths=lengths)
        else:
            mask = self.modules.generator(noisy_features)
        enhanced = torch.mul(mask, noisy_features)
        # Return resynthesized waveforms
        return self.hparams.resynth(torch.expm1(enhanced), noisy)

    def enhance_file(self, filename, output_filename=None):
        """Enhance a single wav file, optionally saving the result"""
        noisy = self.load_audio(filename)
        batch = noisy.unsqueeze(0)
        enhanced = self.enhance_batch(batch)

        if output_filename is not None:
            # enhanced has shape [1, time], i.e. channels first
            torchaudio.save(
                output_filename, enhanced, self.hparams.Sample_rate
            )
            
        return enhanced.squeeze(0)



# STFT arguments
Sample_rate: 16000
N_fft: 512
Win_length: 32
Hop_length: 16
window_fn: !name:torch.hamming_window

# Enhancement Model Arguments

compute_STFT: !new:speechbrain.processing.features.STFT
  sample_rate: !ref <Sample_rate>
  win_length: !ref <Win_length>
  hop_length: !ref <Hop_length>
  n_fft: !ref <N_fft>
  window_fn: !ref <window_fn>
  
compute_ISTFT: !new:speechbrain.processing.features.ISTFT
  sample_rate: !ref <Sample_rate>
  win_length: !ref <Win_length>
  hop_length: !ref <Hop_length>
  window_fn: !ref <window_fn>
  
resynth: !name:speechbrain.processing.signal_processing.resynthesize
  stft: !ref <compute_STFT>
  istft: !ref <compute_ISTFT>
  normalize_wavs: False



# Neural parameters
kernel_size: (5,5)
base_channels: 15

generator: !new:speechbrain.lobes.models.MetricGAN.EnhancementGenerator

discriminator: !new:speechbrain.lobes.models.MetricGAN.MetricDiscriminator
  kernel_size: !ref <kernel_size>
  base_channels: !ref <base_channels>

modules:
  
  generator: !ref <generator>
  discriminator: !ref <discriminator>

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
    generator: !ref <generator>
    discriminator: !ref <discriminator>
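
A side note on the STFT settings above: as far as I know, SpeechBrain's STFT takes win_length and hop_length in milliseconds, not samples. Under that assumption, the small sketch below works out what the values in this file correspond to at 16 kHz. The helper name `ms_to_samples` is made up for illustration and is not part of SpeechBrain.

```python
def ms_to_samples(duration_ms, sample_rate):
    """Convert a duration in milliseconds to a whole number of samples."""
    return int(round(sample_rate / 1000.0 * duration_ms))


sample_rate = 16000                            # Sample_rate: 16000
win_samples = ms_to_samples(32, sample_rate)   # Win_length: 32
hop_samples = ms_to_samples(16, sample_rate)   # Hop_length: 16
n_fft = 512                                    # N_fft: 512
freq_bins = n_fft // 2 + 1                     # one-sided spectrum

print(win_samples, hop_samples, freq_bins)  # 512 256 257
```

So the window fills the FFT exactly, frames overlap by half, and the features have 257 frequency bins per frame.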

Could someone help me load the model parameters as expected?

B.R.

If you’re looking for a pretrained interface for MetricGAN+, there’s already one on HuggingFace:

If not, let me know and I can take a look at your code.
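
In the meantime, one thing that may be worth checking: if I remember correctly, the Pretrained base class verifies at construction time that every name in HPARAMS_NEEDED is a key of the loaded hyperparameters and every name in MODULES_NEEDED is a key of `modules`, and raises an error otherwise. The lookup is exact, so spelling and case matter. Here is a minimal sketch of that kind of check in plain Python. The helper `missing_keys` and the dictionaries are illustrative only, not SpeechBrain code.

```python
def missing_keys(needed, available):
    """Return the names in `needed` that are not keys of `available`."""
    return [name for name in needed if name not in available]


# Stand-ins for what a loaded hyperparams.yaml might provide.
hparams = {"compute_STFT": object(), "resynth": object()}
modules = {"generator": object(), "discriminator": object()}

# Case differs from the YAML key, so the name is reported as missing.
print(missing_keys(["compute_stft", "resynth"], hparams))
# ['compute_stft']

# A single string holding two names is one (missing) key, not two.
print(missing_keys(["generator, discriminator"], modules))
# ['generator, discriminator']

# Separate, correctly spelled names pass the check.
print(missing_keys(["generator", "discriminator"], modules))
# []
```

Running a check like this against the keys of your own YAML before calling from_hparams can make it easier to see which names do not line up.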