import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import CNNEncoder
from .geometry import coords_grid
from .matching import (
global_correlation_softmax_prototype,
local_correlation_softmax_prototype,
)
from .transformer import FeatureTransformer
from .utils import feature_add_position
class UniMatch(nn.Module):
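    """Dense matching model in the GMFlow/UniMatch style: a CNN backbone extracts
    features, a Transformer refines them, a correlation softmax produces matches,
    and a RAFT-style convex upsampler restores full resolution."""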
def __init__(
self,
num_scales=1,
feature_channels=128,
upsample_factor=8,
num_head=1,
ffn_dim_expansion=4,
num_transformer_layers=6,
bilinear_upsample=False,
corr_fn="global",
):
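        """
        Args:
            num_scales: number of backbone output scales (only 1 is supported in forward)
            feature_channels: channel width of the CNN and Transformer features
            upsample_factor: ratio between image resolution and feature resolution
            num_head: number of attention heads in the Transformer
            ffn_dim_expansion: hidden-dim expansion ratio of the Transformer FFN
            num_transformer_layers: number of Transformer layers
            bilinear_upsample: if True, use bilinear instead of learned convex upsampling
            corr_fn: "global" or "local" correlation softmax matching
        """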
super().__init__()
self.feature_channels = feature_channels
self.num_scales = num_scales
self.upsample_factor = upsample_factor
self.bilinear_upsample = bilinear_upsample
if corr_fn == "global":
self.corr_fn = global_correlation_softmax_prototype
elif corr_fn == "local":
self.corr_fn = local_correlation_softmax_prototype
else:
raise NotImplementedError(f"Correlation function {corr_fn} not implemented")
# CNN
self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
# Transformer
self.transformer = FeatureTransformer(
num_layers=num_transformer_layers,
d_model=feature_channels,
nhead=num_head,
ffn_dim_expansion=ffn_dim_expansion,
)
# convex upsampling similar to RAFT
# concat feature0 and low res flow as input
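        # the second conv predicts upsample_factor**2 * 9 logits: one weight per
        # 3x3 coarse neighbor for each of the K*K sub-pixels (K = upsample_factor)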
if not bilinear_upsample:
self.upsampler = nn.Sequential(
nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0),
)
def extract_feature(self, img0, img1):
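        """Encode both images in a single backbone pass and split the batch.

        Returns two lists of per-scale feature maps, ordered low to high resolution.
        """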
concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low
# reverse: resolution from low to high
features = features[::-1]
        feature0, feature1 = [], []
        for feature in features:
            f0, f1 = torch.chunk(feature, 2, 0)  # split the stacked batch back apart
            feature0.append(f0)
            feature1.append(f1)
        return feature0, feature1
def correlate_feature(self, feature0, feature1, attn_splits=2, attn_type="swin"):
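        """Compute the scaled all-pairs correlation between Transformer features.

        Returns a [B, H*W, H*W] cost matrix between every source and target position.
        """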
feature0, feature1 = feature_add_position(
feature0, feature1, attn_splits, self.feature_channels
)
feature0, feature1 = self.transformer(
feature0,
feature1,
attn_type=attn_type,
attn_num_splits=attn_splits,
)
b, c, h, w = feature0.shape
feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
feature1 = feature1.view(b, c, -1) # [B, C, H*W]
        # scaled dot-product cost volume; equivalent to [B, H, W, H, W] flattened
        # over source and target positions
        correlation = torch.matmul(feature0, feature1) / (c**0.5)  # [B, H*W, H*W]
        return correlation
def forward(
self,
img0,
img1,
attn_type="swin",
attn_splits=2,
return_feature=False,
bidirectional=False,
cycle_consistency=False,
corr_mask=None,
):
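        """Predict dense correspondences from img0 to img1.

        Returns (flow_up, flow, correlation); the Transformer features are
        appended when return_feature is True.
        """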
# list of features, resolution low to high
feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
assert self.num_scales == 1 # multi-scale depth model is not supported yet
scale_idx = 0
feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
        if cycle_consistency:
            # get both directions of features, so a single pass matches 0->1 and 1->0
            feature0, feature1 = (
                torch.cat((feature0, feature1), dim=0),
                torch.cat((feature1, feature0), dim=0),
            )
# add position to features
feature0, feature1 = feature_add_position(
feature0, feature1, attn_splits, self.feature_channels
)
# Transformer
feature0, feature1 = self.transformer(
feature0,
feature1,
attn_type=attn_type,
attn_num_splits=attn_splits,
)
b, c, h, w = feature0.shape
# downsampled_img0 = F.interpolate(img0, size=(h, w), mode="bilinear", align_corners=False)
flow_coords = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
# values = torch.cat((downsampled_img0, flow_coords), dim=1) # [B, 5, H, W]
# correlation and softmax
query_results, correlation = self.corr_fn(
feature0, feature1, flow_coords, pred_bidir_flow=bidirectional, corr_mask=corr_mask
)
if bidirectional:
flow_coords = torch.cat((flow_coords, flow_coords), dim=0)
up_feature = torch.cat((feature0, feature1), dim=0)
else:
up_feature = feature0
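        # flow is the soft-matched target coordinate minus the identity grid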
flow = query_results - flow_coords
flow_up = self.upsample_flow(flow, up_feature, bilinear=self.bilinear_upsample)
if return_feature:
return flow_up, flow, correlation, feature0, feature1
else:
return flow_up, flow, correlation
def forward_features(
self,
img0,
img1,
attn_type="swin",
attn_splits=2,
):
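        """Return the position-encoded, Transformer-refined features for both images."""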
feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
assert self.num_scales == 1 # multi-scale depth model is not supported yet
scale_idx = 0
feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
# add position to features
feature0, feature1 = feature_add_position(
feature0, feature1, attn_splits, self.feature_channels
)
# Transformer
feature0, feature1 = self.transformer(
feature0,
feature1,
attn_type=attn_type,
attn_num_splits=attn_splits,
)
return feature0, feature1
    def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=None, is_depth=False):
        """Upsample low-resolution flow, either bilinearly or with a learned convex mask."""
        # default to the model's configured factor so both branches agree
        if upsample_factor is None:
            upsample_factor = self.upsample_factor
        if bilinear:
            # depth is a per-pixel value rather than a displacement, so only
            # flow vectors are rescaled by the resolution change
            multiplier = 1 if is_depth else upsample_factor
            up_flow = (
                F.interpolate(
                    flow, scale_factor=upsample_factor, mode="bilinear", align_corners=False
                )
                * multiplier
            )
        else:
            concat = torch.cat((flow, feature), dim=1)
            mask = self.upsampler(concat)
            up_flow = upsample_flow_with_mask(
                flow, mask, upsample_factor=upsample_factor, is_depth=is_depth
            )
return up_flow
def upsample_flow_with_mask(flow, up_mask, upsample_factor, is_depth=False):
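    """RAFT-style convex upsampling: every fine pixel is a learned convex
    combination (softmax weights) of its 3x3 coarse neighborhood, where
    K = upsample_factor.
    """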
# convex upsampling following raft
    b, flow_channel, h, w = flow.shape
    mask = up_mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
mask = torch.softmax(mask, dim=2)
multiplier = 1 if is_depth else upsample_factor
up_flow = F.unfold(multiplier * flow, [3, 3], padding=1)
up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W]
up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W]
    up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, H, K, W, K]
up_flow = up_flow.reshape(
b, flow_channel, upsample_factor * h, upsample_factor * w
) # [B, 2, K*H, K*W]
return up_flow
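

# --- illustrative usage sketch (not part of the original module) --------------
# A minimal smoke test, assuming 3-channel inputs whose height and width are
# divisible by upsample_factor (8) and by the attention split size (2); the
# relative imports above mean this file must be run from within its package,
# and with random weights the predicted flow is meaningless.
if __name__ == "__main__":
    model = UniMatch().eval()
    img0 = torch.randn(1, 3, 256, 256)
    img1 = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        flow_up, flow, correlation = model(img0, img1, attn_type="swin", attn_splits=2)
    print(flow_up.shape)      # expected: [1, 2, 256, 256], full-resolution flow
    print(flow.shape)         # expected: [1, 2, 32, 32], 1/8-resolution flow
    print(correlation.shape)  # all-pairs cost returned by the matching function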