L0SG committed on
Commit
6d8c66f
1 Parent(s): 4a4bb01
.gitattributes CHANGED
@@ -25,7 +25,6 @@
  *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
  *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ *.pyc
2
+ __pycache__/
3
+ */__pycache__/
4
+ alias_free_cuda/build/
5
+ exp/
6
+ tmp/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 NVIDIA CORPORATION.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,12 @@
  ---
  title: BigVGAN
- emoji: 🦀
- colorFrom: yellow
- colorTo: red
+ emoji: 🔊
+ colorFrom: red
+ colorTo: blue
  sdk: gradio
  sdk_version: 4.38.1
  app_file: app.py
  pinned: false
  license: mit
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
README_model.md ADDED
@@ -0,0 +1,169 @@
1
+ ## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
2
+ #### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
3
+
4
+ <center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
5
+
6
+
7
+ ### [Paper](https://arxiv.org/abs/2206.04658) &emsp; [Project page](https://research.nvidia.com/labs/adlr/projects/bigvgan/) &emsp; [Audio demo](https://bigvgan-demo.github.io/)
8
+
9
+ ## News
10
+ [Jul 2024] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
11
+ * Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference. Our tests show a 1.5x to 3x speedup on a single A100 GPU.
12
+ * Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
13
+ * Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
14
+ * We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
15
+
16
+ ## Installation
17
+ The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
18
+ ```shell
19
+ conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
20
+ conda activate bigvgan
21
+ ```
22
+
23
+ Clone the repository and install dependencies:
24
+ ```shell
25
+ git clone https://github.com/NVIDIA/BigVGAN
26
+ cd BigVGAN
27
+ pip install -r requirements.txt
28
+ ```
29
+
30
+
31
+
32
+ Create a symbolic link to the root of the dataset. The codebase uses file lists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset:
33
+ ``` shell
34
+ cd LibriTTS && \
35
+ ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
36
+ ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
37
+ ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
38
+ ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
39
+ ln -s /path/to/your/LibriTTS/dev-other dev-other && \
40
+ ln -s /path/to/your/LibriTTS/test-clean test-clean && \
41
+ ln -s /path/to/your/LibriTTS/test-other test-other && \
42
+ cd ..
43
+ ```
44
+
45
+ ## Training
46
+ Train the BigVGAN model. Below is an example command for training BigVGAN-v2 on the LibriTTS dataset at 24 kHz with a full 100-band mel spectrogram as input:
47
+ ```shell
48
+ python train.py \
49
+ --config configs/bigvgan_v2_24khz_100band_256x.json \
50
+ --input_wavs_dir LibriTTS \
51
+ --input_training_file LibriTTS/train-full.txt \
52
+ --input_validation_file LibriTTS/val-full.txt \
53
+ --list_input_unseen_wavs_dir LibriTTS LibriTTS \
54
+ --list_input_unseen_validation_file LibriTTS/dev-clean.txt LibriTTS/dev-other.txt \
55
+ --checkpoint_path exp/bigvgan_v2_24khz_100band_256x
56
+ ```
57
+
58
+
59
+ ## Synthesis
60
+ Synthesize audio from a trained BigVGAN model. Below is an example command for generating audio from the model.
61
+ It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
62
+ ```shell
63
+ python inference.py \
64
+ --checkpoint_file exp/bigvgan_v2_24khz_100band_256x/g_03000000 \
65
+ --input_wavs_dir /path/to/your/input_wav \
66
+ --output_dir /path/to/your/output_wav
67
+ ```
68
+
69
+ `inference_e2e.py` supports synthesis directly from a mel spectrogram saved in `.npy` format, with shape `[1, channel, frame]` or `[channel, frame]`.
70
+ It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
71
+
72
+ Make sure that the STFT hyperparameters used to compute the mel spectrogram match those of the model, as defined in the `config.json` of the corresponding model.
73
+ ```shell
74
+ python inference_e2e.py \
75
+ --checkpoint_file exp/bigvgan_v2_24khz_100band_256x/g_03000000 \
76
+ --input_mels_dir /path/to/your/input_mel \
77
+ --output_dir /path/to/your/output_wav
78
+ ```
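+
+ For reference, below is a minimal Python sketch (not part of the repository) of one way to save a mel spectrogram as `.npy` in the expected `[channel, frame]` layout. It assumes the repository's `meldataset.mel_spectrogram` helper and a config object `h` loaded from the model's `config.json`; the function name `save_mel_npy` is illustrative.
+ ```python
+ import librosa
+ import numpy as np
+ import torch
+
+ from meldataset import mel_spectrogram  # repository helper
+
+ def save_mel_npy(wav_path, out_path, h):  # h: model hyperparameters from config.json
+     # Load audio at the model's sampling rate as a mono float waveform in [-1, 1].
+     wav, _ = librosa.load(wav_path, sr=h.sampling_rate, mono=True)
+     wav = torch.FloatTensor(wav).unsqueeze(0)  # [1, T]
+     # Use the same STFT hyperparameters as the model.
+     mel = mel_spectrogram(
+         wav, h.n_fft, h.num_mels, h.sampling_rate,
+         h.hop_size, h.win_size, h.fmin, h.fmax,
+     )  # [1, num_mels, frames]
+     np.save(out_path, mel.squeeze(0).cpu().numpy())  # saved as [channel, frame]
+ ```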
79
+
80
+ ## Using Custom CUDA Kernel for Synthesis
81
+ You can apply the fast CUDA inference kernel by passing the `use_cuda_kernel` argument when instantiating BigVGAN:
82
+
83
+ ```python
84
+ generator = BigVGAN(h, use_cuda_kernel=True)
85
+ ```
86
+
87
+ You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
88
+
89
+ When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
90
+
91
+ Please make sure that both are installed on your system and that the `nvcc` version matches the CUDA version your PyTorch build was compiled with.
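+
+ As a quick, illustrative sanity check (not a repository script), you can compare the CUDA version your PyTorch build was compiled against with the `nvcc` found on your `PATH`:
+ ```python
+ import subprocess
+ import torch
+
+ # CUDA version the installed PyTorch build was compiled against (e.g., "12.1").
+ print("PyTorch CUDA version:", torch.version.cuda)
+ # CUDA toolkit version reported by the local nvcc compiler.
+ print(subprocess.check_output(["nvcc", "--version"], universal_newlines=True))
+ ```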
92
+
93
+ We recommend running `test_cuda_vs_torch_model.py` first to build the CUDA kernel and check its correctness. See the example command below and its output, which should print `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
94
+
95
+ ```shell
96
+ python test_cuda_vs_torch_model.py \
97
+ --checkpoint_file /path/to/your/bigvgan/g_03000000
98
+ ```
99
+
100
+ ```shell
101
+ loading plain Pytorch BigVGAN
102
+ ...
103
+ loading CUDA kernel BigVGAN with auto-build
104
+ Detected CUDA files, patching ldflags
105
+ Emitting ninja build file /path/to/your/BigVGAN/alias_free_cuda/build/build.ninja...
106
+ Building extension module anti_alias_activation_cuda...
107
+ ...
108
+ Loading extension module anti_alias_activation_cuda...
109
+ ...
110
+ Loading '/path/to/your/bigvgan/g_03000000'
111
+ ...
112
+ [Success] test CUDA fused vs. plain torch BigVGAN inference
113
+ > mean_difference=0.0007238413265440613
114
+ ...
115
+ ```
116
+
117
+ If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
118
+
119
+
120
+ ## Pretrained Models
121
+ We provide the [pretrained models](https://drive.google.com/drive/folders/1L2RDeJMBE7QAI8qV51n0QAf4mkSgUUeE?usp=sharing).
122
+ You can download the generator weights (e.g., `g_(training_steps)`) and the corresponding discriminator/optimizer states (e.g., `do_(training_steps)`) from the listed folders.
123
+
124
+ |Folder Name|Sampling Rate|Mel bands|fmax (Hz)|Upsampling Ratio|Params.|Dataset|Fine-Tuned|
125
+ |------|---|---|---|---|---|------|---|
126
+ |bigvgan_v2_44khz_128band_512x|44 kHz|128|22050|512|122M|Large-scale Compilation|No|
127
+ |bigvgan_v2_44khz_128band_256x|44 kHz|128|22050|256|112M|Large-scale Compilation|No|
128
+ |bigvgan_v2_24khz_100band_256x|24 kHz|100|12000|256|112M|Large-scale Compilation|No|
129
+ |bigvgan_v2_22khz_80band_256x|22 kHz|80|11025|256|112M|Large-scale Compilation|No|
130
+ |bigvgan_v2_22khz_80band_fmax8k_256x|22 kHz|80|8000|256|112M|Large-scale Compilation|No|
131
+ |bigvgan_24khz_100band|24 kHz|100|12000|256|112M|LibriTTS|No|
132
+ |bigvgan_base_24khz_100band|24 kHz|100|12000|256|14M|LibriTTS|No|
133
+ |bigvgan_22khz_80band|22 kHz|80|8000|256|112M|LibriTTS + VCTK + LJSpeech|No|
134
+ |bigvgan_base_22khz_80band|22 kHz|80|8000|256|14M|LibriTTS + VCTK + LJSpeech|No|
135
+
136
+ The paper results are based on the original 24 kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on the LibriTTS dataset.
137
+ We also provide 22 kHz BigVGAN models with a band-limited setup (i.e., fmax=8000) for TTS applications.
138
+ Note that the checkpoints use the `snakebeta` activation with log-scale parameterization, which gives the best overall quality.
139
+
140
+ You can fine-tune the models by downloading the checkpoints (both the generator weights and the discriminator/optimizer states) and resuming training with your own audio dataset.
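+
+ As a reference, below is a minimal sketch (not a repository script) of loading a downloaded generator checkpoint for inference. It assumes the repository's `models.BigVGAN` and `env.AttrDict`, and that the checkpoint stores the generator weights under a `"generator"` key; inspect the downloaded checkpoint if your keys differ.
+ ```python
+ import json
+ import torch
+
+ from env import AttrDict
+ from models import BigVGAN
+
+ # Load the model config that ships alongside the checkpoint.
+ with open("config.json") as f:
+     h = AttrDict(json.load(f))
+
+ generator = BigVGAN(h)
+ state_dict = torch.load("g_03000000", map_location="cpu")
+ generator.load_state_dict(state_dict["generator"])  # key assumed; check the checkpoint contents
+ generator.eval()
+ ```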
141
+
142
+ ## Training Details of BigVGAN-v2
143
+ Compared to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 use `batch_size=32` with a longer `segment_size=65536` and were trained on 8 A100 GPUs.
144
+
145
+ Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` by default to fit training on a single A100 GPU. You can fine-tune the models by adjusting `batch_size` to match your GPUs.
146
+
147
+ When training BigVGAN-v2 from scratch with a small batch size, it may encounter the early divergence problem mentioned in the paper. In such cases, we recommend lowering the `clip_grad_norm` value (e.g., `100`) for the early training iterations (e.g., the first 20,000 steps) and then increasing it back to the default `500`.
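+
+ Purely as an illustration of that schedule (this is not the repository's training loop; the helper name and warm-up length below are placeholders), it could look like:
+ ```python
+ def grad_clip_value(step, warmup_steps=20_000, early_value=100.0, default_value=500.0):
+     """Use a smaller max-norm for the early iterations, then the default afterwards."""
+     return early_value if step < warmup_steps else default_value
+
+ # Inside a hypothetical training step:
+ # torch.nn.utils.clip_grad_norm_(generator.parameters(), grad_clip_value(step))
+ ```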
148
+
149
+ ## Evaluation Results of BigVGAN-v2
150
+ Below are the objective results of the 24 kHz model (`bigvgan_v2_24khz_100band_256x`) obtained on the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements across the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
151
+
152
+ |Model|Dataset|Steps|PESQ(↑)|M-STFT(↓)|MCD(↓)|Periodicity(↓)|V/UV F1(↑)|
153
+ |-------|-----|-----|-----|-----|-----|-----|-----|
154
+ |BigVGAN|LibriTTS|1M|4.027|0.7997|0.3745|0.1018|0.9598|
155
+ |BigVGAN|LibriTTS|5M|4.256|0.7409|0.2988|0.0809|0.9698|
156
+ |BigVGAN-v2|Large-scale Compilation|3M|**4.359**|**0.7134**|0.3060|**0.0621**|**0.9777**|
157
+
158
+ ## Acknowledgements
159
+ We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
160
+
161
+ ## References
162
+ * [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
163
+ * [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
164
+ * [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
165
+ * [Julius](https://github.com/adefossez/julius) (for low-pass filter)
166
+ * [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
167
+ * [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
168
+ * [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
169
+
activations.py ADDED
@@ -0,0 +1,120 @@
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ '''
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ '''
25
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
26
+ '''
27
+ Initialization.
28
+ INPUT:
29
+ - in_features: shape of the input
30
+ - alpha: trainable parameter
31
+ alpha is initialized to 1 by default, higher values = higher-frequency.
32
+ alpha will be trained along with the rest of your model.
33
+ '''
34
+ super(Snake, self).__init__()
35
+ self.in_features = in_features
36
+
37
+ # initialize alpha
38
+ self.alpha_logscale = alpha_logscale
39
+ if self.alpha_logscale: # log scale alphas initialized to zeros
40
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
41
+ else: # linear scale alphas initialized to ones
42
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+
46
+ self.no_div_by_zero = 0.000000001
47
+
48
+ def forward(self, x):
49
+ '''
50
+ Forward pass of the function.
51
+ Applies the function to the input elementwise.
52
+ Snake ∶= x + 1/a * sin^2 (xa)
53
+ '''
54
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
55
+ if self.alpha_logscale:
56
+ alpha = torch.exp(alpha)
57
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
58
+
59
+ return x
60
+
61
+
62
+ class SnakeBeta(nn.Module):
63
+ '''
64
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
65
+ Shape:
66
+ - Input: (B, C, T)
67
+ - Output: (B, C, T), same shape as the input
68
+ Parameters:
69
+ - alpha - trainable parameter that controls frequency
70
+ - beta - trainable parameter that controls magnitude
71
+ References:
72
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
73
+ https://arxiv.org/abs/2006.08195
74
+ Examples:
75
+ >>> a1 = snakebeta(256)
76
+ >>> x = torch.randn(256)
77
+ >>> x = a1(x)
78
+ '''
79
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
80
+ '''
81
+ Initialization.
82
+ INPUT:
83
+ - in_features: shape of the input
84
+ - alpha - trainable parameter that controls frequency
85
+ - beta - trainable parameter that controls magnitude
86
+ alpha is initialized to 1 by default, higher values = higher-frequency.
87
+ beta is initialized to 1 by default, higher values = higher-magnitude.
88
+ alpha will be trained along with the rest of your model.
89
+ '''
90
+ super(SnakeBeta, self).__init__()
91
+ self.in_features = in_features
92
+
93
+ # initialize alpha
94
+ self.alpha_logscale = alpha_logscale
95
+ if self.alpha_logscale: # log scale alphas initialized to zeros
96
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
97
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
98
+ else: # linear scale alphas initialized to ones
99
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
100
+ self.beta = Parameter(torch.ones(in_features) * alpha)
101
+
102
+ self.alpha.requires_grad = alpha_trainable
103
+ self.beta.requires_grad = alpha_trainable
104
+
105
+ self.no_div_by_zero = 0.000000001
106
+
107
+ def forward(self, x):
108
+ '''
109
+ Forward pass of the function.
110
+ Applies the function to the input elementwise.
111
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
112
+ '''
113
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
114
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
115
+ if self.alpha_logscale:
116
+ alpha = torch.exp(alpha)
117
+ beta = torch.exp(beta)
118
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
119
+
120
+ return x
alias_free_cuda/__init__.py ADDED
File without changes
alias_free_cuda/activation1d.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from alias_free_torch.resample import UpSample1d, DownSample1d
7
+ # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
8
+ from alias_free_cuda import load
9
+ load.load()
10
+
11
+ class FusedAntiAliasActivation(torch.autograd.Function):
12
+ """
13
+ Assumes filter size 12, replication padding on upsampling, and logscale alpha/beta parameters as inputs
14
+ """
15
+ @staticmethod
16
+ def forward(ctx, inputs, ftr, alpha, beta):
17
+ import anti_alias_activation_cuda
18
+ activation_results = anti_alias_activation_cuda.forward(inputs, ftr, alpha, beta)
19
+ return activation_results
20
+
21
+ @staticmethod
22
+ def backward(ctx, output_grads):
23
+ # TODO: implement bwd pass
24
+ raise NotImplementedError
25
+ return output_grads, None, None
26
+
27
+ class Activation1d(nn.Module):
28
+ def __init__(self,
29
+ activation,
30
+ up_ratio: int = 2,
31
+ down_ratio: int = 2,
32
+ up_kernel_size: int = 12,
33
+ down_kernel_size: int = 12,
34
+ fused: bool = True
35
+ ):
36
+ super().__init__()
37
+ self.up_ratio = up_ratio
38
+ self.down_ratio = down_ratio
39
+ self.act = activation
40
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
41
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
42
+
43
+ self.fused = fused # whether to use fused CUDA kernel or not
44
+
45
+
46
+ def forward(self, x):
47
+ if not self.fused:
48
+ x = self.upsample(x)
49
+ x = self.act(x)
50
+ x = self.downsample(x)
51
+ return x
52
+ else:
53
+ if self.act.__class__.__name__ == "Snake":
54
+ beta = self.act.alpha.data # snake uses same params for alpha and beta
55
+ else:
56
+ beta = self.act.beta.data # snakebeta uses different params for alpha and beta
57
+ alpha = self.act.alpha.data
58
+ if not self.act.alpha_logscale: # exp baked into cuda kernel, cancel it out with a log
59
+ alpha = torch.log(alpha)
60
+ beta = torch.log(beta)
61
+ x = FusedAntiAliasActivation.apply(x, self.upsample.filter, alpha, beta)
62
+ x = self.downsample(x)
63
+ return x
alias_free_cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,48 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <cuda_fp16.h>
18
+ #include <torch/extension.h>
19
+ #include <vector>
20
+
21
+ namespace anti_alias_activation {
22
+
23
+ torch::Tensor fwd_cuda(torch::Tensor const& input,
24
+ torch::Tensor const& filter,
25
+ torch::Tensor const& alpha,
26
+ torch::Tensor const& beta
27
+ );
28
+
29
+ torch::Tensor fwd(torch::Tensor const& input,
30
+ torch::Tensor const& filter,
31
+ torch::Tensor const& alpha,
32
+ torch::Tensor const& beta
33
+ ) {
34
+ AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
35
+ //AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
36
+ // (input.scalar_type() == at::ScalarType::BFloat16),
37
+ // "Only fp16 and bf16 are supported");
38
+
39
+ return fwd_cuda(input, filter, alpha, beta);
40
+ }
41
+
42
+ } // end namespace anti_alias_activation
43
+
44
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
45
+ m.def("forward",
46
+ &anti_alias_activation::fwd,
47
+ "Anti Alias Activation -- Forward.");
48
+ }
alias_free_cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,314 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include <cuda.h>
19
+ #include <cuda_runtime.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_profiler_api.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+ #include <torch/extension.h>
24
+ #include "type_shim.h"
25
+ #include <assert.h>
26
+ #include <cfloat>
27
+ #include <limits>
28
+ #include <stdint.h>
29
+ #include <c10/macros/Macros.h>
30
+
31
+ namespace {
32
+
33
+ /*
34
+ template <typename Datatype, int ELEMENTS_PER_LDG>
35
+ __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
36
+
37
+ template <>
38
+ __device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
39
+
40
+ template <>
41
+ __device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
42
+
43
+ template <>
44
+ __device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
45
+
46
+ template <>
47
+ __device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
48
+
49
+ template <>
50
+ __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
51
+
52
+ template <>
53
+ __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
54
+
55
+ int log2_ceil(int value) {
56
+ int log2_value = 0;
57
+ while ((1 << log2_value) < value) ++log2_value;
58
+ return log2_value;
59
+ }
60
+
61
+ template<typename T>
62
+ struct Add {
63
+ __device__ __forceinline__ T operator()(T a, T b) const {
64
+ return a + b;
65
+ }
66
+ };
67
+
68
+ template<typename T>
69
+ struct Max {
70
+ __device__ __forceinline__ T operator()(T a, T b) const {
71
+ return a < b ? b : a;
72
+ }
73
+ };
74
+
75
+ template <typename T>
76
+ __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
77
+ {
78
+ #if CUDA_VERSION >= 9000
79
+ return __shfl_xor_sync(mask, value, laneMask, width);
80
+ #else
81
+ return __shfl_xor(value, laneMask, width);
82
+ #endif
83
+ }
84
+
85
+ template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
86
+ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
87
+ ReduceOp<acc_t> r;
88
+ #pragma unroll
89
+ for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
90
+ #pragma unroll
91
+ for (int i = 0; i < WARP_BATCH; ++i) {
92
+ acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
93
+ sum[i] = r(sum[i], b);
94
+ }
95
+ }
96
+ }
97
+ */
98
+
99
+ template <typename input_t, typename output_t, typename acc_t>
100
+ __global__ void anti_alias_activation_forward(
101
+ output_t *dst,
102
+ const input_t *src,
103
+ const input_t *ftr,
104
+ const input_t *alpha,
105
+ const input_t *beta,
106
+ int batch_size,
107
+ int channels,
108
+ int seq_len)
109
+ {
110
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
111
+ constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
112
+ constexpr int BUFFER_SIZE = 32;
113
+ constexpr int FILTER_SIZE = 12;
114
+ constexpr int HALF_FILTER_SIZE = 6;
115
+ constexpr int REPLICATION_PAD = 5; // 5 on each side
116
+
117
+ // blockDim/threadIdx = (128, 1, 1)
118
+ // gridDim/blockIdx = (seq_blocks, channels, batches)
119
+ int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
120
+ int local_offset = threadIdx.x * BUFFER_SIZE;
121
+ int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
122
+
123
+
124
+ //int intermediate_seq_len = seq_len * 2 - 1 + 4 * REPLICATION_PAD;
125
+ //int intermediate_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
126
+ //int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
127
+
128
+ int output_seq_len = seq_len * 2 ; //
129
+ int output_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + output_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
130
+ int output_local_offset = threadIdx.x * BUFFER_SIZE * 2;
131
+ int output_seq_offset = blockIdx.x * 128 * BUFFER_SIZE *2 + output_local_offset;
132
+ // get values needed for replication padding before moving pointer
133
+ const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
134
+ input_t seq_left_most_value = right_most_pntr[0];
135
+ input_t seq_right_most_value = right_most_pntr[seq_len - 1];
136
+
137
+ src += block_offset + local_offset;
138
+ dst += output_block_offset + output_local_offset ;
139
+ alpha = alpha + blockIdx.y;
140
+ input_t alpha_val = expf(alpha[0]);
141
+ beta = beta + blockIdx.y;
142
+ input_t beta_val = expf(beta[0]);
143
+ // load data from global memory
144
+ input_t elements[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
145
+ input_t intermediates[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
146
+ //output_t output[2*BUFFER_SIZE];
147
+ input_t filter[FILTER_SIZE];
148
+ //input_t temp_data[ELEMENTS_PER_LDG_STG];
149
+ //uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
150
+
151
+ #pragma unroll
152
+ for (int it = 0; it < FILTER_SIZE; it+=1) {
153
+ filter[it] = ftr[it];
154
+ }
155
+
156
+
157
+ #pragma unroll
158
+ for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE ; it+=1) {
159
+ int element_index = seq_offset + it;
160
+ if ((element_index < 0) && (element_index >= -REPLICATION_PAD)) {
161
+ elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_left_most_value;
162
+ }
163
+ if ((element_index >= seq_len) && (element_index < seq_len + REPLICATION_PAD)) {
164
+ elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_right_most_value;
165
+ }
166
+ if ((element_index >= 0) && (element_index < seq_len)) {
167
+ elements[2*(HALF_FILTER_SIZE+it)] = 2*src[it];
168
+ }
169
+ }
170
+
171
+
172
+
173
+ // apply filter
174
+ #pragma unroll
175
+ for (int it = 0; it < (2 * BUFFER_SIZE + 2*FILTER_SIZE); it+=1) {
176
+ input_t acc = 0.0;
177
+
178
+ int element_index = output_seq_offset + it; // index for output
179
+ #pragma unroll
180
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
181
+ if ((element_index + f_idx) >= 0){
182
+ acc += filter[f_idx] * elements[it+f_idx];
183
+ }
184
+ }
185
+ intermediates[it] = acc;
186
+ }
187
+
188
+ double no_div_by_zero = 0.000000001;
189
+ #pragma unroll
190
+ for (int it = 0; it < 12 + 2 * BUFFER_SIZE; it++) {
191
+ intermediates[it] += (1.0/(beta_val + no_div_by_zero)) * sinf(intermediates[it] * alpha_val) * sinf(intermediates[it] * alpha_val);
192
+ }
193
+
194
+
195
+ // now copy to output
196
+ #pragma unroll
197
+ for (int it = 0; it < 2*BUFFER_SIZE; it+=1){
198
+ int element_index = output_seq_offset + it;
199
+ if (element_index < output_seq_len) {
200
+ dst[it] = intermediates[it+6];
201
+ }
202
+ }
203
+
204
+
205
+
206
+ // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
207
+ // int element_index = seq_offset + it;
208
+ // if (element_index < seq_len) {
209
+ // dst[it] = output[it];
210
+ // }
211
+ // }
212
+
213
+
214
+ // // Upsample convolution
215
+ // for (int it = 0; it < 2 * BUFFER_SIZE + 12; it+=1) {
216
+ // input_t acc = 0.0;
217
+
218
+ // for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
219
+ // acc += filter[f_idx] * elements[it+f_idx];
220
+ // }
221
+ // intermediates[it] = acc;
222
+ // }
223
+
224
+ // // correct the corners of intermediates
225
+ // if (seq_offset == 0) {
226
+ // for (int it = 0; it < 6; it+=1)
227
+ // intermediates[it] = 0;
228
+ // }
229
+
230
+ // if (seq_offset + 32 >= seq_len) {
231
+ // int offset = seq_len % 32 == 0 ? 32 : seq_len % 32;
232
+
233
+ // for (int it = 0; it < 6; it++) {
234
+ // intermediates[6+2*offset+it] = 0;
235
+ // }
236
+ // }
237
+
238
+
239
+
240
+
241
+ // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
242
+ // int element_index = seq_offset + it;
243
+ // if (element_index < seq_len) {
244
+ // dst[it] = output[it];
245
+ // }
246
+ // }
247
+ }
248
+
249
+ template<typename input_t, typename output_t, typename acc_t>
250
+ void dispatch_anti_alias_activation_forward(
251
+ output_t *dst,
252
+ const input_t *src,
253
+ const input_t *ftr,
254
+ const input_t *alpha,
255
+ const input_t *beta,
256
+ int batch_size,
257
+ int channels,
258
+ int seq_len)
259
+ {
260
+ if (seq_len == 0) {
261
+ return;
262
+ } else {
263
+ // use 128 threads per block to maximize GPU utilization
264
+ constexpr int threads_per_block = 128;
265
+ constexpr int seq_len_per_block = 4096;
266
+ int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
267
+ dim3 blocks(blocks_per_seq_len, channels, batch_size);
268
+ dim3 threads(threads_per_block, 1, 1);
269
+
270
+ anti_alias_activation_forward<input_t, output_t, acc_t>
271
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, ftr, alpha, beta, batch_size, channels, seq_len);
272
+ }
273
+ }
274
+ }
275
+
276
+ namespace anti_alias_activation {
277
+
278
+ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& filter, torch::Tensor const& alpha, torch::Tensor const& beta)
279
+ {
280
+ // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
281
+ const int batches = input.size(0);
282
+ const int channels = input.size(1);
283
+ const int seq_len = input.size(2);
284
+
285
+ // Output
286
+ auto act_options = input.options().requires_grad(false);
287
+ int output_seq_len = seq_len*2; // we'll be dilating between each element by interspersing with zeros
288
+
289
+ torch::Tensor anti_alias_activation_results =
290
+ torch::empty({batches, channels, output_seq_len}, act_options);
291
+
292
+ // Softmax Intermediate Result Ptr
293
+ void* input_ptr = static_cast<void*>(input.data_ptr());
294
+ void* filter_ptr = static_cast<void*>(filter.data_ptr());
295
+ void* alpha_ptr = static_cast<void*>(alpha.data_ptr());
296
+ void* beta_ptr = static_cast<void*>(beta.data_ptr());
297
+ void* anti_alias_activation_results_ptr = static_cast<void*>(anti_alias_activation_results.data_ptr());
298
+
299
+ DISPATCH_FLOAT_HALF_AND_BFLOAT(
300
+ input.scalar_type(),
301
+ "dispatch anti alias activation_forward",
302
+ dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
303
+ reinterpret_cast<scalar_t*>(anti_alias_activation_results_ptr),
304
+ reinterpret_cast<const scalar_t*>(input_ptr),
305
+ reinterpret_cast<const scalar_t*>(filter_ptr),
306
+ reinterpret_cast<const scalar_t*>(alpha_ptr),
307
+ reinterpret_cast<const scalar_t*>(beta_ptr),
308
+ batches,
309
+ channels,
310
+ seq_len);
311
+ );
312
+ return anti_alias_activation_results;
313
+ }
314
+ }
alias_free_cuda/compat.h ADDED
@@ -0,0 +1,31 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ /*This code is copied from NVIDIA apex:
18
+ * https://github.com/NVIDIA/apex
19
+ * with minor changes. */
20
+
21
+
22
+
23
+ #ifndef TORCH_CHECK
24
+ #define TORCH_CHECK AT_CHECK
25
+ #endif
26
+
27
+ #ifdef VERSION_GE_1_3
28
+ #define DATA_PTR data_ptr
29
+ #else
30
+ #define DATA_PTR data
31
+ #endif
alias_free_cuda/load.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import os
5
+ import pathlib
6
+ import subprocess
7
+
8
+ from torch.utils import cpp_extension
9
+
10
+ # Setting this param to a list has a problem of generating different
11
+ # compilation commands (with different order of architectures) and
12
+ # leading to recompilation of fused kernels. Set it to empty string
13
+ # to avoid recompilation and assign arch flags explicitly in
14
+ # extra_cuda_cflags below
15
+ os.environ["TORCH_CUDA_ARCH_LIST"] = ""
16
+
17
+
18
+ def load():
19
+ # Check if cuda 11 is installed for compute capability 8.0
20
+ cc_flag = []
21
+ _, bare_metal_major, _ = _get_cuda_bare_metal_version(
22
+ cpp_extension.CUDA_HOME)
23
+ if int(bare_metal_major) >= 11:
24
+ cc_flag.append('-gencode')
25
+ cc_flag.append('arch=compute_80,code=sm_80')
26
+
27
+ # Build path
28
+ srcpath = pathlib.Path(__file__).parent.absolute()
29
+ buildpath = srcpath / 'build'
30
+ _create_build_dir(buildpath)
31
+
32
+ # Helper function to build the kernels.
33
+ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
34
+ return cpp_extension.load(
35
+ name=name,
36
+ sources=sources,
37
+ build_directory=buildpath,
38
+ extra_cflags=['-O3',],
39
+ extra_cuda_cflags=['-O3',
40
+ '-gencode', 'arch=compute_70,code=sm_70',
41
+ '--use_fast_math'] + extra_cuda_flags + cc_flag,
42
+ verbose=True
43
+ )
44
+
45
+ extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
46
+ '-U__CUDA_NO_HALF_CONVERSIONS__',
47
+ '--expt-relaxed-constexpr',
48
+ '--expt-extended-lambda']
49
+
50
+ sources=[srcpath / 'anti_alias_activation.cpp',
51
+ srcpath / 'anti_alias_activation_cuda.cu']
52
+ anti_alias_activation_cuda = _cpp_extention_load_helper(
53
+ "anti_alias_activation_cuda", sources, extra_cuda_flags)
54
+
55
+ def _get_cuda_bare_metal_version(cuda_dir):
56
+ raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
57
+ universal_newlines=True)
58
+ output = raw_output.split()
59
+ release_idx = output.index("release") + 1
60
+ release = output[release_idx].split(".")
61
+ bare_metal_major = release[0]
62
+ bare_metal_minor = release[1][0]
63
+
64
+ return raw_output, bare_metal_major, bare_metal_minor
65
+
66
+
67
+ def _create_build_dir(buildpath):
68
+ try:
69
+ os.mkdir(buildpath)
70
+ except OSError:
71
+ if not os.path.isdir(buildpath):
72
+ print(f"Creation of the build directory {buildpath} failed")
alias_free_cuda/test_activation.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import math
5
+ import torch
6
+ import alias_free_cuda
7
+ from alias_free_cuda import activation1d
8
+ from activations import Snake, SnakeBeta
9
+
10
+ def test_load_fused_kernels():
11
+ try:
12
+ import alias_free_cuda
13
+ import torch
14
+ print("[Success] load_fused_kernels")
15
+ except ImportError as e:
16
+ print("[Fail] load_fused_kernels")
17
+ raise e
18
+
19
+ def test_anti_alias_activation():
20
+ data = torch.rand((10, 10, 50000), device='cuda')
21
+
22
+ # check activations.Snake cuda vs. torch
23
+ fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
24
+ fused_activation_output = fused_anti_alias_activation(data)
25
+
26
+ torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
27
+ torch_activation_output = torch_anti_alias_activation(data)
28
+
29
+ test_result = (fused_activation_output - torch_activation_output).abs()
30
+
31
+ while test_result.dim() != 1:
32
+ test_result = test_result.mean(dim=-1)
33
+
34
+ diff = test_result.mean(dim=-1)
35
+
36
+ if diff <= 1e-3:
37
+ print(
38
+ f"\n[Success] test_fused_anti_alias_activation"
39
+ f"\n > mean_difference={diff}"
40
+ f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
41
+ f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
42
+ )
43
+ else:
44
+ print(
45
+ f"\n[Fail] test_fused_anti_alias_activation"
46
+ f"\n > mean_difference={diff}, "
47
+ f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
48
+ f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
49
+ )
50
+
51
+ if __name__ == "__main__":
52
+ from alias_free_cuda import load
53
+ load.load()
54
+ test_load_fused_kernels()
55
+ test_anti_alias_activation()
alias_free_cuda/test_activation_snake_beta.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import math
5
+ import torch
6
+ import alias_free_cuda
7
+ from alias_free_cuda import activation1d
8
+ from activations import Snake, SnakeBeta
9
+
10
+ def test_load_fused_kernels():
11
+ try:
12
+ import alias_free_cuda
13
+ import torch
14
+ print("[Success] load_fused_kernels")
15
+ except ImportError as e:
16
+ print("[Fail] load_fused_kernels")
17
+ raise e
18
+
19
+ def test_anti_alias_activation():
20
+ data = torch.rand((10, 10, 50000), device='cuda')
21
+
22
+ # check activations.Snake cuda vs. torch
23
+ fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
24
+ fused_activation_output = fused_anti_alias_activation(data)
25
+
26
+ torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
27
+ torch_activation_output = torch_anti_alias_activation(data)
28
+
29
+ test_result = (fused_activation_output - torch_activation_output).abs()
30
+
31
+ while test_result.dim() != 1:
32
+ test_result = test_result.mean(dim=-1)
33
+
34
+ diff = test_result.mean(dim=-1)
35
+
36
+ if diff <= 1e-3:
37
+ print(
38
+ f"\n[Success] test_fused_anti_alias_activation"
39
+ f"\n > mean_difference={diff}"
40
+ f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
41
+ f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
42
+ )
43
+ else:
44
+ print(
45
+ f"\n[Fail] test_fused_anti_alias_activation"
46
+ f"\n > mean_difference={diff}, "
47
+ f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
48
+ f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
49
+ )
50
+
51
+ if __name__ == "__main__":
52
+ from alias_free_cuda import load
53
+ load.load()
54
+ test_load_fused_kernels()
55
+ test_anti_alias_activation()
alias_free_cuda/type_shim.h ADDED
@@ -0,0 +1,97 @@
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+
18
+ #include <ATen/ATen.h>
19
+ #include "compat.h"
20
+
21
+
22
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
23
+ switch(TYPE) \
24
+ { \
25
+ case at::ScalarType::Float: \
26
+ { \
27
+ using scalar_t = float; \
28
+ __VA_ARGS__; \
29
+ break; \
30
+ } \
31
+ case at::ScalarType::Half: \
32
+ { \
33
+ using scalar_t = at::Half; \
34
+ __VA_ARGS__; \
35
+ break; \
36
+ } \
37
+ case at::ScalarType::BFloat16: \
38
+ { \
39
+ using scalar_t = at::BFloat16; \
40
+ __VA_ARGS__; \
41
+ break; \
42
+ } \
43
+ default: \
44
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
45
+ }
46
+
47
+
48
+
49
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
50
+ switch(TYPEIN) \
51
+ { \
52
+ case at::ScalarType::Float: \
53
+ { \
54
+ using scalar_t_in = float; \
55
+ switch(TYPEOUT) \
56
+ { \
57
+ case at::ScalarType::Float: \
58
+ { \
59
+ using scalar_t_out = float; \
60
+ __VA_ARGS__; \
61
+ break; \
62
+ } \
63
+ case at::ScalarType::Half: \
64
+ { \
65
+ using scalar_t_out = at::Half; \
66
+ __VA_ARGS__; \
67
+ break; \
68
+ } \
69
+ case at::ScalarType::BFloat16: \
70
+ { \
71
+ using scalar_t_out = at::BFloat16; \
72
+ __VA_ARGS__; \
73
+ break; \
74
+ } \
75
+ default: \
76
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
77
+ } \
78
+ break; \
79
+ } \
80
+ case at::ScalarType::Half: \
81
+ { \
82
+ using scalar_t_in = at::Half; \
83
+ using scalar_t_out = at::Half; \
84
+ __VA_ARGS__; \
85
+ break; \
86
+ } \
87
+ case at::ScalarType::BFloat16: \
88
+ { \
89
+ using scalar_t_in = at::BFloat16; \
90
+ using scalar_t_out = at::BFloat16; \
91
+ __VA_ARGS__; \
92
+ break; \
93
+ } \
94
+ default: \
95
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
96
+ }
97
+
alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
6
+ from .act import *
alias_free_torch/act.py ADDED
@@ -0,0 +1,28 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from .resample import UpSample1d, DownSample1d
6
+
7
+
8
+ class Activation1d(nn.Module):
9
+ def __init__(self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12):
15
+ super().__init__()
16
+ self.up_ratio = up_ratio
17
+ self.down_ratio = down_ratio
18
+ self.act = activation
19
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
20
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
21
+
22
+ # x: [B,C,T]
23
+ def forward(self, x):
24
+ x = self.upsample(x)
25
+ x = self.act(x)
26
+ x = self.downsample(x)
27
+
28
+ return x
alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if 'sinc' in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(x == 0,
21
+ torch.tensor(1., device=x.device, dtype=x.dtype),
22
+ torch.sin(math.pi * x) / math.pi / x)
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ # LICENSE is in incl_licenses directory.
28
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
29
+ even = (kernel_size % 2 == 0)
30
+ half_size = kernel_size // 2
31
+
32
+ #For kaiser window
33
+ delta_f = 4 * half_width
34
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
35
+ if A > 50.:
36
+ beta = 0.1102 * (A - 8.7)
37
+ elif A >= 21.:
38
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
39
+ else:
40
+ beta = 0.
41
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
42
+
43
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
44
+ if even:
45
+ time = (torch.arange(-half_size, half_size) + 0.5)
46
+ else:
47
+ time = torch.arange(kernel_size) - half_size
48
+ if cutoff == 0:
49
+ filter_ = torch.zeros_like(time)
50
+ else:
51
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
52
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
53
+ # of the constant component in the input signal.
54
+ filter_ /= filter_.sum()
55
+ filter = filter_.view(1, 1, kernel_size)
56
+
57
+ return filter
58
+
59
+
60
+ class LowPassFilter1d(nn.Module):
61
+ def __init__(self,
62
+ cutoff=0.5,
63
+ half_width=0.6,
64
+ stride: int = 1,
65
+ padding: bool = True,
66
+ padding_mode: str = 'replicate',
67
+ kernel_size: int = 12):
68
+ # kernel_size should be even number for stylegan3 setup,
69
+ # in this implementation, odd number is also possible.
70
+ super().__init__()
71
+ if cutoff < -0.:
72
+ raise ValueError("Minimum cutoff must be larger than zero.")
73
+ if cutoff > 0.5:
74
+ raise ValueError("A cutoff above 0.5 does not make sense.")
75
+ self.kernel_size = kernel_size
76
+ self.even = (kernel_size % 2 == 0)
77
+ self.pad_left = kernel_size // 2 - int(self.even)
78
+ self.pad_right = kernel_size // 2
79
+ self.stride = stride
80
+ self.padding = padding
81
+ self.padding_mode = padding_mode
82
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
83
+ self.register_buffer("filter", filter)
84
+
85
+ #input [B, C, T]
86
+ def forward(self, x):
87
+ _, C, _ = x.shape
88
+
89
+ if self.padding:
90
+ x = F.pad(x, (self.pad_left, self.pad_right),
91
+ mode=self.padding_mode)
92
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
93
+ stride=self.stride, groups=C)
94
+
95
+ return out
alias_free_torch/resample.py ADDED
@@ -0,0 +1,49 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ self.stride = ratio
16
+ self.pad = self.kernel_size // ratio - 1
17
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20
+ half_width=0.6 / ratio,
21
+ kernel_size=self.kernel_size)
22
+ self.register_buffer("filter", filter)
23
+
24
+ # x: [B, C, T]
25
+ def forward(self, x):
26
+ _, C, _ = x.shape
27
+
28
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
29
+ x = self.ratio * F.conv_transpose1d(
30
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
31
+ x = x[..., self.pad_left:-self.pad_right]
32
+
33
+ return x
34
+
35
+
36
+ class DownSample1d(nn.Module):
37
+ def __init__(self, ratio=2, kernel_size=None):
38
+ super().__init__()
39
+ self.ratio = ratio
40
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
41
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
42
+ half_width=0.6 / ratio,
43
+ stride=ratio,
44
+ kernel_size=self.kernel_size)
45
+
46
+ def forward(self, x):
47
+ xx = self.lowpass(x)
48
+
49
+ return xx
app.py ADDED
@@ -0,0 +1,461 @@
1
+ import spaces
2
+ import gradio as gr
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ import json
6
+ import torch
7
+ import os
8
+ from env import AttrDict
9
+ from meldataset import mel_spectrogram, MAX_WAV_VALUE
10
+ from models import BigVGAN as Generator
11
+ import librosa
12
+ import numpy as np
13
+ from utils import plot_spectrogram, plot_spectrogram_clipped
14
+ import PIL
15
+
16
+ if torch.cuda.is_available():
17
+ device = torch.device('cuda')
18
+ torch.backends.cudnn.benchmark = False
19
+ print(f"using GPU")
20
+ else:
21
+ device = torch.device('cpu')
22
+ print(f"using CPU")
23
+
24
+
25
+ def load_checkpoint(filepath):
26
+ assert os.path.isfile(filepath)
27
+ print("Loading '{}'".format(filepath))
28
+ checkpoint_dict = torch.load(filepath, map_location='cpu')
29
+ print("Complete.")
30
+ return checkpoint_dict
31
+
32
+
33
+ def inference_gradio(input, model_choice): # input is audio waveform in [T, channel]
34
+ sr, audio = input # unpack input to sampling rate and audio itself
35
+ audio = np.transpose(audio) # transpose to [channel, T] for librosa
36
+ audio = audio / MAX_WAV_VALUE # convert int16 to float range used by BigVGAN
37
+
38
+ h = list_config[model_choice]
39
+ model = list_model[model_choice]
40
+
41
+ if sr != h.sampling_rate: # convert audio to model's sampling rate
42
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=h.sampling_rate)
43
+ if len(audio.shape) == 2: # stereo
44
+ audio = librosa.to_mono(audio) # convert to mono if stereo
45
+ audio = librosa.util.normalize(audio) * 0.95
46
+ output, spec_gen = inference_model(audio, h, model) # output is generated audio in ndarray
47
+
48
+ spec_plot_gen = plot_spectrogram(spec_gen.numpy())
49
+
50
+ output_video = gr.make_waveform((h.sampling_rate, output))
51
+ output_image_gen = PIL.Image.frombytes('RGB',
52
+ spec_plot_gen.canvas.get_width_height(),
53
+ spec_plot_gen.canvas.tostring_rgb())
54
+
55
+ return output_video, output_image_gen
56
+
57
+
58
+ @spaces.GPU(duration=120)
59
+ def inference_model(audio_input, h, model):
60
+ model.to(device)
61
+
62
+ def get_mel(x):
63
+ return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
64
+
65
+ with torch.inference_mode():
66
+ wav = torch.FloatTensor(audio_input)
67
+ # compute mel spectrogram from the ground truth audio
68
+ spec_gt = get_mel(wav.unsqueeze(0)).to(device)
69
+
70
+ y_g_hat = model(spec_gt)
71
+
72
+ audio_gen = y_g_hat.squeeze()
73
+ spec_gen = get_mel(audio_gen.unsqueeze(0))
74
+ audio_gen = audio_gen * MAX_WAV_VALUE
75
+ audio_gen = audio_gen.cpu().numpy().astype('int16')
76
+
77
+ return audio_gen, spec_gen[0].cpu()
78
+
79
+
80
+ css = """
81
+ a {
82
+ color: inherit;
83
+ text-decoration: underline;
84
+ }
85
+ .gradio-container {
86
+ font-family: 'IBM Plex Sans', sans-serif;
87
+ }
88
+ .gr-button {
89
+ color: white;
90
+ border-color: #000000;
91
+ background: #000000;
92
+ }
93
+ input[type='range'] {
94
+ accent-color: #000000;
95
+ }
96
+ .dark input[type='range'] {
97
+ accent-color: #dfdfdf;
98
+ }
99
+ .container {
100
+ max-width: 730px;
101
+ margin: auto;
102
+ padding-top: 1.5rem;
103
+ }
104
+ #gallery {
105
+ min-height: 22rem;
106
+ margin-bottom: 15px;
107
+ margin-left: auto;
108
+ margin-right: auto;
109
+ border-bottom-right-radius: .5rem !important;
110
+ border-bottom-left-radius: .5rem !important;
111
+ }
112
+ #gallery>div>.h-full {
113
+ min-height: 20rem;
114
+ }
115
+ .details:hover {
116
+ text-decoration: underline;
117
+ }
118
+ .gr-button {
119
+ white-space: nowrap;
120
+ }
121
+ .gr-button:focus {
122
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
123
+ outline: none;
124
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
125
+ --tw-border-opacity: 1;
126
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
127
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
128
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
129
+ --tw-ring-opacity: .5;
130
+ }
131
+ #advanced-btn {
132
+ font-size: .7rem !important;
133
+ line-height: 19px;
134
+ margin-top: 12px;
135
+ margin-bottom: 12px;
136
+ padding: 2px 8px;
137
+ border-radius: 14px !important;
138
+ }
139
+ #advanced-options {
140
+ margin-bottom: 20px;
141
+ }
142
+ .footer {
143
+ margin-bottom: 45px;
144
+ margin-top: 35px;
145
+ text-align: center;
146
+ border-bottom: 1px solid #e5e5e5;
147
+ }
148
+ .footer>p {
149
+ font-size: .8rem;
150
+ display: inline-block;
151
+ padding: 0 10px;
152
+ transform: translateY(10px);
153
+ background: white;
154
+ }
155
+ .dark .footer {
156
+ border-color: #303030;
157
+ }
158
+ .dark .footer>p {
159
+ background: #0b0f19;
160
+ }
161
+ .acknowledgments h4{
162
+ margin: 1.25em 0 .25em 0;
163
+ font-weight: bold;
164
+ font-size: 115%;
165
+ }
166
+ #container-advanced-btns{
167
+ display: flex;
168
+ flex-wrap: wrap;
169
+ justify-content: space-between;
170
+ align-items: center;
171
+ }
172
+ .animate-spin {
173
+ animation: spin 1s linear infinite;
174
+ }
175
+ @keyframes spin {
176
+ from {
177
+ transform: rotate(0deg);
178
+ }
179
+ to {
180
+ transform: rotate(360deg);
181
+ }
182
+ }
183
+ #share-btn-container {
184
+ display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
185
+ margin-top: 10px;
186
+ margin-left: auto;
187
+ }
188
+ #share-btn {
189
+ all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
190
+ }
191
+ #share-btn * {
192
+ all: unset;
193
+ }
194
+ #share-btn-container div:nth-child(-n+2){
195
+ width: auto !important;
196
+ min-height: 0px !important;
197
+ }
198
+ #share-btn-container .wrap {
199
+ display: none !important;
200
+ }
201
+ .gr-form{
202
+ flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
203
+ }
204
+ #prompt-container{
205
+ gap: 0;
206
+ }
207
+ #generated_id{
208
+ min-height: 700px
209
+ }
210
+ #setting_id{
211
+ margin-bottom: 12px;
212
+ text-align: center;
213
+ font-weight: 900;
214
+ }
215
+ """
216
+
217
+ ######################## script for loading the models ########################
218
+
219
+ model_path = "L0SG/BigVGAN"
220
+
221
+ list_model_name = [
222
+ "bigvgan_24khz_100band",
223
+ "bigvgan_base_24khz_100band",
224
+ "bigvgan_22khz_80band",
225
+ "bigvgan_base_22khz_80band",
226
+ "bigvgan_v2_22khz_80band_256x",
227
+ "bigvgan_v2_22khz_80band_fmax8k_256x",
228
+ "bigvgan_v2_24khz_100band_256x",
229
+ "bigvgan_v2_44khz_128band_256x",
230
+ "bigvgan_v2_44khz_128band_512x"
231
+ ]
232
+
233
+ model_files = {
234
+ "bigvgan_24khz_100band": "g_05000000",
235
+ "bigvgan_base_24khz_100band": "g_05000000",
236
+ "bigvgan_22khz_80band": "g_05000000",
237
+ "bigvgan_base_22khz_80band": "g_05000000",
238
+ "bigvgan_v2_22khz_80band_256x": "g_03000000",
239
+ "bigvgan_v2_22khz_80band_fmax8k_256x": "g_03000000",
240
+ "bigvgan_v2_24khz_100band_256x": "g_03000000",
241
+ "bigvgan_v2_44khz_128band_256x": "g_03000000",
242
+ "bigvgan_v2_44khz_128band_512x": "g_03000000"
243
+ }
244
+
245
+ list_model = []
246
+ list_config = []
247
+
248
+ for model_name in list_model_name:
249
+ model_file = hf_hub_download(model_path, f"{model_name}/{model_files[model_name]}")
250
+ config_file = hf_hub_download(model_path, f"{model_name}/config.json")
251
+
252
+ with open(config_file) as f:
253
+ data = f.read()
254
+
255
+ json_config = json.loads(data)
256
+ h = AttrDict(json_config)
257
+
258
+ torch.manual_seed(h.seed)
259
+
260
+ generator = Generator(h)
261
+ state_dict_g = load_checkpoint(model_file)
262
+ generator.load_state_dict(state_dict_g['generator'])
263
+ generator.eval()
264
+ generator.remove_weight_norm()
265
+
266
+ list_model.append(generator)
267
+ list_config.append(h)
268
+
269
+ ######################## script for gradio UI ########################
270
+
271
+ iface = gr.Blocks(css=css)
272
+
273
+ with iface:
274
+ gr.HTML(
275
+ """
276
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
277
+ <div
278
+ style="
279
+ display: inline-flex;
280
+ align-items: center;
281
+ gap: 0.8rem;
282
+ font-size: 1.75rem;
283
+ "
284
+ >
285
+ <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
286
+ BigVGAN: A Universal Neural Vocoder with Large-Scale Training
287
+ </h1>
288
+ </div>
289
+ <p style="margin-bottom: 10px; font-size: 94%">
290
+ <a href="https://arxiv.org/abs/2206.04658">[Paper]</a> <a href="https://github.com/NVIDIA/BigVGAN">[Code]</a> <a href="https://bigvgan-demo.github.io/">[Demo]</a> <a href="https://research.nvidia.com/labs/adlr/projects/bigvgan/">[Project page]</a>
291
+ </p>
292
+ </div>
293
+ """
294
+ )
295
+ gr.HTML(
296
+ """
297
+ <div>
298
+ <h2>News</h2>
299
+ <p>[Jul 2024] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:</p>
300
+ <ul>
301
+ <li>Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.</li>
302
+ <li>Improved discriminator and loss: BigVGAN-v2 is trained using a <a href="https://arxiv.org/abs/2311.14957" target="_blank">multi-scale sub-band CQT discriminator</a> and a <a href="https://arxiv.org/abs/2306.06546" target="_blank">multi-scale mel spectrogram loss</a>.</li>
303
+ <li>Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.</li>
304
+ <li>We provide <a href="https://huggingface.co/L0SG/BigVGAN" target="_blank">pretrained checkpoints</a> of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.</li>
305
+ </ul>
306
+ </div>
307
+ """
308
+ )
309
+
310
+ with gr.Group():
311
+ model_choice = gr.Radio(label="Select the model. Default: bigvgan_v2_24khz_100band_256x",
312
+ value="bigvgan_v2_24khz_100band_256x",
313
+ choices=[m for m in list_model_name],
314
+ type="index",
315
+ interactive=True)
316
+ audio_input = gr.Audio(label="Input Audio",
317
+ elem_id="input-audio",
318
+ interactive=True)
319
+ button = gr.Button("Submit")
320
+ output_video = gr.Video(label="Output Audio",
321
+ elem_id="output-video")
322
+ output_image_gen = gr.Image(label="Output Mel Spectrogram",
323
+ elem_id="output-image-gen")
324
+ button.click(inference_gradio,
325
+ inputs=[audio_input, model_choice],
326
+ outputs=[output_video, output_image_gen],
327
+ concurrency_limit=10
328
+ )
329
+
330
+ gr.Examples(
331
+ [
332
+ [os.path.join(os.path.dirname(__file__), "examples/jensen_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
333
+ [os.path.join(os.path.dirname(__file__), "examples/libritts_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
334
+ [os.path.join(os.path.dirname(__file__), "examples/queen_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
335
+ [os.path.join(os.path.dirname(__file__), "examples/dance_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
336
+ [os.path.join(os.path.dirname(__file__), "examples/megalovania_24k.wav"), "bigvgan_v2_24khz_100band_256x"],
337
+ [os.path.join(os.path.dirname(__file__), "examples/hifitts_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
338
+ [os.path.join(os.path.dirname(__file__), "examples/musdbhq_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
339
+ [os.path.join(os.path.dirname(__file__), "examples/musiccaps1_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
340
+ [os.path.join(os.path.dirname(__file__), "examples/musiccaps2_44k.wav"), "bigvgan_v2_44khz_128band_256x"],
341
+ ],
342
+ fn=inference_gradio,
343
+ inputs=[audio_input, model_choice],
344
+ outputs=[output_video, output_image_gen]
345
+ )
346
+
347
+ gr.HTML(
348
+ """
349
+ <table border="1" cellspacing="0" cellpadding="5">
350
+ <thead>
351
+ <tr>
352
+ <th>Folder Name</th>
353
+ <th>Sampling Rate</th>
354
+ <th>Mel band</th>
355
+ <th>fmax</th>
356
+ <th>Upsampling Ratio</th>
357
+ <th>Params.</th>
358
+ <th>Dataset</th>
359
+ <th>Fine-Tuned</th>
360
+ </tr>
361
+ </thead>
362
+ <tbody>
363
+ <tr>
364
+ <td>bigvgan_v2_44khz_128band_512x</td>
365
+ <td>44 kHz</td>
366
+ <td>128</td>
367
+ <td>22050</td>
368
+ <td>512</td>
369
+ <td>122M</td>
370
+ <td>Large-scale Compilation</td>
371
+ <td>No</td>
372
+ </tr>
373
+ <tr>
374
+ <td>bigvgan_v2_44khz_128band_256x</td>
375
+ <td>44 kHz</td>
376
+ <td>128</td>
377
+ <td>22050</td>
378
+ <td>256</td>
379
+ <td>112M</td>
380
+ <td>Large-scale Compilation</td>
381
+ <td>No</td>
382
+ </tr>
383
+ <tr>
384
+ <td>bigvgan_v2_24khz_100band_256x</td>
385
+ <td>24 kHz</td>
386
+ <td>100</td>
387
+ <td>12000</td>
388
+ <td>256</td>
389
+ <td>112M</td>
390
+ <td>Large-scale Compilation</td>
391
+ <td>No</td>
392
+ </tr>
393
+ <tr>
394
+ <td>bigvgan_v2_22khz_80band_256x</td>
395
+ <td>22 kHz</td>
396
+ <td>80</td>
397
+ <td>11025</td>
398
+ <td>256</td>
399
+ <td>112M</td>
400
+ <td>Large-scale Compilation</td>
401
+ <td>No</td>
402
+ </tr>
403
+ <tr>
404
+ <td>bigvgan_v2_22khz_80band_fmax8k_256x</td>
405
+ <td>22 kHz</td>
406
+ <td>80</td>
407
+ <td>8000</td>
408
+ <td>256</td>
409
+ <td>112M</td>
410
+ <td>Large-scale Compilation</td>
411
+ <td>No</td>
412
+ </tr>
413
+ <tr>
414
+ <td>bigvgan_24khz_100band</td>
415
+ <td>24 kHz</td>
416
+ <td>100</td>
417
+ <td>12000</td>
418
+ <td>256</td>
419
+ <td>112M</td>
420
+ <td>LibriTTS</td>
421
+ <td>No</td>
422
+ </tr>
423
+ <tr>
424
+ <td>bigvgan_base_24khz_100band</td>
425
+ <td>24 kHz</td>
426
+ <td>100</td>
427
+ <td>12000</td>
428
+ <td>256</td>
429
+ <td>14M</td>
430
+ <td>LibriTTS</td>
431
+ <td>No</td>
432
+ </tr>
433
+ <tr>
434
+ <td>bigvgan_22khz_80band</td>
435
+ <td>22 kHz</td>
436
+ <td>80</td>
437
+ <td>8000</td>
438
+ <td>256</td>
439
+ <td>112M</td>
440
+ <td>LibriTTS + VCTK + LJSpeech</td>
441
+ <td>No</td>
442
+ </tr>
443
+ <tr>
444
+ <td>bigvgan_base_22khz_80band</td>
445
+ <td>22 kHz</td>
446
+ <td>80</td>
447
+ <td>8000</td>
448
+ <td>256</td>
449
+ <td>14M</td>
450
+ <td>LibriTTS + VCTK + LJSpeech</td>
451
+ <td>No</td>
452
+ </tr>
453
+ </tbody>
454
+ </table>
455
+ <p><b>NOTE: The v1 models are trained using speech audio datasets ONLY! (24kHz models: LibriTTS, 22kHz models: LibriTTS + VCTK + LJSpeech).</b></p>
456
+ </div>
457
+ """
458
+ )
459
+
460
+ iface.queue()
461
+ iface.launch()
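For readers skimming the diff: the checkpoint-loading loop above reduces to the following minimal sketch for a single model. It assumes the `env`, `meldataset`, and `models` modules added in this commit are importable and that the checkpoint/config filenames follow the `model_files` mapping; the input audio path is a placeholder, not a file shipped with this commit.

```python
# Minimal sketch: load one pretrained BigVGAN checkpoint from the Hub and vocode a mel
# spectrogram, mirroring the loading loop and inference_model() in app.py above.
import json
import librosa
import torch
from huggingface_hub import hf_hub_download

from env import AttrDict
from meldataset import mel_spectrogram
from models import BigVGAN as Generator

name = "bigvgan_v2_24khz_100band_256x"
ckpt_file = hf_hub_download("L0SG/BigVGAN", f"{name}/g_03000000")
config_file = hf_hub_download("L0SG/BigVGAN", f"{name}/config.json")

with open(config_file) as f:
    h = AttrDict(json.load(f))

generator = Generator(h)
generator.load_state_dict(torch.load(ckpt_file, map_location="cpu")["generator"])
generator.eval()
generator.remove_weight_norm()

# "input.wav" is a placeholder path; resample to the model's rate as the app does.
wav, _ = librosa.load("input.wav", sr=h.sampling_rate, mono=True)
mel = mel_spectrogram(torch.FloatTensor(wav).unsqueeze(0), h.n_fft, h.num_mels,
                      h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
with torch.inference_mode():
    audio = generator(mel).squeeze().numpy()  # float waveform in [-1, 1]
```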
env.py ADDED
@@ -0,0 +1,18 @@
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import os
5
+ import shutil
6
+
7
+
8
+ class AttrDict(dict):
9
+ def __init__(self, *args, **kwargs):
10
+ super(AttrDict, self).__init__(*args, **kwargs)
11
+ self.__dict__ = self
12
+
13
+
14
+ def build_env(config, config_name, path):
15
+ t_path = os.path.join(path, config_name)
16
+ if config != t_path:
17
+ os.makedirs(path, exist_ok=True)
18
+ shutil.copyfile(config, os.path.join(path, config_name))
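`AttrDict` simply points the instance's `__dict__` at the dict itself, so the JSON config can be read either as keys or as attributes throughout the code. A quick illustrative check (the values are examples, not a full config):

```python
from env import AttrDict

# Dict keys become attributes, and attribute writes update the underlying dict.
h = AttrDict({"sampling_rate": 24000, "num_mels": 100})
assert h.sampling_rate == h["sampling_rate"] == 24000
h.hop_size = 256
assert h["hop_size"] == 256
```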
examples/dance_24k.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7068d78ce4d008a793f6bfbbe49d0f8962a752f07780833c5ab73652da9849fd
3
+ size 479788
examples/hifitts_44k.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01f7653b188bdb7349542bbc8af473208d463639682b684527cef651d8225483
3
+ size 570024
examples/jensen_24k.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ec26c78377e056ba8f08e0c337cc535c0fe08a9d0e7923ef3f5c52369173713
3
+ size 479788
examples/libritts_24k.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e9259975995438846da86fd69f0263a1ef859a6e5a4c4501b7c71bca52d5acc
3
+ size 281644
examples/megalovania_24k.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7970ac637e680876d48ad84e9185db1b21da01929fe46d855e8794bd83d14c20
3
+ size 1548328
examples/musdbhq_44k.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87dbdabc47550f493c2c0e2c9389b6dddffb93977408b54d9c4db3b5f071856c
3
+ size 917548
examples/musiccaps1_44k.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d433e0be92a742e9fd2c6a38d627e8cf8864c78ba76f334bd99ec9d931fb615f
3
+ size 887062
examples/musiccaps2_44k.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fafab98d1d31866e432c6b5cfd67e19278ce5a37547781c30c5638136cbab04
3
+ size 887062
examples/queen_24k.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee9fcaf8d21b098f94541b6f2dfc0803167b39f2aea0ca5c40d0b7430b3954d8
3
+ size 479788
incl_licenses/LICENSE_1 ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jungil Kong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
incl_licenses/LICENSE_2 ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Edward Dixon
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
incl_licenses/LICENSE_3 ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
incl_licenses/LICENSE_4 ADDED
@@ -0,0 +1,29 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2019, Seungwon Park 박승원
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ 3. Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
incl_licenses/LICENSE_5 ADDED
@@ -0,0 +1,16 @@
1
+ Copyright 2020 Alexandre Défossez
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
4
+ associated documentation files (the "Software"), to deal in the Software without restriction,
5
+ including without limitation the rights to use, copy, modify, merge, publish, distribute,
6
+ sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
7
+ furnished to do so, subject to the following conditions:
8
+
9
+ The above copyright notice and this permission notice shall be included in all copies or
10
+ substantial portions of the Software.
11
+
12
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
13
+ NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
14
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
15
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
16
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
incl_licenses/LICENSE_6 ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023-present, Descript
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
incl_licenses/LICENSE_7 ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Charactr Inc.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
incl_licenses/LICENSE_8 ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Amphion
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
inference.py ADDED
@@ -0,0 +1,105 @@
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from __future__ import absolute_import, division, print_function, unicode_literals
5
+
6
+ import glob
7
+ import os
8
+ import argparse
9
+ import json
10
+ import torch
11
+ from scipy.io.wavfile import write
12
+ from env import AttrDict
13
+ from meldataset import mel_spectrogram, MAX_WAV_VALUE
14
+ from models import BigVGAN as Generator
15
+ import librosa
16
+
17
+ h = None
18
+ device = None
19
+ torch.backends.cudnn.benchmark = False
20
+
21
+
22
+ def load_checkpoint(filepath, device):
23
+ assert os.path.isfile(filepath)
24
+ print("Loading '{}'".format(filepath))
25
+ checkpoint_dict = torch.load(filepath, map_location=device)
26
+ print("Complete.")
27
+ return checkpoint_dict
28
+
29
+
30
+ def get_mel(x):
31
+ return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
32
+
33
+
34
+ def scan_checkpoint(cp_dir, prefix):
35
+ pattern = os.path.join(cp_dir, prefix + '*')
36
+ cp_list = glob.glob(pattern)
37
+ if len(cp_list) == 0:
38
+ return ''
39
+ return sorted(cp_list)[-1]
40
+
41
+
42
+ def inference(a, h):
43
+ generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
44
+
45
+ state_dict_g = load_checkpoint(a.checkpoint_file, device)
46
+ generator.load_state_dict(state_dict_g['generator'])
47
+
48
+ filelist = os.listdir(a.input_wavs_dir)
49
+
50
+ os.makedirs(a.output_dir, exist_ok=True)
51
+
52
+ generator.eval()
53
+ generator.remove_weight_norm()
54
+ with torch.no_grad():
55
+ for i, filename in enumerate(filelist):
56
+ # load the ground truth audio and resample if necessary
57
+ wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filename), sr=h.sampling_rate, mono=True)
58
+ wav = torch.FloatTensor(wav).to(device)
59
+ # compute mel spectrogram from the ground truth audio
60
+ x = get_mel(wav.unsqueeze(0))
61
+
62
+ y_g_hat = generator(x)
63
+
64
+ audio = y_g_hat.squeeze()
65
+ audio = audio * MAX_WAV_VALUE
66
+ audio = audio.cpu().numpy().astype('int16')
67
+
68
+ output_file = os.path.join(a.output_dir, os.path.splitext(filename)[0] + '_generated.wav')
69
+ write(output_file, h.sampling_rate, audio)
70
+ print(output_file)
71
+
72
+
73
+ def main():
74
+ print('Initializing Inference Process..')
75
+
76
+ parser = argparse.ArgumentParser()
77
+ parser.add_argument('--input_wavs_dir', default='test_files')
78
+ parser.add_argument('--output_dir', default='generated_files')
79
+ parser.add_argument('--checkpoint_file', required=True)
80
+ parser.add_argument('--use_cuda_kernel', action='store_true', default=False)
81
+
82
+ a = parser.parse_args()
83
+
84
+ config_file = os.path.join(os.path.split(a.checkpoint_file)[0], 'config.json')
85
+ with open(config_file) as f:
86
+ data = f.read()
87
+
88
+ global h
89
+ json_config = json.loads(data)
90
+ h = AttrDict(json_config)
91
+
92
+ torch.manual_seed(h.seed)
93
+ global device
94
+ if torch.cuda.is_available():
95
+ torch.cuda.manual_seed(h.seed)
96
+ device = torch.device('cuda')
97
+ else:
98
+ device = torch.device('cpu')
99
+
100
+ inference(a, h)
101
+
102
+
103
+ if __name__ == '__main__':
104
+ main()
105
+
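As a usage note, `inference.py` is driven entirely by the argparse flags above: it reads `config.json` from the directory containing the checkpoint, vocodes every file in `--input_wavs_dir` (resampling to the model's sampling rate), and writes `*_generated.wav` files into `--output_dir`. A hypothetical invocation, with a placeholder checkpoint path, would be `python inference.py --checkpoint_file /path/to/bigvgan_v2_24khz_100band_256x/g_03000000 --input_wavs_dir test_files --output_dir generated_files`, optionally adding `--use_cuda_kernel` for the fused CUDA inference path.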
meldataset.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import math
8
+ import os
9
+ import random
10
+ import torch
11
+ import torch.utils.data
12
+ import numpy as np
13
+ from librosa.util import normalize
14
+ from scipy.io.wavfile import read
15
+ from librosa.filters import mel as librosa_mel_fn
16
+ import pathlib
17
+ from tqdm import tqdm
18
+
19
+ MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 - 1 to prevent int16 overflow, which can cause an audible pop in corner cases
20
+
21
+
22
+ def load_wav(full_path, sr_target):
23
+ sampling_rate, data = read(full_path)
24
+ if sampling_rate != sr_target:
25
+ raise RuntimeError("Sampling rate of the file {} is {} Hz, but the model requires {} Hz".
26
+ format(full_path, sampling_rate, sr_target))
27
+ return data, sampling_rate
28
+
29
+
30
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
31
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
32
+
33
+
34
+ def dynamic_range_decompression(x, C=1):
35
+ return np.exp(x) / C
36
+
37
+
38
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
39
+ return torch.log(torch.clamp(x, min=clip_val) * C)
40
+
41
+
42
+ def dynamic_range_decompression_torch(x, C=1):
43
+ return torch.exp(x) / C
44
+
45
+
46
+ def spectral_normalize_torch(magnitudes):
47
+ output = dynamic_range_compression_torch(magnitudes)
48
+ return output
49
+
50
+
51
+ def spectral_de_normalize_torch(magnitudes):
52
+ output = dynamic_range_decompression_torch(magnitudes)
53
+ return output
54
+
55
+
56
+ mel_basis = {}
57
+ hann_window = {}
58
+
59
+
60
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
61
+ if torch.min(y) < -1.:
62
+ print('min value is ', torch.min(y))
63
+ if torch.max(y) > 1.:
64
+ print('max value is ', torch.max(y))
65
+
66
+ global mel_basis, hann_window
67
+ str_key_mel_basis = str(fmax) + '_' + str(y.device)
68
+ if str_key_mel_basis not in mel_basis: # build the mel filterbank and window once per (fmax, device) pair
69
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
70
+ mel_basis[str_key_mel_basis] = torch.from_numpy(mel).float().to(y.device)
71
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
72
+
73
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
74
+ y = y.squeeze(1)
75
+
76
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
77
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
78
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
79
+ spec = torch.view_as_real(spec)
80
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
81
+
82
+ spec = torch.matmul(mel_basis[str_key_mel_basis], spec)
83
+ spec = spectral_normalize_torch(spec)
84
+
85
+ return spec
86
+
87
+
88
+ def get_dataset_filelist(a):
89
+ with open(a.input_training_file, 'r', encoding='utf-8') as fi:
90
+ training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
91
+ for x in fi.read().split('\n') if len(x) > 0]
92
+ print("first training file: {}".format(training_files[0]))
93
+
94
+ with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
95
+ validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
96
+ for x in fi.read().split('\n') if len(x) > 0]
97
+ print("first validation file: {}".format(validation_files[0]))
98
+
99
+ list_unseen_validation_files = []
100
+ for i in range(len(a.list_input_unseen_validation_file)):
101
+ with open(a.list_input_unseen_validation_file[i], 'r', encoding='utf-8') as fi:
102
+ unseen_validation_files = [os.path.join(a.list_input_unseen_wavs_dir[i], x.split('|')[0] + '.wav')
103
+ for x in fi.read().split('\n') if len(x) > 0]
104
+ print("first unseen {}th validation fileset: {}".format(i, unseen_validation_files[0]))
105
+ list_unseen_validation_files.append(unseen_validation_files)
106
+
107
+ return training_files, validation_files, list_unseen_validation_files
108
+
109
+
110
+ class MelDataset(torch.utils.data.Dataset):
111
+ def __init__(self, training_files, hparams, segment_size, n_fft, num_mels,
112
+ hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
113
+ device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None, is_seen=True):
114
+ self.audio_files = training_files
115
+ random.seed(1234)
116
+ if shuffle:
117
+ random.shuffle(self.audio_files)
118
+ self.hparams = hparams
119
+ self.is_seen = is_seen
120
+ if self.is_seen:
121
+ self.name = pathlib.Path(self.audio_files[0]).parts[0]
122
+ else:
123
+ self.name = '-'.join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
124
+
125
+ self.segment_size = segment_size
126
+ self.sampling_rate = sampling_rate
127
+ self.split = split
128
+ self.n_fft = n_fft
129
+ self.num_mels = num_mels
130
+ self.hop_size = hop_size
131
+ self.win_size = win_size
132
+ self.fmin = fmin
133
+ self.fmax = fmax
134
+ self.fmax_loss = fmax_loss
135
+ self.cached_wav = None
136
+ self.n_cache_reuse = n_cache_reuse
137
+ self._cache_ref_count = 0
138
+ self.device = device
139
+ self.fine_tuning = fine_tuning
140
+ self.base_mels_path = base_mels_path
141
+
142
+ print("INFO: checking dataset integrity...")
143
+ for i in tqdm(range(len(self.audio_files))):
144
+ assert os.path.exists(self.audio_files[i]), "{} not found".format(self.audio_files[i])
145
+
146
+ def __getitem__(self, index):
147
+
148
+ filename = self.audio_files[index]
149
+ if self._cache_ref_count == 0:
150
+ audio, sampling_rate = load_wav(filename, self.sampling_rate)
151
+ audio = audio / MAX_WAV_VALUE
152
+ if not self.fine_tuning:
153
+ audio = normalize(audio) * 0.95
154
+ self.cached_wav = audio
155
+ if sampling_rate != self.sampling_rate:
156
+ raise ValueError("{} SR doesn't match target {} SR".format(
157
+ sampling_rate, self.sampling_rate))
158
+ self._cache_ref_count = self.n_cache_reuse
159
+ else:
160
+ audio = self.cached_wav
161
+ self._cache_ref_count -= 1
162
+
163
+ audio = torch.FloatTensor(audio)
164
+ audio = audio.unsqueeze(0)
165
+
166
+ if not self.fine_tuning:
167
+ if self.split:
168
+ if audio.size(1) >= self.segment_size:
169
+ max_audio_start = audio.size(1) - self.segment_size
170
+ audio_start = random.randint(0, max_audio_start)
171
+ audio = audio[:, audio_start:audio_start+self.segment_size]
172
+ else:
173
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
174
+
175
+ mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
176
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
177
+ center=False)
178
+ else: # validation step
179
+ # match audio length to self.hop_size * n for evaluation
180
+ if (audio.size(1) % self.hop_size) != 0:
181
+ audio = audio[:, :-(audio.size(1) % self.hop_size)]
182
+ mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
183
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
184
+ center=False)
185
+ assert audio.shape[1] == mel.shape[2] * self.hop_size, "audio shape {} mel shape {}".format(audio.shape, mel.shape)
186
+
187
+ else:
188
+ mel = np.load(
189
+ os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
190
+ mel = torch.from_numpy(mel)
191
+
192
+ if len(mel.shape) < 3:
193
+ mel = mel.unsqueeze(0)
194
+
195
+ if self.split:
196
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
197
+
198
+ if audio.size(1) >= self.segment_size:
199
+ mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
200
+ mel = mel[:, :, mel_start:mel_start + frames_per_seg]
201
+ audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
202
+ else:
203
+ mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
204
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
205
+
206
+ mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
207
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
208
+ center=False)
209
+
210
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
211
+
212
+ def __len__(self):
213
+ return len(self.audio_files)
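A note on shapes: with `center=False` and the `(n_fft - hop_size) / 2` reflection padding applied inside `mel_spectrogram`, the output has exactly `T // hop_size` frames, which is what the `audio.shape[1] == mel.shape[2] * self.hop_size` assertion in `__getitem__` relies on. A small sketch (the STFT parameters are illustrative values for a 24 kHz / 100-band setup, not read from a config):

```python
import torch
from meldataset import mel_spectrogram

# One second of a quiet 440 Hz tone in [-1, 1], shaped [B, T].
t = torch.arange(24000) / 24000.0
wav = 0.95 * torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)

mel = mel_spectrogram(wav, n_fft=1024, num_mels=100, sampling_rate=24000,
                      hop_size=256, win_size=1024, fmin=0, fmax=12000)
assert mel.shape == (1, 100, 24000 // 256)  # [B, num_mels, T // hop_size]
```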
models.py ADDED
@@ -0,0 +1,955 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.nn as nn
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+ from torchaudio.transforms import Spectrogram, Resample
14
+ from librosa.filters import mel as librosa_mel_fn
15
+ from scipy import signal
16
+
17
+ import activations
18
+ from utils import init_weights, get_padding
19
+ from alias_free_torch.act import Activation1d as TorchActivation1d
20
+ import typing
21
+ from typing import List, Optional, Tuple
22
+ from collections import namedtuple
23
+ import math
24
+ import functools
25
+
26
+
27
+ class AMPBlock1(torch.nn.Module):
28
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
29
+ super(AMPBlock1, self).__init__()
30
+ self.h = h
31
+
32
+ self.convs1 = nn.ModuleList([
33
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
34
+ padding=get_padding(kernel_size, dilation[0]))),
35
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
36
+ padding=get_padding(kernel_size, dilation[1]))),
37
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
38
+ padding=get_padding(kernel_size, dilation[2])))
39
+ ])
40
+ self.convs1.apply(init_weights)
41
+
42
+ self.convs2 = nn.ModuleList([
43
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
44
+ padding=get_padding(kernel_size, 1))),
45
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
46
+ padding=get_padding(kernel_size, 1))),
47
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
48
+ padding=get_padding(kernel_size, 1)))
49
+ ])
50
+ self.convs2.apply(init_weights)
51
+
52
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
53
+
54
+ # select which Activation1d, lazy-load cuda version to ensure backward compatibility
55
+ if self.h.get("use_cuda_kernel", False):
56
+ # faster CUDA kernel implementation of Activation1d
57
+ from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
58
+ Activation1d = CudaActivation1d
59
+ else:
60
+ Activation1d = TorchActivation1d
61
+
62
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
63
+ self.activations = nn.ModuleList([
64
+ Activation1d(
65
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
66
+ for _ in range(self.num_layers)
67
+ ])
68
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
69
+ self.activations = nn.ModuleList([
70
+ Activation1d(
71
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
72
+ for _ in range(self.num_layers)
73
+ ])
74
+ else:
75
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
76
+
77
+ def forward(self, x):
78
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
79
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
80
+ xt = a1(x)
81
+ xt = c1(xt)
82
+ xt = a2(xt)
83
+ xt = c2(xt)
84
+ x = xt + x
85
+
86
+ return x
87
+
88
+ def remove_weight_norm(self):
89
+ for l in self.convs1:
90
+ remove_weight_norm(l)
91
+ for l in self.convs2:
92
+ remove_weight_norm(l)
93
+
94
+
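For context on the `snake` / `snakebeta` options used by both AMP blocks: the base snake nonlinearity from the BigVGAN paper is `x + (1/alpha) * sin(alpha * x) ** 2` with a learnable per-channel `alpha` (snakebeta adds a separate magnitude parameter), and `Activation1d` wraps it with upsampling/downsampling for anti-aliasing. A minimal functional sketch, not the actual `activations` module:

```python
import torch

def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # Snake(x) = x + (1/alpha) * sin^2(alpha * x): the identity plus a learnable
    # periodic component; its derivative 1 + sin(2*alpha*x) stays non-negative.
    return x + (1.0 / (alpha + eps)) * torch.sin(alpha * x) ** 2

x = torch.linspace(-3.0, 3.0, steps=7)
print(snake(x, alpha=torch.tensor(1.0)))
```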
95
+ class AMPBlock2(torch.nn.Module):
96
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
97
+ super(AMPBlock2, self).__init__()
98
+ self.h = h
99
+
100
+ self.convs = nn.ModuleList([
101
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
102
+ padding=get_padding(kernel_size, dilation[0]))),
103
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
104
+ padding=get_padding(kernel_size, dilation[1])))
105
+ ])
106
+ self.convs.apply(init_weights)
107
+
108
+ self.num_layers = len(self.convs) # total number of conv layers
109
+
110
+ # select which Activation1d, lazy-load cuda version to ensure backward compatibility
111
+ if self.h.get("use_cuda_kernel", False):
112
+ # faster CUDA kernel implementation of Activation1d
113
+ from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
114
+ Activation1d = CudaActivation1d
115
+ else:
116
+ Activation1d = TorchActivation1d
117
+
118
+ if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
119
+ self.activations = nn.ModuleList([
120
+ Activation1d(
121
+ activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
122
+ for _ in range(self.num_layers)
123
+ ])
124
+ elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
125
+ self.activations = nn.ModuleList([
126
+ Activation1d(
127
+ activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
128
+ for _ in range(self.num_layers)
129
+ ])
130
+ else:
131
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
132
+
133
+ def forward(self, x):
134
+ for c, a in zip (self.convs, self.activations):
135
+ xt = a(x)
136
+ xt = c(xt)
137
+ x = xt + x
138
+
139
+ return x
140
+
141
+ def remove_weight_norm(self):
142
+ for l in self.convs:
143
+ remove_weight_norm(l)
144
+
145
+
146
+ class BigVGAN(torch.nn.Module):
147
+ # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
148
+ # New in v2: if use_cuda_kernel is set to True, it loads optimized CUDA kernels for AMP.
149
+ # NOTE: use_cuda_kernel=True should be used for inference only (training is not supported).
150
+ def __init__(
151
+ self,
152
+ h,
153
+ use_cuda_kernel: bool=False
154
+ ):
155
+ super(BigVGAN, self).__init__()
156
+ self.h = h
157
+ self.h["use_cuda_kernel"] = use_cuda_kernel # add it to global hyperparameters (h)
158
+
159
+ self.num_kernels = len(h.resblock_kernel_sizes)
160
+ self.num_upsamples = len(h.upsample_rates)
161
+
162
+ # pre conv
163
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
164
+
165
+ # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
166
+ resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
167
+
168
+ # transposed conv-based upsamplers. does not apply anti-aliasing
169
+ self.ups = nn.ModuleList()
170
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
171
+ self.ups.append(nn.ModuleList([
172
+ weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
173
+ h.upsample_initial_channel // (2 ** (i + 1)),
174
+ k, u, padding=(k - u) // 2))
175
+ ]))
176
+
177
+ # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
178
+ self.resblocks = nn.ModuleList()
179
+ for i in range(len(self.ups)):
180
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
181
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
182
+ self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
183
+
184
+ # select which Activation1d, lazy-load cuda version to ensure backward compatibility
185
+ if self.h.get("use_cuda_kernel", False):
186
+ # faster CUDA kernel implementation of Activation1d
187
+ from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
188
+ Activation1d = CudaActivation1d
189
+ else:
190
+ Activation1d = TorchActivation1d
191
+
192
+ # post conv
193
+ if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
194
+ activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
195
+ self.activation_post = Activation1d(activation=activation_post)
196
+ elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
197
+ activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
198
+ self.activation_post = Activation1d(activation=activation_post)
199
+ else:
200
+ raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
201
+
202
+ # whether to use bias for the final conv_post. Defaults to True for backward compatibility
203
+ self.use_bias_at_final = h.get("use_bias_at_final", True)
204
+ self.conv_post = weight_norm(Conv1d(
205
+ ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final
206
+ ))
207
+
208
+ # weight initialization
209
+ for i in range(len(self.ups)):
210
+ self.ups[i].apply(init_weights)
211
+ self.conv_post.apply(init_weights)
212
+
213
+ # final tanh activation. Defaults to True for backward compatibility
214
+ self.use_tanh_at_final = h.get("use_tanh_at_final", True)
215
+
216
+ def forward(self, x):
217
+ # pre conv
218
+ x = self.conv_pre(x)
219
+
220
+ for i in range(self.num_upsamples):
221
+ # upsampling
222
+ for i_up in range(len(self.ups[i])):
223
+ x = self.ups[i][i_up](x)
224
+ # AMP blocks
225
+ xs = None
226
+ for j in range(self.num_kernels):
227
+ if xs is None:
228
+ xs = self.resblocks[i * self.num_kernels + j](x)
229
+ else:
230
+ xs += self.resblocks[i * self.num_kernels + j](x)
231
+ x = xs / self.num_kernels
232
+
233
+ # post conv
234
+ x = self.activation_post(x)
235
+ x = self.conv_post(x)
236
+ # final tanh activation
237
+ if self.use_tanh_at_final:
238
+ x = torch.tanh(x)
239
+ else:
240
+ x = torch.clamp(x, min=-1., max=1.) # bound the output to [-1, 1]
241
+
242
+ return x
243
+
244
+ def remove_weight_norm(self):
245
+ print('Removing weight norm...')
246
+ for l in self.ups:
247
+ for l_i in l:
248
+ remove_weight_norm(l_i)
249
+ for l in self.resblocks:
250
+ l.remove_weight_norm()
251
+ remove_weight_norm(self.conv_pre)
252
+ remove_weight_norm(self.conv_post)
253
+
254
+
255
+ class DiscriminatorP(torch.nn.Module):
256
+ def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False):
257
+ super(DiscriminatorP, self).__init__()
258
+ self.period = period
259
+ self.d_mult = h.discriminator_channel_mult
260
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
261
+ self.convs = nn.ModuleList([
262
+ norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
263
+ norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
264
+ norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
265
+ norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
266
+ norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
267
+ ])
268
+ self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
269
+
270
+ def forward(self, x):
271
+ fmap = []
272
+
273
+ # 1d to 2d
274
+ b, c, t = x.shape
275
+ if t % self.period != 0: # pad first
276
+ n_pad = self.period - (t % self.period)
277
+ x = F.pad(x, (0, n_pad), "reflect")
278
+ t = t + n_pad
279
+ x = x.view(b, c, t // self.period, self.period)
280
+
281
+ for l in self.convs:
282
+ x = l(x)
283
+ x = F.leaky_relu(x, 0.1)
284
+ fmap.append(x)
285
+ x = self.conv_post(x)
286
+ fmap.append(x)
287
+ x = torch.flatten(x, 1, -1)
288
+
289
+ return x, fmap
290
+
291
+
292
+ class MultiPeriodDiscriminator(torch.nn.Module):
293
+ def __init__(self, h):
294
+ super(MultiPeriodDiscriminator, self).__init__()
295
+ self.mpd_reshapes = h.mpd_reshapes
296
+ print("mpd_reshapes: {}".format(self.mpd_reshapes))
297
+ discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
298
+ self.discriminators = nn.ModuleList(discriminators)
299
+
300
+ def forward(self, y, y_hat):
301
+ y_d_rs = []
302
+ y_d_gs = []
303
+ fmap_rs = []
304
+ fmap_gs = []
305
+ for i, d in enumerate(self.discriminators):
306
+ y_d_r, fmap_r = d(y)
307
+ y_d_g, fmap_g = d(y_hat)
308
+ y_d_rs.append(y_d_r)
309
+ fmap_rs.append(fmap_r)
310
+ y_d_gs.append(y_d_g)
311
+ fmap_gs.append(fmap_g)
312
+
313
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
314
+
315
+
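Each DiscriminatorP above folds the 1-D waveform into a 2-D view of shape [B, C, T/period, period], so that its (kernel_size, 1) convolutions compare samples lying exactly one period apart. A small sketch of that reshape with illustrative values:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 22050)  # [B, C, T]
period = 7
t = x.shape[-1]
if t % period != 0:  # pad so T becomes a multiple of the period
    x = F.pad(x, (0, period - t % period), "reflect")
x_2d = x.view(1, 1, -1, period)  # [B, C, T // period, period]
print(x_2d.shape)  # torch.Size([1, 1, 3150, 7])
```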
316
+ class DiscriminatorR(nn.Module):
317
+ def __init__(self, cfg, resolution):
318
+ super().__init__()
319
+
320
+ self.resolution = resolution
321
+ assert len(self.resolution) == 3, \
322
+ "MRD layer requires list with len=3, got {}".format(self.resolution)
323
+ self.lrelu_slope = 0.1
324
+
325
+ norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
326
+ if hasattr(cfg, "mrd_use_spectral_norm"):
327
+ print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm))
328
+ norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
329
+ self.d_mult = cfg.discriminator_channel_mult
330
+ if hasattr(cfg, "mrd_channel_mult"):
331
+ print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult))
332
+ self.d_mult = cfg.mrd_channel_mult
333
+
334
+ self.convs = nn.ModuleList([
335
+ norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))),
336
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
337
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
338
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
339
+ norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))),
340
+ ])
341
+ self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
342
+
343
+ def forward(self, x):
344
+ fmap = []
345
+
346
+ x = self.spectrogram(x)
347
+ x = x.unsqueeze(1)
348
+ for l in self.convs:
349
+ x = l(x)
350
+ x = F.leaky_relu(x, self.lrelu_slope)
351
+ fmap.append(x)
352
+ x = self.conv_post(x)
353
+ fmap.append(x)
354
+ x = torch.flatten(x, 1, -1)
355
+
356
+ return x, fmap
357
+
358
+ def spectrogram(self, x):
359
+ n_fft, hop_length, win_length = self.resolution
360
+ x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
361
+ x = x.squeeze(1)
362
+ x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
363
+ x = torch.view_as_real(x) # [B, F, TT, 2]
364
+ mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
365
+
366
+ return mag
367
+
368
+
369
+ class MultiResolutionDiscriminator(nn.Module):
370
+ def __init__(self, cfg, debug=False):
371
+ super().__init__()
372
+ self.resolutions = cfg.resolutions
373
+ assert len(self.resolutions) == 3,\
374
+ "MRD requires list of list with len=3, each element having a list with len=3. got {}".\
375
+ format(self.resolutions)
376
+ self.discriminators = nn.ModuleList(
377
+ [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
378
+ )
379
+
380
+ def forward(self, y, y_hat):
381
+ y_d_rs = []
382
+ y_d_gs = []
383
+ fmap_rs = []
384
+ fmap_gs = []
385
+
386
+ for i, d in enumerate(self.discriminators):
387
+ y_d_r, fmap_r = d(x=y)
388
+ y_d_g, fmap_g = d(x=y_hat)
389
+ y_d_rs.append(y_d_r)
390
+ fmap_rs.append(fmap_r)
391
+ y_d_gs.append(y_d_g)
392
+ fmap_gs.append(fmap_g)
393
+
394
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
395
+
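Each DiscriminatorR scores the magnitude STFT of the waveform at one resolution, where a resolution is the triple [n_fft, hop_length, win_length] taken from cfg.resolutions. A sketch of what spectrogram() computes, using an illustrative triple (spec.abs() is equivalent to the view_as_real + norm pair above):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8192)                      # [B, 1, T]
n_fft, hop_length, win_length = 1024, 256, 1024  # illustrative resolution triple
pad = (n_fft - hop_length) // 2
x = F.pad(x, (pad, pad), mode="reflect").squeeze(1)
spec = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                  center=False, return_complex=True)
mag = spec.abs()                                 # [B, n_fft // 2 + 1, frames]
print(mag.shape)                                 # torch.Size([1, 513, 32])
```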
396
+ # Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
397
+ # Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
398
+ # LICENSE is in incl_licenses directory.
399
+ class DiscriminatorB(nn.Module):
400
+ def __init__(
401
+ self,
402
+ window_length: int,
403
+ channels: int = 32,
404
+ hop_factor: float = 0.25,
405
+ bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
406
+ ):
407
+ super().__init__()
408
+ self.window_length = window_length
409
+ self.hop_factor = hop_factor
410
+ self.spec_fn = Spectrogram(
411
+ n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
412
+ )
413
+ n_fft = window_length // 2 + 1
414
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
415
+ self.bands = bands
416
+ convs = lambda: nn.ModuleList(
417
+ [
418
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
419
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
420
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
421
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
422
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
423
+ ]
424
+ )
425
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
426
+
427
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
428
+
429
+ def spectrogram(self, x):
430
+ # Remove DC offset
431
+ x = x - x.mean(dim=-1, keepdims=True)
432
+ # Peak normalize the volume of input audio
433
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
434
+ x = self.spec_fn(x)
435
+ x = torch.view_as_real(x)
436
+ x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F]
437
+ # Split into bands
438
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
439
+ return x_bands
440
+
441
+ def forward(self, x: torch.Tensor):
442
+ x_bands = self.spectrogram(x.squeeze(1))
443
+ fmap = []
444
+ x = []
445
+
446
+ for band, stack in zip(x_bands, self.band_convs):
447
+ for i, layer in enumerate(stack):
448
+ band = layer(band)
449
+ band = torch.nn.functional.leaky_relu(band, 0.1)
450
+ if i > 0:
451
+ fmap.append(band)
452
+ x.append(band)
453
+
454
+ x = torch.cat(x, dim=-1)
455
+ x = self.conv_post(x)
456
+ fmap.append(x)
457
+
458
+ return x, fmap
459
+
460
+ # Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
461
+ # Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
462
+ # LICENSE is in incl_licenses directory.
463
+ class MultiBandDiscriminator(nn.Module):
464
+ def __init__(
465
+ self,
466
+ h,
467
+ ):
468
+ """
469
+ Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
470
+ and the modified code adapted from https://github.com/gemelo-ai/vocos.
471
+ """
472
+ super().__init__()
473
+ # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
474
+ self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
475
+ self.discriminators = nn.ModuleList(
476
+ [DiscriminatorB(window_length=w) for w in self.fft_sizes]
477
+ )
478
+
479
+ def forward(
480
+ self,
481
+ y: torch.Tensor,
482
+ y_hat: torch.Tensor
483
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
484
+
485
+ y_d_rs = []
486
+ y_d_gs = []
487
+ fmap_rs = []
488
+ fmap_gs = []
489
+
490
+ for d in self.discriminators:
491
+ y_d_r, fmap_r = d(x=y)
492
+ y_d_g, fmap_g = d(x=y_hat)
493
+ y_d_rs.append(y_d_r)
494
+ fmap_rs.append(fmap_r)
495
+ y_d_gs.append(y_d_g)
496
+ fmap_gs.append(fmap_g)
497
+
498
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
499
+
500
+
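DiscriminatorB splits the complex STFT into frequency bands specified as fractions of the bin count and runs a separate convolution stack per band before concatenating along frequency. For the default bands and window_length = 2048, the bin arithmetic in __init__ works out as follows:

```python
window_length = 2048
n_bins = window_length // 2 + 1  # 1025 frequency bins
bands = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0))
bin_ranges = [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]
print(bin_ranges)  # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]
```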
501
+ # Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
502
+ # LICENSE is in incl_licenses directory.
503
+ class DiscriminatorCQT(nn.Module):
504
+ def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
505
+ super().__init__()
506
+ self.cfg = cfg
507
+
508
+ self.filters = cfg["cqtd_filters"]
509
+ self.max_filters = cfg["cqtd_max_filters"]
510
+ self.filters_scale = cfg["cqtd_filters_scale"]
511
+ self.kernel_size = (3, 9)
512
+ self.dilations = cfg["cqtd_dilations"]
513
+ self.stride = (1, 2)
514
+
515
+ self.in_channels = cfg["cqtd_in_channels"]
516
+ self.out_channels = cfg["cqtd_out_channels"]
517
+ self.fs = cfg["sampling_rate"]
518
+ self.hop_length = hop_length
519
+ self.n_octaves = n_octaves
520
+ self.bins_per_octave = bins_per_octave
521
+
522
+ # lazy-load
523
+ from nnAudio import features
524
+ self.cqt_transform = features.cqt.CQT2010v2(
525
+ sr=self.fs * 2,
526
+ hop_length=self.hop_length,
527
+ n_bins=self.bins_per_octave * self.n_octaves,
528
+ bins_per_octave=self.bins_per_octave,
529
+ output_format="Complex",
530
+ pad_mode="constant",
531
+ )
532
+
533
+ self.conv_pres = nn.ModuleList()
534
+ for i in range(self.n_octaves):
535
+ self.conv_pres.append(
536
+ nn.Conv2d(
537
+ self.in_channels * 2,
538
+ self.in_channels * 2,
539
+ kernel_size=self.kernel_size,
540
+ padding=self.get_2d_padding(self.kernel_size),
541
+ )
542
+ )
543
+
544
+ self.convs = nn.ModuleList()
545
+
546
+ self.convs.append(
547
+ nn.Conv2d(
548
+ self.in_channels * 2,
549
+ self.filters,
550
+ kernel_size=self.kernel_size,
551
+ padding=self.get_2d_padding(self.kernel_size),
552
+ )
553
+ )
554
+
555
+ in_chs = min(self.filters_scale * self.filters, self.max_filters)
556
+ for i, dilation in enumerate(self.dilations):
557
+ out_chs = min(
558
+ (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
559
+ )
560
+ self.convs.append(
561
+ weight_norm(nn.Conv2d(
562
+ in_chs,
563
+ out_chs,
564
+ kernel_size=self.kernel_size,
565
+ stride=self.stride,
566
+ dilation=(dilation, 1),
567
+ padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
568
+ ))
569
+ )
570
+ in_chs = out_chs
571
+ out_chs = min(
572
+ (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
573
+ self.max_filters,
574
+ )
575
+ self.convs.append(
576
+ weight_norm(nn.Conv2d(
577
+ in_chs,
578
+ out_chs,
579
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
580
+ padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
581
+ ))
582
+ )
583
+
584
+ self.conv_post = weight_norm(nn.Conv2d(
585
+ out_chs,
586
+ self.out_channels,
587
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
588
+ padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
589
+ ))
590
+
591
+ self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
592
+ self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2)
593
+
594
+ self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
595
+ if self.cqtd_normalize_volume:
596
+ print(f"INFO: cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!")
597
+
598
+ def get_2d_padding(
599
+ self, kernel_size: typing.Tuple[int, int], dilation: typing.Tuple[int, int] = (1, 1)
600
+ ):
601
+ return (
602
+ ((kernel_size[0] - 1) * dilation[0]) // 2,
603
+ ((kernel_size[1] - 1) * dilation[1]) // 2,
604
+ )
605
+
606
+ def forward(self, x):
607
+ fmap = []
608
+
609
+ if self.cqtd_normalize_volume:
610
+ # Remove DC offset
611
+ x = x - x.mean(dim=-1, keepdims=True)
612
+ # Peak normalize the volume of input audio
613
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
614
+
615
+ x = self.resample(x)
616
+
617
+ z = self.cqt_transform(x)
618
+
619
+ z_amplitude = z[:, :, :, 0].unsqueeze(1)
620
+ z_phase = z[:, :, :, 1].unsqueeze(1)
621
+
622
+ z = torch.cat([z_amplitude, z_phase], dim=1)
623
+ z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W]
624
+
625
+ latent_z = []
626
+ for i in range(self.n_octaves):
627
+ latent_z.append(
628
+ self.conv_pres[i](
629
+ z[
630
+ :,
631
+ :,
632
+ :,
633
+ i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
634
+ ]
635
+ )
636
+ )
637
+ latent_z = torch.cat(latent_z, dim=-1)
638
+
639
+ for i, l in enumerate(self.convs):
640
+ latent_z = l(latent_z)
641
+
642
+ latent_z = self.activation(latent_z)
643
+ fmap.append(latent_z)
644
+
645
+ latent_z = self.conv_post(latent_z)
646
+
647
+ return latent_z, fmap
648
+
649
+
650
+ class MultiScaleSubbandCQTDiscriminator(nn.Module):
651
+ def __init__(self, cfg):
652
+ super().__init__()
653
+
654
+ self.cfg = cfg
655
+ # Using get with defaults
656
+ self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32)
657
+ self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024)
658
+ self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1)
659
+ self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4])
660
+ self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1)
661
+ self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1)
662
+ # multi-scale params to loop over
663
+ self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
664
+ self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
665
+ self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
666
+
667
+ self.discriminators = nn.ModuleList(
668
+ [
669
+ DiscriminatorCQT(
670
+ self.cfg,
671
+ hop_length=self.cfg["cqtd_hop_lengths"][i],
672
+ n_octaves=self.cfg["cqtd_n_octaves"][i],
673
+ bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i],
674
+ )
675
+ for i in range(len(self.cfg["cqtd_hop_lengths"]))
676
+ ]
677
+ )
678
+
679
+ def forward(
680
+ self,
681
+ y: torch.Tensor,
682
+ y_hat: torch.Tensor
683
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
684
+
685
+ y_d_rs = []
686
+ y_d_gs = []
687
+ fmap_rs = []
688
+ fmap_gs = []
689
+
690
+ for disc in self.discriminators:
691
+ y_d_r, fmap_r = disc(y)
692
+ y_d_g, fmap_g = disc(y_hat)
693
+ y_d_rs.append(y_d_r)
694
+ fmap_rs.append(fmap_r)
695
+ y_d_gs.append(y_d_g)
696
+ fmap_gs.append(fmap_g)
697
+
698
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
699
+
700
+
701
+ class CombinedDiscriminator(nn.Module):
702
+ # wrapper for chaining multiple discriminator architectures
703
+ # e.g., combine MBD and CQTD into a single class
704
+ def __init__(
705
+ self,
706
+ list_discriminator: List[nn.Module]
707
+ ):
708
+ super().__init__()
709
+ self.discrimiantor = nn.ModuleList(list_discriminator)
710
+
711
+ def forward(
712
+ self,
713
+ y: torch.Tensor,
714
+ y_hat: torch.Tensor
715
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
716
+
717
+ y_d_rs = []
718
+ y_d_gs = []
719
+ fmap_rs = []
720
+ fmap_gs = []
721
+
722
+ for disc in self.discrimiantor:
723
+ y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat)
724
+ y_d_rs.extend(y_d_r)
725
+ fmap_rs.extend(fmap_r)
726
+ y_d_gs.extend(y_d_g)
727
+ fmap_gs.extend(fmap_g)
728
+
729
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
730
+
731
+
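As the class comment suggests, this wrapper lets several discriminators present one interface to the training loop. A hedged construction sketch, assuming `h` is the usual config object carrying the mbd_*/cqtd_* fields read above:

```python
import torch

# Sketch only: chain the two band-split discriminators defined above so the
# trainer sees a single (y_d_rs, y_d_gs, fmap_rs, fmap_gs) interface.
combined_disc = CombinedDiscriminator([
    MultiBandDiscriminator(h),
    MultiScaleSubbandCQTDiscriminator(h),
])

y = torch.randn(2, 1, 8192)      # real audio batch (illustrative)
y_hat = torch.randn(2, 1, 8192)  # generated audio batch (illustrative)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = combined_disc(y, y_hat)
```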
732
+ # Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
733
+ # LICENSE is in incl_licenses directory.
734
+ class MultiScaleMelSpectrogramLoss(nn.Module):
735
+ """Compute distance between mel spectrograms. Can be used
736
+ in a multi-scale way.
737
+
738
+ Parameters
739
+ ----------
740
+ n_mels : List[int]
741
+ Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320],
742
+ window_lengths : List[int], optional
743
+ Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048]
744
+ loss_fn : typing.Callable, optional
745
+ How to compare each loss, by default nn.L1Loss()
746
+ clamp_eps : float, optional
747
+ Clamp on the log magnitude, below, by default 1e-5
748
+ mag_weight : float, optional
749
+ Weight of raw magnitude portion of loss, by default 0.0 (the raw-magnitude term is disabled)
750
+ log_weight : float, optional
751
+ Weight of log magnitude portion of loss, by default 1.0
752
+ pow : float, optional
753
+ Power to raise magnitude to before taking log, by default 1.0
754
+ weight : float, optional
755
+ Weight of this loss, by default 1.0
756
+ match_stride : bool, optional
757
+ Whether to match the stride of convolutional layers, by default False
758
+
759
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
760
+ Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
761
+ """
762
+
763
+ def __init__(
764
+ self,
765
+ sampling_rate: int,
766
+ n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
767
+ window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
768
+ loss_fn: typing.Callable = nn.L1Loss(),
769
+ clamp_eps: float = 1e-5,
770
+ mag_weight: float = 0.0,
771
+ log_weight: float = 1.0,
772
+ pow: float = 1.0,
773
+ weight: float = 1.0,
774
+ match_stride: bool = False,
775
+ mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0],
776
+ mel_fmax: List[float] = [None, None, None, None, None, None, None],
777
+ window_type: str = 'hann',
778
+ ):
779
+ super().__init__()
780
+ self.sampling_rate = sampling_rate
781
+
782
+ STFTParams = namedtuple(
783
+ "STFTParams",
784
+ ["window_length", "hop_length", "window_type", "match_stride"],
785
+ )
786
+
787
+ self.stft_params = [
788
+ STFTParams(
789
+ window_length=w,
790
+ hop_length=w // 4,
791
+ match_stride=match_stride,
792
+ window_type=window_type,
793
+ )
794
+ for w in window_lengths
795
+ ]
796
+ self.n_mels = n_mels
797
+ self.loss_fn = loss_fn
798
+ self.clamp_eps = clamp_eps
799
+ self.log_weight = log_weight
800
+ self.mag_weight = mag_weight
801
+ self.weight = weight
802
+ self.mel_fmin = mel_fmin
803
+ self.mel_fmax = mel_fmax
804
+ self.pow = pow
805
+
806
+ @staticmethod
807
+ @functools.lru_cache(None)
808
+ def get_window(
809
+ window_type, window_length,
810
+ ):
811
+ return signal.get_window(window_type, window_length)
812
+
813
+ @staticmethod
814
+ @functools.lru_cache(None)
815
+ def get_mel_filters(
816
+ sr, n_fft, n_mels, fmin, fmax
817
+ ):
818
+ return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
819
+
820
+ def mel_spectrogram(
821
+ self, wav, n_mels, fmin, fmax, window_length, hop_length, match_stride, window_type
822
+ ):
823
+ # mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
824
+ # https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
825
+ B, C, T = wav.shape
826
+
827
+ if match_stride:
828
+ assert (
829
+ hop_length == window_length // 4
830
+ ), "For match_stride, hop must equal n_fft // 4"
831
+ right_pad = math.ceil(T / hop_length) * hop_length - T
832
+ pad = (window_length - hop_length) // 2
833
+ else:
834
+ right_pad = 0
835
+ pad = 0
836
+
837
+ wav = torch.nn.functional.pad(
838
+ wav, (pad, pad + right_pad), mode='reflect'
839
+ )
840
+
841
+ window = self.get_window(window_type, window_length)
842
+ window = torch.from_numpy(window).to(wav.device).float()
843
+
844
+ stft = torch.stft(
845
+ wav.reshape(-1, T),
846
+ n_fft=window_length,
847
+ hop_length=hop_length,
848
+ window=window,
849
+ return_complex=True,
850
+ center=True,
851
+ )
852
+ _, nf, nt = stft.shape
853
+ stft = stft.reshape(B, C, nf, nt)
854
+ if match_stride:
855
+ # Drop first two and last two frames, which are added
856
+ # because of padding. Now num_frames * hop_length = num_samples.
857
+ stft = stft[..., 2:-2]
858
+ magnitude = torch.abs(stft)
859
+
860
+ nf = magnitude.shape[2]
861
+ mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
862
+ mel_basis = torch.from_numpy(mel_basis).to(wav.device)
863
+ mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
864
+ mel_spectrogram = mel_spectrogram.transpose(-1, 2)
865
+
866
+ return mel_spectrogram
867
+
868
+ def forward(
869
+ self,
870
+ x: torch.Tensor,
871
+ y: torch.Tensor
872
+ ) -> torch.Tensor:
873
+ """Computes mel loss between an estimate and a reference
874
+ signal.
875
+
876
+ Parameters
877
+ ----------
878
+ x : torch.Tensor
879
+ Estimate signal
880
+ y : torch.Tensor
881
+ Reference signal
882
+
883
+ Returns
884
+ -------
885
+ torch.Tensor
886
+ Mel loss.
887
+ """
888
+
889
+ loss = 0.0
890
+ for n_mels, fmin, fmax, s in zip(
891
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
892
+ ):
893
+ kwargs = {
894
+ "n_mels": n_mels,
895
+ "fmin": fmin,
896
+ "fmax": fmax,
897
+ "window_length": s.window_length,
898
+ "hop_length": s.hop_length,
899
+ "match_stride": s.match_stride,
900
+ "window_type": s.window_type,
901
+ }
902
+
903
+ x_mels = self.mel_spectrogram(x, **kwargs)
904
+ y_mels = self.mel_spectrogram(y, **kwargs)
905
+ x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
906
+ y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
907
+
908
+ loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
909
+ loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
910
+
911
+ return loss
912
+
913
+
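With the defaults above, the loss reduces to a sum of L1 distances between log10 mel spectrograms taken at seven STFT scales (the raw-magnitude term is off because mag_weight = 0.0). A minimal usage sketch with illustrative [B, 1, T] waveforms:

```python
import torch

mel_loss_fn = MultiScaleMelSpectrogramLoss(sampling_rate=22050)  # must match the audio
y_hat = torch.randn(2, 1, 16384)  # generated waveform (estimate)
y = torch.randn(2, 1, 16384)      # reference waveform
loss = mel_loss_fn(y_hat, y)      # scalar tensor, typically scaled by a mel-loss weight
```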
914
+ # loss functions
915
+ def feature_loss(
916
+ fmap_r: List[List[torch.Tensor]],
917
+ fmap_g: List[List[torch.Tensor]]
918
+ ) -> torch.Tensor:
919
+
920
+ loss = 0
921
+ for dr, dg in zip(fmap_r, fmap_g):
922
+ for rl, gl in zip(dr, dg):
923
+ loss += torch.mean(torch.abs(rl - gl))
924
+
925
+ return loss * 2  # this equates to lambda=2.0 for the feature matching loss
926
+
927
+ def discriminator_loss(
928
+ disc_real_outputs: List[torch.Tensor],
929
+ disc_generated_outputs: List[torch.Tensor]
930
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
931
+
932
+ loss = 0
933
+ r_losses = []
934
+ g_losses = []
935
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
936
+ r_loss = torch.mean((1-dr)**2)
937
+ g_loss = torch.mean(dg**2)
938
+ loss += (r_loss + g_loss)
939
+ r_losses.append(r_loss.item())
940
+ g_losses.append(g_loss.item())
941
+
942
+ return loss, r_losses, g_losses
943
+
944
+ def generator_loss(
945
+ disc_outputs: List[torch.Tensor]
946
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
947
+
948
+ loss = 0
949
+ gen_losses = []
950
+ for dg in disc_outputs:
951
+ l = torch.mean((1-dg)**2)
952
+ gen_losses.append(l)
953
+ loss += l
954
+
955
+ return loss, gen_losses
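These are the least-squares GAN objectives from HiFi-GAN: the discriminator pushes real scores toward 1 and generated scores toward 0, while the generator pushes its scores toward 1 and additionally matches discriminator feature maps. A schematic sketch (not the repo's actual training loop) of how the pieces combine, assuming `h` provides the MPD config fields:

```python
import torch

mpd = MultiPeriodDiscriminator(h)  # any discriminator above exposes the same interface
y = torch.randn(2, 1, 8192)        # real audio (illustrative)
y_g_hat = torch.randn(2, 1, 8192)  # generator output (illustrative)

# Discriminator step: real -> 1, generated (detached) -> 0
y_d_rs, y_d_gs, _, _ = mpd(y, y_g_hat.detach())
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)

# Generator step: adversarial + feature matching (mel loss added separately)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_g_hat)
loss_gen_adv, _ = generator_loss(y_d_gs)
loss_fm = feature_loss(fmap_rs, fmap_gs)  # already scaled by lambda = 2
loss_gen_total = loss_gen_adv + loss_fm
```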
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ torch
2
+ torchaudio
3
+ numpy<2
4
+ librosa>=0.8.1
5
+ scipy
6
+ tensorboard
7
+ soundfile
8
+ matplotlib
9
+ pesq
10
+ auraloss
11
+ tqdm
12
+ nnAudio
13
+ ninja
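Of these, nnAudio is lazily imported by the CQT discriminator above, and ninja is presumably needed to JIT-compile the optional alias_free_cuda kernels; the rest cover training, logging, and evaluation. Installing with `pip install -r requirements.txt` before running the app or training scripts should pull everything in.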
utils.py ADDED
@@ -0,0 +1,80 @@
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import glob
5
+ import os
6
+ import matplotlib
7
+ import torch
8
+ from torch.nn.utils import weight_norm
9
+ matplotlib.use("Agg")
10
+ import matplotlib.pylab as plt
11
+ from meldataset import MAX_WAV_VALUE
12
+ from scipy.io.wavfile import write
13
+
14
+
15
+ def plot_spectrogram(spectrogram):
16
+ fig, ax = plt.subplots(figsize=(10, 2))
17
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
18
+ interpolation='none')
19
+ plt.colorbar(im, ax=ax)
20
+
21
+ fig.canvas.draw()
22
+ plt.close()
23
+
24
+ return fig
25
+
26
+
27
+ def plot_spectrogram_clipped(spectrogram, clip_max=2.):
28
+ fig, ax = plt.subplots(figsize=(10, 2))
29
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
30
+ interpolation='none', vmin=1e-6, vmax=clip_max)
31
+ plt.colorbar(im, ax=ax)
32
+
33
+ fig.canvas.draw()
34
+ plt.close()
35
+
36
+ return fig
37
+
38
+
39
+ def init_weights(m, mean=0.0, std=0.01):
40
+ classname = m.__class__.__name__
41
+ if classname.find("Conv") != -1:
42
+ m.weight.data.normal_(mean, std)
43
+
44
+
45
+ def apply_weight_norm(m):
46
+ classname = m.__class__.__name__
47
+ if classname.find("Conv") != -1:
48
+ weight_norm(m)
49
+
50
+
51
+ def get_padding(kernel_size, dilation=1):
52
+ return int((kernel_size*dilation - dilation)/2)
53
+
54
+
55
+ def load_checkpoint(filepath, device):
56
+ assert os.path.isfile(filepath)
57
+ print("Loading '{}'".format(filepath))
58
+ checkpoint_dict = torch.load(filepath, map_location=device)
59
+ print("Complete.")
60
+ return checkpoint_dict
61
+
62
+
63
+ def save_checkpoint(filepath, obj):
64
+ print("Saving checkpoint to {}".format(filepath))
65
+ torch.save(obj, filepath)
66
+ print("Complete.")
67
+
68
+
69
+ def scan_checkpoint(cp_dir, prefix):
70
+ pattern = os.path.join(cp_dir, prefix + '????????')
71
+ cp_list = glob.glob(pattern)
72
+ if len(cp_list) == 0:
73
+ return None
74
+ return sorted(cp_list)[-1]
75
+
76
+ def save_audio(audio, path, sr):
77
+ # audio: 1-D float torch.Tensor, expected in [-1, 1]
78
+ audio = audio * MAX_WAV_VALUE
79
+ audio = audio.cpu().numpy().astype('int16')
80
+ write(path, sr, audio)
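A short sketch of how these helpers are typically combined when resuming from a checkpoint and writing audio to disk; the directory, prefix, and sample rate are illustrative, and scan_checkpoint matches the prefix followed by exactly eight characters (e.g. g_05000000):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cp_g = scan_checkpoint("exp/bigvgan", "g_")  # e.g. picks exp/bigvgan/g_05000000
if cp_g is not None:
    state_dict_g = load_checkpoint(cp_g, device)

# save_audio expects a 1-D float tensor in [-1, 1] and writes 16-bit PCM
fake_audio = torch.zeros(22050)
save_audio(fake_audio, "sample.wav", sr=22050)
```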