from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import PreTrainedModel

from .dino import vit_small
from .unimatch import UniMatch
from .configuration_doduo import DoduoConfig


class DoduoModel(PreTrainedModel):
    config_class = DoduoConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = CorrSegFlowNet(
            dino_corr_mask_ratio=config.dino_corr_mask_ratio
        )

    def forward(self, frame_src, frame_dst):
        # Accept raw PIL images and convert them to normalized batch tensors.
        if isinstance(frame_src, Image.Image):
            frame_src = self.model.process_frame(frame_src)
            frame_dst = self.model.process_frame(frame_dst)
        assert frame_src.shape == frame_dst.shape
        return self.model(frame_src, frame_dst)


class CorrSegFlowNet(nn.Module):
    def __init__(
        self,
        dino_corr_mask_ratio: float = 0.1,
    ):
        super().__init__()
        self.dino_corr_mask_ratio = dino_corr_mask_ratio
        self.unimatch = UniMatch(bilinear_upsample=True)
        # Frozen DINO ViT-S/8 backbone, used only to build the correlation mask.
        self.dino = vit_small(patch_size=8, num_classes=0)
        for p in self.dino.parameters():
            p.requires_grad = False
        self.transform = transforms.Compose(
            [
                # Keep only the RGB channels (drops alpha if present).
                lambda x: transforms.ToTensor()(x)[:3],
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def process_frame(self, frame):
        device = next(self.parameters()).device
        frame = self.transform(frame)
        frame = frame.unsqueeze(0).to(device)  # add batch dimension
        return frame

    def forward(
        self,
        frame_src,
        frame_dst,
    ):
        corr_mask = get_dino_corr_mask(
            self.dino, frame_src, frame_dst, mask_ratio=self.dino_corr_mask_ratio
        )
        flow, flow_low, correlation, feature0, feature1 = self.unimatch(
            frame_src,
            frame_dst,
            return_feature=True,
            bidirectional=False,
            cycle_consistency=False,
            corr_mask=corr_mask,
        )
        return flow


@torch.no_grad()
def extract_dino_feature(model, frame, return_h_w=False):
    """frame: B, C, H, W"""
    B = frame.shape[0]
    out = model.get_intermediate_layers(frame, n=1)[0]
    out = out[:, 1:, :]  # discard the [CLS] token
    h = int(frame.shape[2] / model.patch_embed.patch_size)
    w = int(frame.shape[3] / model.patch_embed.patch_size)
    dim = out.shape[-1]
    out = out.reshape(B, -1, dim)
    if return_h_w:
        return out, h, w
    return out


@torch.no_grad()
def get_dino_corr_mask(model, frame_src, frame_dst, mask_ratio):
    # frame_src: B x C x H x W
    # frame_dst: B x C x H x W
    # mask_ratio: fraction of lowest-similarity correlations to mask
    # return: B x h*w x h*w boolean mask (True = masked), or None
    feat_1, h, w = extract_dino_feature(model, frame_src, return_h_w=True)
    feat_2 = extract_dino_feature(model, frame_dst)
    # Cosine similarity between every source patch and every destination patch.
    feat_1_norm = F.normalize(feat_1, dim=2, p=2)
    feat_2_norm = F.normalize(feat_2, dim=2, p=2)
    aff_raw = torch.einsum("bnc,bmc->bnm", [feat_1_norm, feat_2_norm])
    if mask_ratio <= 0:
        # no corr mask
        corr_mask = None
    else:
        # torch.quantile does not support float16 inputs
        if aff_raw.dtype == torch.float16:
            aff_raw = aff_raw.float()
        aff_percentile = torch.quantile(aff_raw, mask_ratio, dim=2, keepdim=True)
        # True for masked
        corr_mask = aff_raw < aff_percentile
    return corr_mask
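

# Usage sketch (illustration only, not part of the module): construct a
# DoduoModel from a local DoduoConfig and run it on a pair of frames.
# The image paths and the 0.1 mask ratio are illustrative assumptions;
# the B x 2 x H x W flow shape follows UniMatch's usual convention.
#
#   from PIL import Image
#   from .configuration_doduo import DoduoConfig
#
#   config = DoduoConfig(dino_corr_mask_ratio=0.1)
#   model = DoduoModel(config).eval()
#   frame_src = Image.open("frame_0.png")   # hypothetical paths
#   frame_dst = Image.open("frame_1.png")
#   with torch.no_grad():
#       flow = model(frame_src, frame_dst)  # dense flow, B x 2 x H x W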