import json
import random
from typing import Any, Dict, List, Tuple, Union

import fsspec
import numpy as np
import torch

from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.utils.audio import AudioProcessor


def load_file(path: str):
    """Load a `.json` or `.pth` file from `path`."""
    if path.endswith(".json"):
        with fsspec.open(path, "r") as f:
            return json.load(f)
    elif path.endswith(".pth"):
        with fsspec.open(path, "rb") as f:
            return torch.load(f, map_location="cpu")
    else:
        raise ValueError(f"Unsupported file type: {path}")


def save_file(obj: Any, path: str):
    """Save `obj` to `path` as a `.json` or `.pth` file."""
    if path.endswith(".json"):
        with fsspec.open(path, "w") as f:
            json.dump(obj, f, indent=4)
    elif path.endswith(".pth"):
        with fsspec.open(path, "wb") as f:
            torch.save(obj, f)
    else:
        raise ValueError(f"Unsupported file type: {path}")


class BaseIDManager:
    """Base `ID` Manager class. Every new `ID` manager must inherit this.
    It defines common `ID` manager specific functions.
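
    Example:
        A minimal runnable sketch; the sample items are illustrative only::

            >>> manager = BaseIDManager()
            >>> manager.set_ids_from_data([{"speaker_name": "b"}, {"speaker_name": "a"}], parse_key="speaker_name")
            >>> manager.name_to_id
            {'a': 0, 'b': 1}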
    """

    def __init__(self, id_file_path: str = ""):
        self.name_to_id = {}

        if id_file_path:
            self.load_ids_from_file(id_file_path)

    @staticmethod
    def _load_json(json_file_path: str) -> Dict:
        with fsspec.open(json_file_path, "r") as f:
            return json.load(f)

    @staticmethod
    def _save_json(json_file_path: str, data: dict) -> None:
        with fsspec.open(json_file_path, "w") as f:
            json.dump(data, f, indent=4)

    def set_ids_from_data(self, items: List, parse_key: str) -> None:
        """Set IDs from data samples.

        Args:
            items (List): Data sampled returned by `load_tts_samples()`.
        """
        self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key)

    def load_ids_from_file(self, file_path: str) -> None:
        """Set IDs from a file.

        Args:
            file_path (str): Path to the file.
        """
        self.name_to_id = load_file(file_path)

    def save_ids_to_file(self, file_path: str) -> None:
        """Save IDs to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        save_file(self.name_to_id, file_path)

    def get_random_id(self) -> Any:
        """Get a random ID.

        Returns:
            Any: A randomly picked ID, or None if no IDs are loaded.
        """
        if self.name_to_id:
            return self.name_to_id[random.choice(list(self.name_to_id.keys()))]

        return None

    @staticmethod
    def parse_ids_from_data(items: List, parse_key: str) -> Dict:
        """Parse IDs from data samples returned by `load_tts_samples()`.

        Args:
            items (list): Data samples returned by `load_tts_samples()`.
            parse_key (str): The key used to parse the data.

        Returns:
            Dict: A mapping from parsed names to integer IDs.
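
        Example:
            A minimal illustration; the items are hypothetical::

                >>> items = [{"speaker_name": "b"}, {"speaker_name": "a"}]
                >>> BaseIDManager.parse_ids_from_data(items, parse_key="speaker_name")
                {'a': 0, 'b': 1}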
        """
        classes = sorted({item[parse_key] for item in items})
        ids = {name: i for i, name in enumerate(classes)}
        return ids


class EmbeddingManager(BaseIDManager):
    """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
    It defines common `Embedding` manager specific functions.

    It expects embeddings files in the following format:

    ::

        {
            'audio_file_key':{
                'name': 'category_name',
                'embedding': [<embedding_values>]
            },
            ...
        }

    `audio_file_key` is a unique key for the audio file in the dataset. It can be the path to the file or any other unique key.
    `embedding` is the embedding vector of the audio file.
    `name` can be the name of the speaker of the audio file.
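
    Example:
        A minimal usage sketch; the file path is hypothetical::

            >>> manager = EmbeddingManager(embedding_file_path="speakers.json")  # doctest: +SKIP
            >>> manager.get_mean_embedding("speaker_name")  # doctest: +SKIP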
    """

    def __init__(
        self,
        embedding_file_path: Union[str, List[str]] = "",
        id_file_path: str = "",
        encoder_model_path: str = "",
        encoder_config_path: str = "",
        use_cuda: bool = False,
    ):
        super().__init__(id_file_path=id_file_path)

        self.embeddings = {}
        self.embeddings_by_names = {}
        self.clip_ids = []
        self.encoder = None
        self.encoder_ap = None
        self.use_cuda = use_cuda

        if embedding_file_path:
            if isinstance(embedding_file_path, list):
                self.load_embeddings_from_list_of_files(embedding_file_path)
            else:
                self.load_embeddings_from_file(embedding_file_path)

        if encoder_model_path and encoder_config_path:
            self.init_encoder(encoder_model_path, encoder_config_path, use_cuda)

    @property
    def num_embeddings(self):
        """Get number of embeddings."""
        return len(self.embeddings)

    @property
    def num_names(self):
        """Get number of embeddings."""
        return len(self.embeddings_by_names)

    @property
    def embedding_dim(self):
        """Dimensionality of embeddings. If embeddings are not loaded, returns zero."""
        if self.embeddings:
            return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"])
        return 0

    @property
    def embedding_names(self):
        """Get embedding names."""
        return list(self.embeddings_by_names.keys())

    def save_embeddings_to_file(self, file_path: str) -> None:
        """Save embeddings to a json file.

        Args:
            file_path (str): Path to the output file.
        """
        save_file(self.embeddings, file_path)

    @staticmethod
    def read_embeddings_from_file(file_path: str):
        """Load embeddings from a json file.

        Args:
            file_path (str): Path to the file.
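
        Example:
            A sketch with a hypothetical file path::

                >>> out = EmbeddingManager.read_embeddings_from_file("speakers.json")  # doctest: +SKIP
                >>> name_to_id, clip_ids, embeddings, embeddings_by_names = out  # doctest: +SKIP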
        """
        embeddings = load_file(file_path)
        speakers = sorted({x["name"] for x in embeddings.values()})
        name_to_id = {name: i for i, name in enumerate(speakers)}
        clip_ids = sorted(embeddings.keys())  # keys are unique; keep a deterministic order
        # cache embeddings_by_names for fast inference using a bigger speakers.json
        embeddings_by_names = {}
        for x in embeddings.values():
            embeddings_by_names.setdefault(x["name"], []).append(x["embedding"])
        return name_to_id, clip_ids, embeddings, embeddings_by_names

    def load_embeddings_from_file(self, file_path: str) -> None:
        """Load embeddings from a json file.

        Args:
            file_path (str): Path to the target json file.
        """
        self.name_to_id, self.clip_ids, self.embeddings, self.embeddings_by_names = self.read_embeddings_from_file(
            file_path
        )

    def load_embeddings_from_list_of_files(self, file_paths: List[str]) -> None:
        """Load embeddings from a list of json files and don't allow duplicate keys.

        Args:
            file_paths (List[str]): List of paths to the target json files.
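
        Example:
            Merging two hypothetical embedding files; colliding clip keys raise a ValueError::

                >>> manager = EmbeddingManager()
                >>> manager.load_embeddings_from_list_of_files(["en_speakers.json", "de_speakers.json"])  # doctest: +SKIP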
        """
        self.name_to_id = {}
        self.clip_ids = []
        self.embeddings_by_names = {}
        self.embeddings = {}
        for file_path in file_paths:
            ids, clip_ids, embeddings, embeddings_by_names = self.read_embeddings_from_file(file_path)
            # check colliding keys
            duplicates = set(self.embeddings.keys()) & set(embeddings.keys())
            if duplicates:
                raise ValueError(f" [!] Duplicate embedding names <{duplicates}> in {file_path}")
            # store values
            self.name_to_id.update(ids)
            self.clip_ids.extend(clip_ids)
            self.embeddings_by_names.update(embeddings_by_names)
            self.embeddings.update(embeddings)

        # reset name_to_id to get the right speaker ids
        self.name_to_id = {name: i for i, name in enumerate(self.name_to_id)}

    def get_embedding_by_clip(self, clip_idx: str) -> List:
        """Get embedding by clip ID.

        Args:
            clip_idx (str): Target clip ID.

        Returns:
            List: embedding as a list.
        """
        return self.embeddings[clip_idx]["embedding"]

    def get_embeddings_by_name(self, idx: str) -> List[List]:
        """Get all embeddings of a speaker.

        Args:
            idx (str): Target name.

        Returns:
            List[List]: all the embeddings of the given speaker.
        """
        return self.embeddings_by_names[idx]

    def get_embeddings_by_names(self) -> Dict:
        """Get all embeddings by names.

        Returns:
            Dict: all the embeddings of each speaker.
        """
        embeddings_by_names = {}
        for x in self.embeddings.values():
            embeddings_by_names.setdefault(x["name"], []).append(x["embedding"])
        return embeddings_by_names

    def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
        """Get the mean embedding of a given name (`idx`).

        Args:
            idx (str): Target name.
            num_samples (int, optional): Number of samples to be averaged. Defaults to None.
            randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False.

        Returns:
            np.ndarray: Mean embedding.
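
        Example:
            A runnable sketch with hypothetical 2-dim embeddings::

                >>> m = EmbeddingManager()
                >>> m.embeddings_by_names = {"spk": [[0.0, 2.0], [2.0, 0.0]]}
                >>> m.get_mean_embedding("spk")
                array([1., 1.])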
        """
        embeddings = self.get_embeddings_by_name(idx)
        if num_samples is None:
            embeddings = np.stack(embeddings).mean(0)
        else:
            assert len(embeddings) >= num_samples, f" [!] {idx} has fewer than {num_samples} samples"
            if randomize:
                embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0)
            else:
                embeddings = np.stack(embeddings[:num_samples]).mean(0)
        return embeddings

    def get_random_embedding(self) -> Any:
        """Get a random embedding.

        Args:

        Returns:
            np.ndarray: embedding.
        """
        if self.embeddings:
            return self.embeddings[random.choice(list(self.embeddings.keys()))]["embedding"]

        return None

    def get_clips(self) -> List:
        return sorted(self.embeddings.keys())

    def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None:
        """Initialize a speaker encoder model.

        Args:
            model_path (str): Model file path.
            config_path (str): Model config file path.
            use_cuda (bool, optional): Use CUDA. Defaults to False.
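
        Example:
            A sketch with hypothetical checkpoint paths::

                >>> manager = EmbeddingManager()
                >>> manager.init_encoder("encoder_model.pth", "encoder_config.json", use_cuda=False)  # doctest: +SKIP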
        """
        self.use_cuda = use_cuda
        self.encoder_config = load_config(config_path)
        self.encoder = setup_encoder_model(self.encoder_config)
        self.encoder_criterion = self.encoder.load_checkpoint(
            self.encoder_config, model_path, eval=True, use_cuda=use_cuda, cache=True
        )
        self.encoder_ap = AudioProcessor(**self.encoder_config.audio)

    def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list:
        """Compute a embedding from a given audio file.

        Args:
            wav_file (Union[str, List[str]]): Target file path.

        Returns:
            list: Computed embedding.
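
        Example:
            A sketch assuming `init_encoder()` has been called; the wav paths are hypothetical::

                >>> emb = manager.compute_embedding_from_clip("sample.wav")  # doctest: +SKIP
                >>> mean_emb = manager.compute_embedding_from_clip(["a.wav", "b.wav"])  # doctest: +SKIP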
        """

        def _compute(wav_file: str):
            waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate)
            if not self.encoder_config.model_params.get("use_torch_spec", False):
                m_input = self.encoder_ap.melspectrogram(waveform)
                m_input = torch.from_numpy(m_input)
            else:
                m_input = torch.from_numpy(waveform)

            if self.use_cuda:
                m_input = m_input.cuda()
            m_input = m_input.unsqueeze(0)
            embedding = self.encoder.compute_embedding(m_input)
            return embedding

        if isinstance(wav_file, list):
            # compute the mean embedding
            embeddings = None
            for wf in wav_file:
                embedding = _compute(wf)
                if embeddings is None:
                    embeddings = embedding
                else:
                    embeddings += embedding
            return (embeddings / len(wav_file))[0].tolist()
        embedding = _compute(wav_file)
        return embedding[0].tolist()

    def compute_embeddings(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
        """Compute embedding from features.

        Args:
            feats (Union[torch.Tensor, np.ndarray]): Input features.

        Returns:
            List: computed embedding.
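
        Example:
            A sketch with random mel features; the shape is hypothetical and an initialized encoder is assumed::

                >>> feats = np.random.rand(80, 120).astype(np.float32)  # (num_mels, T)
                >>> emb = manager.compute_embeddings(feats)  # doctest: +SKIP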
        """
        if isinstance(feats, np.ndarray):
            feats = torch.from_numpy(feats)
        if feats.ndim == 2:
            feats = feats.unsqueeze(0)
        if self.use_cuda:
            feats = feats.cuda()
        return self.encoder.compute_embedding(feats)