File size: 975 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from torch import nn


class MDNBlock(nn.Module):
    """Mixture density network (MDN) head.

    Reference: https://arxiv.org/pdf/2003.01950.pdf

    Applies a pointwise (kernel size 1) convolution, layer normalization
    over the channel dimension, ReLU and dropout, then a second pointwise
    convolution whose output channels are split into two groups that are
    returned as ``mu`` and ``log_sigma``.

    Args:
        in_channels: Number of channels of the input; also the width of
            the hidden layer.
        out_channels: Total number of channels produced by the final
            convolution. The first ``out_channels // 2`` channels become
            ``mu`` and the remaining ones ``log_sigma``, so an odd value
            gives ``log_sigma`` one more channel than ``mu``.
        dropout_p: Dropout probability applied after the ReLU. Defaults
            to ``0.1``.
    """

    def __init__(self, in_channels: int, out_channels: int, dropout_p: float = 0.1):
        super().__init__()
        self.out_channels = out_channels
        # Attribute names are kept stable so existing checkpoints
        # (state_dict keys) keep loading.
        self.conv1 = nn.Conv1d(in_channels, in_channels, 1)
        self.norm = nn.LayerNorm(in_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.conv2 = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        """Compute the two parameter groups of the mixture.

        Args:
            x: 3-D tensor laid out as ``(batch, in_channels, length)``.

        Returns:
            A ``(mu, log_sigma)`` tuple. ``mu`` holds the first
            ``out_channels // 2`` channels of the final convolution and
            ``log_sigma`` the remaining channels; both keep the batch and
            length dimensions of ``x``.
        """
        hidden = self.conv1(x)
        # nn.LayerNorm normalizes the trailing dimension, so move the
        # channels last for the normalization and back again afterwards.
        hidden = self.norm(hidden.transpose(1, 2)).transpose(1, 2)
        hidden = self.dropout(self.relu(hidden))
        mu_sigma = self.conv2(hidden)

        half = self.out_channels // 2
        # TODO: check whether mu should be passed through a sigmoid;
        # it is currently returned unbounded.
        mu = mu_sigma[:, :half, :]
        log_sigma = mu_sigma[:, half:, :]
        return mu, log_sigma