# doduo/geometry.py
import torch
import torch.nn.functional as F


def coords_grid(b, h, w, homogeneous=False, device=None):
    # Build a [B, 2, H, W] grid of (x, y) pixel coordinates; with
    # homogeneous=True, a channel of ones is appended, giving [B, 3, H, W].
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")  # [H, W]

    stacks = [x, y]

    if homogeneous:
        ones = torch.ones_like(x)  # [H, W]
        stacks.append(ones)

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]
    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid
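
# Usage sketch (illustrative addition, not part of the original file): channel 0
# holds x coordinates and channel 1 holds y coordinates.
#   >>> grid = coords_grid(1, 2, 3)   # [1, 2, 2, 3]
#   >>> grid[0, 0]                    # x coordinates
#   tensor([[0., 1., 2.],
#           [0., 1., 2.]])
#   >>> grid[0, 1]                    # y coordinates
#   tensor([[0., 0., 0.],
#           [1., 1., 1.]])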


def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    # Build a [H, W, 2] grid of (x, y) coordinates spanning the given window.
    assert device is not None

    x, y = torch.meshgrid(
        torch.linspace(w_min, w_max, len_w, device=device),
        torch.linspace(h_min, h_max, len_h, device=device),
        indexing="ij",
    )
    grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]

    return grid
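
# Usage sketch (illustrative addition): a 3x3 window of coordinates covering
# x in [9, 11] and y in [4, 6].
#   >>> win = generate_window_grid(4, 6, 9, 11, 3, 3, device=torch.device("cpu"))
#   >>> win.shape                     # [H, W, 2]
#   torch.Size([3, 3, 2])
#   >>> win[0, 0], win[-1, -1]        # (x, y) corners of the window
#   (tensor([9., 4.]), tensor([11., 6.]))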


def normalize_coords(coords, h, w):
    # coords: [B, H, W, 2] in pixel scale
    c = torch.tensor([(w - 1) / 2.0, (h - 1) / 2.0], device=coords.device)
    return (coords - c) / c  # [-1, 1]
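
# Usage sketch (illustrative addition): pixel corners map onto the grid_sample
# convention, e.g. for a 5x7 image, (0, 0) -> (-1, -1) and (6, 4) -> (1, 1).
#   >>> coords = torch.tensor([[[[0.0, 0.0], [6.0, 4.0]]]])  # [1, 1, 2, 2]
#   >>> normalize_coords(coords, h=5, w=7)
#   tensor([[[[-1., -1.],
#             [ 1.,  1.]]]])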


def bilinear_sample(img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False):
    # img: [B, C, H, W]
    # sample_coords: [B, 2, H, W] in image (pixel) scale
    if sample_coords.size(1) != 2:  # [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    b, _, h, w = sample_coords.shape

    # Normalize to [-1, 1]; the (w - 1) / (h - 1) convention matches
    # align_corners=True in F.grid_sample.
    x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1

    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)

    if return_mask:
        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1)  # [B, H, W]
        return img, mask

    return img
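
# Usage sketch (illustrative addition): identity sampling returns the input,
# since coords_grid assigns each pixel its own coordinate.
#   >>> img = torch.arange(12.0).view(1, 1, 3, 4)
#   >>> coords = coords_grid(1, 3, 4)           # [1, 2, 3, 4]
#   >>> torch.allclose(bilinear_sample(img, coords), img)
#   True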


def flow_warp(feature, flow, mask=False, padding_mode="zeros"):
    # feature: [B, C, H, W]
    # flow: [B, 2, H, W]
    b, c, h, w = feature.size()
    assert flow.size(1) == 2

    grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]

    return bilinear_sample(feature, grid, padding_mode=padding_mode, return_mask=mask)
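
# Usage sketch (illustrative addition): a uniform flow of (+1, 0) samples each
# pixel from one column to the right, shifting content left; the last column
# falls out of bounds and is zero-padded.
#   >>> feat = torch.arange(8.0).view(1, 1, 2, 4)
#   >>> flow = torch.zeros(1, 2, 2, 4)
#   >>> flow[:, 0] = 1.0                        # x-flow of +1 pixel
#   >>> flow_warp(feat, flow)[0, 0]
#   tensor([[1., 2., 3., 0.],
#           [5., 6., 7., 0.]])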


def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
    # fwd_flow, bwd_flow: [B, 2, H, W]
    # alpha and beta values follow UnFlow (https://arxiv.org/abs/1711.07837)
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]

    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    # For consistent (non-occluded) pixels, fwd_flow + warped_bwd_flow ~ 0.
    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)

    threshold = alpha * flow_mag + beta
    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
    bwd_occ = (diff_bwd > threshold).float()

    return fwd_occ, bwd_occ
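
# Usage sketch (illustrative addition): zero flows are trivially consistent
# (the warped backward flow exactly cancels the forward flow), so both
# occlusion masks come back empty.
#   >>> fwd = torch.zeros(1, 2, 4, 4)
#   >>> bwd = torch.zeros(1, 2, 4, 4)
#   >>> fwd_occ, bwd_occ = forward_backward_consistency_check(fwd, bwd)
#   >>> fwd_occ.sum().item(), bwd_occ.sum().item()
#   (0.0, 0.0)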


def back_project(depth, intrinsics):
    # Back-project 2D pixel coordinates to 3D points in the camera frame.
    # depth: [B, H, W]
    # intrinsics: [B, 3, 3]
    b, h, w = depth.shape

    grid = coords_grid(b, h, w, homogeneous=True, device=depth.device)  # [B, 3, H, W]

    intrinsics_inv = torch.inverse(intrinsics)  # [B, 3, 3]

    points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1)  # [B, 3, H, W]

    return points
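
# Usage sketch (illustrative addition, using a hypothetical unit-focal
# intrinsics matrix): at unit depth, pixel (x, y) back-projects to
# (x - cx, y - cy, 1); with cx = cy = 0 that is simply (x, y, 1).
#   >>> K = torch.eye(3)[None]                  # fx = fy = 1, cx = cy = 0
#   >>> depth = torch.ones(1, 2, 2)
#   >>> back_project(depth, K)[0, :, 1, 1]      # pixel (x=1, y=1)
#   tensor([1., 1., 1.])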


def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None):
    # Transform 3D points from the reference camera frame to the target camera frame.
    # points_ref: [B, 3, H, W]
    # extrinsics_ref: [B, 4, 4]
    # extrinsics_tgt: [B, 4, 4]
    # extrinsics_rel: [B, 4, 4], relative pose transform
    b, _, h, w = points_ref.shape

    if extrinsics_rel is None:
        extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref))  # [B, 4, 4]

    points_tgt = (
        torch.bmm(extrinsics_rel[:, :3, :3], points_ref.view(b, 3, -1))
        + extrinsics_rel[:, :3, -1:]
    )  # [B, 3, H*W]

    points_tgt = points_tgt.view(b, 3, h, w)  # [B, 3, H, W]

    return points_tgt
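
# Usage sketch (illustrative addition): a pure translation moves every point by
# the same offset; here the relative pose shifts points by +1 along z.
#   >>> points = torch.zeros(1, 3, 2, 2)
#   >>> T = torch.eye(4)[None]
#   >>> T[:, 2, 3] = 1.0                        # translate along z
#   >>> camera_transform(points, extrinsics_rel=T)[0, :, 0, 0]
#   tensor([0., 0., 1.])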


def reproject(points_tgt, intrinsics, return_mask=False):
    # Project 3D points in the target camera frame to target-view pixel coordinates.
    # points_tgt: [B, 3, H, W]
    # intrinsics: [B, 3, 3]
    b, _, h, w = points_tgt.shape

    proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w)  # [B, 3, H, W]

    X = proj_points[:, 0]
    Y = proj_points[:, 1]
    Z = proj_points[:, 2].clamp(min=1e-3)

    pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w)  # [B, 2, H, W] in image scale

    if return_mask:
        # valid mask in pixel space
        mask = (
            (pixel_coords[:, 0] >= 0)
            & (pixel_coords[:, 0] <= (w - 1))
            & (pixel_coords[:, 1] >= 0)
            & (pixel_coords[:, 1] <= (h - 1))
        )  # [B, H, W]
        return pixel_coords, mask

    return pixel_coords
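
# Usage sketch (illustrative addition, unit-focal intrinsics as above): a 3D
# point (2, 1, 2) projects to pixel (x, y) = (X/Z, Y/Z) = (1.0, 0.5).
#   >>> K = torch.eye(3)[None]
#   >>> pts = torch.tensor([2.0, 1.0, 2.0]).view(1, 3, 1, 1)
#   >>> reproject(pts, K)[0, :, 0, 0]
#   tensor([1.0000, 0.5000])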


def reproject_coords(
    depth_ref,
    intrinsics,
    extrinsics_ref=None,
    extrinsics_tgt=None,
    extrinsics_rel=None,
    return_mask=False,
):
    # Compute reprojection sample coordinates: back-project the reference depth,
    # transform to the target camera, then project to target pixels.
    points_ref = back_project(depth_ref, intrinsics)  # [B, 3, H, W]

    points_tgt = camera_transform(
        points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel
    )

    if return_mask:
        reproj_coords, mask = reproject(
            points_tgt, intrinsics, return_mask=return_mask
        )  # [B, 2, H, W] in image scale
        return reproj_coords, mask

    reproj_coords = reproject(
        points_tgt, intrinsics, return_mask=return_mask
    )  # [B, 2, H, W] in image scale

    return reproj_coords
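
# Usage sketch (illustrative addition): with an identity relative pose and the
# hypothetical unit-focal intrinsics, every pixel reprojects onto itself.
#   >>> K, T = torch.eye(3)[None], torch.eye(4)[None]
#   >>> coords = reproject_coords(torch.ones(1, 2, 2), K, extrinsics_rel=T)
#   >>> torch.allclose(coords, coords_grid(1, 2, 2))
#   True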


def compute_flow_with_depth_pose(
    depth_ref,
    intrinsics,
    extrinsics_ref=None,
    extrinsics_tgt=None,
    extrinsics_rel=None,
    return_mask=False,
):
    # Rigid flow induced by depth and relative camera pose:
    # reprojected coordinates minus the identity pixel grid.
    b, h, w = depth_ref.shape
    coords_init = coords_grid(b, h, w, device=depth_ref.device)  # [B, 2, H, W]

    if return_mask:
        reproj_coords, mask = reproject_coords(
            depth_ref,
            intrinsics,
            extrinsics_ref,
            extrinsics_tgt,
            extrinsics_rel=extrinsics_rel,
            return_mask=return_mask,
        )  # [B, 2, H, W]
        rigid_flow = reproj_coords - coords_init
        return rigid_flow, mask

    reproj_coords = reproject_coords(
        depth_ref,
        intrinsics,
        extrinsics_ref,
        extrinsics_tgt,
        extrinsics_rel=extrinsics_rel,
        return_mask=return_mask,
    )  # [B, 2, H, W]

    rigid_flow = reproj_coords - coords_init

    return rigid_flow
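
# Usage sketch (illustrative addition): at unit depth, a camera translation of
# +1 along x yields a uniform rigid flow of (+1, 0) under the hypothetical
# unit-focal intrinsics above.
#   >>> K = torch.eye(3)[None]
#   >>> T = torch.eye(4)[None]
#   >>> T[:, 0, 3] = 1.0                        # camera translation along x
#   >>> flow = compute_flow_with_depth_pose(torch.ones(1, 4, 4), K, extrinsics_rel=T)
#   >>> flow[0, :, 0, 0]
#   tensor([1., 0.])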