doduo / modeling_doduo.py
stevetod's picture
Upload model (#1)
5189ac9
raw
history blame contribute delete
No virus
3.55 kB
from typing import Dict, Tuple
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import PreTrainedModel
from .dino import vit_small
from .unimatch import UniMatch
from .configuration_doduo import DoduoConfig
class DoduoModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``CorrSegFlowNet``.

    The wrapped network is built from ``config.dino_corr_mask_ratio`` and is
    stored on ``self.model``.
    """

    config_class = DoduoConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = CorrSegFlowNet(
            dino_corr_mask_ratio=config.dino_corr_mask_ratio
        )

    def forward(self, frame_src, frame_dst):
        """Run the wrapped network on a source/destination frame pair.

        Args:
            frame_src: Source frame, either a ``PIL.Image.Image`` or an
                already preprocessed tensor.
            frame_dst: Destination frame, same accepted types as
                ``frame_src``.

        Returns:
            Whatever ``self.model(frame_src, frame_dst)`` returns (the flow
            output of ``CorrSegFlowNet``).

        Raises:
            AssertionError: If the two frames do not have identical shapes
                after preprocessing.
        """
        # Convert each frame on its own: previously only ``frame_src`` was
        # inspected, so a PIL/tensor mix crashed inside the preprocessing or
        # on the shape check below.
        if isinstance(frame_src, Image.Image):
            frame_src = self.model.process_frame(frame_src)
        if isinstance(frame_dst, Image.Image):
            frame_dst = self.model.process_frame(frame_dst)
        assert frame_src.shape == frame_dst.shape, (
            "frame_src and frame_dst must have the same shape"
        )
        return self.model(frame_src, frame_dst)
class CorrSegFlowNet(nn.Module):
    """Flow network combining ``UniMatch`` with a frozen DINO ViT backbone.

    The DINO backbone (``vit_small`` with ``patch_size=8``) is only used to
    build a correlation mask that is handed to ``UniMatch``; all of its
    parameters have ``requires_grad`` switched off.

    Args:
        dino_corr_mask_ratio: Value forwarded as ``mask_ratio`` to
            ``get_dino_corr_mask`` on every forward pass.
    """

    def __init__(
        self,
        dino_corr_mask_ratio: float = 0.1,
    ):
        super().__init__()
        self.dino_corr_mask_ratio = dino_corr_mask_ratio

        # Submodule creation order is kept (unimatch first, then dino) so the
        # parameter ordering of the module is unchanged.
        self.unimatch = UniMatch(bilinear_upsample=True)
        self.dino = vit_small(patch_size=8, num_classes=0)
        for param in self.dino.parameters():
            param.requires_grad = False

        def _to_three_channel_tensor(image):
            # Convert to a tensor and keep only the first three channels
            # (drops e.g. an alpha channel when present).
            return transforms.ToTensor()(image)[:3]

        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        self.transform = transforms.Compose(
            [
                _to_three_channel_tensor,
                transforms.Normalize(mean=channel_mean, std=channel_std),
            ]
        )

    def process_frame(self, frame):
        """Preprocess one image into a single-item batch on the module's device.

        The device is taken from the first parameter of this module; the
        frame goes through ``self.transform`` and gets a leading batch
        dimension before being moved there.
        """
        device = next(self.parameters()).device
        tensor = self.transform(frame)
        return tensor.unsqueeze(0).to(device)

    def forward(
        self,
        frame_src,
        frame_dst,
    ):
        """Return the flow predicted by ``UniMatch`` for the two frames.

        A correlation mask is first computed from DINO features of both
        frames and passed to ``UniMatch`` via ``corr_mask``. Only the first
        of the five values returned by ``UniMatch`` is handed back.
        """
        mask = get_dino_corr_mask(
            self.dino,
            frame_src,
            frame_dst,
            mask_ratio=self.dino_corr_mask_ratio,
        )
        outputs = self.unimatch(
            frame_src,
            frame_dst,
            return_feature=True,
            bidirectional=False,
            cycle_consistency=False,
            corr_mask=mask,
        )
        # UniMatch is expected to yield exactly five values here; everything
        # except the flow is discarded.
        flow, _flow_low, _correlation, _feature0, _feature1 = outputs
        return flow
@torch.no_grad()
def extract_dino_feature(model, frame, return_h_w=False):
    """Extract per-patch tokens for ``frame`` from ``model``.

    Args:
        model: Backbone exposing ``get_intermediate_layers(frame, n=1)`` and
            ``patch_embed.patch_size``.
        frame: Image batch indexed as (B, C, H, W).
        return_h_w: When true, also return the patch-grid height and width.

    Returns:
        The tokens with the first ([CLS]) token removed, reshaped to
        (B, -1, dim). With ``return_h_w`` set, a tuple
        ``(tokens, h, w)`` where ``h`` and ``w`` are the frame height and
        width divided by the patch size, truncated to ``int``.
    """
    batch_size = frame.shape[0]
    tokens = model.get_intermediate_layers(frame, n=1)[0]
    # The first token is the [CLS] token; only patch tokens are kept.
    patch_tokens = tokens[:, 1:, :]

    patch = model.patch_embed.patch_size
    grid_h = int(frame.shape[2] / patch)
    grid_w = int(frame.shape[3] / patch)

    channels = patch_tokens.shape[-1]
    features = patch_tokens.reshape(batch_size, -1, channels)
    if not return_h_w:
        return features
    return features, grid_h, grid_w
@torch.no_grad()
def get_dino_corr_mask(
    model, frame_src, frame_dst, mask_ratio
):
    """Build a boolean mask over low-affinity source/destination patch pairs.

    Patch features of both frames are L2-normalised and compared with a dot
    product (cosine similarity). For every source patch, the destination
    patches whose similarity falls below the ``mask_ratio`` quantile of that
    row are marked ``True`` (masked).

    Args:
        model: Backbone passed through to ``extract_dino_feature``.
        frame_src: Source frames, B x C x H x W.
        frame_dst: Destination frames, B x C x H x W.
        mask_ratio: Quantile in (0, 1] of destination patches to mask per
            source patch. Values ``<= 0`` disable masking.

    Returns:
        ``None`` when ``mask_ratio <= 0``; otherwise a boolean tensor of
        shape B x N_src x N_dst where ``True`` means masked.
    """
    if mask_ratio <= 0:
        # No corr mask requested: skip both backbone passes entirely
        # instead of computing an affinity that is thrown away.
        return None

    feat_src = F.normalize(extract_dino_feature(model, frame_src), dim=2, p=2)
    feat_dst = F.normalize(extract_dino_feature(model, frame_dst), dim=2, p=2)
    affinity = torch.einsum("bnc,bmc->bnm", feat_src, feat_dst)

    # torch.quantile only accepts float32/float64 input, so upcast any other
    # dtype (float16 as before, and also bfloat16).
    if affinity.dtype not in (torch.float32, torch.float64):
        affinity = affinity.float()
    threshold = torch.quantile(affinity, mask_ratio, dim=2, keepdim=True)

    # True for masked
    return affinity < threshold