# LVM / torch_vqvae_model.py
# Hugging Face file-page residue: uploaded by Emma02 ("Add application file",
# commit a858bb2, 9.23 kB; page links: raw, history, blame; scan result: "No virus").
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops.layers.torch import Rearrange
def normalize(in_channels):
    """Build the GroupNorm layer used throughout the encoder and decoder.

    Args:
        in_channels: number of channels to normalize. GroupNorm requires this
            to be divisible by the group count (32).

    Returns:
        An affine ``torch.nn.GroupNorm`` with 32 groups and ``eps=1e-6``.
    """
    return torch.nn.GroupNorm(
        num_channels=in_channels,
        num_groups=32,
        affine=True,
        eps=1e-6,
    )
def swish(x):
    """Apply the swish (SiLU) activation element-wise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
class ResBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> act -> 3x3 conv) twice, plus a skip.

    When ``in_channels != out_channels`` the skip path is projected with a
    bias-free 1x1 convolution (``conv_out``) so both branches can be summed.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output; ``None`` means ``in_channels``.
        activation_fn: ``"relu"`` or ``"swish"``.

    Raises:
        ValueError: if ``activation_fn`` is not one of the supported names
            (previously an unknown name silently disabled the non-linearity).
    """

    def __init__(self, in_channels, out_channels=None, activation_fn="relu"):
        super().__init__()
        if activation_fn not in ("relu", "swish"):
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected 'relu' or 'swish'"
            )
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Always use the resolved self.out_channels: the raw argument may be None,
        # which used to crash nn.Conv2d when out_channels was omitted.
        # Module names/order are unchanged so existing state dicts still load.
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(
                in_channels, self.out_channels, kernel_size=1, stride=1, padding=0, bias=False
            )
        self.activation_fn = activation_fn
        if activation_fn == "relu":
            self.actn = nn.ReLU()

    def _activate(self, x):
        """Apply the configured non-linearity (identity for any other value)."""
        if self.activation_fn == "relu":
            return self.actn(x)
        if self.activation_fn == "swish":
            return swish(x)
        return x

    def forward(self, x_in):
        """Return ``residual(x_in) + skip(x_in)`` with ``out_channels`` channels."""
        h = self.conv1(self._activate(self.norm1(x_in)))
        h = self.conv2(self._activate(self.norm2(h)))
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)
        return h + x_in
class Encoder(nn.Module):
    """Convolutional encoder: image -> latent grid with ``embedding_dim`` channels.

    Layout:
      * 3x3 stem conv (no bias) from ``in_channels`` to ``filters`` channels;
      * for each entry ``m`` of ``ch_mult`` a stage of ``num_res_blocks``
        swish ResBlocks ending at ``filters * m`` channels, followed by 2x2
        average pooling on every stage except the last;
      * ``num_res_blocks`` further ResBlocks, GroupNorm, swish and a 1x1 conv
        to ``embedding_dim`` channels.

    Spatial size therefore shrinks by ``2 ** (len(ch_mult) - 1)`` (16 with the
    defaults). All arguments are optional and their defaults reproduce the
    original hard-coded configuration (what ``VQVAE()`` builds). Module names
    (``conv1``, ``blocks``, ``norm1``, ``conv2``) are unchanged so existing
    state dicts still load.
    """

    def __init__(self, in_channels=3, filters=128, num_res_blocks=2,
                 ch_mult=(1, 1, 2, 2, 4), embedding_dim=32):
        super().__init__()
        self.filters = filters
        self.num_res_blocks = num_res_blocks
        self.ch_mult = list(ch_mult)
        self.in_ch_mult = (1,) + tuple(self.ch_mult)
        self.embedding_dim = embedding_dim
        self.conv_downsample = False  # kept for attribute compatibility; not used
        block_in_ch = self.filters * self.in_ch_mult[0]
        self.conv1 = nn.Conv2d(
            in_channels, block_in_ch, kernel_size=3, stride=1, padding=1, bias=False
        )
        blocks = []
        for mult in self.ch_mult:
            block_out_ch = self.filters * mult
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
                block_in_ch = block_out_ch
        # Trailing blocks at the final (lowest) resolution.
        for _ in range(self.num_res_blocks):
            blocks.append(ResBlock(block_in_ch, block_in_ch, activation_fn="swish"))
        self.norm1 = normalize(block_in_ch)
        self.conv2 = nn.Conv2d(
            block_in_ch, self.embedding_dim, kernel_size=1, stride=1, padding=0
        )
        self.blocks = nn.ModuleList(blocks)

    def _run_blocks(self, x, start, count):
        """Apply ``self.blocks[start:start + count]`` to ``x`` in order."""
        for k in range(start, start + count):
            x = self.blocks[k](x)
        return x

    def forward(self, x):
        """Encode ``x`` (assumed (b, in_channels, H, W)) to (b, embedding_dim, h, w)."""
        x = self.conv1(x)
        n = self.num_res_blocks
        num_stages = len(self.ch_mult)
        for stage in range(num_stages):
            # Index by num_res_blocks rather than a hard-coded 2.
            x = self._run_blocks(x, stage * n, n)
            if stage < num_stages - 1:
                x = F.avg_pool2d(x, (2, 2), (2, 2))
        x = self._run_blocks(x, num_stages * n, n)
        x = swish(self.norm1(x))
        return self.conv2(x)
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Args:
        codebook_size: number of codebook entries (token vocabulary size).
        emb_dim: dimension of each codebook vector; must equal the channel
            dimension of the latents passed to ``forward``.
        beta: weight of the commitment loss. ``None`` (the default) means
            0.0, the value the original implementation always used.
            Previously any value passed here was silently ignored.
    """

    def __init__(self, codebook_size=8192, emb_dim=32, beta=None):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of each embedding
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
        self.beta = 0.0 if beta is None else beta
        self.z_dim = emb_dim

    def forward(self, z):
        """Quantize ``z`` to its nearest (Euclidean) codebook vectors.

        Args:
            z: latent tensor of shape (b, c, h, w) with ``c == emb_dim``.

        Returns:
            A tuple ``(quantized, tokens, loss, perplexity)``:
              * quantized: (b, c, h, w); numerically (up to rounding) the chosen
                codebook vectors, with a straight-through gradient to ``z``;
              * tokens: (b, h, w) int64 codebook indices;
              * loss: ``codebook_loss + beta * commitment_loss`` (both MSE);
              * perplexity: exp of the entropy of token usage in this batch.
        """
        b, c, h, w = z.size()
        flat = z.permute(0, 2, 3, 1).reshape(-1, c)
        codebook = self.embedding.weight
        with torch.no_grad():
            tokens = torch.cdist(flat, codebook).argmin(dim=1)
        quantized = F.embedding(tokens, codebook).view(b, h, w, c).permute(0, 3, 1, 2)

        # Codebook loss updates only the codebook (z detached); the commitment
        # loss updates only the encoder side (codebook detached).
        codebook_loss = F.mse_loss(quantized, z.detach())
        commitment_loss = F.mse_loss(quantized.detach(), z)
        loss = codebook_loss + self.beta * commitment_loss

        # bincount yields the same per-code counts as summing a one-hot matrix,
        # without materialising an (N, codebook_size) tensor.
        counts = torch.bincount(tokens, minlength=self.codebook_size).to(z.dtype)
        # NOTE: counts are per-process; a cross-process all_reduce (disabled in
        # the original) would be needed for a global usage estimate.
        p = counts / counts.sum()
        perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10)))

        tokens = tokens.view(b, h, w)
        # Straight-through estimator: forward value comes from the codebook,
        # gradient flows unchanged to z.
        quantized = z + (quantized - z).detach()
        return quantized, tokens, loss, perplexity

    def get_codebook_feat(self, indices, shape=None):
        """Look up codebook vectors for integer ``indices``.

        Args:
            indices: integer tensor of codebook indices, any shape (it is
                flattened; need not be contiguous).
            shape: optional ``(b, h, w, c)``; if given, the result is reshaped
                to it and permuted to a contiguous (b, c, h, w) tensor.

        Returns:
            ``(N, emb_dim)`` tensor of vectors, or (b, c, h, w) when ``shape`` is set.
        """
        # Direct lookup: same values as one_hot(indices) @ weight, but without the
        # (N, codebook_size) one-hot matrix and without matmul rounding (e.g. TF32).
        z_q = F.embedding(indices.reshape(-1), self.embedding.weight)
        if shape is not None:
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
        return z_q
class Decoder(nn.Module):
    """Convolutional decoder: latent grid (``embedding_dim`` channels) -> image.

    Layout:
      * 3x3 conv (with bias) from ``embedding_dim`` to ``filters * ch_mult[-1]``;
      * ``num_res_blocks`` swish ResBlocks at that width;
      * for each entry of ``ch_mult`` in reverse, a stage of ``num_res_blocks``
        ResBlocks, followed on every stage except the last by a 3x3 conv to
        4x channels and a 2x depth-to-space upsample (``self.upsample``);
      * GroupNorm, swish and a 3x3 conv to ``out_channels``.

    All arguments are optional and their defaults reproduce the original
    hard-coded configuration (what ``VQVAE()`` builds). Module names
    (``conv1``, ``blocks``, ``up_convs``, ``norm1``, ``conv6``) are unchanged
    so existing state dicts still load.
    """

    def __init__(self, out_channels=3, filters=128, num_res_blocks=2,
                 ch_mult=(1, 1, 2, 2, 4), embedding_dim=32):
        super().__init__()
        self.filters = filters
        self.num_res_blocks = num_res_blocks
        self.ch_mult = list(ch_mult)
        self.in_ch_mult = (1,) + tuple(self.ch_mult)
        self.embedding_dim = embedding_dim
        self.out_channels = out_channels
        self.in_channels = self.embedding_dim
        self.conv_downsample = False  # kept for attribute compatibility; not used
        block_in_ch = self.filters * self.ch_mult[-1]
        self.conv1 = nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
        blocks = []
        for _ in range(self.num_res_blocks):
            blocks.append(ResBlock(block_in_ch, block_in_ch, activation_fn="swish"))
        upsample_conv_layers = []
        for i in reversed(range(len(self.ch_mult))):
            block_out_ch = self.filters * self.ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch, activation_fn="swish"))
                block_in_ch = block_out_ch
            if i > 0:
                # 4x channels, folded into a 2x2 spatial block by self.upsample.
                upsample_conv_layers.append(
                    nn.Conv2d(block_in_ch, block_out_ch * 4, kernel_size=3, stride=1, padding=1)
                )
        # Channel grouping is (h2 w2 c), which is NOT nn.PixelShuffle's (c h2 w2)
        # layout; the learned up_convs weights depend on this exact pattern.
        self.upsample = Rearrange("b h w (h2 w2 c) -> b (h h2) (w w2) c", h2=2, w2=2)
        self.norm1 = normalize(block_in_ch)
        self.conv6 = nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.blocks = nn.ModuleList(blocks)
        self.up_convs = nn.ModuleList(upsample_conv_layers)

    def _run_blocks(self, x, start, count):
        """Apply ``self.blocks[start:start + count]`` to ``x`` in order."""
        for k in range(start, start + count):
            x = self.blocks[k](x)
        return x

    def forward(self, x):
        """Decode ``x`` (assumed (b, embedding_dim, h, w)) to (b, out_channels, H, W)."""
        x = self.conv1(x)
        n = self.num_res_blocks
        x = self._run_blocks(x, 0, n)
        num_stages = len(self.ch_mult)
        for stage in range(num_stages):
            # Index by num_res_blocks rather than the hard-coded 2.
            x = self._run_blocks(x, n + stage * n, n)
            if stage < num_stages - 1:
                x = self.up_convs[stage](x)
                # self.upsample works on channels-last tensors: NCHW -> NHWC -> NCHW.
                x = x.permute(0, 2, 3, 1)
                x = self.upsample(x)
                x = x.permute(0, 3, 1, 2)
        x = swish(self.norm1(x))
        return self.conv6(x)
class VQVAE(nn.Module):
    """Encoder + vector quantizer + decoder with default configurations.

    With the default Encoder the latent grid is the input's spatial size
    divided by 16 (four 2x2 average pools), e.g. 256x256 -> 16x16 tokens
    and 224x224 -> 14x14 tokens.
    """

    def __init__(self, ):
        super().__init__()
        self.encoder = Encoder()
        self.quantizer = VectorQuantizer()
        self.decoder = Decoder()

    def forward(self, x):
        """Reconstruct ``x`` through the quantizer; tokens, loss and perplexity are discarded."""
        latent = self.encoder(x)
        quant, _tokens, _loss, _perplexity = self.quantizer(latent)
        return self.decoder(quant)

    def tokenize(self, x):
        """Map images to codebook indices.

        Args:
            x: image tensor of shape (..., C, H, W); all leading dimensions are
                flattened into one batch for the encoder.

        Returns:
            Token tensor of shape (..., h, w), restoring the leading dimensions.
        """
        batch_shape = x.shape[:-3]
        x = x.reshape(-1, *x.shape[-3:])
        _quant, tokens, _loss, _perplexity = self.quantizer(self.encoder(x))
        return tokens.reshape(*batch_shape, *tokens.shape[1:])

    def decode(self, tokens):
        """Decode codebook indices back to images.

        Args:
            tokens: (b, ...) integer tensor; everything after the batch dim is
                flattened and must hold a square number of tokens (h == w).

        Returns:
            Decoder output for a (b, emb_dim, hw, hw) latent grid.

        Raises:
            ValueError: if the per-sample token count is not a positive square.
                (The previous lookup table mapped 224 tokens to a 14x14 grid,
                which always crashed; 14x14 grids actually have 196 tokens.)
        """
        tokens = einops.rearrange(tokens, 'b ... -> b (...)')
        b, num_tokens = tokens.shape
        hw = int(round(num_tokens ** 0.5))
        if num_tokens == 0 or hw * hw != num_tokens:
            raise ValueError(
                f"Invalid tokens shape: expected a square number of tokens per sample, got {num_tokens}"
            )
        quant = self.quantizer.get_codebook_feat(tokens, (b, hw, hw, self.quantizer.emb_dim))
        return self.decoder(quant)
class VAEDecoder(nn.Module):
    """Decoder-only model: codebook indices -> decoded images."""

    def __init__(self, ):
        super().__init__()
        self.quantizer = VectorQuantizer()
        self.decoder = Decoder()

    def forward(self, x, grid_size=14):
        """Decode token indices laid out as ``grid_size x grid_size`` grids.

        Args:
            x: integer token tensor whose element count is a multiple of
                ``grid_size ** 2``; it is split into
                ``x.numel() // grid_size ** 2`` grids (previously the batch size
                was hard-coded to 1).
            grid_size: side length of each token grid (default 14, as before).

        Returns:
            Decoder output for a (batch, emb_dim, grid_size, grid_size) latent.
        """
        batch = x.numel() // (grid_size * grid_size)
        quant = self.quantizer.get_codebook_feat(
            x, (batch, grid_size, grid_size, self.quantizer.emb_dim)
        )
        return self.decoder(quant)
def get_tokenizer():
    """Build a ``VQVAE`` and load weights from ``xh_ckpt.pth`` next to this module.

    Returns:
        The ``VQVAE`` with the checkpoint's state dict loaded (strictly, so key
        mismatches raise). The module is left in its default (training) mode.
    """
    checkpoint_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "xh_ckpt.pth"
    )
    # The model is built on CPU and load_state_dict copies into its CPU
    # parameters, so load the tensors onto CPU directly. This avoids requiring
    # CUDA when the checkpoint holds GPU tensors.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True on torch versions that support it).
    torch_state_dict = torch.load(checkpoint_path, map_location="cpu")
    net = VQVAE()
    net.load_state_dict(torch_state_dict)
    return net