Source code for aquamvs.fusion

"""Depth map fusion into unified point clouds."""

from pathlib import Path

import numpy as np
import open3d as o3d
import torch
import torch.nn.functional as F

from .config import ReconstructionConfig
from .projection.protocol import ProjectionModel


def _sample_depth_map(
    depth_map: torch.Tensor,
    pixels: torch.Tensor,
    valid: torch.Tensor,
) -> torch.Tensor:
    """Sample a depth map at sub-pixel locations via bilinear interpolation.

    Args:
        depth_map: Depth map, shape (H, W), float32. NaN for invalid.
        pixels: Pixel coordinates (u, v), shape (K, 2), float32.
        valid: Validity mask, shape (K,), bool. Invalid pixels are skipped.

    Returns:
        Sampled depths, shape (K,), float32. NaN for invalid pixels or
        out-of-bounds locations.
    """
    H, W = depth_map.shape
    result = torch.full((pixels.shape[0],), float("nan"), device=pixels.device)

    if not valid.any():
        return result

    # Replace NaN in depth map with 0 for grid_sample (NaN poisons bilinear)
    depth_clean = torch.where(
        torch.isnan(depth_map), torch.zeros_like(depth_map), depth_map
    )
    nan_mask = torch.isnan(depth_map).float()  # 1 where NaN

    # Normalize pixel coords to [-1, 1] for grid_sample
    valid_pixels = pixels[valid]  # (V, 2)
    grid_x = 2.0 * valid_pixels[:, 0] / (W - 1) - 1.0
    grid_y = 2.0 * valid_pixels[:, 1] / (H - 1) - 1.0
    grid = torch.stack([grid_x, grid_y], dim=-1).reshape(1, 1, -1, 2)

    # Sample depth
    depth_4d = depth_clean.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
    sampled = F.grid_sample(
        depth_4d, grid, mode="bilinear", padding_mode="zeros", align_corners=True
    )
    sampled = sampled.reshape(-1)  # (V,)

    # Also sample the NaN mask to detect if any neighbor was NaN
    nan_sampled = F.grid_sample(
        nan_mask.unsqueeze(0).unsqueeze(0),
        grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
    nan_sampled = nan_sampled.reshape(-1)

    # Mark as NaN where any neighbor was NaN (nan_sampled > 0)
    sampled = torch.where(
        nan_sampled > 0.01,
        torch.tensor(float("nan"), device=sampled.device),
        sampled,
    )

    result[valid] = sampled
    return result


[docs] def filter_depth_map( ref_name: str, ref_model: ProjectionModel, ref_depth: torch.Tensor, ref_confidence: torch.Tensor, target_names: list[str], target_models: dict[str, ProjectionModel], target_depths: dict[str, torch.Tensor], config: ReconstructionConfig, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Filter a depth map by cross-view geometric consistency. For each valid pixel in the reference depth map, reprojects into each target camera and checks depth agreement. Pixels consistent with fewer than min_consistent_views other cameras are invalidated. Args: ref_name: Reference camera name. ref_model: Projection model for the reference camera. ref_depth: Reference depth map, shape (H, W), float32. NaN = invalid. ref_confidence: Reference confidence map, shape (H, W), float32. target_names: List of other camera names to check against. target_models: Camera name to ProjectionModel mapping. target_depths: Camera name to depth map (H, W) mapping. config: Fusion configuration. Returns: filtered_depth: Depth map with inconsistent pixels set to NaN, shape (H, W), float32. filtered_confidence: Confidence map with inconsistent pixels set to 0, shape (H, W), float32. consistency_count: Number of agreeing views per pixel, shape (H, W), int32. 0 for invalid/filtered pixels. """ H, W = ref_depth.shape device = ref_depth.device # Step 1: Build valid pixel mask (not NaN, above min_confidence) valid_mask = ~torch.isnan(ref_depth) & (ref_confidence >= config.min_confidence) valid_indices = torch.where(valid_mask.reshape(-1))[0] # (K,) flat indices N_valid = valid_indices.shape[0] if N_valid == 0: # No valid pixels -- return all-invalid return ( torch.full((H, W), float("nan"), device=device), torch.zeros(H, W, device=device), torch.zeros(H, W, device=device, dtype=torch.int32), ) # Step 2: Get pixel coordinates for valid pixels v_coords = valid_indices // W # row u_coords = valid_indices % W # col pixels = torch.stack([u_coords.float(), v_coords.float()], dim=-1) # (K, 2) # Step 3: Back-project to 3D depths_flat = ref_depth.reshape(-1)[valid_indices] # (K,) origins, directions = ref_model.cast_ray(pixels) # (K, 3), (K, 3) points_3d = origins + depths_flat.unsqueeze(-1) * directions # (K, 3) # Step 4: Check consistency against each target camera consistency_count = torch.zeros(N_valid, device=device, dtype=torch.int32) for target_name in target_names: model_t = target_models[target_name] depth_t = target_depths[target_name] # (H, W) # Project 3D points into target camera pixels_t, valid_t = model_t.project(points_3d) # (K, 2), (K,) # Look up target depth at projected pixel locations (bilinear) depth_t_lookup = _sample_depth_map(depth_t, pixels_t, valid_t) # (K,) # Cast rays in target camera to get expected ray depth # Only for pixels that are valid in both projection and depth lookup lookup_valid = valid_t & ~torch.isnan(depth_t_lookup) if lookup_valid.any(): origins_t, directions_t = model_t.cast_ray(pixels_t[lookup_valid]) d_expected = ((points_3d[lookup_valid] - origins_t) * directions_t).sum( dim=-1 ) # Compare depths depth_diff = (depth_t_lookup[lookup_valid] - d_expected).abs() consistent = depth_diff < config.depth_tolerance # Accumulate count consistent_full = torch.zeros(N_valid, device=device, dtype=torch.bool) consistent_full[lookup_valid] = consistent consistency_count += consistent_full.int() # Step 5: Apply min_consistent_views threshold passes = consistency_count >= config.min_consistent_views # Step 6: Build output maps filtered_depth = torch.full((H, W), float("nan"), device=device) filtered_confidence = torch.zeros(H, W, device=device) consistency_map = torch.zeros(H, W, device=device, dtype=torch.int32) passing_indices = valid_indices[passes] filtered_depth.reshape(-1)[passing_indices] = depths_flat[passes] filtered_confidence.reshape(-1)[passing_indices] = ref_confidence.reshape(-1)[ passing_indices ] consistency_map.reshape(-1)[valid_indices] = consistency_count return filtered_depth, filtered_confidence, consistency_map
def filter_all_depth_maps( ring_cameras: list[str], projection_models: dict[str, ProjectionModel], depth_maps: dict[str, torch.Tensor], confidence_maps: dict[str, torch.Tensor], config: ReconstructionConfig, ) -> dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run geometric consistency filtering for all ring cameras. Args: ring_cameras: List of ring camera names (reference views). projection_models: Camera name to ProjectionModel mapping. depth_maps: Camera name to depth map (H, W) mapping. confidence_maps: Camera name to confidence map (H, W) mapping. config: Fusion configuration. Returns: Dict mapping camera name to (filtered_depth, filtered_confidence, consistency_count) tuple. """ results = {} for ref_name in ring_cameras: if ref_name not in depth_maps: continue # Target cameras = all other ring cameras that have depth maps target_names = [ name for name in ring_cameras if name != ref_name and name in depth_maps ] filtered_depth, filtered_conf, count = filter_depth_map( ref_name=ref_name, ref_model=projection_models[ref_name], ref_depth=depth_maps[ref_name], ref_confidence=confidence_maps[ref_name], target_names=target_names, target_models=projection_models, target_depths=depth_maps, config=config, ) results[ref_name] = (filtered_depth, filtered_conf, count) return results
[docs] def backproject_depth_map( model: ProjectionModel, depth_map: torch.Tensor, image: torch.Tensor, confidence: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: """Back-project a filtered depth map to a colored 3D point cloud. For each valid (non-NaN) pixel in the depth map, casts a ray through the projection model, computes the 3D point at the stored ray depth, and samples the color from the reference image. Args: model: Projection model for this camera. depth_map: Filtered depth map, shape (H, W), float32. NaN = invalid. image: Reference BGR image, shape (H, W, 3), uint8. confidence: Optional confidence map, shape (H, W), float32. If provided, included in the output for downstream use. Returns: Dict with keys: - "points": shape (K, 3), float32 -- 3D world coordinates - "colors": shape (K, 3), float32 -- RGB colors in [0, 1] - "confidence": shape (K,), float32 -- confidence values (if provided, else ones) Where K is the number of valid (non-NaN) pixels. Returns empty tensors if no valid pixels. """ H, W = depth_map.shape device = depth_map.device # Ensure image is on the same device as the depth map (torch.from_numpy # always returns CPU tensors, but depth maps may be on CUDA -- see B.10) image = image.to(device) # Find valid pixels valid_mask = ~torch.isnan(depth_map) valid_indices = torch.where(valid_mask.reshape(-1))[0] # (K,) K = valid_indices.shape[0] if K == 0: return { "points": torch.zeros(0, 3, device=device), "colors": torch.zeros(0, 3, device=device), "confidence": torch.zeros(0, device=device), } # Get pixel coordinates v_coords = valid_indices // W u_coords = valid_indices % W pixels = torch.stack([u_coords.float(), v_coords.float()], dim=-1) # (K, 2) # Get depths depths = depth_map.reshape(-1)[valid_indices] # (K,) # Cast rays and compute 3D points origins, directions = model.cast_ray(pixels) # (K, 3), (K, 3) points = origins + depths.unsqueeze(-1) * directions # (K, 3) # Sample colors from the image (BGR uint8 -> RGB float [0, 1]) # Use integer pixel coords (nearest neighbor for color -- sub-pixel not needed) colors_bgr = image[v_coords.long(), u_coords.long()] # (K, 3), uint8 colors_rgb = colors_bgr[:, [2, 1, 0]].float() / 255.0 # (K, 3), RGB [0, 1] # Confidence if confidence is not None: conf = confidence.reshape(-1)[valid_indices] # (K,) else: conf = torch.ones(K, device=device) return { "points": points, "colors": colors_rgb, "confidence": conf, }
[docs] def fuse_depth_maps( ring_cameras: list[str], projection_models: dict[str, ProjectionModel], filtered_depth_maps: dict[str, torch.Tensor], filtered_confidence_maps: dict[str, torch.Tensor], images: dict[str, torch.Tensor], config: ReconstructionConfig, ) -> o3d.geometry.PointCloud: """Fuse filtered depth maps from all ring cameras into a single point cloud. Back-projects each camera's filtered depth map to 3D with colors, concatenates all points, deduplicates via voxel grid downsampling, and estimates normals. Args: ring_cameras: List of ring camera names. projection_models: Camera name to ProjectionModel mapping. filtered_depth_maps: Camera name to filtered depth map (H, W) mapping (output of filter_all_depth_maps). filtered_confidence_maps: Camera name to filtered confidence map (H, W) mapping. images: Camera name to BGR image (H, W, 3) uint8 mapping. config: Fusion configuration. Returns: Open3D PointCloud with positions, colors, and estimated normals. """ # Step 1: Back-project each camera's depth map all_points = [] all_colors = [] for cam_name in ring_cameras: if cam_name not in filtered_depth_maps: continue result = backproject_depth_map( model=projection_models[cam_name], depth_map=filtered_depth_maps[cam_name], image=images[cam_name], confidence=filtered_confidence_maps.get(cam_name), ) if result["points"].shape[0] > 0: all_points.append(result["points"]) all_colors.append(result["colors"]) if len(all_points) == 0: # Return empty point cloud return o3d.geometry.PointCloud() # Step 2: Concatenate points_cat = torch.cat(all_points, dim=0) # (N_total, 3) colors_cat = torch.cat(all_colors, dim=0) # (N_total, 3) # Step 3: Convert to Open3D point cloud pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(points_cat.cpu().numpy().astype(np.float64)) pcd.colors = o3d.utility.Vector3dVector(colors_cat.cpu().numpy().astype(np.float64)) # Step 4: Voxel grid downsampling for deduplication pcd = pcd.voxel_down_sample(voxel_size=config.voxel_size) # Step 5: Estimate normals # Use a search radius proportional to voxel size for local consistency pcd.estimate_normals( search_param=o3d.geometry.KDTreeSearchParamHybrid( radius=config.voxel_size * 5, max_nn=30, ) ) # Orient normals toward camera (upward, -Z in our coordinate system) # The cameras are above the surface, so normals should point toward -Z pcd.orient_normals_towards_camera_location( camera_location=np.array([0.0, 0.0, 0.0]) ) return pcd
def save_point_cloud( pcd: o3d.geometry.PointCloud, path: str | Path, ) -> None: """Save a point cloud to a PLY file (binary). Args: pcd: Open3D PointCloud with points, colors, and optionally normals. path: Output file path (should end with .ply). """ o3d.io.write_point_cloud(str(path), pcd, write_ascii=False) def load_point_cloud( path: str | Path, ) -> o3d.geometry.PointCloud: """Load a point cloud from a PLY file. Args: path: Path to .ply file. Returns: Open3D PointCloud. """ return o3d.io.read_point_cloud(str(path))