import math
from typing import Tuple

import torch
import torch.nn.functional as F
from jaxtyping import Float, Integer
from torch import Tensor

from sf3d.models.utils import dot, triangle_intersection_2d


def _box_assign_vertex_to_cube_face(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    bbox: Float[Tensor, "2 3"],
) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf"]]:  # noqa: F821
    """Project every triangle onto one of the six cube faces.

    The cube face is chosen from the averaged vertex normals of the triangle
    (the canonical axis direction with the largest dot product wins). The UV
    coordinates are the two remaining coordinates of the bbox-normalized
    vertex positions.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of each triangle.
        bbox: Row 0 is the minimum corner, row 1 the maximum corner.

    Returns:
        A tuple ``(uv, index)``: per-triangle-corner UVs in [0, 1] and the
        cube face id (0..5 = +x, -x, +y, -y, +z, -z) of each triangle.
    """
    # Normalize per axis to [0, 1]. (An isotropic scale would keep the aspect
    # ratio but uses the UV space worse.) The clamp guards against a
    # zero-extent axis (perfectly flat mesh), which would otherwise yield NaN.
    extent = (bbox[1:] - bbox[:1]).clamp(min=1e-12)
    v_pos_normalized = (vertex_positions - bbox[:1]) / extent
    # And to [-1, 1]
    v_pos_normalized = 2.0 * v_pos_normalized - 1.0

    # Gather the three corner positions / normals of every triangle
    tri_stack = v_pos_normalized[triangle_idxs]
    tri_stack_nrm = vertex_normals[triangle_idxs]

    # Just average the normals per face
    face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)

    # Decide based on the face normal in which box map we project
    abs_pos = tri_stack.abs()
    axis = torch.tensor(
        [
            [1, 0, 0],  # 0
            [-1, 0, 0],  # 1
            [0, 1, 0],  # 2
            [0, -1, 0],  # 3
            [0, 0, 1],  # 4
            [0, 0, -1],  # 5
        ],
        device=face_normal.device,
        dtype=face_normal.dtype,
    )
    face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
    index = face_normal_axis.argmax(-1)

    max_axis = torch.ones_like(abs_pos[..., 0])
    uc = torch.zeros_like(tri_stack[..., :1])
    vc = torch.zeros_like(tri_stack[..., :1])

    # Per cube face: (dominant axis, u column, v column, negate v)
    projections = (
        (0, 1, 2, True),  # +x
        (0, 1, 2, True),  # -x
        (1, 0, 2, True),  # +y
        (1, 0, 2, True),  # -y
        (2, 0, 1, False),  # +z
        (2, 0, 1, True),  # -z
    )
    for face_id, (dominant, u_col, v_col, negate_v) in enumerate(projections):
        mask = index == face_id
        selected = tri_stack[mask]
        max_axis[mask] = abs_pos[mask][..., dominant]
        uc[mask] = selected[..., u_col : u_col + 1]
        v_vals = selected[..., v_col : v_col + 1]
        vc[mask] = -v_vals if negate_v else v_vals

    # UC from [-1, 1] to [0, 1]
    max_dim_div = max_axis.max(dim=0, keepdim=True).values
    uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
    vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)

    uv = torch.stack([uc, vc], dim=-1)

    return uv, index


def _assign_faces_uv_to_atlas_index(
    vertex_positions: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
    face_index: Integer[Tensor, "Nf"],  # noqa: F821
) -> Integer[Tensor, "Nf"]:  # noqa: F821
    """Resolve overlaps inside each cube-face projection.

    Three overlap passes are performed:

    * The first set (ids 0..5) is placed in the upper two thirds of the UV
      atlas. Conceptually these are the directly visible surfaces of each
      cube side.
    * The second set (ids 6..11) is placed in the left half of the lower
      third. This is the first set of occluded surfaces; they are still
      stored in projected fashion.
    * Everything that still overlaps ends up with id 12 and is later
      scattered in the bottom right of the atlas.

    Args:
        vertex_positions: Vertex positions.
        triangle_idxs: Vertex indices of each triangle.
        face_uv: Projected per-corner UVs of each triangle.
        face_index: Cube face id (0..5) of each triangle.

    Returns:
        The atlas slot id of every triangle, clamped to [0, 12].
    """
    triangle_pos = vertex_positions[triangle_idxs]

    assign_idx = face_index.clone()
    for overlap_step in range(3):
        overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
        for i in range(overlap_step * 6, (overlap_step + 1) * 6):
            mask = assign_idx == i
            if not mask.any():
                continue
            # Get all elements belonging to the projection face
            uv_triangle = face_uv[mask]
            cur_triangle_pos = triangle_pos[mask]
            # Find the center of the uv coordinates
            center_uv = uv_triangle.mean(dim=1, keepdim=True)
            # And also the radius of the triangle
            uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values

            potentially_overlapping_mask = (
                # Find all close triangles
                (center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
                # Do not select the same element by offsetting with a large
                # valued identity matrix
                + torch.eye(
                    uv_triangle.shape[0],
                    device=uv_triangle.device,
                    dtype=uv_triangle.dtype,
                ).unsqueeze(-1)
                * 1000
            )
            # Mark all potentially overlapping triangles to reduce the number
            # of triangle intersection tests
            potentially_overlapping_mask = (
                potentially_overlapping_mask
                <= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
            ).squeeze(-1)
            overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)

            # Only unique triangles (A|B and B|A should be the same)
            f = torch.min(overlap_coords, dim=-1).values
            s = torch.max(overlap_coords, dim=-1).values
            overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
            first, second = overlap_coords.unbind(-1)

            # Get the triangles
            tri_1 = uv_triangle[first]
            tri_2 = uv_triangle[second]

            # Perform the actual test with the reduced number of potentially
            # overlapping triangles
            its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)

            # So we now need to detect which triangles are the occluded ones.
            # We always assume the first to be the visible one (the others
            # should move). In the previous step we used a lexicographical
            # sort to get the unique pairs; here we sort along the
            # orthographic projection axis of the cube face.
            # The cube face is ``i % 6`` so that the later overlap passes
            # (i >= 6) use the axis of their projection as well, instead of
            # always falling through to z.
            cube_face = i % 6
            ax = cube_face // 2
            use_max = cube_face % 2 == 1

            tri1_c = cur_triangle_pos[first].mean(dim=1)
            tri2_c = cur_triangle_pos[second].mean(dim=1)

            mark_first = (
                (tri1_c[..., ax] > tri2_c[..., ax])
                if use_max
                else (tri1_c[..., ax] < tri2_c[..., ax])
            )
            first[mark_first] = second[mark_first]

            # Lastly the same index can be tested multiple times.
            # If one test marks it as overlapping we keep it marked as such.
            # We do this by testing if it has been marked at least once.
            unique_idx, rev_idx = torch.unique(first, return_inverse=True)

            add = torch.zeros_like(unique_idx, dtype=torch.float32)
            add.index_add_(0, rev_idx, its.float())
            its_mask = add > 0

            # And fill it in the overlapping indicator
            idx = torch.where(mask)[0][unique_idx]
            overlapping_indicator[idx] = its_mask

        # Move the index to the overlap regions (shift by 6)
        assign_idx[overlapping_indicator] += 6

    # We do not care about the correct face placement after the first 2 slices
    max_idx = 6 * 2
    return assign_idx.clamp(0, max_idx)


def _find_slice_offset_and_scale(
    index: Integer[Tensor, "Nf"],  # noqa: F821
) -> Tuple[
    Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"]  # noqa: F821
]:  # noqa: F821
    """Compute the atlas offset and the inverse scale of every triangle.

    Args:
        index: Atlas slot id of each triangle.

    Returns:
        ``(offset_x, offset_y, div_x, div_y)`` as float32 tensors. The final
        coordinate is later computed as ``uv / div + offset``.
    """
    if index.numel() == 0:
        # ``index.max()`` is undefined for empty input.
        empty = torch.zeros_like(index, dtype=torch.float32)
        return empty, empty.clone(), empty.clone(), empty.clone()

    # The first slice is a 3x2 grid, one cell per cube face
    off = 1 / 3
    dupl_off = 1 / 6

    # Here, we need to decide how to pack the textures in the case of overlap
    def x_offset_calc(x, i):
        offset_calc = i // 6
        # Initial coordinates - just 3x2 grid
        if offset_calc == 0:
            return off * x
        # Smaller 3x2 grid plus eventual shift to the right for the
        # second overlap
        return dupl_off * x + min(offset_calc - 1, 1) * 0.5

    def y_offset_calc(x, i):
        offset_calc = i // 6
        # Initial coordinates - just a 3x2 grid
        if offset_calc == 0:
            return off * x
        # Smaller coordinates in the lowest row
        return dupl_off * x + off * 2

    offset_x = torch.zeros_like(index, dtype=torch.float32)
    offset_y = torch.zeros_like(index, dtype=torch.float32)
    offset_x_vals = [0, 1, 2, 0, 1, 2]
    offset_y_vals = [0, 0, 0, 1, 1, 1]
    for i in range(index.max().item() + 1):
        mask = index == i
        if not mask.any():
            continue
        offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
        offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)

    div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
    # All overlap elements are saved in half scale
    div_x[index >= 6] = 6
    div_y = div_x.clone()  # Same for y
    # Except for the random overlaps
    div_x[index >= 12] = 2
    # But the random overlaps are saved in a large block in the lower third
    div_y[index >= 12] = 3

    return offset_x, offset_y, div_x, div_y


def rotation_flip_matrix_2d(
    rad: float, flip_x: bool = False, flip_y: bool = False
) -> Float[Tensor, "2 2"]:
    """Build a 2D rotation matrix that is optionally followed by axis flips.

    Args:
        rad: Rotation angle in radians.
        flip_x: Negate the first output coordinate.
        flip_y: Negate the second output coordinate.

    Returns:
        The float32 matrix ``flip @ rotation``.
    """
    cos = math.cos(rad)
    sin = math.sin(rad)
    rot_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32)
    flip_mat = torch.tensor(
        [
            [-1 if flip_x else 1, 0],
            [0, -1 if flip_y else 1],
        ],
        dtype=torch.float32,
    )

    return flip_mat @ rot_mat


def calculate_tangents(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
) -> Float[Tensor, "Nv 3"]:
    """Compute per-vertex tangents from per-face UV coordinates.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of each triangle.
        face_uv: Per-corner UV coordinates of each triangle.

    Returns:
        One tangent per vertex, normalized and with the component along the
        vertex normal removed (via the project helper ``dot``).
    """
    pos = vertex_positions[triangle_idxs].unbind(1)
    tex = face_uv.unbind(1)

    tangents = torch.zeros_like(vertex_normals)
    tansum = torch.zeros_like(vertex_normals)

    # Compute tangent space for each triangle
    duv1 = tex[1] - tex[0]
    duv2 = tex[2] - tex[0]
    dpos1 = pos[1] - pos[0]
    dpos2 = pos[2] - pos[0]

    tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

    denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

    # Avoid division by zero for degenerated texture coordinates. The clamp
    # has to preserve the sign: a plain lower clamp would turn every negative
    # determinant (mirrored UVs) into +1e-6, flipping and inflating the
    # tangent.
    denom_safe = torch.where(
        denom > 0.0, denom.clamp(min=1e-6), denom.clamp(max=-1e-6)
    )
    tang = tng_nom / denom_safe

    # Update all 3 vertices
    for i in range(3):
        # The normal index is always the same as the position index
        idx = triangle_idxs[:, i][:, None].repeat(1, 3)
        tangents.scatter_add_(0, idx, tang)  # tangents[n_i] += tang
        tansum.scatter_add_(0, idx, torch.ones_like(tang))  # tansum[n_i] += 1

    # Average. The individual triangle tangents are not normalized first, so
    # triangles with a larger tangent magnitude influence the result more.
    # Vertices that are not part of any triangle keep a zero tangent instead
    # of becoming NaN (0 / 0).
    tangents = tangents / tansum.clamp(min=1.0)

    # Normalize and make sure tangent is perpendicular to normal
    tangents = F.normalize(tangents, dim=1)
    tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)

    return tangents


def _rotate_uv_slices_consistent_space(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],  # noqa: F821
) -> Float[Tensor, "Nf 3 2"]:
    """Rotate the UVs of each cube-face slice towards a common tangent space.

    For every slice the mean actual tangent is compared with the mean of an
    expected tangent field (derived from ``(-y, x, 0)`` projected into the
    tangent plane of the vertex normal). The slice UVs are rotated by the
    angle between both and rescaled to [0, 1].

    Note:
        ``uv`` is modified in place and also returned.
    """
    tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
    pos_stack = torch.stack(
        [
            -vertex_positions[..., 1],
            vertex_positions[..., 0],
            torch.zeros_like(vertex_positions[..., 0]),
        ],
        dim=-1,
    )
    # ``dim`` has to be passed by keyword: the second positional argument of
    # ``F.normalize`` is the norm order ``p``, not the dimension.
    expected_tangents = F.normalize(
        torch.linalg.cross(
            vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
        ),
        dim=-1,
    )

    actual_tangents = tangents[triangle_idxs]
    expected_tangents = expected_tangents[triangle_idxs]

    def rotation_matrix_2d(theta):
        # Stacking keeps the device and dtype of ``theta``
        c, s = torch.cos(theta), torch.sin(theta)
        return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])

    # Now find the rotation
    index_mod = index % 6  # Shouldn't be necessary. Just for safety
    for i in range(6):
        mask = index_mod == i
        if not mask.any():
            continue

        actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
        expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))

        dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
        cross_product = (
            actual_mean_tangent[0] * expected_mean_tangent[1]
            - actual_mean_tangent[1] * expected_mean_tangent[0]
        )
        angle = torch.atan2(cross_product, dot_product)

        rot_matrix = rotation_matrix_2d(angle).to(device=uv.device, dtype=uv.dtype)
        # Center the uv coordinates to be in the range of -1 to 1
        uv_cur = uv[mask] * 2 - 1
        # Rotate them
        rotated = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)

        # Rescale to be within the 0-1 range. The clamp avoids 0 / 0 for a
        # slice whose coordinates are all identical.
        lo, hi = rotated.min(), rotated.max()
        uv[mask] = (rotated - lo) / (hi - lo).clamp(min=1e-12)

    return uv


def _handle_slice_uvs(
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],  # noqa: F821
    island_padding: float,
    max_index: int = 6 * 2,
) -> Float[Tensor, "Nf 3 2"]:  # noqa: F821
    """Normalize the first overlap slices and apply the island padding.

    Note:
        ``uc`` / ``vc`` are views of ``uv``, so the normalization of the
        overlap slices writes into the input tensor.
    """
    uc, vc = uv.unbind(-1)

    # Get the second slice (The first overlap)
    index_filter = [index == i for i in range(6, max_index)]

    # Normalize them to always fully fill the atlas patch
    for fi in index_filter:
        if not fi.any():
            continue
        # Scale the slice but only up to a factor of 2
        # This keeps the texture resolution in line with the first slice
        # (Half space in UV)
        u_sel = uc[fi]
        v_sel = vc[fi]
        uc[fi] = (u_sel - u_sel.min()) / (u_sel.max() - u_sel.min()).clip(0.5)
        vc[fi] = (v_sel - v_sel.min()) / (v_sel.max() - v_sel.min()).clip(0.5)

    uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
    vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)

    return torch.stack([uc_padded, vc_padded], dim=-1)


def _handle_remaining_uvs(
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],  # noqa: F821
    island_padding: float,
) -> Float[Tensor, "Nf 3 2"]:
    """Place every remaining triangle (slot id >= 12) in its own grid cell.

    Note:
        ``uv`` is modified in place and also returned.
    """
    uc, vc = uv.unbind(-1)
    # Get all remaining elements
    remaining_filter = index >= 6 * 2
    squares_left = remaining_filter.sum()

    if squares_left == 0:
        return uv

    uc = uc[remaining_filter]
    vc = vc[remaining_filter]

    # The remaining triangles are distributed in a rectangle
    # The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
    ratio = 0.5 * (1 / 3)  # 1.5
    # sqrt(744/(0.5*(1/3)))

    mult = math.sqrt(squares_left / ratio)
    num_square_width = int(math.ceil(0.5 * mult))
    num_square_height = int(math.ceil(squares_left / num_square_width))

    width = 1 / num_square_width
    height = 1 / num_square_height

    # The idea is again to keep the texture resolution consistent with the
    # first slice. This only occupies half the region in the texture chart
    # but the scaling on the squares assumes full coverage.
    clip_val = min(width, height) * 1.5
    # Now normalize the UVs taking the maximum scaling into account
    uc = (uc - uc.min(dim=1, keepdim=True).values) / (
        uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
    ).clip(clip_val)
    vc = (vc - vc.min(dim=1, keepdim=True).values) / (
        vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
    ).clip(clip_val)
    # Add a small padding
    uc = (
        uc * (1 - island_padding * num_square_width * 0.5)
        + island_padding * num_square_width * 0.25
    ).clip(0, 1)
    vc = (
        vc * (1 - island_padding * num_square_height * 0.5)
        + island_padding * num_square_height * 0.25
    ).clip(0, 1)

    uc = uc * width
    vc = vc * height

    # And calculate offsets for each element
    idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
    x_idx = idx % num_square_width
    y_idx = idx // num_square_width
    # And move each triangle to its own spot
    uc = uc + x_idx[:, None] * width
    vc = vc + y_idx[:, None] * height

    uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
    vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)

    uv[remaining_filter] = torch.stack([uc, vc], dim=-1)

    return uv


def _distribute_individual_uvs_in_atlas(
    face_uv: Float[Tensor, "Nf 3 2"],
    assigned_faces: Integer[Tensor, "Nf"],  # noqa: F821
    offset_x: Float[Tensor, "Nf"],  # noqa: F821
    offset_y: Float[Tensor, "Nf"],  # noqa: F821
    div_x: Float[Tensor, "Nf"],  # noqa: F821
    div_y: Float[Tensor, "Nf"],  # noqa: F821
    island_padding: float,
) -> Float[Tensor, "Nuv 2"]:
    """Scale and offset the per-slice UVs into their final atlas position.

    Returns:
        The flattened UV coordinates, three consecutive rows per triangle.
    """
    # Place the slice first
    placed_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
    # Then handle the remaining overlap elements
    placed_uv = _handle_remaining_uvs(placed_uv, assigned_faces, island_padding)

    uc, vc = placed_uv.unbind(-1)
    uc = uc / div_x[:, None] + offset_x[:, None]
    vc = vc / div_y[:, None] + offset_y[:, None]

    uv = torch.stack([uc, vc], dim=-1).view(-1, 2)

    return uv


def _get_unique_face_uv(
    uv: Float[Tensor, "Nuv 2"],
) -> Tuple[Float[Tensor, "Utex 2"], Integer[Tensor, "Nf 3"]]:
    """Deduplicate UV rows and build the face to UV index mapping.

    Args:
        uv: Flattened UV coordinates, three consecutive rows per triangle.

    Returns:
        The unique UV rows and, per triangle, the three indices into them.
    """
    unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
    # And add the face to uv index mapping
    vtex_idx = unique_idx.view(-1, 3)

    return unique_uv, vtex_idx


def _align_mesh_with_main_axis(
    vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
    """Rotate the mesh so that its principal axes match the canonical axes.

    Note:
        Calls ``torch.manual_seed(0)`` and therefore resets the global torch
        RNG state as a side effect.

    Raises:
        ValueError: If the principal axes cannot be mapped to three distinct
            canonical axes.
    """
    # Use pca to find the 2 main axes (third is derived by cross product)
    # Set the random seed so it's repeatable
    torch.manual_seed(0)
    _, _, v = torch.pca_lowrank(vertex_positions, q=2)
    main_axis, secondary_axis = v[:, 0], v[:, 1]

    main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
    # Orthogonalize the second axis
    secondary_axis: Float[Tensor, "3"] = F.normalize(
        secondary_axis - dot(secondary_axis, main_axis) * main_axis, eps=1e-6, dim=-1
    )
    # Create perpendicular third axis
    third_axis: Float[Tensor, "3"] = F.normalize(
        torch.linalg.cross(main_axis, secondary_axis), dim=-1, eps=1e-6
    )

    # Check to which canonical axis each aligns
    main_axis_max_idx = main_axis.abs().argmax().item()
    secondary_axis_max_idx = secondary_axis.abs().argmax().item()
    third_axis_max_idx = third_axis.abs().argmax().item()

    # Now sort the axes based on the argmax so they align with the canonical
    # axes. If two axes have the same argmax move one of them.
    all_possible_axis = {0, 1, 2}
    cur_index = 1
    while (
        len({main_axis_max_idx, secondary_axis_max_idx, third_axis_max_idx}) != 3
    ):
        # Find missing axis
        missing_axis = (
            all_possible_axis
            - {main_axis_max_idx, secondary_axis_max_idx, third_axis_max_idx}
        ).pop()
        # First try the third axis as it has the smallest contribution to the
        # overall shape, then the secondary one
        if cur_index == 1:
            third_axis_max_idx = missing_axis
        elif cur_index == 2:
            secondary_axis_max_idx = missing_axis
        else:
            raise ValueError("Could not find 3 unique axis")
        cur_index += 1

    if len({main_axis_max_idx, secondary_axis_max_idx, third_axis_max_idx}) != 3:
        raise ValueError("Could not find 3 unique axis")

    axes = [None] * 3
    axes[main_axis_max_idx] = main_axis
    axes[secondary_axis_max_idx] = secondary_axis
    axes[third_axis_max_idx] = third_axis
    # Create rotation matrix from the individual axes
    rot_mat = torch.stack(axes, dim=1).T

    # Now rotate the vertex positions and vertex normals so the mesh aligns
    # with the main axis
    vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
    vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)

    return vertex_positions, vertex_normals


def box_projection_uv_unwrap(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    island_padding: float,
) -> Tuple[Float[Tensor, "Utex 2"], Integer[Tensor, "Nf 3"]]:
    """Unwrap a mesh into a UV atlas using a box (cube map) projection.

    Args:
        vertex_positions: Vertex positions.
        vertex_normals: Per-vertex normals.
        triangle_idxs: Vertex indices of each triangle.
        island_padding: Padding added around the UV islands.

    Returns:
        The unique UV coordinates and, per triangle, the three indices into
        them.
    """
    # Align the mesh with main axis directions first
    vertex_positions, vertex_normals = _align_mesh_with_main_axis(
        vertex_positions, vertex_normals
    )

    bbox: Float[Tensor, "2 3"] = torch.stack(
        [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
    )
    # First decide in which cube face the triangle is placed
    face_uv, face_index = _box_assign_vertex_to_cube_face(
        vertex_positions, vertex_normals, triangle_idxs, bbox
    )

    # Rotate the UV islands in a way that they align with the radial z
    # tangent space
    face_uv = _rotate_uv_slices_consistent_space(
        vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
    )

    # Then find where the face is placed in the atlas.
    # This has to detect potential overlaps
    assigned_atlas_index = _assign_faces_uv_to_atlas_index(
        vertex_positions, triangle_idxs, face_uv, face_index
    )

    # Then figure out the final place in the atlas based on the assignment
    offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
        assigned_atlas_index
    )

    # Next distribute the faces in the uv atlas
    placed_uv = _distribute_individual_uvs_in_atlas(
        face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
    )

    # And get the unique per-triangle UV coordinates
    return _get_unique_face_uv(placed_uv)