import numpy as np
import torch
import torchaudio
from coqpit import Coqpit
from torch import nn

from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.utils.generic_utils import set_init_dict
from TTS.utils.io import load_fsspec


class PreEmphasis(nn.Module):
    """Pre-emphasis FIR filter: y[t] = x[t] - coefficient * x[t-1], applied to a batch of waveforms."""

    def __init__(self, coefficient=0.97):
        super().__init__()
        self.coefficient = coefficient
        # Kernel of shape (out_channels=1, in_channels=1, kernel_size=2) for conv1d.
        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        assert len(x.size()) == 2  # expects (batch, samples)
        # Reflect-pad one sample on the left so the output length matches the input.
        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
        return torch.nn.functional.conv1d(x, self.filter).squeeze(1)


class BaseEncoder(nn.Module):
    """Base `encoder` class. Every new `encoder` model must inherit this.

    It defines common `encoder` specific functions.
    """

    # pylint: disable=W0102
    def __init__(self):
        super().__init__()

    def get_torch_mel_spectrogram_class(self, audio_config):
        # Despite the name, this returns an `nn.Sequential` instance that maps
        # raw waveforms to mel spectrograms.
        return torch.nn.Sequential(
            PreEmphasis(audio_config["preemphasis"]),
            # TorchSTFT(
            #     n_fft=audio_config["fft_size"],
            #     hop_length=audio_config["hop_length"],
            #     win_length=audio_config["win_length"],
            #     sample_rate=audio_config["sample_rate"],
            #     window="hamming_window",
            #     mel_fmin=0.0,
            #     mel_fmax=None,
            #     use_htk=True,
            #     do_amp_to_db=False,
            #     n_mels=audio_config["num_mels"],
            #     power=2.0,
            #     use_mel=True,
            #     mel_norm=None,
            # )
            torchaudio.transforms.MelSpectrogram(
                sample_rate=audio_config["sample_rate"],
                n_fft=audio_config["fft_size"],
                win_length=audio_config["win_length"],
                hop_length=audio_config["hop_length"],
                window_fn=torch.hamming_window,
                n_mels=audio_config["num_mels"],
            ),
        )

    @torch.no_grad()
    def inference(self, x, l2_norm=True):
        return self.forward(x, l2_norm)

    @torch.no_grad()
    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
        """Generate an embedding for an utterance.

        Slices `num_eval` evenly spaced windows of `num_frames` frames from the
        input, embeds each window, and optionally averages the results.

        Args:
            x (Tensor): input of shape (1, T); a raw waveform when
                `use_torch_spec` is enabled, otherwise a feature sequence.
        """
        # Map the frame count to the waveform sample count.
        if self.use_torch_spec:
            num_frames = num_frames * self.audio_config["hop_length"]

        max_len = x.shape[1]

        if max_len < num_frames:
            num_frames = max_len

        offsets = np.linspace(0, max_len - num_frames, num=num_eval)

        frames_batch = []
        for offset in offsets:
            offset = int(offset)
            end_offset = int(offset + num_frames)
            frames = x[:, offset:end_offset]
            frames_batch.append(frames)

        frames_batch = torch.cat(frames_batch, dim=0)
        embeddings = self.inference(frames_batch, l2_norm=l2_norm)

        if return_mean:
            embeddings = torch.mean(embeddings, dim=0, keepdim=True)
        return embeddings

    def get_criterion(self, c: Coqpit, num_classes=None):
        if c.loss == "ge2e":
            criterion = GE2ELoss(loss_method="softmax")
        elif c.loss == "angleproto":
            criterion = AngleProtoLoss()
        elif c.loss == "softmaxproto":
            criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
        else:
            raise ValueError("%s is not a supported loss." % c.loss)
        return criterion

    def load_checkpoint(
        self,
        config: Coqpit,
        checkpoint_path: str,
        eval: bool = False,  # pylint: disable=redefined-builtin
        use_cuda: bool = False,
        criterion=None,
        cache=False,
    ):
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        try:
            self.load_state_dict(state["model"])
            print(" > Model fully restored.")
        except (KeyError, RuntimeError) as error:
            # In eval mode, a partial state-dict match is fatal.
            if eval:
                raise error
            print(" > Partial model initialization.")
            model_dict = self.state_dict()
            model_dict = set_init_dict(model_dict, state["model"], config)
            self.load_state_dict(model_dict)
            del model_dict

        # Load the criterion for restore_path.
        if criterion is not None and "criterion" in state:
            try:
                criterion.load_state_dict(state["criterion"])
            except (KeyError, RuntimeError) as error:
                print(" > Criterion load ignored because of:", error)

        # Instantiate and load the criterion for the encoder classifier at inference time.
        if (
            eval
            and criterion is None
            and "criterion" in state
            and getattr(config, "map_classid_to_classname", None) is not None
        ):
            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
            criterion.load_state_dict(state["criterion"])

        if use_cuda:
            self.cuda()
            if criterion is not None:
                criterion = criterion.cuda()

        if eval:
            self.eval()
            assert not self.training

        if not eval:
            return criterion, state["step"]
        return criterion