import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import CNNEncoder
from .geometry import coords_grid
from .matching import (
    global_correlation_softmax_prototype,
    local_correlation_softmax_prototype,
)
from .transformer import FeatureTransformer
from .utils import feature_add_position


class UniMatch(nn.Module):
    def __init__(
        self,
        num_scales=1,
        feature_channels=128,
        upsample_factor=8,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        bilinear_upsample=False,
        corr_fn="global",
    ):
        super().__init__()

        self.feature_channels = feature_channels
        self.num_scales = num_scales
        self.upsample_factor = upsample_factor
        self.bilinear_upsample = bilinear_upsample
        if corr_fn == "global":
            self.corr_fn = global_correlation_softmax_prototype
        elif corr_fn == "local":
            self.corr_fn = local_correlation_softmax_prototype
        else:
            raise NotImplementedError(f"Correlation function {corr_fn} not implemented")

        # CNN
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer
        self.transformer = FeatureTransformer(
            num_layers=num_transformer_layers,
            d_model=feature_channels,
            nhead=num_head,
            ffn_dim_expansion=ffn_dim_expansion,
        )

        # Convex upsampling head similar to RAFT: for each coarse pixel, predict softmax
        # weights over its 3x3 neighborhood for every fine sub-pixel (upsample_factor**2 of them).
        # Input is feature0 concatenated with the 2-channel low-res flow.
        if not bilinear_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0),
            )

    def extract_feature(self, img0, img1):
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(concat)  # list of [2B, C, H, W], resolution from high to low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []

        for i in range(len(features)):
            feature = features[i]
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def correlate_feature(self, feature0, feature1, attn_splits=2, attn_type="swin"):
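        # Add positional encoding, run the cross-attention transformer, then build the
        # dense all-pairs correlation volume, scaled by sqrt(C) as in dot-product attention.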
        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )
        feature0, feature1 = self.transformer(
            feature0,
            feature1,
            attn_type=attn_type,
            attn_num_splits=attn_splits,
        )
        b, c, h, w = feature0.shape
        feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.view(b, c, -1)  # [B, C, H*W]
        correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
            c**0.5
        )  # [B, H, W, H, W]
        correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]
        return correlation

    def forward(
        self,
        img0,
        img1,
        attn_type="swin",
        attn_splits=2,
        return_feature=False,
        bidirectional=False,
        cycle_consistency=False,
        corr_mask=None,
    ):
        # per-scale feature lists, resolution from low to high
        feature0_list, feature1_list = self.extract_feature(img0, img1)
        assert self.num_scales == 1  # multi-scale depth model is not supported yet
        scale_idx = 0
        feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

        if cycle_consistency:
            # get both directions of features
            feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
                (feature1, feature0), dim=0
            )

        # add position to features
        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )

        # Transformer
        feature0, feature1 = self.transformer(
            feature0,
            feature1,
            attn_type=attn_type,
            attn_num_splits=attn_splits,
        )
        b, c, h, w = feature0.shape
        flow_coords = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
        # correlation and softmax
        query_results, correlation = self.corr_fn(
            feature0, feature1, flow_coords, pred_bidir_flow=bidirectional, corr_mask=corr_mask
        )
        if bidirectional:
            flow_coords = torch.cat((flow_coords, flow_coords), dim=0)
            up_feature = torch.cat((feature0, feature1), dim=0)
        else:
            up_feature = feature0
        flow = query_results - flow_coords  # flow = matched target coords minus identity grid
        flow_up = self.upsample_flow(flow, up_feature, bilinear=self.bilinear_upsample)
        if return_feature:
            return flow_up, flow, correlation, feature0, feature1
        else:
            return flow_up, flow, correlation

    def forward_features(
        self,
        img0,
        img1,
        attn_type="swin",
        attn_splits=2,
    ):
        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features
        assert self.num_scales == 1  # multi-scale depth model is not supported yet
        scale_idx = 0
        feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
        # add position to features
        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )

        # Transformer
        feature0, feature1 = self.transformer(
            feature0,
            feature1,
            attn_type=attn_type,
            attn_num_splits=attn_splits,
        )
        return feature0, feature1

    def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, is_depth=False):
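        # Either plain bilinear upsampling of the coarse flow, or RAFT-style learned
        # convex upsampling via self.upsampler (requires bilinear_upsample=False at init).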
        if bilinear:
            multiplier = 1 if is_depth else upsample_factor
            up_flow = (
                F.interpolate(
                    flow, scale_factor=upsample_factor, mode="bilinear", align_corners=False
                )
                * multiplier
            )
        else:
            concat = torch.cat((flow, feature), dim=1)
            mask = self.upsampler(concat)
            up_flow = upsample_flow_with_mask(
                flow, mask, upsample_factor=self.upsample_factor, is_depth=is_depth
            )
        return up_flow


def upsample_flow_with_mask(flow, up_mask, upsample_factor, is_depth=False):
    # convex upsampling following raft

    mask = up_mask
    b, flow_channel, h, w = flow.shape
    mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
    mask = torch.softmax(mask, dim=2)

    multiplier = 1 if is_depth else upsample_factor
    up_flow = F.unfold(multiplier * flow, [3, 3], padding=1)
    up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]

    up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
    up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, H, K, W, K]
    up_flow = up_flow.reshape(
        b, flow_channel, upsample_factor * h, upsample_factor * w
    )  # [B, 2, K*H, K*W]

    return up_flow
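

if __name__ == "__main__":
    # Minimal sanity check for convex upsampling: a sketch with random inputs, shapes only.
    # (Run within the containing package, e.g. `python -m <pkg>.unimatch`, so the relative
    # imports above resolve.) With upsample_factor K, the mask needs K*K*9 channels,
    # matching the `upsample_factor**2 * 9` output channels of `self.upsampler`.
    b, h, w, k = 1, 8, 8, 8
    flow = torch.randn(b, 2, h, w)
    up_mask = torch.randn(b, k * k * 9, h, w)
    up_flow = upsample_flow_with_mask(flow, up_mask, upsample_factor=k)
    assert up_flow.shape == (b, 2, k * h, k * w)  # [1, 2, 64, 64]

    # Hypothetical end-to-end usage (a sketch; assumes input sides divisible by
    # upsample_factor and default constructor arguments):
    #   model = UniMatch(feature_channels=128, upsample_factor=8, corr_fn="global")
    #   img0, img1 = torch.randn(1, 3, 256, 320), torch.randn(1, 3, 256, 320)
    #   flow_up, flow, correlation = model(img0, img1, attn_type="swin", attn_splits=2)
    #   # flow_up: [1, 2, 256, 320]; flow and the correlation volume are at 1/8 resolution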