from torch import nn

from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer


class RelativePositionTransformerEncoder(nn.Module):
    """Speedy Speech encoder built on a Transformer with relative position encoding.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): dictionary of parameters for the `RelativePositionTransformer` layers.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        # convolutional prenet in front of the transformer stack
        self.prenet = ResidualConv1dBNBlock(
            in_channels,
            hidden_channels,
            hidden_channels,
            kernel_size=5,
            num_res_blocks=3,
            num_conv_blocks=1,
            dilations=[1, 1, 1],
        )
        self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params)

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        if x_mask is None:
            x_mask = 1
        o = self.prenet(x) * x_mask
        o = self.rel_pos_transformer(o, x_mask)
        return o


class ResidualConv1dBNEncoder(nn.Module):
    """Residual convolutional encoder as in the original Speedy Speech paper.

    TODO: Integrate speaker conditioning vector.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        hidden_channels (int): number of hidden channels.
        params (dict): dictionary of parameters for the residual convolutional blocks.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, params):
        super().__init__()
        self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU())
        self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params)
        self.postnet = nn.Sequential(
            *[
                nn.Conv1d(hidden_channels, hidden_channels, 1),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_channels),
                nn.Conv1d(hidden_channels, out_channels, 1),
            ]
        )

    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
        if x_mask is None:
            x_mask = 1
        o = self.prenet(x) * x_mask
        o = self.res_conv_block(o, x_mask)
        # residual connection; assumes `in_channels == hidden_channels`
        o = self.postnet(o + x) * x_mask
        return o * x_mask
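# Usage sketch (illustrative values, not library defaults): both encoder
# modules above take a feature tensor `x` of shape [B, C, T] and a
# broadcastable mask `x_mask` of shape [B, 1, T], e.g.
#
#   import torch
#
#   enc = ResidualConv1dBNEncoder(
#       in_channels=128,
#       out_channels=80,
#       hidden_channels=128,  # must equal `in_channels` for the residual `o + x`
#       params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1],
#               "num_conv_blocks": 2, "num_res_blocks": 13},
#   )
#   x = torch.rand(2, 128, 37)     # [B, C, T]
#   x_mask = torch.ones(2, 1, 37)  # [B, 1, T]
#   o = enc(x, x_mask)             # -> [2, 80, 37]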
class Encoder(nn.Module):
    # pylint: disable=dangerous-default-value
    """Factory class for the Speedy Speech encoder; instantiates one of the encoder types internally.

    Args:
        out_channels (int): number of output channels.
        in_hidden_channels (int): input and hidden channels. The model keeps the input channels for the intermediate layers.
        encoder_type (str): encoder layer type. One of 'relative_position_transformer', 'residual_conv_bn' or 'fftransformer'. Defaults to 'residual_conv_bn'.
        encoder_params (dict): model parameters for the specified encoder type.
        c_in_channels (int): number of channels for the conditional input.

    Note:
        Default `encoder_params` values to be set in config.json:

        ```python
        # for 'relative_position_transformer'
        encoder_params = {
            "hidden_channels_ffn": 128,
            "num_heads": 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "rel_attn_window_size": 4,
            "input_length": None,
        }

        # for 'residual_conv_bn'
        encoder_params = {
            "kernel_size": 4,
            "dilations": 4 * [1, 2, 4] + [1],
            "num_conv_blocks": 2,
            "num_res_blocks": 13,
        }

        # for 'fftransformer'
        encoder_params = {
            "hidden_channels_ffn": 1024,
            "num_heads": 2,
            "num_layers": 6,
            "dropout_p": 0.1,
        }
        ```
    """

    def __init__(
        self,
        in_hidden_channels,
        out_channels,
        encoder_type="residual_conv_bn",
        encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
        c_in_channels=0,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels

        # init the encoder for the given type
        if encoder_type.lower() == "relative_position_transformer":
            # text encoder
            # pylint: disable=unexpected-keyword-arg
            self.encoder = RelativePositionTransformerEncoder(
                in_hidden_channels, out_channels, in_hidden_channels, encoder_params
            )
        elif encoder_type.lower() == "residual_conv_bn":
            self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params)
        elif encoder_type.lower() == "fftransformer":
            assert (
                in_hidden_channels == out_channels
            ), " [!] `in_channels` must equal `out_channels` when the encoder type is 'fftransformer'"
            # pylint: disable=unexpected-keyword-arg
            self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)
        else:
            raise NotImplementedError(" [!] unknown encoder type.")

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """
        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
            g: [B, C, 1]
        """
        o = self.encoder(x, x_mask)
        return o * x_mask
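# Minimal smoke test for the factory class; the channel sizes and sequence
# length below are arbitrary example values, not library defaults.
if __name__ == "__main__":
    import torch

    encoder = Encoder(in_hidden_channels=128, out_channels=80)  # default 'residual_conv_bn' type
    x = torch.rand(2, 128, 37)     # [B, C, T]
    x_mask = torch.ones(2, 1, 37)  # [B, 1, T]
    o = encoder(x, x_mask)
    assert o.shape == (2, 80, 37)  # [B, out_channels, T]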