from typing import Dict

import numpy as np
import torch
from torch import nn


def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
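    # Convert a numpy array to a torch tensor of the given dtype on the target device (None passes through).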
    if cuda:
        device = "cuda"
    if np_array is None:
        return None
    tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
    return tensor


def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
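    # Load the style reference wav and return its mel spectrogram as a [1, C, T] float tensor.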
    if cuda:
        device = "cuda"
    style_mel = torch.as_tensor(
        ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)),
        dtype=torch.float,
        device=device,
    ).unsqueeze(0)
    return style_mel


def run_model_torch(
    model: nn.Module,
    inputs: torch.Tensor,
    speaker_id: int = None,
    style_mel: torch.Tensor = None,
    style_text: str = None,
    d_vector: torch.Tensor = None,
    language_id: torch.Tensor = None,
) -> Dict:
    """Run a torch model for inference. It does not support batch inference.

    Args:
        model (nn.Module): The model to run inference.
        inputs (torch.Tensor): Input tensor with character ids.
        speaker_id (int, optional): Input speaker id for multi-speaker models. Defaults to None.
        style_mel (torch.Tensor, optional): Spectrogram used for voice styling (GST/Capacitron). Defaults to None.
        style_text (str, optional): Transcription of the style reference for Capacitron models. Defaults to None.
        d_vector (torch.Tensor, optional): d-vector for multi-speaker models. Defaults to None.
        language_id (torch.Tensor, optional): Language id for multilingual models. Defaults to None.

    Returns:
        Dict: model outputs.
    """
    input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
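    # use the underlying module's inference when the model is wrapped (e.g. by DataParallel)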
    if hasattr(model, "module"):
        _func = model.module.inference
    else:
        _func = model.inference
    outputs = _func(
        inputs,
        aux_input={
            "x_lengths": input_lengths,
            "speaker_ids": speaker_id,
            "d_vectors": d_vector,
            "style_mel": style_mel,
            "style_text": style_text,
            "language_ids": language_id,
        },
    )
    return outputs


def trim_silence(wav, ap):
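    # Cut the waveform at the end-point (trailing silence) detected by the audio processor.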
    return wav[: ap.find_endpoint(wav)]


def inv_spectrogram(postnet_output, ap, CONFIG):
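    # Tacotron predicts linear spectrograms; other models predict mel spectrograms.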
    if CONFIG.model.lower() in ["tacotron"]:
        wav = ap.inv_spectrogram(postnet_output.T)
    else:
        wav = ap.inv_melspectrogram(postnet_output.T)
    return wav


def id_to_torch(aux_id, cuda=False, device="cpu"):
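    # Convert a speaker or language id to a torch tensor on the target device (None passes through).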
    if cuda:
        device = "cuda"
    if aux_id is not None:
        aux_id = np.asarray(aux_id)
        aux_id = torch.from_numpy(aux_id).to(device)
    return aux_id


def embedding_to_torch(d_vector, cuda=False, device="cpu"):
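    # Convert a d-vector to a float tensor of shape [1, D] on the target device (None passes through).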
    if cuda:
        device = "cuda"
    if d_vector is not None:
        d_vector = np.asarray(d_vector)
        d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
        d_vector = d_vector.squeeze().unsqueeze(0).to(device)
    return d_vector


# TODO: perform GL with pytorch for batching
def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
    """Apply griffin-lim to each sample iterating throught the first dimension.
    Args:
        inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size.
        input_lens (Tensor or np.Array): 1D array of sample lengths.
        CONFIG (Dict): TTS config.
        ap (AudioProcessor): TTS audio processor.
    """
    wavs = []
    for idx, spec in enumerate(inputs):
        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length  # inverse librosa padding
        wav = inv_spectrogram(spec, ap, CONFIG)
        # assert len(wav) == wav_len, f" [!] wav length: {len(wav)} vs expected: {wav_len}"
        wavs.append(wav[:wav_len])
    return wavs


def synthesis(
    model,
    text,
    CONFIG,
    use_cuda,
    speaker_id=None,
    style_wav=None,
    style_text=None,
    use_griffin_lim=False,
    do_trim_silence=False,
    d_vector=None,
    language_id=None,
):
    """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
    the vocoder model.

    Args:
        model (TTS.tts.models):
            The TTS model to synthesize audio with.

        text (str):
            The input text to convert to speech.

        CONFIG (Coqpit):
            Model configuration.

        use_cuda (bool):
            Enable/disable CUDA.

        speaker_id (int):
            Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None.

        style_wav (str | Dict[str, float]):
            Path to a waveform, or a dict of GST style-token weights, used for computing the style embedding
            with GST or Capacitron. Defaults to None, in which case Capacitron models sample from the prior
            distribution to generate random but realistic prosody.

        style_text (str):
            Transcription of style_wav for Capacitron models. Defaults to None.

        use_griffin_lim (bool):
            Decode output features into a waveform with Griffin-Lim. Defaults to False.

        do_trim_silence (bool):
            trim silence after synthesis. Defaults to False.

        d_vector (torch.Tensor):
            d-vector for multi-speaker models in shape :math:`[1, D]`. Defaults to None.

        language_id (int):
            Language ID passed to the language embedding layer in multilingual models. Defaults to None.
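
    Example:
        A minimal sketch, assuming a trained single-speaker model and its ``CONFIG`` are
        already loaded (names are illustrative)::

            outputs = synthesis(model, "Hello world.", CONFIG, use_cuda=False, use_griffin_lim=True)
            wav = outputs["wav"]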
    """
    # device
    device = next(model.parameters()).device
    if use_cuda:
        device = "cuda"

    # GST or Capacitron processing
    # TODO: need to handle the case of setting both gst and capacitron to true somewhere
    style_mel = None
    if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
        if isinstance(style_wav, dict):
            style_mel = style_wav
        else:
            style_mel = compute_style_mel(style_wav, model.ap, device=device)

    if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
        style_mel = compute_style_mel(style_wav, model.ap, device=device)
        style_mel = style_mel.transpose(1, 2)  # [1, time, depth]

    language_name = None
    if language_id is not None:
        language = [k for k, v in model.language_manager.name_to_id.items() if v == language_id]
        assert len(language) == 1, "language_id must be a valid language"
        language_name = language[0]

    # convert text to sequence of token IDs
    text_inputs = np.asarray(
        model.tokenizer.text_to_ids(text, language=language_name),
        dtype=np.int32,
    )
    # pass tensors to backend
    if speaker_id is not None:
        speaker_id = id_to_torch(speaker_id, device=device)

    if d_vector is not None:
        d_vector = embedding_to_torch(d_vector, device=device)

    if language_id is not None:
        language_id = id_to_torch(language_id, device=device)

    if not isinstance(style_mel, dict):
        # GST or Capacitron style mel
        style_mel = numpy_to_torch(style_mel, torch.float, device=device)
        if style_text is not None:
            style_text = np.asarray(
                model.tokenizer.text_to_ids(style_text, language=language_name),
                dtype=np.int32,
            )
            style_text = numpy_to_torch(style_text, torch.long, device=device)
            style_text = style_text.unsqueeze(0)

    text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
    text_inputs = text_inputs.unsqueeze(0)
    # synthesize voice
    outputs = run_model_torch(
        model,
        text_inputs,
        speaker_id,
        style_mel,
        style_text,
        d_vector=d_vector,
        language_id=language_id,
    )
    model_outputs = outputs["model_outputs"]
    model_outputs = model_outputs[0].data.cpu().numpy()
    alignments = outputs["alignments"]

    # decode output features to a waveform if Griffin-Lim is requested
    wav = None
    model_outputs = model_outputs.squeeze()
    if model_outputs.ndim == 2:  # [T, C_spec]
        if use_griffin_lim:
            wav = inv_spectrogram(model_outputs, model.ap, CONFIG)
            # trim silence
            if do_trim_silence:
                wav = trim_silence(wav, model.ap)
    else:  # [T,]
        wav = model_outputs
    return_dict = {
        "wav": wav,
        "alignments": alignments,
        "text_inputs": text_inputs,
        "outputs": outputs,
    }
    return return_dict


def transfer_voice(
    model,
    CONFIG,
    use_cuda,
    reference_wav,
    speaker_id=None,
    d_vector=None,
    reference_speaker_id=None,
    reference_d_vector=None,
    do_trim_silence=False,
    use_griffin_lim=False,
):
    """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
    the vocoder model.

    Args:
        model (TTS.tts.models):
            The TTS model to synthesize audio with.

        CONFIG (Coqpit):
            Model configuration.

        use_cuda (bool):
            Enable/disable CUDA.

        reference_wav (str):
            Path of the reference waveform to be used for voice conversion.

        speaker_id (int):
            Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None.

        d_vector (torch.Tensor):
            d-vector for multi-speaker models in shape :math:`[1, D]`. Defaults to None.

        reference_speaker_id (int):
            Reference Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None.

        reference_d_vector (torch.Tensor):
            Reference d-vector for multi-speaker models in shape :math:`[1, D]`. Defaults to None.

        use_griffin_lim (bool):
            Decode output features into a waveform with Griffin-Lim. Defaults to False.

        do_trim_silence (bool):
            trim silence after synthesis. Defaults to False.
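
    Example:
        A minimal sketch, assuming a model that supports voice conversion and its ``CONFIG``
        are already loaded (names are illustrative)::

            wav = transfer_voice(
                model, CONFIG, use_cuda=False,
                reference_wav="reference.wav", speaker_id=0, use_griffin_lim=True,
            )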
    """
    # device
    device = next(model.parameters()).device
    if use_cuda:
        device = "cuda"

    # pass tensors to backend
    if speaker_id is not None:
        speaker_id = id_to_torch(speaker_id, device=device)

    if d_vector is not None:
        d_vector = embedding_to_torch(d_vector, device=device)

    if reference_d_vector is not None:
        reference_d_vector = embedding_to_torch(reference_d_vector, device=device)

    # load reference_wav audio
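    # (embedding_to_torch is reused here to build a [1, T] float tensor from the loaded waveform)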
    reference_wav = embedding_to_torch(
        model.ap.load_wav(
            reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate
        ),
        device=device,
    )

    if hasattr(model, "module"):
        _func = model.module.inference_voice_conversion
    else:
        _func = model.inference_voice_conversion
    model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector)

    # decode output features to a waveform if Griffin-Lim is requested
    wav = None
    model_outputs = model_outputs.squeeze()
    if model_outputs.ndim == 2:  # [T, C_spec]
        if use_griffin_lim:
            wav = inv_spectrogram(model_outputs, model.ap, CONFIG)
            # trim silence
            if do_trim_silence:
                wav = trim_silence(wav, model.ap)
    else:  # [T,]
        wav = model_outputs

    return wav