|
from typing import Dict, Tuple |
|
from PIL import Image |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from torchvision import transforms |
|
from transformers import PreTrainedModel |
|
|
|
from .dino import vit_small |
|
from .unimatch import UniMatch |
|
from .configuration_doduo import DoduoConfig |
|
|
|
class DoduoModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``CorrSegFlowNet``.

    The wrapped network is stored as ``self.model``; ``forward`` accepts
    either PIL images or already-prepared tensors.
    """

    config_class = DoduoConfig

    def __init__(self, config):
        """Build the wrapped network from ``config.dino_corr_mask_ratio``."""
        super().__init__(config)
        self.model = CorrSegFlowNet(
            dino_corr_mask_ratio=config.dino_corr_mask_ratio
        )

    def forward(self, frame_src, frame_dst):
        """Run the wrapped network on a source/destination frame pair.

        Args:
            frame_src: source frame, either a ``PIL.Image.Image`` or a tensor
                already prepared for the network.
            frame_dst: destination frame, same accepted types as
                ``frame_src``.

        Returns:
            Whatever ``CorrSegFlowNet.forward`` returns for the pair (the
            predicted flow).

        Raises:
            ValueError: if the two frames do not have the same shape after
                any PIL images have been converted.
        """
        # Convert each frame on its own so a PIL image may be paired with a
        # tensor; tensors are passed through untouched.
        if isinstance(frame_src, Image.Image):
            frame_src = self.model.process_frame(frame_src)
        if isinstance(frame_dst, Image.Image):
            frame_dst = self.model.process_frame(frame_dst)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if frame_src.shape != frame_dst.shape:
            raise ValueError(
                "frame_src and frame_dst must have the same shape, got "
                f"{tuple(frame_src.shape)} and {tuple(frame_dst.shape)}"
            )
        return self.model(frame_src, frame_dst)
|
|
|
class CorrSegFlowNet(nn.Module):
    """Flow network combining a frozen DINO ViT with a UniMatch matcher.

    The DINO model is only used to derive a correlation mask (via
    ``get_dino_corr_mask``); that mask is handed to UniMatch, whose flow
    prediction is the output of this module.
    """

    def __init__(
        self,
        dino_corr_mask_ratio: float = 0.1,
    ):
        """Create the sub-networks and the image preprocessing pipeline.

        Args:
            dino_corr_mask_ratio: passed as ``mask_ratio`` to
                ``get_dino_corr_mask`` on every forward pass.
        """
        super().__init__()

        self.dino_corr_mask_ratio = dino_corr_mask_ratio
        self.unimatch = UniMatch(bilinear_upsample=True)
        self.dino = vit_small(patch_size=8, num_classes=0)
        # DINO is used as a fixed feature extractor: exclude it from training.
        for param in self.dino.parameters():
            param.requires_grad = False

        # A named static helper (rather than a lambda) keeps the pipeline,
        # and therefore the module, picklable.
        self.transform = transforms.Compose(
            [
                self._to_rgb_tensor,
                # Mean/std commonly used for ImageNet-pretrained backbones.
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    @staticmethod
    def _to_rgb_tensor(frame):
        """Convert an image to a tensor with at most three (RGB) channels.

        PIL images in modes other than RGB/RGBA (grayscale, palette, CMYK,
        ...) are converted to RGB first, because slicing their raw channels
        would give the wrong channel count or meaning for the 3-channel
        normalization that follows. RGB and RGBA inputs take the original
        path: ``ToTensor`` followed by dropping any alpha channel.
        """
        if isinstance(frame, Image.Image) and frame.mode not in ("RGB", "RGBA"):
            frame = frame.convert("RGB")
        return transforms.ToTensor()(frame)[:3]

    def process_frame(self, frame):
        """Preprocess one image into a batch of size one on the model's device.

        Args:
            frame: an image accepted by ``self.transform`` (e.g. a PIL image).

        Returns:
            The transformed frame with a leading batch dimension added, moved
            to the device of this module's parameters.
        """
        device = next(self.parameters()).device
        return self.transform(frame).unsqueeze(0).to(device)

    def forward(
        self,
        frame_src,
        frame_dst,
    ):
        """Predict the flow from ``frame_src`` to ``frame_dst``.

        Args:
            frame_src: source frame batch (assumed ``(B, C, H, W)``, matching
                what ``process_frame`` produces).
            frame_dst: destination frame batch with the same layout.

        Returns:
            The first value returned by UniMatch (``flow``).
        """
        corr_mask = get_dino_corr_mask(
            self.dino,
            frame_src,
            frame_dst,
            mask_ratio=self.dino_corr_mask_ratio,
        )

        # With ``return_feature=True`` UniMatch yields five values; only the
        # flow is exposed, the rest are deliberately discarded.
        flow, _flow_low, _correlation, _feature0, _feature1 = self.unimatch(
            frame_src,
            frame_dst,
            return_feature=True,
            bidirectional=False,
            cycle_consistency=False,
            corr_mask=corr_mask,
        )
        return flow
|
|
|
@torch.no_grad()
def extract_dino_feature(model, frame, return_h_w=False):
    """Extract per-patch tokens from a DINO-style ViT without gradients.

    Args:
        model: object exposing ``get_intermediate_layers(frame, n=1)`` and
            ``patch_embed.patch_size`` (an int, or a ``(height, width)``
            pair).
        frame: image batch of shape ``(B, C, H, W)``.
        return_h_w: when true, also return the patch-grid height and width.

    Returns:
        The tokens as a ``(B, N, dim)`` tensor with the first token of each
        sequence removed, or ``(tokens, h, w)`` when ``return_h_w`` is true,
        where ``h``/``w`` are the frame height/width floor-divided by the
        patch size.
    """
    batch_size = frame.shape[0]
    tokens = model.get_intermediate_layers(frame, n=1)[0]
    # Drop the leading token of every sequence (presumably the [CLS] token)
    # so only patch tokens remain.
    tokens = tokens[:, 1:, :]
    tokens = tokens.reshape(batch_size, -1, tokens.shape[-1])
    if not return_h_w:
        return tokens

    patch_size = model.patch_embed.patch_size
    if isinstance(patch_size, (tuple, list)):
        patch_h, patch_w = patch_size
    else:
        patch_h = patch_w = patch_size
    # Integer floor division: exact for ints, no float round-trip.
    grid_h = frame.shape[2] // patch_h
    grid_w = frame.shape[3] // patch_w
    return tokens, grid_h, grid_w
|
|
|
@torch.no_grad()
def get_dino_corr_mask(
    model, frame_src, frame_dst, mask_ratio
):
    """Build a mask of the weakest DINO correspondences between two frames.

    Patch features of both frames are L2-normalized and compared with a dot
    product (cosine similarity). For every source patch, the ``mask_ratio``
    quantile of its similarities to all destination patches is used as a
    threshold.

    Args:
        model: DINO-style ViT passed through to ``extract_dino_feature``.
        frame_src: source frame batch, ``(B, C, H, W)``.
        frame_dst: destination frame batch, ``(B, C, H, W)``.
        mask_ratio: quantile in ``(0, 1]`` used as the per-source-patch
            threshold; values ``<= 0`` disable masking.

    Returns:
        ``None`` when ``mask_ratio <= 0``; otherwise a boolean tensor of
        shape ``(B, N_src, N_dst)`` that is ``True`` where the similarity is
        strictly below the threshold of its source patch.
    """
    if mask_ratio <= 0:
        # Masking disabled: skip both ViT forward passes and the affinity.
        return None

    feat_src = F.normalize(extract_dino_feature(model, frame_src), dim=2, p=2)
    feat_dst = F.normalize(extract_dino_feature(model, frame_dst), dim=2, p=2)
    affinity = torch.einsum("bnc,bmc->bnm", feat_src, feat_dst)

    # torch.quantile only accepts float32/float64 input, so promote every
    # reduced-precision dtype (float16 and bfloat16 alike).
    if affinity.dtype not in (torch.float32, torch.float64):
        affinity = affinity.float()
    threshold = torch.quantile(affinity, mask_ratio, dim=2, keepdim=True)

    return affinity < threshold