from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms
from transformers import PreTrainedModel

from .dino import vit_small
from .unimatch import UniMatch
from .configuration_doduo import DoduoConfig

class DoduoModel(PreTrainedModel):
    """Hugging Face wrapper around CorrSegFlowNet that predicts dense flow between a source and a destination frame."""

    config_class = DoduoConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = CorrSegFlowNet(
            dino_corr_mask_ratio=config.dino_corr_mask_ratio
        )

    def forward(self, frame_src, frame_dst):
        """Accepts either PIL images or preprocessed tensors of shape (B, 3, H, W);
        both frames are expected to be of the same type and spatial size."""
        if isinstance(frame_src, Image.Image):
            # Preprocess raw PIL images (normalize and add a batch dimension).
            frame_src = self.model.process_frame(frame_src)
            frame_dst = self.model.process_frame(frame_dst)
        assert frame_src.shape == frame_dst.shape
        return self.model(frame_src, frame_dst)
    
class CorrSegFlowNet(nn.Module):
    """Flow network: UniMatch matching guided by a correlation mask built from frozen DINO features."""

    def __init__(
        self,
        dino_corr_mask_ratio: float = 0.1,
    ):
        super().__init__()

        self.dino_corr_mask_ratio = dino_corr_mask_ratio
        self.unimatch = UniMatch(bilinear_upsample=True)
        # Frozen DINO ViT-S/8 backbone, used only to compute the correlation mask.
        self.dino = vit_small(patch_size=8, num_classes=0)
        for p in self.dino.parameters():
            p.requires_grad = False

        # ImageNet normalization; the ToTensor slice keeps only the first three (RGB)
        # channels, dropping an alpha channel if present.
        self.transform = transforms.Compose(
            [
                lambda x: transforms.ToTensor()(x)[:3],
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def process_frame(self, frame):
        """Convert a PIL image into a normalized (1, 3, H, W) tensor on the model's device."""
        device = next(self.parameters()).device
        frame = self.transform(frame)
        frame = frame.unsqueeze(0).to(device)
        return frame

    def forward(
        self,
        frame_src,
        frame_dst,
    ):
        # Mask out low-similarity correspondences using frozen DINO features.
        corr_mask = get_dino_corr_mask(
            self.dino,
            frame_src,
            frame_dst,
            mask_ratio=self.dino_corr_mask_ratio
        )

        # UniMatch also returns intermediate outputs; only the upsampled flow is used here.
        flow, flow_low, correlation, feature0, feature1 = self.unimatch(
            frame_src,
            frame_dst,
            return_feature=True,
            bidirectional=False,
            cycle_consistency=False,
            corr_mask=corr_mask,
        )
        return flow

@torch.no_grad()
def extract_dino_feature(model, frame, return_h_w=False):
    """Extract DINO patch tokens from the last layer.

    frame: (B, C, H, W) tensor. Returns a (B, h*w, dim) tensor of patch tokens and,
    if return_h_w is True, also the patch-grid height and width.
    """
    B = frame.shape[0]
    out = model.get_intermediate_layers(frame, n=1)[0]
    out = out[:, 1:, :]  # discard the [CLS] token, keep only patch tokens
    h, w = int(frame.shape[2] / model.patch_embed.patch_size), int(
        frame.shape[3] / model.patch_embed.patch_size
    )
    dim = out.shape[-1]
    out = out.reshape(B, -1, dim)
    if return_h_w:
        return out, h, w
    return out

@torch.no_grad()
def get_dino_corr_mask(
    model, frame_src, frame_dst, mask_ratio
):
    """Build a boolean correlation mask from DINO features.

    frame_src, frame_dst: (B, C, H, W) tensors.
    mask_ratio: fraction of the lowest-similarity matches to mask per source token.
    Returns a (B, h*w, h*w) boolean tensor (True = masked), or None if mask_ratio <= 0.
    """
    feat_1, h, w = extract_dino_feature(model, frame_src, return_h_w=True)
    feat_2 = extract_dino_feature(model, frame_dst)

    # Cosine similarity between every source patch token and every destination patch token.
    feat_1_norm = F.normalize(feat_1, dim=2, p=2)
    feat_2_norm = F.normalize(feat_2, dim=2, p=2)
    aff_raw = torch.einsum("bnc,bmc->bnm", [feat_1_norm, feat_2_norm])

    if mask_ratio <= 0:
        # no corr mask
        corr_mask = None
    else:
        if aff_raw.dtype == torch.float16:
            # torch.quantile does not support float16 inputs.
            aff_raw = aff_raw.float()
        # Per-source-token similarity threshold at the mask_ratio quantile.
        aff_percentile = torch.quantile(aff_raw, mask_ratio, 2, keepdim=True)
        # True for masked
        corr_mask = aff_raw < aff_percentile
    return corr_mask
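
# Usage sketch (illustration only; not part of the original module). The checkpoint
# path and image file names are placeholders, and both frames must share the same
# spatial size. forward() handles the PIL preprocessing:
#
#   model = DoduoModel.from_pretrained("path/to/doduo-checkpoint")
#   model.eval()
#   frame_src = Image.open("frame_src.png")
#   frame_dst = Image.open("frame_dst.png")
#   with torch.no_grad():
#       flow = model(frame_src, frame_dst)  # dense flow from frame_src to frame_dst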