WaveGRU-Text-To-Speech / inference.py
NTT123
add fast cpp wavegru
d1a84ee
raw
history blame
No virus
2.4 kB
import jax
import jax.numpy as jnp
import librosa
import numpy as np
import pax
from text import english_cleaners
from utils import (
create_tacotron_model,
load_tacotron_ckpt,
load_tacotron_config,
load_wavegru_ckpt,
load_wavegru_config,
)
from wavegru import WaveGRU
def load_tacotron_model(alphabet_file, config_file, model_file):
    """Load the Tacotron alphabet, network and config into memory.

    Args:
        alphabet_file: path to a UTF-8 text file; its content is split on
            "\\n" and each piece becomes one alphabet entry.
        config_file: path handed to ``load_tacotron_config``.
        model_file: checkpoint path handed to ``load_tacotron_ckpt``.

    Returns:
        A tuple ``(alphabet, net, config)`` where ``net`` has been switched
        to eval mode and passed through ``jax.device_put``.
    """
    with open(alphabet_file, "r", encoding="utf-8") as handle:
        symbols = handle.read().split("\n")

    tacotron_config = load_tacotron_config(config_file)
    fresh_net = create_tacotron_model(tacotron_config)
    # Only the middle element of the checkpoint triple (the network) is used.
    _, restored_net, _ = load_tacotron_ckpt(fresh_net, None, model_file)
    ready_net = jax.device_put(restored_net.eval())
    return symbols, ready_net, tacotron_config
# Tacotron inference entry point: `net.inference(text, max_len=10000)` wrapped
# with `pax.pure`. Called as `tacotron_inference_fn(net, tokens)`.
# NOTE(review): max_len presumably caps the number of generated frames —
# confirm against the Tacotron model's `inference` method.
tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=10000))
def text_to_mel(net, text, alphabet, config):
    """Convert text to a mel spectrogram with the Tacotron network.

    The text is cleaned with ``english_cleaners``, right-padded with
    ``config["PAD"]``, mapped to alphabet indices and passed to
    ``tacotron_inference_fn`` as a batch of one.

    Args:
        net: Tacotron network passed through to ``tacotron_inference_fn``.
        text: input string.
        alphabet: sequence of symbols; a symbol's position is its token id.
        config: Tacotron config; only ``config["PAD"]`` is read here.

    Returns:
        The result of ``tacotron_inference_fn`` for the token batch.

    Raises:
        ValueError: if the cleaned, padded text contains a character that is
            not present in ``alphabet``.
    """
    text = english_cleaners(text)
    # Append between 1 and 100 copies of PAD. Assuming PAD is a single
    # character this rounds the length up to the next multiple of 100, and a
    # text that is already an exact multiple still receives 100 more.
    text = text + config["PAD"] * (100 - (len(text) % 100))

    # Build the lookup table once instead of scanning the alphabet for every
    # character. `setdefault` keeps the first index of a repeated symbol,
    # which is what `alphabet.index` would return.
    symbol_to_id = {}
    for idx, symbol in enumerate(alphabet):
        symbol_to_id.setdefault(symbol, idx)

    try:
        tokens = [symbol_to_id[c] for c in text]
    except KeyError as err:
        # Keep raising ValueError (as list.index did), but name the character.
        raise ValueError(
            f"character {err.args[0]!r} is not in the alphabet"
        ) from err

    tokens = jnp.array(tokens, dtype=jnp.int32)
    # tokens[None] adds a leading batch axis of size one.
    mel = tacotron_inference_fn(net, tokens[None])
    return mel
def load_wavegru_net(config_file, model_file):
    """Load the WaveGRU config and network into memory.

    Args:
        config_file: path handed to ``load_wavegru_config``.
        model_file: checkpoint path handed to ``load_wavegru_ckpt``.

    Returns:
        A tuple ``(config, net)`` where ``net`` has been switched to eval
        mode and passed through ``jax.device_put``.
    """
    config = load_wavegru_config(config_file)

    # The constructor arguments are read straight from the config.
    hyperparam_names = ("mel_dim", "embed_dim", "rnn_dim", "upsample_factors")
    hyperparams = {name: config[name] for name in hyperparam_names}
    vocoder = WaveGRU(**hyperparams)

    # Only the middle element of the checkpoint triple (the network) is used.
    _, vocoder, _ = load_wavegru_ckpt(vocoder, None, model_file)
    return config, jax.device_put(vocoder.eval())
# WaveGRU inference entry point: `net.inference(mel, no_gru=True)` wrapped with
# `pax.pure`. `mel_to_wav` feeds the first element of its result to
# `netcpp.inference`.
# NOTE(review): no_gru=True presumably skips the GRU sampling loop so that the
# separate `netcpp` object can run it — confirm in the WaveGRU implementation.
wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=True))
def mel_to_wav(net, netcpp, mel, config):
    """Convert a mel spectrogram to an int16 waveform.

    Args:
        net: WaveGRU network passed to ``wavegru_inference``.
        netcpp: object whose ``inference(features, 1.0)`` method produces the
            raw samples (NOTE(review): the 1.0 looks like a sampling
            temperature — confirm against the netcpp implementation).
        mel: mel spectrogram; a 2-D input gets a leading axis added.
        config: WaveGRU config; ``num_pad_frames`` and ``mel_min`` are read.

    Returns:
        A NumPy ``int16`` array of audio samples.
    """
    batched = mel if len(mel.shape) != 2 else mel[None]

    # Pad the middle axis on both sides with log(mel_min).
    edge = config["num_pad_frames"] // 2 + 4
    fill_value = np.log(config["mel_min"])
    padded = np.pad(
        batched,
        [(0, 0), (edge, edge), (0, 0)],
        constant_values=fill_value,
    )

    # Run the JAX network, fetch the first batch item to host memory, then
    # let netcpp generate the samples.
    features = jax.device_get(wavegru_inference(net, padded)[0])
    samples = np.array(netcpp.inference(features, 1.0))

    # Mu-law expansion (after shifting by 127) followed by de-emphasis.
    audio = librosa.mu_expand(samples - 127, mu=255)
    audio = librosa.effects.deemphasis(audio, coef=0.86)

    # Apply a gain of 2, then scale down only when the peak exceeds 1.0.
    audio = audio * 2.0
    peak = np.max(np.abs(audio))
    audio = audio / max(1.0, peak)

    # Scale to the 16-bit range and clip before the integer cast.
    pcm = np.clip(audio * 2**15, a_min=-(2**15), a_max=(2**15) - 1)
    return pcm.astype(np.int16)