From d3601e3fa34b10b6d10aa1c579f7a276d79d9ea3 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 12 Feb 2025 00:50:38 +0200 Subject: [PATCH] init das --- __init__.py | 5 +- custom_cogvideox_transformer_3d.py | 74 +- das/das_nodes.py | 146 ++ das/spatracker/__init__.py | 5 + das/spatracker/models/__init__.py | 5 + das/spatracker/models/build_spatracker.py | 55 + das/spatracker/models/core/__init__.py | 5 + das/spatracker/models/core/embeddings.py | 250 +++ das/spatracker/models/core/model_utils.py | 477 ++++++ .../models/core/spatracker/__init__.py | 5 + .../models/core/spatracker/blocks.py | 999 ++++++++++++ .../models/core/spatracker/dpt/__init__.py | 0 .../models/core/spatracker/dpt/base_model.py | 16 + .../models/core/spatracker/dpt/blocks.py | 394 +++++ .../models/core/spatracker/dpt/midas_net.py | 77 + .../models/core/spatracker/dpt/models.py | 231 +++ .../models/core/spatracker/dpt/transforms.py | 231 +++ .../models/core/spatracker/dpt/vit.py | 596 +++++++ .../models/core/spatracker/feature_net.py | 916 +++++++++++ .../models/core/spatracker/loftr/__init__.py | 1 + .../core/spatracker/loftr/linear_attention.py | 81 + .../core/spatracker/loftr/transformer.py | 142 ++ .../models/core/spatracker/losses.py | 90 + .../models/core/spatracker/softsplat.py | 539 ++++++ .../models/core/spatracker/spatracker.py | 736 +++++++++ das/spatracker/models/core/spatracker/unet.py | 258 +++ .../models/core/spatracker/vit/__init__.py | 0 .../models/core/spatracker/vit/common.py | 43 + .../models/core/spatracker/vit/encoder.py | 397 +++++ das/spatracker/predictor.py | 288 ++++ das/spatracker/utils/__init__.py | 5 + das/spatracker/utils/basic.py | 397 +++++ das/spatracker/utils/geom.py | 547 +++++++ das/spatracker/utils/improc.py | 1447 +++++++++++++++++ das/spatracker/utils/misc.py | 166 ++ das/spatracker/utils/samp.py | 152 ++ das/spatracker/utils/visualizer.py | 409 +++++ das/spatracker/utils/vox.py | 500 ++++++ model_loading.py | 10 +- nodes.py | 18 +- pipeline_cogvideox.py | 29 + 41 files changed, 10724 insertions(+), 18 deletions(-) create mode 100644 das/das_nodes.py create mode 100644 das/spatracker/__init__.py create mode 100644 das/spatracker/models/__init__.py create mode 100644 das/spatracker/models/build_spatracker.py create mode 100644 das/spatracker/models/core/__init__.py create mode 100644 das/spatracker/models/core/embeddings.py create mode 100644 das/spatracker/models/core/model_utils.py create mode 100644 das/spatracker/models/core/spatracker/__init__.py create mode 100644 das/spatracker/models/core/spatracker/blocks.py create mode 100644 das/spatracker/models/core/spatracker/dpt/__init__.py create mode 100644 das/spatracker/models/core/spatracker/dpt/base_model.py create mode 100644 das/spatracker/models/core/spatracker/dpt/blocks.py create mode 100644 das/spatracker/models/core/spatracker/dpt/midas_net.py create mode 100644 das/spatracker/models/core/spatracker/dpt/models.py create mode 100644 das/spatracker/models/core/spatracker/dpt/transforms.py create mode 100644 das/spatracker/models/core/spatracker/dpt/vit.py create mode 100644 das/spatracker/models/core/spatracker/feature_net.py create mode 100644 das/spatracker/models/core/spatracker/loftr/__init__.py create mode 100644 das/spatracker/models/core/spatracker/loftr/linear_attention.py create mode 100644 das/spatracker/models/core/spatracker/loftr/transformer.py create mode 100644 das/spatracker/models/core/spatracker/losses.py create mode 100644 
das/spatracker/models/core/spatracker/softsplat.py create mode 100644 das/spatracker/models/core/spatracker/spatracker.py create mode 100644 das/spatracker/models/core/spatracker/unet.py create mode 100644 das/spatracker/models/core/spatracker/vit/__init__.py create mode 100644 das/spatracker/models/core/spatracker/vit/common.py create mode 100644 das/spatracker/models/core/spatracker/vit/encoder.py create mode 100644 das/spatracker/predictor.py create mode 100644 das/spatracker/utils/__init__.py create mode 100644 das/spatracker/utils/basic.py create mode 100644 das/spatracker/utils/geom.py create mode 100644 das/spatracker/utils/improc.py create mode 100644 das/spatracker/utils/misc.py create mode 100644 das/spatracker/utils/samp.py create mode 100644 das/spatracker/utils/visualizer.py create mode 100644 das/spatracker/utils/vox.py diff --git a/__init__.py b/__init__.py index a608714..afed27d 100644 --- a/__init__.py +++ b/__init__.py @@ -1,7 +1,8 @@ from .nodes import NODE_CLASS_MAPPINGS as NODES_CLASS, NODE_DISPLAY_NAME_MAPPINGS as NODES_DISPLAY from .model_loading import NODE_CLASS_MAPPINGS as MODEL_CLASS, NODE_DISPLAY_NAME_MAPPINGS as MODEL_DISPLAY +from .das.das_nodes import NODE_CLASS_MAPPINGS as DAS_CLASS, NODE_DISPLAY_NAME_MAPPINGS as DAS_DISPLAY -NODE_CLASS_MAPPINGS = {**NODES_CLASS, **MODEL_CLASS} -NODE_DISPLAY_NAME_MAPPINGS = {**NODES_DISPLAY, **MODEL_DISPLAY} +NODE_CLASS_MAPPINGS = {**NODES_CLASS, **MODEL_CLASS, **DAS_CLASS} +NODE_DISPLAY_NAME_MAPPINGS = {**NODES_DISPLAY, **MODEL_DISPLAY, **DAS_DISPLAY} __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] \ No newline at end of file diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py index 7a5d3b0..a459aba 100644 --- a/custom_cogvideox_transformer_3d.py +++ b/custom_cogvideox_transformer_3d.py @@ -453,6 +453,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): use_learned_positional_embeddings: bool = False, patch_bias: bool = True, attention_mode: Optional[str] = "sdpa", + das_transformer: bool = False, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -557,6 +558,41 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): else: #CogVideoX-5B self.teacache_coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02] + + #das + # Create linear layers for combining hidden states and tracking maps + if das_transformer: + num_tracking_blocks = 18 + self.combine_linears = nn.ModuleList( + [nn.Linear(inner_dim, inner_dim) for _ in range(num_tracking_blocks)] + ) + # Initialize weights of combine_linears to zero + for linear in self.combine_linears: + linear.weight.data.zero_() + linear.bias.data.zero_() + + # Create transformer blocks for processing tracking maps + self.transformer_blocks_copy = nn.ModuleList( + [ + CogVideoXBlock( + dim=inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + time_embed_dim=self.config.time_embed_dim, + dropout=self.config.dropout, + activation_fn=self.config.activation_fn, + attention_bias=self.config.attention_bias, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + ) + for _ in range(num_tracking_blocks) + ] + ) + + # For initial combination of hidden states and tracking maps + self.initial_combine_linear = nn.Linear(inner_dim, inner_dim) + self.initial_combine_linear.weight.data.zero_() + 
self.initial_combine_linear.bias.data.zero_() def _set_gradient_checkpointing(self, module, value=False): @@ -573,6 +609,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): controlnet_states: torch.Tensor = None, controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0, video_flow_features: Optional[torch.Tensor] = None, + tracking_maps: Optional[torch.Tensor] = None, return_dict: bool = True, ): batch_size, num_frames, channels, height, width = hidden_states.shape @@ -602,13 +639,27 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): #print("hidden_states before patch_embedding", hidden_states.shape) #torch.Size([2, 4, 16, 60, 90]) hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - #print("hidden_states after patch_embedding", hidden_states.shape) #1.5: torch.Size([2, 2926, 3072]) #1.0: torch.Size([2, 5626, 3072]) hidden_states = self.embedding_dropout(hidden_states) - text_seq_length = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - #print("hidden_states after split", hidden_states.shape) #1.5: torch.Size([2, 2700, 3072]) #1.0: torch.Size([2, 5400, 3072]) + if tracking_maps is not None: + # Process tracking maps + prompt_embed = encoder_hidden_states.clone() + tracking_maps_hidden_states = self.patch_embed(prompt_embed, tracking_maps) + tracking_maps_hidden_states = self.embedding_dropout(tracking_maps_hidden_states) + del prompt_embed + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + tracking_maps = tracking_maps_hidden_states[:, text_seq_length:] + + # Combine hidden states and tracking maps initially + combined = hidden_states + tracking_maps + tracking_maps = self.initial_combine_linear(combined) + else: + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] if self.use_fastercache: self.fastercache_counter+=1 @@ -706,6 +757,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): if self.use_teacache: ori_hidden_states = hidden_states.clone() ori_encoder_hidden_states = encoder_hidden_states.clone() + for i, block in enumerate(self.transformer_blocks): hidden_states, encoder_hidden_states = block( hidden_states=hidden_states, @@ -731,6 +783,18 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): controlnet_block_weight = controlnet_weights hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight + #das + if i < len(self.transformer_blocks_copy) and tracking_maps is not None: + tracking_maps, _ = self.transformer_blocks_copy[i]( + hidden_states=tracking_maps, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + # Combine hidden states and tracking maps + tracking_maps = self.combine_linears[i](tracking_maps) + hidden_states = hidden_states + tracking_maps + if self.use_teacache: self.previous_residual = hidden_states - ori_hidden_states self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states diff --git a/das/das_nodes.py b/das/das_nodes.py new file mode 100644 index 0000000..17b9b88 --- /dev/null +++ b/das/das_nodes.py @@ -0,0 +1,146 @@ +import torch +import comfy.model_management as mm +from 
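# The forward-pass additions above mirror a zero-initialized side-branch (ControlNet-style)
# recipe: tracking-map tokens go through the same patch_embed as the video tokens, are mixed
# in via a zero-initialized linear, run through a copied stack of the first N transformer
# blocks, and are added back to the main hidden states after per-block zero-initialized
# linears. A minimal, self-contained sketch of that pattern, assuming toy dimensions and a
# stand-in block (not the real CogVideoXBlock, which also takes text tokens and a timestep
# embedding):
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
    def forward(self, x):
        return x + self.ff(x)

dim, depth, num_tracking_blocks = 64, 8, 4      # the patch uses 18 tracking blocks
blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])
blocks_copy = nn.ModuleList([ToyBlock(dim) for _ in range(num_tracking_blocks)])
combine = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_tracking_blocks)])
initial_combine = nn.Linear(dim, dim)
for lin in [initial_combine, *combine]:         # zero init: the branch is a no-op until trained
    nn.init.zeros_(lin.weight)
    nn.init.zeros_(lin.bias)

hidden = torch.randn(2, 2700, dim)              # video tokens after patch_embed
tracking = torch.randn(2, 2700, dim)            # tracking-map tokens after the same patch_embed

tracking = initial_combine(hidden + tracking)   # initial combination, as in the forward pass
for i, block in enumerate(blocks):
    hidden = block(hidden)
    if i < num_tracking_blocks:
        tracking = combine[i](blocks_copy[i](tracking))
        hidden = hidden + tracking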
..utils import log +import os +import numpy as np +import folder_paths + +class CogVideoDASTrackingEncode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "vae": ("VAE",), + "images": ("IMAGE", ), + }, + "optional": { + "enable_tiling": ("BOOLEAN", {"default": True, "tooltip": "Enable tiling for the VAE to reduce memory usage"}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }, + } + + RETURN_TYPES = ("DASTRACKING",) + RETURN_NAMES = ("das_tracking",) + FUNCTION = "encode" + CATEGORY = "CogVideoWrapper" + + def encode(self, vae, images, enable_tiling=False, strength=1.0, start_percent=0.0, end_percent=1.0): + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + generator = torch.Generator(device=device).manual_seed(0) + + try: + vae.enable_slicing() + except: + pass + + vae_scaling_factor = vae.config.scaling_factor + + if enable_tiling: + from ..mz_enable_vae_encode_tiling import enable_vae_encode_tiling + enable_vae_encode_tiling(vae) + + vae.to(device) + + try: + vae._clear_fake_context_parallel_cache() + except: + pass + + tracking_maps = images.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W + tracking_first_frame = tracking_maps[:, :, 0:1, :, :] + tracking_first_frame *= 2.0 - 1.0 + print("tracking_first_frame shape: ", tracking_first_frame.shape) + + tracking_first_frame_latent = vae.encode(tracking_first_frame).latent_dist.sample(generator).permute(0, 2, 1, 3, 4) + tracking_first_frame_latent = tracking_first_frame_latent * vae_scaling_factor * strength + log.info(f"Encoded tracking first frame latents shape: {tracking_first_frame_latent.shape}") + + tracking_latents = vae.encode(tracking_maps).latent_dist.sample(generator).permute(0, 2, 1, 3, 4) # B, T, C, H, W + + tracking_latents = tracking_latents * vae_scaling_factor * strength + + log.info(f"Encoded tracking latents shape: {tracking_latents.shape}") + vae.to(offload_device) + + return ({ + "tracking_maps": tracking_latents, + "tracking_image_latents": tracking_first_frame_latent, + "start_percent": start_percent, + "end_percent": end_percent + }, ) + +class DAS_SpaTracker: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "images": ("IMAGE", ), + "depth_images": ("IMAGE", ), + "density": ("INT", {"default": 70, "min": 1, "max": 100, "step": 1}), + }, + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("tracking_video",) + FUNCTION = "encode" + CATEGORY = "CogVideoWrapper" + + def encode(self, images, depth_images, density): + device = mm.get_torch_device() + offload_device = mm.unet_offload_device() + generator = torch.Generator(device=device).manual_seed(0) + + model_path = os.path.join(folder_paths.models_dir, 'spatracker', 'spaT_final.pth') + from .spatracker.predictor import SpaTrackerPredictor + from .spatracker.utils.visualizer import Visualizer + + if not hasattr(self, "tracker"): + self.tracker = SpaTrackerPredictor( + checkpoint=model_path, + interp_shape=(384, 576), + seq_length=12 + ).to(device) + + segm_mask = np.ones((480, 720), dtype=np.uint8) + + video = images.permute(0, 3, 1, 2).to(device).unsqueeze(0) + video_depth = depth_images.permute(0, 3, 1, 2).to(device) + video_depth = video_depth[:, 0:1, :, :] + + pred_tracks, pred_visibility, T_Firsts = self.tracker( + video * 255, + video_depth=video_depth, + grid_size=density, 
+ backward_tracking=False, + depth_predictor=None, + grid_query_frame=0, + segm_mask=torch.from_numpy(segm_mask)[None, None].to(device), + wind_length=12, + progressive_tracking=False + ) + + vis = Visualizer(grayscale=False, fps=24, pad_value=0) + msk_query = (T_Firsts == 0) + pred_tracks = pred_tracks[:,:,msk_query.squeeze()] + pred_visibility = pred_visibility[:,:,msk_query.squeeze()] + + tracking_video = vis.visualize(video=video, tracks=pred_tracks, + visibility=pred_visibility, save_video=False, + filename="temp") + + tracking_video = tracking_video.squeeze(0).permute(0, 2, 3, 1) # [T, H, W, C] + tracking_video = (tracking_video / 255.0).float() + + return (tracking_video,) + + +NODE_CLASS_MAPPINGS = { + "CogVideoDASTrackingEncode": CogVideoDASTrackingEncode, + "DAS_SpaTracker": DAS_SpaTracker, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode", + "DAS_SpaTracker": "DAS SpaTracker", + } diff --git a/das/spatracker/__init__.py b/das/spatracker/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/das/spatracker/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/das/spatracker/models/__init__.py b/das/spatracker/models/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/das/spatracker/models/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/das/spatracker/models/build_spatracker.py b/das/spatracker/models/build_spatracker.py new file mode 100644 index 0000000..783b1b0 --- /dev/null +++ b/das/spatracker/models/build_spatracker.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
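# Hedged usage sketch of the two nodes above, called directly rather than through a ComfyUI
# graph. The import path is an assumption (the nodes normally run inside the wrapper's node
# package), `vae` is assumed to be the diffusers CogVideoX VAE the wrapper loads, and
# `images` / `depth_images` are hypothetical [T, H, W, C] float tensors in [0, 1] following
# ComfyUI's IMAGE convention (depth replicated to three channels).
from das.das_nodes import DAS_SpaTracker, CogVideoDASTrackingEncode

def encode_das_tracking(vae, images, depth_images):
    # 1) run SpaTracker on RGB + depth and render the coloured tracking video
    (tracking_video,) = DAS_SpaTracker().encode(images, depth_images, density=70)
    # 2) VAE-encode the tracking video into the latents dict the sampler consumes
    (das_tracking,) = CogVideoDASTrackingEncode().encode(
        vae, tracking_video, enable_tiling=True,
        strength=1.0, start_percent=0.0, end_percent=1.0,
    )
    # keys: "tracking_maps", "tracking_image_latents", "start_percent", "end_percent"
    return das_tracking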
+ +import torch + +from .core.spatracker.spatracker import SpaTracker + + +def build_spatracker( + checkpoint: str, + seq_length: int = 8, +): + model_name = checkpoint.split("/")[-1].split(".")[0] + return build_spatracker_from_cfg(checkpoint=checkpoint, seq_length=seq_length) + + + +# model used to produce the results in the paper +def build_spatracker_from_cfg(checkpoint=None, seq_length=8): + return _build_spatracker( + stride=4, + sequence_len=seq_length, + checkpoint=checkpoint, + ) + + +def _build_spatracker( + stride, + sequence_len, + checkpoint=None, +): + spatracker = SpaTracker( + stride=stride, + S=sequence_len, + add_space_attn=True, + space_depth=6, + time_depth=6, + ) + if checkpoint is not None: + with open(checkpoint, "rb") as f: + if "safetensors" in checkpoint: + from safetensors.torch import load_file + state_dict = load_file(checkpoint) + else: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + if "model" in state_dict: + model_paras = spatracker.state_dict() + paras_dict = {k: v for k,v in state_dict["model"].items() if k in spatracker.state_dict()} + model_paras.update(paras_dict) + state_dict = model_paras + spatracker.load_state_dict(state_dict) + return spatracker diff --git a/das/spatracker/models/core/__init__.py b/das/spatracker/models/core/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/das/spatracker/models/core/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/das/spatracker/models/core/embeddings.py b/das/spatracker/models/core/embeddings.py new file mode 100644 index 0000000..1b84c0d --- /dev/null +++ b/das/spatracker/models/core/embeddings.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
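# Hedged usage sketch for the loader above. The checkpoint location mirrors the one the
# DAS_SpaTracker node looks up (models/spatracker/spaT_final.pth under ComfyUI's model
# directory) and is otherwise an assumption; both .pth and .safetensors files are handled,
# and when the weights live under a "model" key only the entries matching SpaTracker's
# state_dict are kept.
import os
import torch
import folder_paths  # ComfyUI's model-path helper, as used by das_nodes.py

ckpt = os.path.join(folder_paths.models_dir, "spatracker", "spaT_final.pth")
tracker = build_spatracker(ckpt, seq_length=12).eval()
tracker = tracker.to("cuda" if torch.cuda.is_available() else "cpu")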
+ +import torch +import numpy as np + +def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = np.arange(grid_size_h, dtype=np.float32) + grid_w = np.arange(grid_size_w, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate( + [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 3 == 0 + + # use half of dimensions to encode grid_h + B, S, N, _ = grid.shape + gridx = grid[..., 0].view(B*S*N).detach().cpu().numpy() + gridy = grid[..., 1].view(B*S*N).detach().cpu().numpy() + gridz = grid[..., 2].view(B*S*N).detach().cpu().numpy() + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridx) # (N, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridy) # (N, D/3) + emb_z = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridz) # (N, D/3) + + + emb = np.concatenate([emb_h, emb_w, emb_z], axis=1) # (N, D) + emb = torch.from_numpy(emb).to(grid.device) + return emb.view(B, S, N, embed_dim) + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = np.arange(grid_size_h, dtype=np.float32) + grid_w = np.arange(grid_size_w, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate( + [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_2d_embedding(xy, C, cat_coords=True): + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = ( 
+ torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3 + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3 + return pe + + +def get_3d_embedding(xyz, C, cat_coords=True): + B, N, D = xyz.shape + assert D == 3 + + x = xyz[:, :, 0:1] + y = xyz[:, :, 1:2] + z = xyz[:, :, 2:3] + div_term = ( + torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) + pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe_z[:, :, 0::2] = torch.sin(z * div_term) + pe_z[:, :, 1::2] = torch.cos(z * div_term) + + pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3 + if cat_coords: + pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3 + return pe + + +def get_4d_embedding(xyzw, C, cat_coords=True): + B, N, D = xyzw.shape + assert D == 4 + + x = xyzw[:, :, 0:1] + y = xyzw[:, :, 1:2] + z = xyzw[:, :, 2:3] + w = xyzw[:, :, 3:4] + div_term = ( + torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C) + ).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) + pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) + pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe_z[:, :, 0::2] = torch.sin(z * div_term) + pe_z[:, :, 1::2] = torch.cos(z * div_term) + + pe_w[:, :, 0::2] = torch.sin(w * div_term) + pe_w[:, :, 1::2] = torch.cos(w * div_term) + + pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3 + if cat_coords: + pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3 + return pe + +import torch.nn as nn +class Embedder_Fourier(nn.Module): + def __init__(self, input_dim, max_freq_log2, N_freqs, + log_sampling=True, include_input=True, + periodic_fns=(torch.sin, torch.cos)): + ''' + :param input_dim: dimension of input to be embedded + :param max_freq_log2: log2 of max freq; min freq is 1 by default + :param N_freqs: number of frequency bands + :param log_sampling: if True, frequency bands are linerly sampled in log-space + :param include_input: if True, raw input is included in the embedding + :param periodic_fns: periodic functions used to embed input + ''' + super(Embedder_Fourier, self).__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + + self.out_dim = 0 + if self.include_input: + self.out_dim += self.input_dim + + self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2. 
** torch.linspace(0., max_freq_log2, N_freqs) + else: + self.freq_bands = torch.linspace( + 2. ** 0., 2. ** max_freq_log2, N_freqs) + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, + input: torch.Tensor, + rescale: float = 1.0): + ''' + :param input: tensor of shape [..., self.input_dim] + :return: tensor of shape [..., self.out_dim] + ''' + assert (input.shape[-1] == self.input_dim) + out = [] + if self.include_input: + out.append(input/rescale) + + for i in range(len(self.freq_bands)): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + out.append(p_fn(input * freq)) + out = torch.cat(out, dim=-1) + + assert (out.shape[-1] == self.out_dim) + return out \ No newline at end of file diff --git a/das/spatracker/models/core/model_utils.py b/das/spatracker/models/core/model_utils.py new file mode 100644 index 0000000..3eda98a --- /dev/null +++ b/das/spatracker/models/core/model_utils.py @@ -0,0 +1,477 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from easydict import EasyDict as edict +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt + +EPS = 1e-6 + +def nearest_sample2d(im, x, y, return_inbounds=False): + # x and y are each B, N + # output is B, C, N + if len(im.shape) == 5: + B, N, C, H, W = list(im.shape) + else: + B, C, H, W = list(im.shape) + N = list(x.shape)[1] + + x = x.float() + y = y.float() + H_f = torch.tensor(H, dtype=torch.float32) + W_f = torch.tensor(W, dtype=torch.float32) + + # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte() + y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() + inbounds = (x_valid & y_valid).float() + inbounds = inbounds.reshape( + B, N + ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) + return output, inbounds + + return output # B, C, N + +def smart_cat(tensor1, tensor2, dim): + if tensor1 is None: + return tensor2 + return torch.cat([tensor1, tensor2], dim=dim) + + +def normalize_single(d): + # d is a whatever shape torch tensor + dmin = torch.min(d) + dmax = torch.max(d) + d = (d - dmin) / (EPS + (dmax - dmin)) + return d + + +def normalize(d): + # d is B x whatever. 
normalize within each element of the batch + out = torch.zeros(d.size()) + if d.is_cuda: + out = out.cuda() + B = list(d.size())[0] + for b in list(range(B)): + out[b] = normalize_single(d[b]) + return out + + +def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"): + # returns a meshgrid sized B x Y x X + + grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device)) + grid_y = torch.reshape(grid_y, [1, Y, 1]) + grid_y = grid_y.repeat(B, 1, X) + + grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device)) + grid_x = torch.reshape(grid_x, [1, 1, X]) + grid_x = grid_x.repeat(B, Y, 1) + + if stack: + # note we stack in xy order + # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) + grid = torch.stack([grid_x, grid_y], dim=-1) + return grid + else: + return grid_y, grid_x + + +def reduce_masked_mean(x, mask, dim=None, keepdim=False): + # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting + # returns shape-1 + # axis can be a list of axes + for (a, b) in zip(x.size(), mask.size()): + assert a == b # some shape mismatch! + prod = x * mask + if dim is None: + numer = torch.sum(prod) + denom = EPS + torch.sum(mask) + else: + numer = torch.sum(prod, dim=dim, keepdim=keepdim) + denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim) + + mean = numer / denom + return mean + + +def bilinear_sample2d(im, x, y, return_inbounds=False): + # x and y are each B, N + # output is B, C, N + if len(im.shape) == 5: + B, N, C, H, W = list(im.shape) + else: + B, C, H, W = list(im.shape) + N = list(x.shape)[1] + + x = x.float() + y = y.float() + H_f = torch.tensor(H, dtype=torch.float32) + W_f = torch.tensor(W, dtype=torch.float32) + + # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte() + y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() + inbounds = (x_valid & y_valid).float() + inbounds = inbounds.reshape( + B, N + ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) + return output, inbounds + + return output # B, C, N + + +def procrustes_analysis(X0,X1,Weight): # [B,N,3] + # translation + t0 = X0.mean(dim=1,keepdim=True) + t1 = X1.mean(dim=1,keepdim=True) + X0c = X0-t0 + X1c = X1-t1 + # scale + # s0 = (X0c**2).sum(dim=-1).mean().sqrt() + # s1 = (X1c**2).sum(dim=-1).mean().sqrt() + # X0cs = X0c/s0 + # X1cs = X1c/s1 + # rotation (use double for SVD, float loses precision) + U,_,V = (X0c.t()@X1c).double().svd(some=True) + R = (U@V.t()).float() + if R.det()<0: R[2] *= -1 + # align X1 to X0: X1to0 = (X1-t1)/@R.t()+t0 + se3 = edict(t0=t0[0],t1=t1[0],R=R) + + return se3 + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. + + The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where + :math:`B` is the batch size, :math:`C` is the number of channels, + :math:`H` is the height of the image, and :math:`W` is the width of the + image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is + interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. 
+ + Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, + in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note + that in this case the order of the components is slightly different + from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. + + If `align_corners` is `True`, the coordinate :math:`x` is assumed to be + in the range :math:`[0,W-1]`, with 0 corresponding to the center of the + left-most image pixel :math:`W-1` to the center of the right-most + pixel. + + If `align_corners` is `False`, the coordinate :math:`x` is assumed to + be in the range :math:`[0,W]`, with 0 corresponding to the left edge of + the left-most pixel :math:`W` to the right edge of the right-most + pixel. + + Similar conventions apply to the :math:`y` for the range + :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range + :math:`[0,T-1]` and :math:`[0,T]`. + + Args: + input (Tensor): batch of input images. + coords (Tensor): batch of coordinates. + align_corners (bool, optional): Coordinate convention. Defaults to `True`. + padding_mode (str, optional): Padding mode. Defaults to `"border"`. + + Returns: + Tensor: sampled points. + """ + + sizes = input.shape[2:] + + assert len(sizes) in [2, 3] + + if len(sizes) == 3: + # t x y -> x y t to match dimensions T H W in grid_sample + coords = coords[..., [1, 2, 0]] + + if align_corners: + coords = coords * torch.tensor( + [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device + ) + else: + coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) + + coords -= 1 + + return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) + + +def sample_features4d(input, coords): + r"""Sample spatial features + + `sample_features4d(input, coords)` samples the spatial features + :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. + + The field is sampled at coordinates :attr:`coords` using bilinear + interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, + 3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the + same convention as :func:`bilinear_sampler` with `align_corners=True`. + + The output tensor has one feature per point, and has shape :math:`(B, + R, C)`. + + Args: + input (Tensor): spatial features. + coords (Tensor): points. + + Returns: + Tensor: sampled features. + """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view( + B, -1, feats.shape[1] * feats.shape[3] + ) # B C R 1 -> B R C + + +def sample_features5d(input, coords): + r"""Sample spatio-temporal features + + `sample_features5d(input, coords)` works in the same way as + :func:`sample_features4d` but for spatio-temporal features and points: + :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is + a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i, + x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`. + + Args: + input (Tensor): spatio-temporal features. + coords (Tensor): spatio-temporal points. + + Returns: + Tensor: sampled features. 
+ """ + + B, T, _, _, _ = input.shape + + # B T C H W -> B C T H W + input = input.permute(0, 2, 1, 3, 4) + + # B R1 R2 3 -> B R1 R2 1 3 + coords = coords.unsqueeze(3) + + # B C R1 R2 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 3, 1, 4).view( + B, feats.shape[2], feats.shape[3], feats.shape[1] + ) # B C R1 R2 1 -> B R1 R2 C + +def vis_PCA(fmaps, save_dir): + """ + visualize the PCA of the feature maps + args: + fmaps: feature maps 1 C H W + save_dir: the directory to save the PCA visualization + """ + + pca = PCA(n_components=3) + fmap_vis = fmaps[0,...] + fmap_vnorm = ( + (fmap_vis-fmap_vis.min())/ + (fmap_vis.max()-fmap_vis.min())) + H_vis, W_vis = fmap_vis.shape[1:] + fmap_vnorm = fmap_vnorm.reshape(fmap_vnorm.shape[0], + -1).permute(1,0) + fmap_pca = pca.fit_transform(fmap_vnorm.detach().cpu().numpy()) + pca = fmap_pca.reshape(H_vis,W_vis,3) + plt.imsave(save_dir, + ( + (pca-pca.min())/ + (pca.max()-pca.min()) + )) + + + # debug=False + # if debug==True: + # pcd_idx = 60 + # vis_PCA(fmapYZ[0,:1], "./yz.png") + # vis_PCA(fmapXZ[0,:1], "./xz.png") + # vis_PCA(fmaps[0,:1], "./xy.png") + # vis_PCA(fmaps[0,-1:], "./xy_.png") + # fxy_q = fxy[0,0,pcd_idx:pcd_idx+1, :, None, None] + # fyz_q = fyz[0,0,pcd_idx:pcd_idx+1, :, None, None] + # fxz_q = fxz[0,0,pcd_idx:pcd_idx+1, :, None, None] + # corr_map = (fxy_q*fmaps[0,-1:]).sum(dim=1) + # corr_map_yz = (fyz_q*fmapYZ[0,-1:]).sum(dim=1) + # corr_map_xz = (fxz_q*fmapXZ[0,-1:]).sum(dim=1) + # coord_last = coords[0,-1,pcd_idx:pcd_idx+1] + # coord_last_neigh = coords[0,-1, self.neigh_indx[pcd_idx]] + # depth_last = depths_dnG[-1,0] + # abs_res = (depth_last-coord_last[-1,-1]).abs() + # abs_res = (abs_res - abs_res.min())/(abs_res.max()-abs_res.min()) + # res_dp = torch.exp(-abs_res) + # enhance_corr = res_dp*corr_map + # plt.imsave("./res.png", res_dp.detach().cpu().numpy()) + # plt.imsave("./enhance_corr.png", enhance_corr[0].detach().cpu().numpy()) + # plt.imsave("./corr_map.png", corr_map[0].detach().cpu().numpy()) + # plt.imsave("./corr_map_yz.png", corr_map_yz[0].detach().cpu().numpy()) + # plt.imsave("./corr_map_xz.png", corr_map_xz[0].detach().cpu().numpy()) + # img_feat = cv2.imread("./xy.png") + # cv2.circle(img_feat, (int(coord_last[0,0]), int(coord_last[0,1])), 2, (0, 0, 255), -1) + # for p_i in coord_last_neigh: + # cv2.circle(img_feat, (int(p_i[0]), int(p_i[1])), 1, (0, 255, 0), -1) + # cv2.imwrite("./xy_coord.png", img_feat) + # import ipdb; ipdb.set_trace() \ No newline at end of file diff --git a/das/spatracker/models/core/spatracker/__init__.py b/das/spatracker/models/core/spatracker/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/das/spatracker/models/core/spatracker/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/das/spatracker/models/core/spatracker/blocks.py b/das/spatracker/models/core/spatracker/blocks.py new file mode 100644 index 0000000..f8450ee --- /dev/null +++ b/das/spatracker/models/core/spatracker/blocks.py @@ -0,0 +1,999 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast +from einops import rearrange +import collections +from functools import partial +from itertools import repeat +import torchvision.models as tvm + +from .vit.encoder import ImageEncoderViT as vitEnc +from .dpt.models import DPTEncoder +from .loftr import LocalFeatureTransformer +# from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +to_2tuple = _ntuple(2) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + +class Attention(nn.Module): + def __init__(self, query_dim, context_dim=None, + num_heads=8, dim_head=48, qkv_bias=False, flash=False): + super().__init__() + inner_dim = self.inner_dim = dim_head * num_heads + context_dim = default(context_dim, query_dim) + self.scale = dim_head**-0.5 + self.heads = num_heads + self.flash = flash + + self.qkv = nn.Linear(query_dim, inner_dim*3, bias=qkv_bias) + self.proj = nn.Linear(inner_dim, query_dim) + + def forward(self, x, context=None, attn_bias=None): + B, N1, _ = x.shape + C = self.inner_dim + h = self.heads + # q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3) + # k, v = self.to_kv(context).chunk(2, dim=-1) + # context = default(context, x) + + qkv = self.qkv(x).reshape(B, N1, 3, h, C // h) + q, k, v = qkv[:,:, 0], qkv[:,:, 1], qkv[:,:, 2] + N2 = x.shape[1] + + k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) + v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) + q = q.reshape(B, N1, h, C // h).permute(0, 2, 1, 3) + if self.flash==False: + sim = (q @ k.transpose(-2, -1)) * self.scale + if attn_bias is not None: + sim = sim + attn_bias + attn = sim.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N1, C) + else: + input_args = [x.half().contiguous() for x in [q, k, v]] + x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore + + # return self.to_out(x.float()) + return self.proj(x.float()) + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=3, + padding=1, + stride=stride, + padding_mode="zeros", + ) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, 
padding_mode="zeros" + ) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__( + self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0, + Embed3D=False + ): + super(BasicEncoder, self).__init__() + self.stride = stride + self.norm_fn = norm_fn + self.in_planes = 64 + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes) + self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(self.in_planes) + self.norm2 = nn.BatchNorm2d(output_dim * 2) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(self.in_planes) + self.norm2 = nn.InstanceNorm2d(output_dim * 2) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d( + input_dim, + self.in_planes, + kernel_size=7, + stride=2, + padding=3, + padding_mode="zeros", + ) + self.relu1 = nn.ReLU(inplace=True) + + self.shallow = False + if self.shallow: + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1) + else: + if Embed3D: + self.conv_fuse = nn.Conv2d(64+63, + self.in_planes, kernel_size=3, padding=1) + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + self.layer4 = self._make_layer(128, stride=2) + self.conv2 = nn.Conv2d( + 128 + 128 + 96 + 64, + output_dim * 2, + kernel_size=3, + padding=1, + padding_mode="zeros", + ) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", + nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + 
return nn.Sequential(*layers) + + def forward(self, x, feat_PE=None): + _, _, H, W = x.shape + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + if self.shallow: + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + a = F.interpolate( + a, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + b = F.interpolate( + b, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + c = F.interpolate( + c, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + x = self.conv2(torch.cat([a, b, c], dim=1)) + else: + if feat_PE is not None: + x = self.conv_fuse(torch.cat([x, feat_PE], dim=1)) + a = self.layer1(x) + else: + a = self.layer1(x) + b = self.layer2(a) + c = self.layer3(b) + d = self.layer4(c) + a = F.interpolate( + a, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + b = F.interpolate( + b, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + c = F.interpolate( + c, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + d = F.interpolate( + d, + (H // self.stride, W // self.stride), + mode="bilinear", + align_corners=True, + ) + x = self.conv2(torch.cat([a, b, c, d], dim=1)) + x = self.norm2(x) + x = self.relu2(x) + x = self.conv3(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + return x + +class VitEncoder(nn.Module): + def __init__(self, input_dim=4, output_dim=128, stride=4): + super(VitEncoder, self).__init__() + self.vit = vitEnc(img_size=512, + depth=6, num_heads=8, in_chans=input_dim, + out_chans=output_dim,embed_dim=384).cuda() + self.stride = stride + def forward(self, x): + T, C, H, W = x.shape + x_resize = F.interpolate(x.view(-1, C, H, W), size=(512, 512), + mode='bilinear', align_corners=False) + x_resize = self.vit(x_resize) + x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride), + mode='bilinear', align_corners=False) + return x + +class DPTEnc(nn.Module): + def __init__(self, input_dim=3, output_dim=128, stride=2): + super(DPTEnc, self).__init__() + self.dpt = DPTEncoder() + self.stride = stride + def forward(self, x): + T, C, H, W = x.shape + x = (x-0.5)/0.5 + x_resize = F.interpolate(x.view(-1, C, H, W), size=(384, 384), + mode='bilinear', align_corners=False) + x_resize = self.dpt(x_resize) + x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride), + mode='bilinear', align_corners=False) + return x + +# class DPT_DINOv2(nn.Module): +# def __init__(self, encoder='vits', features=64, out_channels=[48, 96, 192, 384], +# use_bn=True, use_clstoken=False, localhub=True, stride=2, enc_only=True): +# super(DPT_DINOv2, self).__init__() +# self.stride = stride +# self.enc_only = enc_only +# assert encoder in ['vits', 'vitb', 'vitl'] + +# if localhub: +# self.pretrained = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False) +# else: +# self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder)) + +# state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vits14_pretrain.pth") +# self.pretrained.load_state_dict(state_dict, strict=True) +# self.pretrained.requires_grad_(False) +# dim = self.pretrained.blocks[0].attn.qkv.in_features +# if enc_only == True: +# out_channels=[128, 128, 128, 128] + +# self.DPThead = DPTHeadEnc(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + +# 
def forward(self, x): +# mean_ = torch.tensor([0.485, 0.456, 0.406], +# device=x.device).view(1, 3, 1, 1) +# std_ = torch.tensor([0.229, 0.224, 0.225], +# device=x.device).view(1, 3, 1, 1) +# x = (x+1)/2 +# x = (x - mean_)/std_ +# h, w = x.shape[-2:] +# h_re, w_re = 560, 560 +# x_resize = F.interpolate(x, size=(h_re, w_re), +# mode='bilinear', align_corners=False) +# with torch.no_grad(): +# features = self.pretrained.get_intermediate_layers(x_resize, 4, return_class_token=True) +# patch_h, patch_w = h_re // 14, w_re // 14 +# feat = self.DPThead(features, patch_h, patch_w, self.enc_only) +# feat = F.interpolate(feat, size=(h//self.stride, w//self.stride), mode="bilinear", align_corners=True) + +# return feat + + +class VGG19(nn.Module): + def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None: + super().__init__() + self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) + self.amp = amp + self.amp_dtype = amp_dtype + + def forward(self, x, **kwargs): + with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype): + feats = {} + scale = 1 + for layer in self.layers: + if isinstance(layer, nn.MaxPool2d): + feats[scale] = x + scale = scale*2 + x = layer(x) + return feats + +class CNNandDinov2(nn.Module): + def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16): + super().__init__() + # in case the Internet connection is not stable, please load the DINOv2 locally + self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', + 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False) + + state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth") + self.dinov2_vitl14.load_state_dict(state_dict, strict=True) + + + cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {} + self.cnn = VGG19(**cnn_kwargs) + self.amp = amp + self.amp_dtype = amp_dtype + if self.amp: + dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype) + self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP + + + def train(self, mode: bool = True): + return self.cnn.train(mode) + + def forward(self, x, upsample = False): + B,C,H,W = x.shape + feature_pyramid = self.cnn(x) + + if not upsample: + with torch.no_grad(): + if self.dinov2_vitl14[0].device != x.device: + self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype) + dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype)) + features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14) + del dinov2_features_16 + feature_pyramid[16] = features_16 + return feature_pyramid + +class Dinov2(nn.Module): + def __init__(self, amp = True, amp_dtype = torch.float16): + super().__init__() + # in case the Internet connection is not stable, please load the DINOv2 locally + self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', + 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False) + + state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth") + self.dinov2_vitl14.load_state_dict(state_dict, strict=True) + + self.amp = amp + self.amp_dtype = amp_dtype + if self.amp: + self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype) + + def forward(self, x, upsample = False): + B,C,H,W = x.shape + mean_ = torch.tensor([0.485, 0.456, 0.406], + device=x.device).view(1, 3, 1, 1) + std_ = torch.tensor([0.229, 0.224, 0.225], + device=x.device).view(1, 3, 1, 1) + x = (x+1)/2 + x = (x - mean_)/std_ + 
h_re, w_re = 560, 560 + x_resize = F.interpolate(x, size=(h_re, w_re), + mode='bilinear', align_corners=True) + if not upsample: + with torch.no_grad(): + dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype)) + features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14) + del dinov2_features_16 + features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True) + return features_16 + +class AttnBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, + flash=False, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.flash=flash + + self.attn = Attention( + hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash, + **block_kwargs + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + +class CrossAttnBlock(nn.Module): + def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, + flash=True, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.norm_context = nn.LayerNorm(hidden_size) + + self.cross_attn = Attention( + hidden_size, context_dim=context_dim, + num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash + + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + + def forward(self, x, context): + with autocast(): + x = x + self.cross_attn( + self.norm1(x), self.norm_context(context) + ) + x = x + self.mlp(self.norm2(x)) + return x + + +def bilinear_sampler(img, coords, mode="bilinear", mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + # go to 0,1 then 0,2 then -1,1 + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +class CorrBlock: + def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None): + B, S, C, H_prev, W_prev = fmaps.shape + self.S, self.C, self.H, self.W = S, C, H_prev, W_prev + + self.num_levels = num_levels + self.radius = radius + self.fmaps_pyramid = [] + self.depth_pyramid = [] + self.fmaps_pyramid.append(fmaps) + if depths_dnG is not None: + self.depth_pyramid.append(depths_dnG) + for i in range(self.num_levels - 1): + if depths_dnG is not None: + depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev) + depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2) + _, _, H, W = depths_dnG_.shape + depths_dnG = depths_dnG_.reshape(B, S, 1, H, W) + self.depth_pyramid.append(depths_dnG) + fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev) + fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) + 
_, _, H, W = fmaps_.shape + fmaps = fmaps_.reshape(B, S, C, H, W) + H_prev = H + W_prev = W + self.fmaps_pyramid.append(fmaps) + + def sample(self, coords): + r = self.radius + B, S, N, D = coords.shape + assert D == 2 + + H, W = self.H, self.W + out_pyramid = [] + for i in range(self.num_levels): + corrs = self.corrs_pyramid[i] # B, S, N, H, W + _, _, _, H, W = corrs.shape + + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( + coords.device + ) + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl) + corrs = corrs.view(B, S, N, -1) + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 + return out.contiguous().float() + + def corr(self, targets): + B, S, N, C = targets.shape + assert C == self.C + assert S == self.S + + fmap1 = targets + + self.corrs_pyramid = [] + for fmaps in self.fmaps_pyramid: + _, _, _, H, W = fmaps.shape + fmap2s = fmaps.view(B, S, C, H * W) + corrs = torch.matmul(fmap1, fmap2s) + corrs = corrs.view(B, S, N, H, W) + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + self.corrs_pyramid.append(corrs) + + def corr_sample(self, targets, coords, coords_dp=None): + B, S, N, C = targets.shape + r = self.radius + Dim_c = (2*r+1)**2 + assert C == self.C + assert S == self.S + + out_pyramid = [] + out_pyramid_dp = [] + for i in range(self.num_levels): + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( + coords.device + ) + centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + fmaps = self.fmaps_pyramid[i] + _, _, _, H, W = fmaps.shape + fmap2s = fmaps.view(B*S, C, H, W) + if len(self.depth_pyramid)>0: + depths_dnG_i = self.depth_pyramid[i] + depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W) + dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2)) + dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0] + out_pyramid_dp.append(dp_corrs) + fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2)) + fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1 + corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1)) + corrs = corrs / torch.sqrt(torch.tensor(C).float()) + corrs = corrs.view(B, S, N, -1) + out_pyramid.append(corrs) + + out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 + if len(self.depth_pyramid)>0: + out_dp = torch.cat(out_pyramid_dp, dim=-1) + self.fcorrD = out_dp.contiguous().float() + else: + self.fcorrD = torch.zeros_like(out).contiguous().float() + return out.contiguous().float() + + +class EUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. 
+ """ + + def __init__( + self, + space_depth=12, + time_depth=12, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + vq_depth=3, + add_space_attn=True, + add_time_attn=True, + flash=True + ): + super().__init__() + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + self.flash = flash + self.flow_head = nn.Sequential( + nn.Linear(hidden_size, output_dim, bias=True), + nn.ReLU(inplace=True), + nn.Linear(output_dim, output_dim, bias=True), + nn.ReLU(inplace=True), + nn.Linear(output_dim, output_dim, bias=True) + ) + + cross_attn_kwargs = { + "d_model": 384, + "nhead": 4, + "layer_names": ['self', 'cross'] * 3, + } + self.gnn = LocalFeatureTransformer(cross_attn_kwargs) + + # Attention Modules in the temporal dimension + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash) if add_time_attn else nn.Identity() + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash) + for _ in range(space_depth) + ] + ) + assert len(self.time_blocks) >= len(self.space_blocks) + + # Placeholder for the rigid transformation + self.RigidProj = nn.Linear(self.hidden_size, 128, bias=True) + self.Proj = nn.Linear(self.hidden_size, 128, bias=True) + + self.se3_dec = nn.Linear(384, 3, bias=True) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def forward(self, input_tensor, se3_feature): + """ Updating with Transformer + + Args: + input_tensor: B, N, T, C + arap_embed: B, N, T, C + """ + B, N, T, C = input_tensor.shape + x = self.input_transform(input_tensor) + tokens = x + K = 0 + j = 0 + for i in range(len(self.time_blocks)): + tokens_time = rearrange(tokens, "b n t c -> (b n) t c", b=B, t=T, n=N+K) + tokens_time = self.time_blocks[i](tokens_time) + tokens = rearrange(tokens_time, "(b n) t c -> b n t c ", b=B, t=T, n=N+K) + if self.add_space_attn and ( + i % (len(self.time_blocks) // len(self.space_blocks)) == 0 + ): + tokens_space = rearrange(tokens, "b n t c -> (b t) n c ", b=B, t=T, n=N) + tokens_space = self.space_blocks[j](tokens_space) + tokens = rearrange(tokens_space, "(b t) n c -> b n t c ", b=B, t=T, n=N) + j += 1 + + B, N, S, _ = tokens.shape + feat0, feat1 = self.gnn(tokens.view(B*N*S, -1)[None,...], se3_feature[None, ...]) + + so3 = F.tanh(self.se3_dec(feat0.view(B*N*S, -1)[None,...].view(B, N, S, -1))/100) + flow = self.flow_head(feat0.view(B,N,S,-1)) + + return flow, _, _, feat1, so3 + + +class FusionFormer(nn.Module): + """ + Fuse the feature tracks info with the low rank motion tokens + """ + def __init__( + self, + d_model=64, + nhead=8, + attn_iters=4, + mlp_ratio=4.0, + flash=False, + input_dim=35, + output_dim=384+3, + ): + super().__init__() + self.flash = flash + self.in_proj = nn.ModuleList( + [ + nn.Linear(input_dim, d_model) + for _ in range(2) + ] + ) + self.out_proj = nn.Linear(d_model, output_dim, bias=True) + self.time_blocks = nn.ModuleList( + [ + CrossAttnBlock(d_model, d_model, nhead, mlp_ratio=mlp_ratio) + for _ in range(attn_iters) + ] + ) + self.space_blocks = nn.ModuleList( + [ + AttnBlock(d_model, nhead, 
mlp_ratio=mlp_ratio, flash=self.flash) + for _ in range(attn_iters) + ] + ) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + self.out_proj.weight.data.fill_(0) + self.out_proj.bias.data.fill_(0) + + def forward(self, x, token_cls): + """ Fuse the feature tracks info with the low rank motion tokens + + Args: + x: B, S, N, C + Traj_whole: B T N C + + """ + B, S, N, C = x.shape + _, T, _, _ = token_cls.shape + x = self.in_proj[0](x) + token_cls = self.in_proj[1](token_cls) + token_cls = rearrange(token_cls, 'b t n c -> (b n) t c') + + for i in range(len(self.space_blocks)): + x = rearrange(x, 'b s n c -> (b n) s c') + x = self.time_blocks[i](x, token_cls) + x = self.space_blocks[i](x.permute(1,0,2)) + x = rearrange(x, '(b s) n c -> b s n c', b=B, s=S, n=N) + + x = self.out_proj(x) + delta_xyz = x[..., :3] + feat_traj = x[..., 3:] + return delta_xyz, feat_traj + +class Lie(): + """ + Lie algebra for SO(3) and SE(3) operations in PyTorch + """ + + def so3_to_SO3(self,w): # [...,3] + wx = self.skew_symmetric(w) + theta = w.norm(dim=-1)[...,None,None] + I = torch.eye(3,device=w.device,dtype=torch.float32) + A = self.taylor_A(theta) + B = self.taylor_B(theta) + R = I+A*wx+B*wx@wx + return R + + def SO3_to_so3(self,R,eps=1e-7): # [...,3,3] + trace = R[...,0,0]+R[...,1,1]+R[...,2,2] + theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi + lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird + w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0] + w = torch.stack([w0,w1,w2],dim=-1) + return w + + def se3_to_SE3(self,wu): # [...,3] + w,u = wu.split([3,3],dim=-1) + wx = self.skew_symmetric(w) + theta = w.norm(dim=-1)[...,None,None] + I = torch.eye(3,device=w.device,dtype=torch.float32) + A = self.taylor_A(theta) + B = self.taylor_B(theta) + C = self.taylor_C(theta) + R = I+A*wx+B*wx@wx + V = I+B*wx+C*wx@wx + Rt = torch.cat([R,(V@u[...,None])],dim=-1) + return Rt + + def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4] + R,t = Rt.split([3,1],dim=-1) + w = self.SO3_to_so3(R) + wx = self.skew_symmetric(w) + theta = w.norm(dim=-1)[...,None,None] + I = torch.eye(3,device=w.device,dtype=torch.float32) + A = self.taylor_A(theta) + B = self.taylor_B(theta) + invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx + u = (invV@t)[...,0] + wu = torch.cat([w,u],dim=-1) + return wu + + def skew_symmetric(self,w): + w0,w1,w2 = w.unbind(dim=-1) + O = torch.zeros_like(w0) + wx = torch.stack([torch.stack([O,-w2,w1],dim=-1), + torch.stack([w2,O,-w0],dim=-1), + torch.stack([-w1,w0,O],dim=-1)],dim=-2) + return wx + + def taylor_A(self,x,nth=10): + # Taylor expansion of sin(x)/x + ans = torch.zeros_like(x) + denom = 1. + for i in range(nth+1): + if i>0: denom *= (2*i)*(2*i+1) + ans = ans+(-1)**i*x**(2*i)/denom + return ans + def taylor_B(self,x,nth=10): + # Taylor expansion of (1-cos(x))/x**2 + ans = torch.zeros_like(x) + denom = 1. + for i in range(nth+1): + denom *= (2*i+1)*(2*i+2) + ans = ans+(-1)**i*x**(2*i)/denom + return ans + def taylor_C(self,x,nth=10): + # Taylor expansion of (x-sin(x))/x**3 + ans = torch.zeros_like(x) + denom = 1. 
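As a quick sanity check on the exponential map used in Lie.so3_to_SO3 (R = I + A*wx + B*wx@wx, with A = sin(theta)/theta and B = (1 - cos(theta))/theta**2, which taylor_A and taylor_B approximate), the closed form should agree with torch.matrix_exp of the skew-symmetric matrix. A standalone verification, not part of the patch:

import torch

def hat(w):
    # skew-symmetric matrix of a 3-vector, same layout as Lie.skew_symmetric
    wx = torch.zeros(3, 3)
    wx[0, 1], wx[0, 2] = -w[2], w[1]
    wx[1, 0], wx[1, 2] =  w[2], -w[0]
    wx[2, 0], wx[2, 1] = -w[1], w[0]
    return wx

w = torch.tensor([0.1, -0.2, 0.3])
wx = hat(w)
theta = w.norm()

A = torch.sin(theta) / theta                 # what taylor_A approximates
B = (1 - torch.cos(theta)) / theta ** 2      # what taylor_B approximates
R_closed = torch.eye(3) + A * wx + B * (wx @ wx)

print(torch.allclose(R_closed, torch.matrix_exp(wx), atol=1e-6))   # True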
+ for i in range(nth+1): + denom *= (2*i+2)*(2*i+3) + ans = ans+(-1)**i*x**(2*i)/denom + return ans + + + +def pix2cam(coords, + intr): + """ + Args: + coords: [B, T, N, 3] + intr: [B, T, 3, 3] + """ + coords=coords.detach() + B, S, N, _, = coords.shape + xy_src = coords.reshape(B*S*N, 3) + intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3) + xy_src = torch.cat([xy_src[..., :2], torch.ones_like(xy_src[..., :1])], dim=-1) + xyz_src = (torch.inverse(intr)@xy_src[...,None])[...,0] + dp_pred = coords[..., 2] + xyz_src_ = (xyz_src*(dp_pred.reshape(S*N, 1))) + xyz_src_ = xyz_src_.reshape(B, S, N, 3) + return xyz_src_ + +def cam2pix(coords, + intr): + """ + Args: + coords: [B, T, N, 3] + intr: [B, T, 3, 3] + """ + coords=coords.detach() + B, S, N, _, = coords.shape + xy_src = coords.reshape(B*S*N, 3).clone() + intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3) + xy_src = xy_src / (xy_src[..., 2:]+1e-5) + xyz_src = (intr@xy_src[...,None])[...,0] + dp_pred = coords[..., 2] + xyz_src[...,2] *= dp_pred.reshape(S*N) + xyz_src = xyz_src.reshape(B, S, N, 3) + return xyz_src + +def edgeMat(traj3d): + """ + Args: + traj3d: [B, T, N, 3] + """ + B, T, N, _ = traj3d.shape + traj3d = traj3d + traj3d = traj3d.view(B, T, N, 3) + traj3d = traj3d[..., None, :] - traj3d[..., None, :, :] # B, T, N, N, 3 + edgeMat = traj3d.norm(dim=-1) # B, T, N, N + return edgeMat \ No newline at end of file diff --git a/das/spatracker/models/core/spatracker/dpt/__init__.py b/das/spatracker/models/core/spatracker/dpt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/das/spatracker/models/core/spatracker/dpt/base_model.py b/das/spatracker/models/core/spatracker/dpt/base_model.py new file mode 100644 index 0000000..5c2e0e9 --- /dev/null +++ b/das/spatracker/models/core/spatracker/dpt/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. 
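pix2cam and cam2pix above are the standard pinhole back-projection and projection applied per tracked point. A minimal round trip with a hypothetical intrinsics matrix (the numbers are made up for illustration):

import torch

K = torch.tensor([[500.,   0., 320.],
                  [  0., 500., 240.],
                  [  0.,   0.,   1.]])       # hypothetical fx, fy, cx, cy

uvh = torch.tensor([400., 300., 1.])         # homogeneous pixel coordinate (u, v, 1)
depth = 2.5

xyz = depth * (torch.linalg.inv(K) @ uvh)    # pixel -> camera, as in pix2cam
proj = K @ (xyz / xyz[2])                    # camera -> pixel, as in cam2pix
print(xyz)        # tensor([0.4000, 0.3000, 2.5000])
print(proj[:2])   # tensor([400., 300.]) -- original pixel recovered, up to float error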
+ + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device("cpu")) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/das/spatracker/models/core/spatracker/dpt/blocks.py b/das/spatracker/models/core/spatracker/dpt/blocks.py new file mode 100644 index 0000000..1e40674 --- /dev/null +++ b/das/spatracker/models/core/spatracker/dpt/blocks.py @@ -0,0 +1,394 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, + _make_pretrained_vit_tiny +) + + +def _make_encoder( + backbone, + features, + use_pretrained, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout="ignore", + enable_attention_hooks=False, +): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch( + [256, 512, 1024, 2048], features, groups=groups, expand=expand + ) # efficientnet_lite3 + elif backbone == "vit_tiny_r_s16_p8_384": + pretrained = _make_pretrained_vit_tiny( + use_pretrained, + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + return scratch + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, 
resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output diff --git a/das/spatracker/models/core/spatracker/dpt/midas_net.py b/das/spatracker/models/core/spatracker/dpt/midas_net.py new file mode 100644 index 0000000..34d6d7e --- /dev/null +++ b/das/spatracker/models/core/spatracker/dpt/midas_net.py @@ -0,0 +1,77 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet_large(BaseModel): + """Network for monocular depth estimation.""" + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. 
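The DPT-style decoder blocks above all follow the same contract: residually refine the incoming path, add the (refined) skip connection, refine again, then upsample by a factor of 2 before the next fusion stage. A standalone shape sketch of that pattern, with plain convolutions standing in for ResidualConvUnit_custom:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyFusion(nn.Module):
    """Mimics the FeatureFusionBlock_custom contract: refine, add skip, refine, upsample x2."""
    def __init__(self, ch):
        super().__init__()
        self.refine1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.refine2 = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x, skip=None):
        if skip is not None:
            x = x + self.refine1(skip)       # fuse the encoder skip
        x = x + self.refine2(x)              # residual refinement
        return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)

fuse = TinyFusion(64)
deep = torch.randn(1, 64, 12, 12)            # coarser decoder feature
skip = torch.randn(1, 64, 12, 12)            # encoder feature at the same scale
print(fuse(deep, skip).shape)                # torch.Size([1, 64, 24, 24])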
Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_large, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder( + backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained + ) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/das/spatracker/models/core/spatracker/dpt/models.py b/das/spatracker/models/core/spatracker/dpt/models.py new file mode 100644 index 0000000..2784859 --- /dev/null +++ b/das/spatracker/models/core/spatracker/dpt/models.py @@ -0,0 +1,231 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=True, + enable_attention_hooks=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + "vit_tiny_r_s16_p8_384": [0, 1, 2, 3], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + enable_attention_hooks=enable_attention_hooks, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + self.proj_out = nn.Sequential( + nn.Conv2d( + 256+512+384+384, + 256, + kernel_size=3, + padding=1, + padding_mode="zeros", + ), + 
nn.BatchNorm2d(128 * 2), + nn.ReLU(True), + nn.Conv2d( + 128 * 2, + 128, + kernel_size=3, + padding=1, + padding_mode="zeros", + ) + ) + + + def forward(self, x, only_enc=False): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + if only_enc: + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + a = (layer_1) + b = ( + F.interpolate( + layer_2, + scale_factor=2, + mode="bilinear", + align_corners=True, + ) + ) + c = ( + F.interpolate( + layer_3, + scale_factor=8, + mode="bilinear", + align_corners=True, + ) + ) + d = ( + F.interpolate( + layer_4, + scale_factor=16, + mode="bilinear", + align_corners=True, + ) + ) + x = self.proj_out(torch.cat([a, b, c, d], dim=1)) + return x + else: + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + _,_,H_out,W_out = path_1.size() + path_2_up = F.interpolate(path_2, size=(H_out,W_out), mode="bilinear", align_corners=True) + path_3_up = F.interpolate(path_3, size=(H_out,W_out), mode="bilinear", align_corners=True) + path_4_up = F.interpolate(path_4, size=(H_out,W_out), mode="bilinear", align_corners=True) + + out = self.scratch.output_conv(path_1+path_2_up+path_3_up+path_4_up) + + return out + + +class DPTDepthModel(DPT): + def __init__( + self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs + ): + features = kwargs["features"] if "features" in kwargs else 256 + + self.scale = scale + self.shift = shift + self.invert = invert + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + inv_depth = super().forward(x).squeeze(dim=1) + + if self.invert: + depth = self.scale * inv_depth + self.shift + depth[depth < 1e-8] = 1e-8 + depth = 1.0 / depth + return depth + else: + return inv_depth + +class DPTEncoder(DPT): + def __init__( + self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs + ): + features = kwargs["features"] if "features" in kwargs else 256 + + self.scale = scale + self.shift = shift + + head = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + features = super().forward(x, only_enc=True).squeeze(dim=1) + + return features + + +class DPTSegmentationModel(DPT): + def __init__(self, num_classes, path=None, **kwargs): + + features = kwargs["features"] if "features" in kwargs else 256 + + kwargs["use_bn"] = True + + head = nn.Sequential( + nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(features), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(features, num_classes, kernel_size=1), + Interpolate(scale_factor=2, 
mode="bilinear", align_corners=True), + ) + + super().__init__(head, **kwargs) + + self.auxlayer = nn.Sequential( + nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(features), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(features, num_classes, kernel_size=1), + ) + + if path is not None: + self.load(path) diff --git a/das/spatracker/models/core/spatracker/dpt/transforms.py b/das/spatracker/models/core/spatracker/dpt/transforms.py new file mode 100644 index 0000000..399adbc --- /dev/null +++ b/das/spatracker/models/core/spatracker/dpt/transforms.py @@ -0,0 +1,231 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height).""" + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std.""" + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input.""" + + def __init__(self): + pass + + def __call__(self, sample): + image = 
np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/das/spatracker/models/core/spatracker/dpt/vit.py b/das/spatracker/models/core/spatracker/dpt/vit.py new file mode 100644 index 0000000..0a0b108 --- /dev/null +++ b/das/spatracker/models/core/spatracker/dpt/vit.py @@ -0,0 +1,596 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +attention = {} + + +def get_attention(name): + def hook(module, input, output): + x = input[0] + B, N, C = x.shape + qkv = ( + module.qkv(x) + .reshape(B, N, 3, module.num_heads, C // module.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * module.scale + + attn = attn.softmax(dim=-1) # [:,:,1,1:] + attention[name] = attn + + return hook + + +def get_mean_attention_map(attn, token, shape): + attn = attn[:, :, token, 1:] + attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float() + attn = torch.nn.functional.interpolate( + attn, size=shape[2:], mode="bicubic", align_corners=False + ).squeeze(0) + + all_attn = torch.mean(attn, 0) + + return all_attn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // 
pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + enable_attention_hooks=False, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if enable_attention_hooks: + pretrained.model.blocks[hooks[0]].attn.register_forward_hook( + get_attention("attn_1") + ) + pretrained.model.blocks[hooks[1]].attn.register_forward_hook( + get_attention("attn_2") + ) + 
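_resize_pos_embed above adapts a ViT's learned position embedding to a new token grid by bilinearly interpolating the grid part while keeping the readout token(s) untouched. A standalone sketch of the same operation on a dummy embedding:

import torch
import torch.nn.functional as F

vit_dim, gs_old, gs_new = 768, 24, (30, 40)              # e.g. 384/16 = 24 patches per side
posemb = torch.randn(1, 1 + gs_old * gs_old, vit_dim)    # readout token + 24x24 grid

tok, grid = posemb[:, :1], posemb[:, 1:]
grid = grid.reshape(1, gs_old, gs_old, vit_dim).permute(0, 3, 1, 2)   # 1 x C x 24 x 24
grid = F.interpolate(grid, size=gs_new, mode="bilinear")
grid = grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], vit_dim)

posemb_new = torch.cat([tok, grid], dim=1)
print(posemb_new.shape)   # torch.Size([1, 1201, 768]) = 1 readout token + 30*40 grid tokens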
pretrained.model.blocks[hooks[2]].attn.register_forward_hook( + get_attention("attn_3") + ) + pretrained.model.blocks[hooks[3]].attn.register_forward_hook( + get_attention("attn_4") + ) + pretrained.attention = attention + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=384, + use_vit_only=False, + use_readout="ignore", + start_index=1, + enable_attention_hooks=False, +): + pretrained = nn.Module() + pretrained.model = model + pretrained.model.patch_size = [32, 32] + ps = pretrained.model.patch_size[0] + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + if enable_attention_hooks: + pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1")) + pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2")) + pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3")) + pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4")) + pretrained.attention = attention + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [32, 32] + + # We inject this function into the VisionTransformer instances so 
that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, + use_readout="ignore", + hooks=None, + use_vit_only=False, + enable_attention_hooks=False, +): + # model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + # model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained) + model = timm.create_model("vit_small_r26_s32_384", pretrained=pretrained) + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[128, 256, 384, 384], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + +def _make_pretrained_vit_tiny( + pretrained, + use_readout="ignore", + hooks=None, + use_vit_only=False, + enable_attention_hooks=False, +): + # model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained) + import ipdb; ipdb.set_trace() + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_tiny_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + +def _make_pretrained_vitl16_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + + +def _make_pretrained_vitb16_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + + +def _make_pretrained_deitb16_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + enable_attention_hooks=enable_attention_hooks, + ) + + +def _make_pretrained_deitb16_distil_384( + pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False +): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + enable_attention_hooks=enable_attention_hooks, + ) diff --git a/das/spatracker/models/core/spatracker/feature_net.py 
b/das/spatracker/models/core/spatracker/feature_net.py new file mode 100644 index 0000000..cb0cb12 --- /dev/null +++ b/das/spatracker/models/core/spatracker/feature_net.py @@ -0,0 +1,916 @@ +""" + Adapted from ConvONet + https://github.com/autonomousvision/convolutional_occupancy_networks/blob/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/src/encoder/pointnet.py#L1 +""" + + +import torch +import torch.nn as nn +import torch.nn.functional as F +# from torch_scatter import scatter_mean, scatter_max +from .unet import UNet +from ..model_utils import ( + vis_PCA +) +from einops import rearrange +import numpy as np + +def compute_iou(occ1, occ2): + ''' Computes the Intersection over Union (IoU) value for two sets of + occupancy values. + + Args: + occ1 (tensor): first set of occupancy values + occ2 (tensor): second set of occupancy values + ''' + occ1 = np.asarray(occ1) + occ2 = np.asarray(occ2) + + # Put all data in second dimension + # Also works for 1-dimensional data + if occ1.ndim >= 2: + occ1 = occ1.reshape(occ1.shape[0], -1) + if occ2.ndim >= 2: + occ2 = occ2.reshape(occ2.shape[0], -1) + + # Convert to boolean values + occ1 = (occ1 >= 0.5) + occ2 = (occ2 >= 0.5) + + # Compute IOU + area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1) + area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1) + + iou = (area_intersect / area_union) + + return iou + + +def chamfer_distance(points1, points2, use_kdtree=True, give_id=False): + ''' Returns the chamfer distance for the sets of points. + + Args: + points1 (numpy array): first point set + points2 (numpy array): second point set + use_kdtree (bool): whether to use a kdtree + give_id (bool): whether to return the IDs of nearest points + ''' + if use_kdtree: + return chamfer_distance_kdtree(points1, points2, give_id=give_id) + else: + return chamfer_distance_naive(points1, points2) + + +def chamfer_distance_naive(points1, points2): + ''' Naive implementation of the Chamfer distance. + + Args: + points1 (numpy array): first point set + points2 (numpy array): second point set + ''' + assert(points1.size() == points2.size()) + batch_size, T, _ = points1.size() + + points1 = points1.view(batch_size, T, 1, 3) + points2 = points2.view(batch_size, 1, T, 3) + + distances = (points1 - points2).pow(2).sum(-1) + + chamfer1 = distances.min(dim=1)[0].mean(dim=1) + chamfer2 = distances.min(dim=2)[0].mean(dim=1) + + chamfer = chamfer1 + chamfer2 + return chamfer + + +def chamfer_distance_kdtree(points1, points2, give_id=False): + ''' KD-tree based implementation of the Chamfer distance. 
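chamfer_distance_naive above is the standard symmetric Chamfer distance: for every point, the squared distance to its nearest neighbour in the other set, averaged in both directions and summed. A tiny hand-checked example:

import torch

p1 = torch.tensor([[[0., 0., 0.],
                    [1., 0., 0.]]])                 # B=1, two points
p2 = torch.tensor([[[0., 0., 0.],
                    [1., 1., 0.]]])

d = (p1[:, :, None] - p2[:, None]).pow(2).sum(-1)   # pairwise squared distances, B x T x T
chamfer = d.min(dim=2)[0].mean(dim=1) + d.min(dim=1)[0].mean(dim=1)
print(chamfer)   # tensor([1.]): 0.5 per direction (one exact match, one neighbour at distance 1)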
+ + Args: + points1 (numpy array): first point set + points2 (numpy array): second point set + give_id (bool): whether to return the IDs of the nearest points + ''' + # Points have size batch_size x T x 3 + batch_size = points1.size(0) + + # First convert points to numpy + points1_np = points1.detach().cpu().numpy() + points2_np = points2.detach().cpu().numpy() + + # Get list of nearest neighbors indieces + idx_nn_12, _ = get_nearest_neighbors_indices_batch(points1_np, points2_np) + idx_nn_12 = torch.LongTensor(idx_nn_12).to(points1.device) + # Expands it as batch_size x 1 x 3 + idx_nn_12_expand = idx_nn_12.view(batch_size, -1, 1).expand_as(points1) + + # Get list of nearest neighbors indieces + idx_nn_21, _ = get_nearest_neighbors_indices_batch(points2_np, points1_np) + idx_nn_21 = torch.LongTensor(idx_nn_21).to(points1.device) + # Expands it as batch_size x T x 3 + idx_nn_21_expand = idx_nn_21.view(batch_size, -1, 1).expand_as(points2) + + # Compute nearest neighbors in points2 to points in points1 + # points_12[i, j, k] = points2[i, idx_nn_12_expand[i, j, k], k] + points_12 = torch.gather(points2, dim=1, index=idx_nn_12_expand) + + # Compute nearest neighbors in points1 to points in points2 + # points_21[i, j, k] = points2[i, idx_nn_21_expand[i, j, k], k] + points_21 = torch.gather(points1, dim=1, index=idx_nn_21_expand) + + # Compute chamfer distance + chamfer1 = (points1 - points_12).pow(2).sum(2).mean(1) + chamfer2 = (points2 - points_21).pow(2).sum(2).mean(1) + + # Take sum + chamfer = chamfer1 + chamfer2 + + # If required, also return nearest neighbors + if give_id: + return chamfer1, chamfer2, idx_nn_12, idx_nn_21 + + return chamfer + + +def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1): + ''' Returns the nearest neighbors for point sets batchwise. + + Args: + points_src (numpy array): source points + points_tgt (numpy array): target points + k (int): number of nearest neighbors to return + ''' + indices = [] + distances = [] + + for (p1, p2) in zip(points_src, points_tgt): + raise NotImplementedError() + # kdtree = KDTree(p2) + dist, idx = kdtree.query(p1, k=k) + indices.append(idx) + distances.append(dist) + + return indices, distances + + +def make_3d_grid(bb_min, bb_max, shape): + ''' Makes a 3D grid. + + Args: + bb_min (tuple): bounding box minimum + bb_max (tuple): bounding box maximum + shape (tuple): output shape + ''' + size = shape[0] * shape[1] * shape[2] + + pxs = torch.linspace(bb_min[0], bb_max[0], shape[0]) + pys = torch.linspace(bb_min[1], bb_max[1], shape[1]) + pzs = torch.linspace(bb_min[2], bb_max[2], shape[2]) + + pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size) + pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size) + pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size) + p = torch.stack([pxs, pys, pzs], dim=1) + + return p + + +def transform_points(points, transform): + ''' Transforms points with regard to passed camera information. + + Args: + points (tensor): points tensor + transform (tensor): transformation matrices + ''' + assert(points.size(2) == 3) + assert(transform.size(1) == 3) + assert(points.size(0) == transform.size(0)) + + if transform.size(2) == 4: + R = transform[:, :, :3] + t = transform[:, :, 3:] + points_out = points @ R.transpose(1, 2) + t.transpose(1, 2) + elif transform.size(2) == 3: + K = transform + points_out = points @ K.transpose(1, 2) + + return points_out + + +def b_inv(b_mat): + ''' Performs batch matrix inversion. 
+ + Arguments: + b_mat: the batch of matrices that should be inverted + ''' + + eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat) + b_inv, _ = torch.gesv(eye, b_mat) + return b_inv + +def project_to_camera(points, transform): + ''' Projects points to the camera plane. + + Args: + points (tensor): points tensor + transform (tensor): transformation matrices + ''' + p_camera = transform_points(points, transform) + p_camera = p_camera[..., :2] / p_camera[..., 2:] + return p_camera + + +def fix_Rt_camera(Rt, loc, scale): + ''' Fixes Rt camera matrix. + + Args: + Rt (tensor): Rt camera matrix + loc (tensor): location + scale (float): scale + ''' + # Rt is B x 3 x 4 + # loc is B x 3 and scale is B + batch_size = Rt.size(0) + R = Rt[:, :, :3] + t = Rt[:, :, 3:] + + scale = scale.view(batch_size, 1, 1) + R_new = R * scale + t_new = t + R @ loc.unsqueeze(2) + + Rt_new = torch.cat([R_new, t_new], dim=2) + + assert(Rt_new.size() == (batch_size, 3, 4)) + return Rt_new + +def normalize_coordinate(p, padding=0.1, plane='xz'): + ''' Normalize coordinate to [0, 1] for unit cube experiments + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + plane (str): plane feature type, ['xz', 'xy', 'yz'] + ''' + # breakpoint() + if plane == 'xz': + xy = p[:, :, [0, 2]] + elif plane =='xy': + xy = p[:, :, [0, 1]] + else: + xy = p[:, :, [1, 2]] + + xy = torch.clamp(xy, min=1e-6, max=1. - 1e-6) + + # xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5) + # xy_new = xy_new + 0.5 # range (0, 1) + + # # f there are outliers out of the range + # if xy_new.max() >= 1: + # xy_new[xy_new >= 1] = 1 - 10e-6 + # if xy_new.min() < 0: + # xy_new[xy_new < 0] = 0.0 + # xy_new = (xy + 1.) / 2. + return xy + +def normalize_3d_coordinate(p, padding=0.1): + ''' Normalize coordinate to [0, 1] for unit cube experiments. + Corresponds to our 3D model + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + ''' + + p_nor = p / (1 + padding + 10e-4) # (-0.5, 0.5) + p_nor = p_nor + 0.5 # range (0, 1) + # f there are outliers out of the range + if p_nor.max() >= 1: + p_nor[p_nor >= 1] = 1 - 10e-4 + if p_nor.min() < 0: + p_nor[p_nor < 0] = 0.0 + return p_nor + +def normalize_coord(p, vol_range, plane='xz'): + ''' Normalize coordinate to [0, 1] for sliding-window experiments + + Args: + p (tensor): point + vol_range (numpy array): volume boundary + plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume + ''' + p[:, 0] = (p[:, 0] - vol_range[0][0]) / (vol_range[1][0] - vol_range[0][0]) + p[:, 1] = (p[:, 1] - vol_range[0][1]) / (vol_range[1][1] - vol_range[0][1]) + p[:, 2] = (p[:, 2] - vol_range[0][2]) / (vol_range[1][2] - vol_range[0][2]) + + if plane == 'xz': + x = p[:, [0, 2]] + elif plane =='xy': + x = p[:, [0, 1]] + elif plane =='yz': + x = p[:, [1, 2]] + else: + x = p + return x + +def coordinate2index(x, reso, coord_type='2d'): + ''' Normalize coordinate to [0, 1] for unit cube experiments. 
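Note that b_inv above relies on torch.gesv, which has been removed from recent PyTorch releases; if that helper is ever exercised, an equivalent batched inverse can be written with torch.linalg.solve. A replacement sketch, not part of the patch:

import torch

def b_inv_modern(b_mat):
    # batched matrix inverse by solving B X = I, mirroring the intent of b_inv above
    eye = torch.eye(b_mat.size(-1), dtype=b_mat.dtype, device=b_mat.device)
    return torch.linalg.solve(b_mat, eye.expand_as(b_mat))

A = torch.randn(4, 3, 3) + 3 * torch.eye(3)          # reasonably conditioned batch
print(torch.allclose(b_inv_modern(A) @ A, torch.eye(3).expand(4, 3, 3), atol=1e-4))   # True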
+ Corresponds to our 3D model + + Args: + x (tensor): coordinate + reso (int): defined resolution + coord_type (str): coordinate type + ''' + x = (x * reso).long() + if coord_type == '2d': # plane + index = x[:, :, 0] + reso * x[:, :, 1] + elif coord_type == '3d': # grid + index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2]) + index = index[:, None, :] + return index + +def coord2index(p, vol_range, reso=None, plane='xz'): + ''' Normalize coordinate to [0, 1] for sliding-window experiments. + Corresponds to our 3D model + + Args: + p (tensor): points + vol_range (numpy array): volume boundary + reso (int): defined resolution + plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume + ''' + # normalize to [0, 1] + x = normalize_coord(p, vol_range, plane=plane) + + if isinstance(x, np.ndarray): + x = np.floor(x * reso).astype(int) + else: #* pytorch tensor + x = (x * reso).long() + + if x.shape[1] == 2: + index = x[:, 0] + reso * x[:, 1] + index[index > reso**2] = reso**2 + elif x.shape[1] == 3: + index = x[:, 0] + reso * (x[:, 1] + reso * x[:, 2]) + index[index > reso**3] = reso**3 + + return index[None] + +def update_reso(reso, depth): + ''' Update the defined resolution so that UNet can process. + + Args: + reso (int): defined resolution + depth (int): U-Net number of layers + ''' + base = 2**(int(depth) - 1) + if ~(reso / base).is_integer(): # when this is not integer, U-Net dimension error + for i in range(base): + if ((reso + i) / base).is_integer(): + reso = reso + i + break + return reso + +def decide_total_volume_range(query_vol_metric, recep_field, unit_size, unet_depth): + ''' Update the defined resolution so that UNet can process. + + Args: + query_vol_metric (numpy array): query volume size + recep_field (int): defined the receptive field for U-Net + unit_size (float): the defined voxel size + unet_depth (int): U-Net number of layers + ''' + reso = query_vol_metric / unit_size + recep_field - 1 + reso = update_reso(int(reso), unet_depth) # make sure input reso can be processed by UNet + input_vol_metric = reso * unit_size + p_c = np.array([0.0, 0.0, 0.0]).astype(np.float32) + lb_input_vol, ub_input_vol = p_c - input_vol_metric/2, p_c + input_vol_metric/2 + lb_query_vol, ub_query_vol = p_c - query_vol_metric/2, p_c + query_vol_metric/2 + input_vol = [lb_input_vol, ub_input_vol] + query_vol = [lb_query_vol, ub_query_vol] + + # handle the case when resolution is too large + if reso > 10000: + reso = 1 + + return input_vol, query_vol, reso + +def add_key(base, new, base_name, new_name, device=None): + ''' Add new keys to the given input + + Args: + base (tensor): inputs + new (tensor): new info for the inputs + base_name (str): name for the input + new_name (str): name for the new info + device (device): pytorch device + ''' + if (new is not None) and (isinstance(new, dict)): + if device is not None: + for key in new.keys(): + new[key] = new[key].to(device) + base = {base_name: base, + new_name: new} + return base + +class map2local(object): + ''' Add new keys to the given input + + Args: + s (float): the defined voxel size + pos_encoding (str): method for the positional encoding, linear|sin_cos + ''' + def __init__(self, s, pos_encoding='linear'): + super().__init__() + self.s = s + self.pe = positional_encoding(basis_function=pos_encoding) + + def __call__(self, p): + p = torch.remainder(p, self.s) / self.s # always possitive + # p = torch.fmod(p, self.s) / self.s # same sign as input p! 
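+ # note: with pos_encoding='linear' the encoder applied on the next line returns the wrapped local coordinate unchanged; with 'sin_cos' it expands it into NeRF-style Fourier features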
+ p = self.pe(p) + return p + +class positional_encoding(object): + ''' Positional Encoding (presented in NeRF) + + Args: + basis_function (str): basis function + ''' + def __init__(self, basis_function='sin_cos'): + super().__init__() + self.func = basis_function + + L = 10 + freq_bands = 2.**(np.linspace(0, L-1, L)) + self.freq_bands = freq_bands * math.pi + + def __call__(self, p): + if self.func == 'sin_cos': + out = [] + p = 2.0 * p - 1.0 # chagne to the range [-1, 1] + for freq in self.freq_bands: + out.append(torch.sin(freq * p)) + out.append(torch.cos(freq * p)) + p = torch.cat(out, dim=2) + return p + +# Resnet Blocks +class ResnetBlockFC(nn.Module): + ''' Fully connected ResNet Block class. + + Args: + size_in (int): input dimension + size_out (int): output dimension + size_h (int): hidden dimension + ''' + + def __init__(self, size_in, size_out=None, size_h=None): + super().__init__() + # Attributes + if size_out is None: + size_out = size_in + + if size_h is None: + size_h = min(size_in, size_out) + + self.size_in = size_in + self.size_h = size_h + self.size_out = size_out + # Submodules + self.fc_0 = nn.Linear(size_in, size_h) + self.fc_1 = nn.Linear(size_h, size_out) + self.actvn = nn.ReLU() + + if size_in == size_out: + self.shortcut = None + else: + self.shortcut = nn.Linear(size_in, size_out, bias=False) + # Initialization + nn.init.zeros_(self.fc_1.weight) + + def forward(self, x): + net = self.fc_0(self.actvn(x)) + dx = self.fc_1(self.actvn(net)) + + if self.shortcut is not None: + x_s = self.shortcut(x) + else: + x_s = x + + return x_s + dx + + + +''' +------------------ the key model for Pointnet ---------------------------- +''' + + +class LocalSoftSplat(nn.Module): + + def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max', + unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, + hw=None, grid_resolution=None, plane_type='xz', padding=0.1, + n_blocks=4, splat_func=None): + super().__init__() + c_dim = ch + + self.c_dim = c_dim + + self.fc_pos = nn.Linear(dim, 2*hidden_dim) + self.blocks = nn.ModuleList([ + ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + + if unet: + self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) + else: + self.unet = None + + # get splat func + self.splat_func = splat_func + def forward(self, img_feat, + Fxy2xz, Fxy2yz, Dz, gridxy=None): + """ + Args: + img_feat (tensor): image features + Fxy2xz (tensor): transformation matrix from xy to xz + Fxy2yz (tensor): transformation matrix from xy to yz + """ + B, T, _, H, W = img_feat.shape + fea_reshp = rearrange(img_feat, 'b t c h w -> (b h w) t c', + c=img_feat.shape[2], h=H, w=W) + + gridyz = gridxy + Fxy2yz + gridxz = gridxy + Fxy2xz + # normalize + gridyz[:, 0, ...] = (gridyz[:, 0, ...] / (H - 1) - 0.5) * 2 + gridyz[:, 1, ...] = (gridyz[:, 1, ...] / (Dz - 1) - 0.5) * 2 + gridxz[:, 0, ...] = (gridxz[:, 0, ...] / (W - 1) - 0.5) * 2 + gridxz[:, 1, ...] = (gridxz[:, 1, ...] 
/ (Dz - 1) - 0.5) * 2 + if len(self.blocks) > 0: + net = self.fc_pos(fea_reshp) + net = self.blocks[0](net) + for block in self.blocks[1:]: + # splat and fusion + net_plane = rearrange(net, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W) + + net_planeYZ = self.splat_func(net_plane, Fxy2yz, None, + strMode="avg", tenoutH=Dz, tenoutW=H) + + net_planeXZ = self.splat_func(net_plane, Fxy2xz, None, + strMode="avg", tenoutH=Dz, tenoutW=W) + + net_plane = net_plane + ( + F.grid_sample( + net_planeYZ, gridyz.permute(0,2,3,1), mode='bilinear', padding_mode='border') + + F.grid_sample( + net_planeXZ, gridxz.permute(0,2,3,1), mode='bilinear', padding_mode='border') + ) + + pooled = rearrange(net_plane, 't c h w -> (h w) t c', + c=net_plane.shape[1], h=H, w=W) + + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + net_plane = rearrange(c, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W) + else: + net_plane = rearrange(img_feat, 'b t c h w -> (b t) c h w', + c=img_feat.shape[2], h=H, w=W) + net_planeYZ = self.splat_func(net_plane, Fxy2yz, None, + strMode="avg", tenoutH=Dz, tenoutW=H) + net_planeXZ = self.splat_func(net_plane, Fxy2xz, None, + strMode="avg", tenoutH=Dz, tenoutW=W) + + return net_plane[None], net_planeYZ[None], net_planeXZ[None] + + + +class LocalPoolPointnet(nn.Module): + ''' PointNet-based encoder network with ResNet blocks for each point. + Number of input points are fixed. + + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + unet (bool): weather to use U-Net + unet_kwargs (str): U-Net parameters + unet3d (bool): weather to use 3D U-Net + unet3d_kwargs (str): 3D U-Net parameters + plane_resolution (int): defined resolution for plane feature + grid_resolution (int): defined resolution for grid feature + plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + ''' + + def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max', + unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, + hw=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5): + super().__init__() + c_dim = ch + unet3d = False + plane_type = ['xy', 'xz', 'yz'] + plane_resolution = hw + + self.c_dim = c_dim + + self.fc_pos = nn.Linear(dim, 2*hidden_dim) + self.blocks = nn.ModuleList([ + ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + + if unet: + self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) + else: + self.unet = None + + if unet3d: + # self.unet3d = UNet3D(**unet3d_kwargs) + raise NotImplementedError() + else: + self.unet3d = None + + self.reso_plane = plane_resolution + self.reso_grid = grid_resolution + self.plane_type = plane_type + self.padding = padding + + if scatter_type == 'max': + self.scatter = scatter_max + elif scatter_type == 'mean': + self.scatter = scatter_mean + else: + raise ValueError('incorrect scatter type') + + def generate_plane_features(self, p, c, plane='xz'): + # acquire indices of features in plane + xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) + index = coordinate2index(xy, 
self.reso_plane) + + # scatter plane features from points + fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2) + c = c.permute(0, 2, 1) # B x 512 x T + fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 + fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso) + + # process the plane features with UNet + if self.unet is not None: + fea_plane = self.unet(fea_plane) + + return fea_plane + + def generate_grid_features(self, p, c): + p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding) + index = coordinate2index(p_nor, self.reso_grid, coord_type='3d') + # scatter grid features from points + fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3) + c = c.permute(0, 2, 1) + fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3 + fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparce matrix (B x 512 x reso x reso) + + if self.unet3d is not None: + fea_grid = self.unet3d(fea_grid) + + return fea_grid + + def pool_local(self, xy, index, c): + bs, fea_dim = c.size(0), c.size(2) + keys = xy.keys() + + c_out = 0 + for key in keys: + # scatter plane features from points + if key == 'grid': + fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3) + else: + c_permute = c.permute(0, 2, 1) + fea = self.scatter(c_permute, index[key], dim_size=self.reso_plane**2) + if self.scatter == scatter_max: + fea = fea[0] + # gather feature back to points + fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) + c_out = c_out + fea + return c_out.permute(0, 2, 1) + + + def forward(self, p_input, img_feats=None): + """ + Args: + p_input (tensor): input points T 3 H W + img_feats (tensor): image features T C H W + """ + T, _, H, W = img_feats.size() + p = rearrange(p_input, 't c h w -> (h w) t c', c=3, h=H, w=W) + fea_reshp = rearrange(img_feats, 't c h w -> (h w) t c', + c=img_feats.shape[1], h=H, w=W) + + # acquire the index for each point + coord = {} + index = {} + if 'xz' in self.plane_type: + coord['xz'] = normalize_coordinate(p.clone(), plane='xz', padding=self.padding) + index['xz'] = coordinate2index(coord['xz'], self.reso_plane) + if 'xy' in self.plane_type: + coord['xy'] = normalize_coordinate(p.clone(), plane='xy', padding=self.padding) + index['xy'] = coordinate2index(coord['xy'], self.reso_plane) + if 'yz' in self.plane_type: + coord['yz'] = normalize_coordinate(p.clone(), plane='yz', padding=self.padding) + index['yz'] = coordinate2index(coord['yz'], self.reso_plane) + if 'grid' in self.plane_type: + coord['grid'] = normalize_3d_coordinate(p.clone(), padding=self.padding) + index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d') + + net = self.fc_pos(p) + fea_reshp + net = self.blocks[0](net) + for block in self.blocks[1:]: + pooled = self.pool_local(coord, index, net) + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + + fea = {} + + if 'grid' in self.plane_type: + fea['grid'] = self.generate_grid_features(p, c) + if 'xz' in self.plane_type: + fea['xz'] = self.generate_plane_features(p, c, plane='xz') + if 'xy' in self.plane_type: + fea['xy'] = self.generate_plane_features(p, c, plane='xy') + if 'yz' in self.plane_type: + fea['yz'] = self.generate_plane_features(p, c, plane='yz') + + ret = torch.stack([fea['xy'], fea['xz'], fea['yz']]).permute((1, 0, 2, 3, 4)) + return ret + +class PatchLocalPoolPointnet(nn.Module): + ''' 
PointNet-based encoder network with ResNet blocks. + First transform input points to local system based on the given voxel size. + Support non-fixed number of point cloud, but need to precompute the index + + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + unet (bool): weather to use U-Net + unet_kwargs (str): U-Net parameters + unet3d (bool): weather to use 3D U-Net + unet3d_kwargs (str): 3D U-Net parameters + plane_resolution (int): defined resolution for plane feature + grid_resolution (int): defined resolution for grid feature + plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + local_coord (bool): whether to use local coordinate + pos_encoding (str): method for the positional encoding, linear|sin_cos + unit_size (float): defined voxel unit size for local system + ''' + + def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', + unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None, + plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5, + local_coord=False, pos_encoding='linear', unit_size=0.1): + super().__init__() + self.c_dim = c_dim + + self.blocks = nn.ModuleList([ + ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + self.reso_plane = plane_resolution + self.reso_grid = grid_resolution + self.plane_type = plane_type + self.padding = padding + + if unet: + self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) + else: + self.unet = None + + if unet3d: + # self.unet3d = UNet3D(**unet3d_kwargs) + raise NotImplementedError() + else: + self.unet3d = None + + if scatter_type == 'max': + self.scatter = scatter_max + elif scatter_type == 'mean': + self.scatter = scatter_mean + else: + raise ValueError('incorrect scatter type') + + if local_coord: + self.map2local = map2local(unit_size, pos_encoding=pos_encoding) + else: + self.map2local = None + + if pos_encoding == 'sin_cos': + self.fc_pos = nn.Linear(60, 2*hidden_dim) + else: + self.fc_pos = nn.Linear(dim, 2*hidden_dim) + + def generate_plane_features(self, index, c): + c = c.permute(0, 2, 1) + # scatter plane features from points + if index.max() < self.reso_plane**2: + fea_plane = c.new_zeros(c.size(0), self.c_dim, self.reso_plane**2) + fea_plane = scatter_mean(c, index, out=fea_plane) # B x c_dim x reso^2 + else: + fea_plane = scatter_mean(c, index) # B x c_dim x reso^2 + if fea_plane.shape[-1] > self.reso_plane**2: # deal with outliers + fea_plane = fea_plane[:, :, :-1] + + fea_plane = fea_plane.reshape(c.size(0), self.c_dim, self.reso_plane, self.reso_plane) + + # process the plane features with UNet + if self.unet is not None: + fea_plane = self.unet(fea_plane) + + return fea_plane + + def generate_grid_features(self, index, c): + # scatter grid features from points + c = c.permute(0, 2, 1) + if index.max() < self.reso_grid**3: + fea_grid = c.new_zeros(c.size(0), self.c_dim, self.reso_grid**3) + fea_grid = scatter_mean(c, index, out=fea_grid) # B x c_dim x reso^3 + else: + fea_grid = scatter_mean(c, index) # B x c_dim x reso^3 + if fea_grid.shape[-1] > self.reso_grid**3: # deal with 
outliers + fea_grid = fea_grid[:, :, :-1] + fea_grid = fea_grid.reshape(c.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) + + if self.unet3d is not None: + fea_grid = self.unet3d(fea_grid) + + return fea_grid + + def pool_local(self, index, c): + bs, fea_dim = c.size(0), c.size(2) + keys = index.keys() + + c_out = 0 + for key in keys: + # scatter plane features from points + if key == 'grid': + fea = self.scatter(c.permute(0, 2, 1), index[key]) + else: + fea = self.scatter(c.permute(0, 2, 1), index[key]) + if self.scatter == scatter_max: + fea = fea[0] + # gather feature back to points + fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) + c_out += fea + return c_out.permute(0, 2, 1) + + + def forward(self, inputs): + p = inputs['points'] + index = inputs['index'] + + batch_size, T, D = p.size() + + if self.map2local: + pp = self.map2local(p) + net = self.fc_pos(pp) + else: + net = self.fc_pos(p) + + net = self.blocks[0](net) + for block in self.blocks[1:]: + pooled = self.pool_local(index, net) + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + + fea = {} + if 'grid' in self.plane_type: + fea['grid'] = self.generate_grid_features(index['grid'], c) + if 'xz' in self.plane_type: + fea['xz'] = self.generate_plane_features(index['xz'], c) + if 'xy' in self.plane_type: + fea['xy'] = self.generate_plane_features(index['xy'], c) + if 'yz' in self.plane_type: + fea['yz'] = self.generate_plane_features(index['yz'], c) + + return fea \ No newline at end of file diff --git a/das/spatracker/models/core/spatracker/loftr/__init__.py b/das/spatracker/models/core/spatracker/loftr/__init__.py new file mode 100644 index 0000000..a343f89 --- /dev/null +++ b/das/spatracker/models/core/spatracker/loftr/__init__.py @@ -0,0 +1 @@ +from .transformer import LocalFeatureTransformer diff --git a/das/spatracker/models/core/spatracker/loftr/linear_attention.py b/das/spatracker/models/core/spatracker/loftr/linear_attention.py new file mode 100644 index 0000000..61b1b85 --- /dev/null +++ b/das/spatracker/models/core/spatracker/loftr/linear_attention.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + 
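A minimal usage sketch of the LinearAttention module above, assuming it can be imported from this file (adjust the import to wherever linear_attention.py sits in your checkout). Shapes follow the docstring: queries are [N, L, H, D], keys/values are [N, S, H, D], and the optional masks are [N, L] / [N, S]. The einsum that contracts K and V into an (N, H, D, D) tensor before touching the queries is what keeps the cost linear in the sequence lengths rather than quadratic.

    import torch

    from linear_attention import LinearAttention  # assumed import path

    attn = LinearAttention(eps=1e-6)
    q = torch.randn(2, 100, 8, 32)   # queries  [N, L, H, D]
    k = torch.randn(2, 120, 8, 32)   # keys     [N, S, H, D]
    v = torch.randn(2, 120, 8, 32)   # values   [N, S, H, D]
    kv_mask = torch.ones(2, 120)
    kv_mask[:, 100:] = 0             # treat the last 20 key/value positions as padding
    out = attn(q, k, v, q_mask=None, kv_mask=kv_mask)
    print(out.shape)                 # torch.Size([2, 100, 8, 32])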
+class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) + + # Compute the attention and the weighted average + softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() \ No newline at end of file diff --git a/das/spatracker/models/core/spatracker/loftr/transformer.py b/das/spatracker/models/core/spatracker/loftr/transformer.py new file mode 100644 index 0000000..2f6abe7 --- /dev/null +++ b/das/spatracker/models/core/spatracker/loftr/transformer.py @@ -0,0 +1,142 @@ +''' +modified from +https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py +''' +import torch +from torch.nn import Module, Dropout +import copy +import torch.nn as nn +import torch.nn.functional as F + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + # QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + # if kv_mask is not None: + # QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float(-1e12)) + # softmax_temp = 1. 
/ queries.size(3)**.5 # sqrt(D) + # A = torch.softmax(softmax_temp * QK, dim=2) + # if self.use_dropout: + # A = self.dropout(A) + # queried_values_ = torch.einsum("nlsh,nshd->nlhd", A, values) + + # Compute the attention and the weighted average + input_args = [x.half().contiguous() for x in [queries.permute(0,2,1,3), keys.permute(0,2,1,3), values.permute(0,2,1,3)]] + queried_values = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).float() # type: ignore + + + return queried_values.contiguous() + +class TransformerEncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead,): + super(TransformerEncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + encoder_layer = TransformerEncoderLayer(config['d_model'], config['nhead']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" + + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + return feat0, feat1 \ No newline at end of file diff --git a/das/spatracker/models/core/spatracker/losses.py 
b/das/spatracker/models/core/spatracker/losses.py new file mode 100644 index 0000000..f8e0cb8 --- /dev/null +++ b/das/spatracker/models/core/spatracker/losses.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from ..model_utils import reduce_masked_mean +from .blocks import ( + pix2cam +) +from ..model_utils import ( + bilinear_sample2d +) + +EPS = 1e-6 +import torchvision.transforms.functional as TF + +sigma = 3 +x_grid = torch.arange(-7,8,1) +y_grid = torch.arange(-7,8,1) +x_grid, y_grid = torch.meshgrid(x_grid, y_grid) +gridxy = torch.stack([x_grid, y_grid], dim=-1).float() +gs_kernel = torch.exp(-torch.sum(gridxy**2, dim=-1)/(2*sigma**2)) + + +def balanced_ce_loss(pred, gt, valid=None): + total_balanced_loss = 0.0 + for j in range(len(gt)): + B, S, N = gt[j].shape + # pred and gt are the same shape + for (a, b) in zip(pred[j].size(), gt[j].size()): + assert a == b # some shape mismatch! + # if valid is not None: + for (a, b) in zip(pred[j].size(), valid[j].size()): + assert a == b # some shape mismatch! + + pos = (gt[j] > 0.95).float() + neg = (gt[j] < 0.05).float() + + label = pos * 2.0 - 1.0 + a = -label * pred[j] + b = F.relu(a) + loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) + + pos_loss = reduce_masked_mean(loss, pos * valid[j]) + neg_loss = reduce_masked_mean(loss, neg * valid[j]) + balanced_loss = pos_loss + neg_loss + total_balanced_loss += balanced_loss / float(N) + import ipdb; ipdb.set_trace() + return total_balanced_loss + + +def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8, + intr=None, trajs_g_all=None): + """Loss function defined over sequence of flow predictions""" + total_flow_loss = 0.0 + + for j in range(len(flow_gt)): + B, S, N, D = flow_gt[j].shape + # assert D == 3 + B, S1, N = vis[j].shape + B, S2, N = valids[j].shape + assert S == S1 + assert S == S2 + n_predictions = len(flow_preds[j]) + if intr is not None: + intr_i = intr[j] + flow_loss = 0.0 + for i in range(n_predictions): + i_weight = gamma ** (n_predictions - i - 1) + flow_pred = flow_preds[j][i][..., -N:, :D] + flow_gt_j = flow_gt[j].clone() + if intr is not None: + xyz_j_gt = pix2cam(flow_gt_j, intr_i) + try: + i_loss = (flow_pred - flow_gt_j).abs() # B, S, N, 3 + except: + import ipdb; ipdb.set_trace() + if D==3: + i_loss[...,2]*=30 + i_loss = torch.mean(i_loss, dim=3) # B, S, N + flow_loss += i_weight * (reduce_masked_mean(i_loss, valids[j])) + + flow_loss = flow_loss / n_predictions + total_flow_loss += flow_loss / float(N) + + + return total_flow_loss diff --git a/das/spatracker/models/core/spatracker/softsplat.py b/das/spatracker/models/core/spatracker/softsplat.py new file mode 100644 index 0000000..fc4aad5 --- /dev/null +++ b/das/spatracker/models/core/spatracker/softsplat.py @@ -0,0 +1,539 @@ +#!/usr/bin/env python + +"""The code of softsplat function is modified from: +https://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py + +""" + + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in 
objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = 
strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() + # end + + return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, + tenMetric:torch.Tensor, strMode:str, tenoutH=None, tenoutW=None): + assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) + + if strMode == 'sum': assert(tenMetric is None) + if strMode == 'avg': assert(tenMetric is None) + if strMode.split('-')[0] == 'linear': assert(tenMetric is not None) + if strMode.split('-')[0] == 'soft': assert(tenMetric is not None) + + if strMode == 'avg': + tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) + + elif strMode.split('-')[0] == 'linear': + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split('-')[0] == 'soft': + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow, tenoutH, tenoutW) + + if strMode.split('-')[0] in ['avg', 'linear', 'soft']: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split('-')) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'addeps': + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'zeroeps': + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split('-')[1] == 'clipeps': + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow, H=None, W=None): + if H is None: + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + else: + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], H, W]) + + if tenIn.is_cuda == True: + cuda_launch(cuda_kernel('softsplat_out', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ 
tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) / SIZE_1(tenIn) ) % SIZE_0(tenIn); + const int intC = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) ) % SIZE_1(tenIn); + const int intY = ( intIndex / SIZE_3(tenIn) ) % SIZE_2(tenIn); + const int intX = ( intIndex ) % SIZE_3(tenIn); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOut': tenOut + }))( + grid=tuple([int((tenIn.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + elif tenIn.is_cuda != True: + assert(False) + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None + tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None + Hgrad = None + Wgrad = None + + if tenIngrad is not None: + 
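+            # the gradient w.r.t. tenIn is a bilinear gather of tenOutgrad: each source pixel sums the output gradient at its four splat targets, weighted by the same NW/NE/SW/SE bilinear weights as the forward pass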
cuda_launch(cuda_kernel('softsplat_ingrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenFlowgrad is not None: + cuda_launch(cuda_kernel('softsplat_flowgrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, 
+ {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], + 
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + return tenIngrad, tenFlowgrad, Hgrad, Wgrad + # end +# end \ No newline at end of file diff --git a/das/spatracker/models/core/spatracker/spatracker.py b/das/spatracker/models/core/spatracker/spatracker.py new file mode 100644 index 0000000..c21c468 --- /dev/null +++ b/das/spatracker/models/core/spatracker/spatracker.py @@ -0,0 +1,736 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from easydict import EasyDict as edict +from einops import rearrange +from sklearn.cluster import SpectralClustering +from .blocks import Lie +import matplotlib.pyplot as plt +import cv2 + +import torch.nn.functional as F +from .blocks import ( + BasicEncoder, + CorrBlock, + EUpdateFormer, + FusionFormer, + pix2cam, + cam2pix, + edgeMat, + VitEncoder, + DPTEnc, + Dinov2 +) + +from .feature_net import ( + LocalSoftSplat +) + +from ..model_utils import ( + meshgrid2d, bilinear_sample2d, smart_cat, sample_features5d, vis_PCA +) +from ..embeddings import ( + get_2d_embedding, + get_3d_embedding, + get_1d_sincos_pos_embed_from_grid, + get_2d_sincos_pos_embed, + get_3d_sincos_pos_embed_from_grid, + Embedder_Fourier, +) +import numpy as np +from .softsplat import softsplat + +torch.manual_seed(0) +from comfy.utils import ProgressBar +from tqdm import tqdm + +def get_points_on_a_grid(grid_size, interp_shape, + grid_center=(0, 0), device="cuda"): + if grid_size == 1: + return torch.tensor([interp_shape[1] / 2, + interp_shape[0] / 2], device=device)[ + None, None + ] + + grid_y, grid_x = meshgrid2d( + 1, grid_size, grid_size, stack=False, norm=False, device=device + ) + step = interp_shape[1] // 64 + if grid_center[0] != 0 or grid_center[1] != 0: + grid_y = grid_y - grid_size / 2.0 + grid_x = grid_x - grid_size / 2.0 + grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * ( + interp_shape[0] - step * 2 + ) + grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * ( + interp_shape[1] - step * 2 + ) + + grid_y = grid_y + grid_center[0] + grid_x = grid_x + grid_center[1] + xy = torch.stack([grid_x, grid_y], dim=-1).to(device) + return xy + + +def sample_pos_embed(grid_size, embed_dim, coords): + if coords.shape[-1] == 2: + pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, + grid_size=grid_size) + pos_embed = ( + torch.from_numpy(pos_embed) + .reshape(grid_size[0], grid_size[1], embed_dim) + .float() + .unsqueeze(0) + .to(coords.device) + ) + sampled_pos_embed = bilinear_sample2d( + pos_embed.permute(0, 3, 1, 2), + coords[:, 0, :, 0], coords[:, 0, :, 1] + ) + elif coords.shape[-1] == 3: + sampled_pos_embed = get_3d_sincos_pos_embed_from_grid( + embed_dim, coords[:, :1, ...] 
+ ).float()[:,0,...].permute(0, 2, 1) + + return sampled_pos_embed + + +class SpaTracker(nn.Module): + def __init__( + self, + S=8, + stride=8, + add_space_attn=True, + num_heads=8, + hidden_size=384, + space_depth=12, + time_depth=12, + args=edict({}) + ): + super(SpaTracker, self).__init__() + + # step1: config the arch of the model + self.args=args + # step1.1: config the default value of the model + if getattr(args, "depth_color", None) == None: + self.args.depth_color = False + if getattr(args, "if_ARAP", None) == None: + self.args.if_ARAP = True + if getattr(args, "flash_attn", None) == None: + self.args.flash_attn = True + if getattr(args, "backbone", None) == None: + self.args.backbone = "CNN" + if getattr(args, "Nblock", None) == None: + self.args.Nblock = 0 + if getattr(args, "Embed3D", None) == None: + self.args.Embed3D = True + + # step1.2: config the model parameters + self.S = S + self.stride = stride + self.hidden_dim = 256 + self.latent_dim = latent_dim = 128 + self.b_latent_dim = self.latent_dim//3 + self.corr_levels = 4 + self.corr_radius = 3 + self.add_space_attn = add_space_attn + self.lie = Lie() + + # step2: config the model components + # @Encoder + self.fnet = BasicEncoder(input_dim=3, + output_dim=self.latent_dim, norm_fn="instance", dropout=0, + stride=stride, Embed3D=False + ) + + # conv head for the tri-plane features + self.headyz = nn.Sequential( + nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1)) + + self.headxz = nn.Sequential( + nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1)) + + # @UpdateFormer + self.updateformer = EUpdateFormer( + space_depth=space_depth, + time_depth=time_depth, + input_dim=456, + hidden_size=hidden_size, + num_heads=num_heads, + output_dim=latent_dim + 3, + mlp_ratio=4.0, + add_space_attn=add_space_attn, + flash=getattr(self.args, "flash_attn", True) + ) + self.support_features = torch.zeros(100, 384).to("cuda") + 0.1 + + self.norm = nn.GroupNorm(1, self.latent_dim) + + self.ffeat_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), + nn.GELU(), + ) + self.ffeatyz_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), + nn.GELU(), + ) + self.ffeatxz_updater = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), + nn.GELU(), + ) + + #TODO @NeuralArap: optimize the arap + self.embed_traj = Embedder_Fourier( + input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True + ) + self.embed3d = Embedder_Fourier( + input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True + ) + self.embedConv = nn.Conv2d(self.latent_dim+63, + self.latent_dim, 3, padding=1) + + # @Vis_predictor + self.vis_predictor = nn.Sequential( + nn.Linear(128, 1), + ) + + self.embedProj = nn.Linear(63, 456) + self.zeroMLPflow = nn.Linear(195, 130) + + def prepare_track(self, rgbds, queries): + """ + NOTE: + Normalized the rgbs and sorted the queries via their first appeared time + Args: + rgbds: the input rgbd images (B T 4 H W) + queries: the input queries (B N 4) + Return: + rgbds: the normalized rgbds (B T 4 H W) + queries: the sorted queries (B N 4) + track_mask: + """ + assert (rgbds.shape[2]==4) and (queries.shape[2]==4) + #Step1: normalize the rgbs input + device = rgbds.device + rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] 
/ 255.0) - 1.0 + B, T, C, H, W = rgbds.shape + B, N, __ = queries.shape + self.traj_e = torch.zeros((B, T, N, 3), device=device) + self.vis_e = torch.zeros((B, T, N), device=device) + + #Step2: sort the points via their first appeared time + first_positive_inds = queries[0, :, 0].long() + __, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False) + inv_sort_inds = torch.argsort(sort_inds, dim=0) + first_positive_sorted_inds = first_positive_inds[sort_inds] + # check if can be inverse + assert torch.allclose( + first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds] + ) + + # filter those points never appear points during 1 - T + ind_array = torch.arange(T, device=device) + ind_array = ind_array[None, :, None].repeat(B, 1, N) + track_mask = (ind_array >= + first_positive_inds[None, None, :]).unsqueeze(-1) + + # scale the coords_init + coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat( + 1, self.S, 1, 1 + ) + coords_init[..., :2] /= float(self.stride) + + #Step3: initial the regular grid + gridx = torch.linspace(0, W//self.stride - 1, W//self.stride) + gridy = torch.linspace(0, H//self.stride - 1, H//self.stride) + gridx, gridy = torch.meshgrid(gridx, gridy) + gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute( + 2, 1, 0 + ) + vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10 + + # Step4: initial traj for neural arap + T_series = torch.linspace(0, 5, T).reshape(1, T, 1 , 1).cuda() # 1 T 1 1 + T_series = T_series.repeat(B, 1, N, 1) + # get the 3d traj in the camera coordinates + intr_init = self.intrs[:,queries[0,:,0].long()] + Traj_series = pix2cam(queries[:,:,None,1:].double(), intr_init.double()) + #torch.inverse(intr_init.double())@queries[:,:,1:,None].double() # B N 3 1 + Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float() + Traj_series = torch.cat([T_series, Traj_series], dim=-1) + # get the indicator for the neural arap + Traj_mask = -1e2*torch.ones_like(T_series) + Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1) + + return ( + rgbds, + first_positive_inds, + first_positive_sorted_inds, + sort_inds, inv_sort_inds, + track_mask, gridxy, coords_init[..., sort_inds, :].clone(), + vis_init, Traj_series[..., sort_inds, :].clone() + ) + + def sample_trifeat(self, t, + coords, + featMapxy, + featMapyz, + featMapxz): + """ + Sample the features from the 5D triplane feature map 3*(B S C H W) + Args: + t: the time index + coords: the coordinates of the points B S N 3 + featMapxy: the feature map B S C Hx Wy + featMapyz: the feature map B S C Hy Wz + featMapxz: the feature map B S C Hx Wz + """ + # get xy_t yz_t xz_t + queried_t = t.reshape(1, 1, -1, 1) + xy_t = torch.cat( + [queried_t, coords[..., [0,1]]], + dim=-1 + ) + yz_t = torch.cat( + [queried_t, coords[..., [1, 2]]], + dim=-1 + ) + xz_t = torch.cat( + [queried_t, coords[..., [0, 2]]], + dim=-1 + ) + featxy_init = sample_features5d(featMapxy, xy_t) + + featyz_init = sample_features5d(featMapyz, yz_t) + featxz_init = sample_features5d(featMapxz, xz_t) + + featxy_init = featxy_init.repeat(1, self.S, 1, 1) + featyz_init = featyz_init.repeat(1, self.S, 1, 1) + featxz_init = featxz_init.repeat(1, self.S, 1, 1) + + return featxy_init, featyz_init, featxz_init + + def neural_arap(self, coords, Traj_arap, intrs_S, T_mark): + """ calculate the ARAP embedding and offset + Args: + coords: the coordinates of the current points 1 S N' 3 + Traj_arap: the trajectory of the points 1 T N' 5 + intrs_S: the camera intrinsics B S 3 3 + + """ + coords_out = 
coords.clone() + coords_out[..., :2] *= float(self.stride) + coords_out[..., 2] = coords_out[..., 2]/self.Dz + coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near + intrs_S = intrs_S[:, :, None, ...].repeat(1, 1, coords_out.shape[2], 1, 1) + B, S, N, D = coords_out.shape + if S != intrs_S.shape[1]: + intrs_S = torch.cat( + [intrs_S, intrs_S[:, -1:].repeat(1, S - intrs_S.shape[1],1,1,1)], dim=1 + ) + T_mark = torch.cat( + [T_mark, T_mark[:, -1:].repeat(1, S - T_mark.shape[1],1)], dim=1 + ) + xyz_ = pix2cam(coords_out.double(), intrs_S.double()[:,:,0]) + xyz_ = xyz_.float() + xyz_embed = torch.cat([T_mark[...,None], xyz_, + torch.zeros_like(T_mark[...,None])], dim=-1) + + xyz_embed = self.embed_traj(xyz_embed) + Traj_arap_embed = self.embed_traj(Traj_arap) + d_xyz,traj_feat = self.arapFormer(xyz_embed, Traj_arap_embed) + # update in camera coordinate + xyz_ = xyz_ + d_xyz.clamp(-5, 5) + # project back to the image plane + coords_out = cam2pix(xyz_.double(), intrs_S[:,:,0].double()).float() + # resize back + coords_out[..., :2] /= float(self.stride) + coords_out[..., 2] = (coords_out[..., 2] - self.d_near)/(self.d_far-self.d_near) + coords_out[..., 2] *= self.Dz + + return xyz_, coords_out, traj_feat + + def gradient_arap(self, coords, aff_avg=None, aff_std=None, aff_f_sg=None, + iter=0, iter_num=4, neigh_idx=None, intr=None, msk_track=None): + with torch.enable_grad(): + coords.requires_grad_(True) + y = self.ARAP_ln(coords, aff_f_sg=aff_f_sg, neigh_idx=neigh_idx, + iter=iter, iter_num=iter_num, intr=intr,msk_track=msk_track) + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=coords, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True, allow_unused=True)[0] + + return gradients.detach() + + def forward_iteration( + self, + fmapXY, + fmapYZ, + fmapXZ, + coords_init, + feat_init=None, + vis_init=None, + track_mask=None, + iters=4, + intrs_S=None, + ): + B, S_init, N, D = coords_init.shape + assert D == 3 + assert B == 1 + B, S, __, H8, W8 = fmapXY.shape + device = fmapXY.device + + if S_init < S: + coords = torch.cat( + [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], + dim=1 + ) + vis_init = torch.cat( + [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 + ) + intrs_S = torch.cat( + [intrs_S, intrs_S[:, -1].repeat(1, S - S_init, 1, 1)], dim=1 + ) + else: + coords = coords_init.clone() + + fcorr_fnXY = CorrBlock( + fmapXY, num_levels=self.corr_levels, radius=self.corr_radius + ) + fcorr_fnYZ = CorrBlock( + fmapYZ, num_levels=self.corr_levels, radius=self.corr_radius + ) + fcorr_fnXZ = CorrBlock( + fmapXZ, num_levels=self.corr_levels, radius=self.corr_radius + ) + + ffeats = torch.split(feat_init.clone(), dim=-1, split_size_or_sections=1) + ffeats = [f.squeeze(-1) for f in ffeats] + + times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1) + pos_embed = sample_pos_embed( + grid_size=(H8, W8), + embed_dim=456, + coords=coords[..., :2], + ) + pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1) + + times_embed = ( + torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None] + .repeat(B, 1, 1) + .float() + .to(device) + ) + coord_predictions = [] + attn_predictions = [] + Rot_ln = 0 + support_feat = self.support_features + + comfy_pbar = ProgressBar(iters) + + for __ in tqdm(range(iters), desc="Processing iterations", leave=True): + coords = coords.detach() + # if self.args.if_ARAP == True: + # # refine the 
track with arap + # xyz_pred, coords, flows_cat0 = self.neural_arap(coords.detach(), + # Traj_arap.detach(), + # intrs_S, T_mark) + with torch.no_grad(): + fcorrsXY = fcorr_fnXY.corr_sample(ffeats[0], coords[..., :2]) + fcorrsYZ = fcorr_fnYZ.corr_sample(ffeats[1], coords[..., [1,2]]) + fcorrsXZ = fcorr_fnXZ.corr_sample(ffeats[2], coords[..., [0,2]]) + # fcorrs = fcorrsXY + fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ + LRR = fcorrs.shape[3] + fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR) + + flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3) + flows_cat = get_3d_embedding(flows_, 64, cat_coords=True) + flows_cat = self.zeroMLPflow(flows_cat) + + + ffeats_xy = ffeats[0].permute(0, + 2, 1, 3).reshape(B * N, S, self.latent_dim) + ffeats_yz = ffeats[1].permute(0, + 2, 1, 3).reshape(B * N, S, self.latent_dim) + ffeats_xz = ffeats[2].permute(0, + 2, 1, 3).reshape(B * N, S, self.latent_dim) + ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz + + if track_mask.shape[1] < vis_init.shape[1]: + track_mask = torch.cat( + [ + track_mask, + torch.zeros_like(track_mask[:, 0]).repeat( + 1, vis_init.shape[1] - track_mask.shape[1], 1, 1 + ), + ], + dim=1, + ) + concat = ( + torch.cat([track_mask, vis_init], dim=2) + .permute(0, 2, 1, 3) + .reshape(B * N, S, 2) + ) + + transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2) + + if transformer_input.shape[-1] < pos_embed.shape[-1]: + # padding the transformer_input to the same dimension as pos_embed + transformer_input = F.pad( + transformer_input, (0, pos_embed.shape[-1] - transformer_input.shape[-1]), + "constant", 0 + ) + + x = transformer_input + pos_embed + times_embed + x = rearrange(x, "(b n) t d -> b n t d", b=B) + + delta, AttnMap, so3_dist, delta_se3F, so3 = self.updateformer(x, support_feat) + support_feat = support_feat + delta_se3F[0]/100 + delta = rearrange(delta, " b n t d -> (b n) t d") + d_coord = delta[:, :, :3] + d_feats = delta[:, :, 3:] + + ffeats_xy = self.ffeat_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xy.reshape(-1, self.latent_dim) + ffeats_yz = self.ffeatyz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_yz.reshape(-1, self.latent_dim) + ffeats_xz = self.ffeatxz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xz.reshape(-1, self.latent_dim) + ffeats[0] = ffeats_xy.reshape(B, N, S, self.latent_dim).permute( + 0, 2, 1, 3 + ) # B,S,N,C + ffeats[1] = ffeats_yz.reshape(B, N, S, self.latent_dim).permute( + 0, 2, 1, 3 + ) # B,S,N,C + ffeats[2] = ffeats_xz.reshape(B, N, S, self.latent_dim).permute( + 0, 2, 1, 3 + ) # B,S,N,C + coords = coords + d_coord.reshape(B, N, S, 3).permute(0, 2, 1, 3) + if torch.isnan(coords).any(): + import ipdb; ipdb.set_trace() + + coords_out = coords.clone() + coords_out[..., :2] *= float(self.stride) + + coords_out[..., 2] = coords_out[..., 2]/self.Dz + coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near + + coord_predictions.append(coords_out) + attn_predictions.append(AttnMap) + comfy_pbar.update(1) + + ffeats_f = ffeats[0] + ffeats[1] + ffeats[2] + vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim)).reshape( + B, S, N + ) + self.support_features = support_feat.detach() + return coord_predictions, attn_predictions, vis_e, feat_init, Rot_ln + + + def forward(self, rgbds, queries, iters=4, feat_init=None, + is_train=False, intrs=None, wind_S=None): + self.support_features = torch.zeros(100, 384).to("cuda") + 0.1 + self.is_train=is_train + B, T, C, H, W = rgbds.shape + 
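+ # NOTE: when `intrs` is None below, a simple pinhole guess is used:
+ # fx = fy = W and the principal point at the image centre, i.e.
+ # K = [[W, 0, W//2], [0, W, H//2], [0, 0, 1]]
+ # (for a 512x320 clip: [[512, 0, 256], [0, 512, 160], [0, 0, 1]]).
+ # Pass real calibration via `intrs` whenever it is available.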
# set the intrinsic or simply initialized + if intrs is None: + intrs = torch.from_numpy(np.array([[W, 0.0, W//2], + [0.0, W, H//2], + [0.0, 0.0, 1.0]])) + intrs = intrs[None, + None,...].repeat(B, T, 1, 1).float().to(rgbds.device) + self.intrs = intrs + + # prepare the input for tracking + ( + rgbds, + first_positive_inds, + first_positive_sorted_inds, sort_inds, + inv_sort_inds, track_mask, gridxy, + coords_init, vis_init, Traj_arap + ) = self.prepare_track(rgbds.clone(), queries) + coords_init_ = coords_init.clone() + vis_init_ = vis_init[:, :, sort_inds].clone() + + depth_all = rgbds[:, :, 3,...] + d_near = self.d_near = depth_all[depth_all>0.01].min().item() + d_far = self.d_far = depth_all[depth_all>0.01].max().item() + + if wind_S is not None: + self.S = wind_S + + B, N, __ = queries.shape + self.Dz = Dz = W//self.stride + w_idx_start = 0 + p_idx_end = 0 + p_idx_start = 0 + fmaps_ = None + vis_predictions = [] + coord_predictions = [] + attn_predictions = [] + p_idx_end_list = [] + Rigid_ln_total = 0 + while w_idx_start < T - self.S // 2: + curr_wind_points = torch.nonzero( + first_positive_sorted_inds < w_idx_start + self.S) + if curr_wind_points.shape[0] == 0: + w_idx_start = w_idx_start + self.S // 2 + continue + p_idx_end = curr_wind_points[-1] + 1 + p_idx_end_list.append(p_idx_end) + # the T may not be divided by self.S + rgbds_seq = rgbds[:, w_idx_start:w_idx_start + self.S].clone() + S = S_local = rgbds_seq.shape[1] + if S < self.S: + rgbds_seq = torch.cat( + [rgbds_seq, + rgbds_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)], + dim=1, + ) + S = rgbds_seq.shape[1] + + rgbs_ = rgbds_seq.reshape(B * S, C, H, W)[:, :3] + depths = rgbds_seq.reshape(B * S, C, H, W)[:, 3:].clone() + # open the mask + # Traj_arap[:, w_idx_start:w_idx_start + self.S, :p_idx_end, -1] = 0 + #step1: normalize the depth map + + depths = (depths - d_near)/(d_far-d_near) + depths_dn = nn.functional.interpolate( + depths, scale_factor=1.0 / self.stride, mode="nearest") + depths_dnG = depths_dn*Dz + + #step2: normalize the coordinate + coords_init_[:, :, p_idx_start:p_idx_end, 2] = ( + coords_init[:, :, p_idx_start:p_idx_end, 2] - d_near + )/(d_far-d_near) + coords_init_[:, :, p_idx_start:p_idx_end, 2] *= Dz + + # efficient triplane splatting + gridxyz = torch.cat([gridxy[None,...].repeat( + depths_dn.shape[0],1,1,1), depths_dnG], dim=1) + Fxy2yz = gridxyz[:,[1, 2], ...] - gridxyz[:,:2] + Fxy2xz = gridxyz[:,[0, 2], ...] - gridxyz[:,:2] + if getattr(self.args, "Embed3D", None) == True: + gridxyz_nm = gridxyz.clone() + gridxyz_nm[:,0,...] = (gridxyz_nm[:,0,...]-gridxyz_nm[:,0,...].min())/(gridxyz_nm[:,0,...].max()-gridxyz_nm[:,0,...].min()) + gridxyz_nm[:,1,...] = (gridxyz_nm[:,1,...]-gridxyz_nm[:,1,...].min())/(gridxyz_nm[:,1,...].max()-gridxyz_nm[:,1,...].min()) + gridxyz_nm[:,2,...] 
= (gridxyz_nm[:,2,...]-gridxyz_nm[:,2,...].min())/(gridxyz_nm[:,2,...].max()-gridxyz_nm[:,2,...].min()) + gridxyz_nm = 2*(gridxyz_nm-0.5) + _,_,h4,w4 = gridxyz_nm.shape + gridxyz_nm = gridxyz_nm.permute(0,2,3,1).reshape(S*h4*w4, 3) + featPE = self.embed3d(gridxyz_nm).view(S, h4, w4, -1).permute(0,3,1,2) + if fmaps_ is None: + fmaps_ = torch.cat([self.fnet(rgbs_),featPE], dim=1) + fmaps_ = self.embedConv(fmaps_) + else: + fmaps_new = torch.cat([self.fnet(rgbs_[self.S // 2 :]),featPE[self.S // 2 :]], dim=1) + fmaps_new = self.embedConv(fmaps_new) + fmaps_ = torch.cat( + [fmaps_[self.S // 2 :], fmaps_new], dim=0 + ) + else: + if fmaps_ is None: + fmaps_ = self.fnet(rgbs_) + else: + fmaps_ = torch.cat( + [fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0 + ) + + fmapXY = fmaps_[:, :self.latent_dim].reshape( + B, S, self.latent_dim, H // self.stride, W // self.stride + ) + + fmapYZ = softsplat(fmapXY[0], Fxy2yz, None, + strMode="avg", tenoutH=self.Dz, tenoutW=H//self.stride) + fmapXZ = softsplat(fmapXY[0], Fxy2xz, None, + strMode="avg", tenoutH=self.Dz, tenoutW=W//self.stride) + + fmapYZ = self.headyz(fmapYZ)[None, ...] + fmapXZ = self.headxz(fmapXZ)[None, ...] + + if p_idx_end - p_idx_start > 0: + queried_t = (first_positive_sorted_inds[p_idx_start:p_idx_end] + - w_idx_start) + (featxy_init, + featyz_init, + featxz_init) = self.sample_trifeat( + t=queried_t,featMapxy=fmapXY, + featMapyz=fmapYZ,featMapxz=fmapXZ, + coords=coords_init_[:, :1, p_idx_start:p_idx_end] + ) + # T, S, N, C, 3 + feat_init_curr = torch.stack([featxy_init, + featyz_init, featxz_init], dim=-1) + feat_init = smart_cat(feat_init, feat_init_curr, dim=2) + + if p_idx_start > 0: + # preprocess the coordinates of last windows + last_coords = coords[-1][:, self.S // 2 :].clone() + last_coords[..., :2] /= float(self.stride) + last_coords[..., 2:] = (last_coords[..., 2:]-d_near)/(d_far-d_near) + last_coords[..., 2:] = last_coords[..., 2:]*Dz + + coords_init_[:, : self.S // 2, :p_idx_start] = last_coords + coords_init_[:, self.S // 2 :, :p_idx_start] = last_coords[ + :, -1 + ].repeat(1, self.S // 2, 1, 1) + + last_vis = vis[:, self.S // 2 :].unsqueeze(-1) + vis_init_[:, : self.S // 2, :p_idx_start] = last_vis + vis_init_[:, self.S // 2 :, :p_idx_start] = last_vis[:, -1].repeat( + 1, self.S // 2, 1, 1 + ) + + coords, attns, vis, __, Rigid_ln = self.forward_iteration( + fmapXY=fmapXY, + fmapYZ=fmapYZ, + fmapXZ=fmapXZ, + coords_init=coords_init_[:, :, :p_idx_end], + feat_init=feat_init[:, :, :p_idx_end], + vis_init=vis_init_[:, :, :p_idx_end], + track_mask=track_mask[:, w_idx_start : w_idx_start + self.S, :p_idx_end], + iters=iters, + intrs_S=self.intrs[:, w_idx_start : w_idx_start + self.S], + ) + + Rigid_ln_total+=Rigid_ln + + if is_train: + vis_predictions.append(torch.sigmoid(vis[:, :S_local])) + coord_predictions.append([coord[:, :S_local] for coord in coords]) + attn_predictions.append(attns) + + self.traj_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = coords[-1][:, :S_local] + self.vis_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = vis[:, :S_local] + + track_mask[:, : w_idx_start + self.S, :p_idx_end] = 0.0 + w_idx_start = w_idx_start + self.S // 2 + + p_idx_start = p_idx_end + + self.traj_e = self.traj_e[:, :, inv_sort_inds] + self.vis_e = self.vis_e[:, :, inv_sort_inds] + + self.vis_e = torch.sigmoid(self.vis_e) + train_data = ( + (vis_predictions, coord_predictions, attn_predictions, + p_idx_end_list, sort_inds, Rigid_ln_total) + ) + if self.is_train: + return self.traj_e, feat_init, self.vis_e, 
train_data + else: + return self.traj_e, feat_init, self.vis_e + diff --git a/das/spatracker/models/core/spatracker/unet.py b/das/spatracker/models/core/spatracker/unet.py new file mode 100644 index 0000000..715ae32 --- /dev/null +++ b/das/spatracker/models/core/spatracker/unet.py @@ -0,0 +1,258 @@ +''' +Codes are from: +https://github.com/jaxony/unet-pytorch/blob/master/model.py +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from collections import OrderedDict +from torch.nn import init +import numpy as np + +def conv3x3(in_channels, out_channels, stride=1, + padding=1, bias=True, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=bias, + groups=groups) + +def upconv2x2(in_channels, out_channels, mode='transpose'): + if mode == 'transpose': + return nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=2, + stride=2) + else: + # out_channels is always going to be the same + # as in_channels + return nn.Sequential( + nn.Upsample(mode='bilinear', scale_factor=2), + conv1x1(in_channels, out_channels)) + +def conv1x1(in_channels, out_channels, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + groups=groups, + stride=1) + + +class DownConv(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 MaxPool. + A ReLU activation follows each convolution. + """ + def __init__(self, in_channels, out_channels, pooling=True): + super(DownConv, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.pooling = pooling + + self.conv1 = conv3x3(self.in_channels, self.out_channels) + self.conv2 = conv3x3(self.out_channels, self.out_channels) + + if self.pooling: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + before_pool = x + if self.pooling: + x = self.pool(x) + return x, before_pool + + +class UpConv(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 UpConvolution. + A ReLU activation follows each convolution. + """ + def __init__(self, in_channels, out_channels, + merge_mode='concat', up_mode='transpose'): + super(UpConv, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.merge_mode = merge_mode + self.up_mode = up_mode + + self.upconv = upconv2x2(self.in_channels, self.out_channels, + mode=self.up_mode) + + if self.merge_mode == 'concat': + self.conv1 = conv3x3( + 2*self.out_channels, self.out_channels) + else: + # num of input channels to conv2 is same + self.conv1 = conv3x3(self.out_channels, self.out_channels) + self.conv2 = conv3x3(self.out_channels, self.out_channels) + + + def forward(self, from_down, from_up): + """ Forward pass + Arguments: + from_down: tensor from the encoder pathway + from_up: upconv'd tensor from the decoder pathway + """ + from_up = self.upconv(from_up) + if self.merge_mode == 'concat': + x = torch.cat((from_up, from_down), 1) + else: + x = from_up + from_down + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + return x + + +class UNet(nn.Module): + """ `UNet` class is based on https://arxiv.org/abs/1505.04597 + + The U-Net is a convolutional encoder-decoder neural network. + Contextual spatial information (from the decoding, + expansive pathway) about an input tensor is merged with + information representing the localization of details + (from the encoding, compressive pathway). 
+ + Modifications to the original paper: + (1) padding is used in 3x3 convolutions to prevent loss + of border pixels + (2) merging outputs does not require cropping due to (1) + (3) residual connections can be used by specifying + UNet(merge_mode='add') + (4) if non-parametric upsampling is used in the decoder + pathway (specified by upmode='upsample'), then an + additional 1x1 2d convolution occurs after upsampling + to reduce channel dimensionality by a factor of 2. + This channel halving happens with the convolution in + the tranpose convolution (specified by upmode='transpose') + """ + + def __init__(self, num_classes, in_channels=3, depth=5, + start_filts=64, up_mode='transpose', + merge_mode='concat', **kwargs): + """ + Arguments: + in_channels: int, number of channels in the input tensor. + Default is 3 for RGB images. + depth: int, number of MaxPools in the U-Net. + start_filts: int, number of convolutional filters for the + first conv. + up_mode: string, type of upconvolution. Choices: 'transpose' + for transpose convolution or 'upsample' for nearest neighbour + upsampling. + """ + super(UNet, self).__init__() + + if up_mode in ('transpose', 'upsample'): + self.up_mode = up_mode + else: + raise ValueError("\"{}\" is not a valid mode for " + "upsampling. Only \"transpose\" and " + "\"upsample\" are allowed.".format(up_mode)) + + if merge_mode in ('concat', 'add'): + self.merge_mode = merge_mode + else: + raise ValueError("\"{}\" is not a valid mode for" + "merging up and down paths. " + "Only \"concat\" and " + "\"add\" are allowed.".format(up_mode)) + + # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' + if self.up_mode == 'upsample' and self.merge_mode == 'add': + raise ValueError("up_mode \"upsample\" is incompatible " + "with merge_mode \"add\" at the moment " + "because it doesn't make sense to use " + "nearest neighbour to reduce " + "depth channels (by half).") + + self.num_classes = num_classes + self.in_channels = in_channels + self.start_filts = start_filts + self.depth = depth + + self.down_convs = [] + self.up_convs = [] + + # create the encoder pathway and add to a list + for i in range(depth): + ins = self.in_channels if i == 0 else outs + outs = self.start_filts*(2**i) + pooling = True if i < depth-1 else False + + down_conv = DownConv(ins, outs, pooling=pooling) + self.down_convs.append(down_conv) + + # create the decoder pathway and add to a list + # - careful! decoding only requires depth-1 blocks + for i in range(depth-1): + ins = outs + outs = ins // 2 + up_conv = UpConv(ins, outs, up_mode=up_mode, + merge_mode=merge_mode) + self.up_convs.append(up_conv) + + # add the list of modules to current module + self.down_convs = nn.ModuleList(self.down_convs) + self.up_convs = nn.ModuleList(self.up_convs) + + self.conv_final = conv1x1(outs, self.num_classes) + + self.reset_params() + + @staticmethod + def weight_init(m): + if isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight) + init.constant_(m.bias, 0) + + + def reset_params(self): + for i, m in enumerate(self.modules()): + self.weight_init(m) + + + def forward(self, x): + encoder_outs = [] + # encoder pathway, save outputs for merging + for i, module in enumerate(self.down_convs): + x, before_pool = module(x) + encoder_outs.append(before_pool) + for i, module in enumerate(self.up_convs): + before_pool = encoder_outs[-(i+2)] + x = module(before_pool, x) + + # No softmax is used. 
This means you need to use + # nn.CrossEntropyLoss is your training script, + # as this module includes a softmax already. + x = self.conv_final(x) + return x + +if __name__ == "__main__": + """ + testing + """ + model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32) + print(model) + print(sum(p.numel() for p in model.parameters())) + + reso = 176 + x = np.zeros((1, 1, reso, reso)) + x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan + x = torch.FloatTensor(x) + + out = model(x) + print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso))) + + # loss = torch.sum(out) + # loss.backward() \ No newline at end of file diff --git a/das/spatracker/models/core/spatracker/vit/__init__.py b/das/spatracker/models/core/spatracker/vit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/das/spatracker/models/core/spatracker/vit/common.py b/das/spatracker/models/core/spatracker/vit/common.py new file mode 100644 index 0000000..d67662c --- /dev/null +++ b/das/spatracker/models/core/spatracker/vit/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x \ No newline at end of file diff --git a/das/spatracker/models/core/spatracker/vit/encoder.py b/das/spatracker/models/core/spatracker/vit/encoder.py new file mode 100644 index 0000000..eeaf053 --- /dev/null +++ b/das/spatracker/models/core/spatracker/vit/encoder.py @@ -0,0 +1,397 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
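+ # Minimal shape-check sketch for the encoder defined below. It assumes the
+ # default hyper-parameters (img_size=1024, patch_size=16, embed_dim=768,
+ # out_chans=256, all-global attention), under which a (1, 3, 1024, 1024)
+ # image maps to a (1, 256, 64, 64) feature map. The helper name
+ # `_encoder_shape_check` is illustrative only and is never called here;
+ # `torch` and `ImageEncoderViT` are resolved from this module at call time.
+ def _encoder_shape_check():
+     x = torch.randn(1, 3, 1024, 1024)
+     enc = ImageEncoderViT()  # default ViT-B/16 style configuration
+     with torch.no_grad():
+         feats = enc(x)
+     assert feats.shape == (1, 256, 64, 64)
+     return feats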
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import ( + LayerNorm2d, MLPBlock +) + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
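+ # The parameter created here has shape
+ # (1, img_size // patch_size, img_size // patch_size, embed_dim),
+ # e.g. (1, 64, 64, 768) with the defaults, and is added to the patch
+ # embeddings in forward() before the transformer blocks.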
+ self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. 
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. 
+ rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x \ No newline at end of file diff --git a/das/spatracker/predictor.py b/das/spatracker/predictor.py new file mode 100644 index 0000000..6cb3e3d --- /dev/null +++ b/das/spatracker/predictor.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
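+ # Usage sketch, assuming a downloaded SpaTracker checkpoint and some
+ # monocular depth source (both placeholders here):
+ #
+ #   predictor = SpaTrackerPredictor(
+ #       checkpoint="path/to/spaT_final.pth",   # placeholder path
+ #       interp_shape=(384, 512),
+ #       seq_length=12,
+ #   )
+ #   # video: (1, T, 3, H, W) tensor with values in [0, 255];
+ #   # video_depth: (T, 1, H, W), or pass `depth_predictor=` instead.
+ #   tracks, visibilities, t_firsts = predictor(
+ #       video, video_depth=video_depth, grid_size=10, wind_length=12
+ #   )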
+ +import torch +import torch.nn.functional as F +import time + +from tqdm import tqdm +from .models.core.spatracker.spatracker import get_points_on_a_grid +from .models.core.model_utils import smart_cat +from .models.build_spatracker import ( + build_spatracker, +) +from .models.core.model_utils import ( + meshgrid2d, bilinear_sample2d, smart_cat +) + +from comfy.utils import ProgressBar + +class SpaTrackerPredictor(torch.nn.Module): + def __init__( + self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", + interp_shape=(384, 512), + seq_length=16 + ): + super().__init__() + self.interp_shape = interp_shape + self.support_grid_size = 6 + model = build_spatracker(checkpoint, seq_length=seq_length) + + self.model = model + self.model.eval() + + @torch.no_grad() + def forward( + self, + video, # (1, T, 3, H, W) + video_depth = None, # (T, 1, H, W) + # input prompt types: + # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. + # *backward_tracking=True* will compute tracks in both directions. + # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates. + # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask. + # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. + queries: torch.Tensor = None, + segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W) + grid_size: int = 0, + grid_query_frame: int = 0, # only for dense and regular grid tracks + backward_tracking: bool = False, + depth_predictor=None, + wind_length: int = 8, + progressive_tracking: bool = False, + ): + if queries is None and grid_size == 0: + tracks, visibilities, T_Firsts = self._compute_dense_tracks( + video, + grid_query_frame=grid_query_frame, + backward_tracking=backward_tracking, + video_depth=video_depth, + depth_predictor=depth_predictor, + wind_length=wind_length, + ) + else: + tracks, visibilities, T_Firsts = self._compute_sparse_tracks( + video, + queries, + segm_mask, + grid_size, + add_support_grid=False, #(grid_size == 0 or segm_mask is not None), + grid_query_frame=grid_query_frame, + backward_tracking=backward_tracking, + video_depth=video_depth, + depth_predictor=depth_predictor, + wind_length=wind_length, + ) + + return tracks, visibilities, T_Firsts + + def _compute_dense_tracks( + self, video, grid_query_frame, grid_size=30, backward_tracking=False, + depth_predictor=None, video_depth=None, wind_length=8 + ): + *_, H, W = video.shape + grid_step = W // grid_size + grid_width = W // grid_step + grid_height = H // grid_step + tracks = visibilities = T_Firsts = None + grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device) + grid_pts[0, :, 0] = grid_query_frame + for offset in tqdm(range(grid_step * grid_step)): + ox = offset % grid_step + oy = offset // grid_step + grid_pts[0, :, 1] = ( + torch.arange(grid_width).repeat(grid_height) * grid_step + ox + ) + grid_pts[0, :, 2] = ( + torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy + ) + tracks_step, visibilities_step, T_First_step = self._compute_sparse_tracks( + video=video, + queries=grid_pts, + backward_tracking=backward_tracking, + wind_length=wind_length, + video_depth=video_depth, + depth_predictor=depth_predictor, + ) + tracks = smart_cat(tracks, tracks_step, dim=2) + visibilities = smart_cat(visibilities, visibilities_step, dim=2) + T_Firsts 
= smart_cat(T_Firsts, T_First_step, dim=1) + + + return tracks, visibilities, T_Firsts + + def _compute_sparse_tracks( + self, + video, + queries, + segm_mask=None, + grid_size=0, + add_support_grid=False, + grid_query_frame=0, + backward_tracking=False, + depth_predictor=None, + video_depth=None, + wind_length=8, + ): + B, T, C, H, W = video.shape + assert B == 1 + + video = video.reshape(B * T, C, H, W) + video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear") + video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) + + if queries is not None: + queries = queries.clone() + B, N, D = queries.shape + assert D == 3 + queries[:, :, 1] *= self.interp_shape[1] / W + queries[:, :, 2] *= self.interp_shape[0] / H + elif grid_size > 0: + grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device) + if segm_mask is not None: + segm_mask = F.interpolate( + segm_mask, tuple(self.interp_shape), mode="nearest" + ) + point_mask = segm_mask[0, 0][ + (grid_pts[0, :, 1]).round().long().cpu(), + (grid_pts[0, :, 0]).round().long().cpu(), + ].bool() + grid_pts_extra = grid_pts[:, point_mask] + else: + grid_pts_extra = None + if grid_pts_extra is not None: + total_num = int(grid_pts_extra.shape[1]) + total_num = min(800, total_num) + pick_idx = torch.randperm(grid_pts_extra.shape[1])[:total_num] + grid_pts_extra = grid_pts_extra[:, pick_idx] + queries_extra = torch.cat( + [ + torch.ones_like(grid_pts_extra[:, :, :1]) * grid_query_frame, + grid_pts_extra, + ], + dim=2, + ) + + queries = torch.cat( + [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], + dim=2, + ) + + if add_support_grid: + grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device) + grid_pts = torch.cat( + [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 + ) + queries = torch.cat([queries, grid_pts], dim=1) + + ## ----------- estimate the video depth -----------## + if video_depth is None: + with torch.no_grad(): + if video[0].shape[0]>30: + vidDepths = [] + for i in range(video[0].shape[0]//30+1): + if (i+1)*30 > video[0].shape[0]: + end_idx = video[0].shape[0] + else: + end_idx = (i+1)*30 + if end_idx == i*30: + break + video_ = video[0][i*30:end_idx] + vidDepths.append(depth_predictor.infer(video_/255)) + + video_depth = torch.cat(vidDepths, dim=0) + + else: + video_depth = depth_predictor.infer(video[0]/255) + video_depth = F.interpolate(video_depth, + tuple(self.interp_shape), mode="nearest") + + # from PIL import Image + # import numpy + # depth_frame = video_depth[0].detach().cpu() + # depth_frame = depth_frame.squeeze(0) + # print(depth_frame) + # print(depth_frame.min(), depth_frame.max()) + # depth_img = (depth_frame * 255).numpy().astype(numpy.uint8) + # depth_img = Image.fromarray(depth_img, mode='L') + # depth_img.save('outputs/depth_map.png') + + # frame = video[0, 0].detach().cpu() + # frame = frame.permute(1, 2, 0) + # frame = (frame * 255).numpy().astype(numpy.uint8) + # frame = Image.fromarray(frame, mode='RGB') + # frame.save('outputs/frame.png') + + depths = video_depth + rgbds = torch.cat([video, depths[None,...]], dim=2) + # get the 3D queries + comfy_pbar = ProgressBar(queries.shape[1]) + depth_interp=[] + for i in tqdm(range(queries.shape[1]), desc="Processing queries"): + depth_interp_i = bilinear_sample2d(video_depth[queries[:, i:i+1, 0].long()], + queries[:, i:i+1, 1], queries[:, i:i+1, 2]) + depth_interp.append(depth_interp_i) + comfy_pbar.update(1) + + depth_interp = torch.cat(depth_interp, dim=1) + queries = 
smart_cat(queries, depth_interp,dim=-1) + + #NOTE: free the memory of depth_predictor + del depth_predictor + torch.cuda.empty_cache() + t0 = time.time() + tracks, __, visibilities = self.model(rgbds=rgbds, queries=queries, iters=6, wind_S=wind_length) + print("Time taken for inference: ", time.time()-t0) + + if backward_tracking: + tracks, visibilities = self._compute_backward_tracks( + rgbds, queries, tracks, visibilities + ) + if add_support_grid: + queries[:, -self.support_grid_size ** 2 :, 0] = T - 1 + if add_support_grid: + tracks = tracks[:, :, : -self.support_grid_size ** 2] + visibilities = visibilities[:, :, : -self.support_grid_size ** 2] + thr = 0.9 + visibilities = visibilities > thr + + # correct query-point predictions + # see https://github.com/facebookresearch/co-tracker/issues/28 + + # TODO: batchify + + for i in tqdm(range(len(queries)), desc="Processing queries", leave=False): + queries_t = queries[i, :tracks.size(2), 0].to(torch.int64) + arange = torch.arange(0, len(queries_t)) + + # overwrite the predictions with the query points + tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:] + + # correct visibilities, the query points should be visible + visibilities[i, queries_t, arange] = True + + T_First = queries[..., :tracks.size(2), 0].to(torch.uint8) + tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) + tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) + return tracks, visibilities, T_First + + def _compute_backward_tracks(self, video, queries, tracks, visibilities): + inv_video = video.flip(1).clone() + inv_queries = queries.clone() + inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 + + inv_tracks, __, inv_visibilities = self.model( + rgbds=inv_video, queries=queries, iters=6 + ) + + inv_tracks = inv_tracks.flip(1) + inv_visibilities = inv_visibilities.flip(1) + + mask = tracks == 0 + + tracks[mask] = inv_tracks[mask] + visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] + return tracks, visibilities \ No newline at end of file diff --git a/das/spatracker/utils/__init__.py b/das/spatracker/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/das/spatracker/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/das/spatracker/utils/basic.py b/das/spatracker/utils/basic.py new file mode 100644 index 0000000..4a4a15e --- /dev/null +++ b/das/spatracker/utils/basic.py @@ -0,0 +1,397 @@ +import os +import numpy as np +from os.path import isfile +import torch +import torch.nn.functional as F +EPS = 1e-6 +import copy + +def sub2ind(height, width, y, x): + return y*width + x + +def ind2sub(height, width, ind): + y = ind // width + x = ind % width + return y, x + +def get_lr_str(lr): + lrn = "%.1e" % lr # e.g., 5.0e-04 + lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4 + return lrn + +def strnum(x): + s = '%g' % x + if '.' 
in s: + if x < 1.0: + s = s[s.index('.'):] + s = s[:min(len(s),4)] + return s + +def assert_same_shape(t1, t2): + for (x, y) in zip(list(t1.shape), list(t2.shape)): + assert(x==y) + +def print_stats(name, tensor): + shape = tensor.shape + tensor = tensor.detach().cpu().numpy() + print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape) + +def print_stats_py(name, tensor): + shape = tensor.shape + print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape) + +def print_(name, tensor): + tensor = tensor.detach().cpu().numpy() + print(name, tensor, tensor.shape) + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + +def normalize_single(d): + # d is a whatever shape torch tensor + dmin = torch.min(d) + dmax = torch.max(d) + d = (d-dmin)/(EPS+(dmax-dmin)) + return d + +def normalize(d): + # d is B x whatever. normalize within each element of the batch + out = torch.zeros(d.size()) + if d.is_cuda: + out = out.cuda() + B = list(d.size())[0] + for b in list(range(B)): + out[b] = normalize_single(d[b]) + return out + +def hard_argmax2d(tensor): + B, C, Y, X = list(tensor.shape) + assert(C==1) + + # flatten the Tensor along the height and width axes + flat_tensor = tensor.reshape(B, -1) + # argmax of the flat tensor + argmax = torch.argmax(flat_tensor, dim=1) + + # convert the indices into 2d coordinates + argmax_y = torch.floor(argmax / X) # row + argmax_x = argmax % X # col + + argmax_y = argmax_y.reshape(B) + argmax_x = argmax_x.reshape(B) + return argmax_y, argmax_x + +def argmax2d(heat, hard=True): + B, C, Y, X = list(heat.shape) + assert(C==1) + + if hard: + # hard argmax + loc_y, loc_x = hard_argmax2d(heat) + loc_y = loc_y.float() + loc_x = loc_x.float() + else: + heat = heat.reshape(B, Y*X) + prob = torch.nn.functional.softmax(heat, dim=1) + + grid_y, grid_x = meshgrid2d(B, Y, X) + + grid_y = grid_y.reshape(B, -1) + grid_x = grid_x.reshape(B, -1) + + loc_y = torch.sum(grid_y*prob, dim=1) + loc_x = torch.sum(grid_x*prob, dim=1) + # these are B + + return loc_y, loc_x + +def reduce_masked_mean(x, mask, dim=None, keepdim=False): + # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting + # returns shape-1 + # axis can be a list of axes + for (a,b) in zip(x.size(), mask.size()): + # if not b==1: + assert(a==b) # some shape mismatch! 
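+ # reduce_masked_mean below is a weighted mean that ignores masked-out
+ # entries: sum(x * mask) / (EPS + sum(mask)), over the whole tensor when
+ # dim is None or along the given dims; EPS guards against an all-zero mask.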
+ # assert(x.size() == mask.size()) + prod = x*mask + if dim is None: + numer = torch.sum(prod) + denom = EPS+torch.sum(mask) + else: + numer = torch.sum(prod, dim=dim, keepdim=keepdim) + denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim) + + mean = numer/denom + return mean + +def reduce_masked_median(x, mask, keep_batch=False): + # x and mask are the same shape + assert(x.size() == mask.size()) + device = x.device + + B = list(x.shape)[0] + x = x.detach().cpu().numpy() + mask = mask.detach().cpu().numpy() + + if keep_batch: + x = np.reshape(x, [B, -1]) + mask = np.reshape(mask, [B, -1]) + meds = np.zeros([B], np.float32) + for b in list(range(B)): + xb = x[b] + mb = mask[b] + if np.sum(mb) > 0: + xb = xb[mb > 0] + meds[b] = np.median(xb) + else: + meds[b] = np.nan + meds = torch.from_numpy(meds).to(device) + return meds.float() + else: + x = np.reshape(x, [-1]) + mask = np.reshape(mask, [-1]) + if np.sum(mask) > 0: + x = x[mask > 0] + med = np.median(x) + else: + med = np.nan + med = np.array([med], np.float32) + med = torch.from_numpy(med).to(device) + return med.float() + +def pack_seqdim(tensor, B): + shapelist = list(tensor.shape) + B_, S = shapelist[:2] + assert(B==B_) + otherdims = shapelist[2:] + tensor = torch.reshape(tensor, [B*S]+otherdims) + return tensor + +def unpack_seqdim(tensor, B): + shapelist = list(tensor.shape) + BS = shapelist[0] + assert(BS%B==0) + otherdims = shapelist[1:] + S = int(BS/B) + tensor = torch.reshape(tensor, [B,S]+otherdims) + return tensor + +def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False): + # returns a meshgrid sized B x Y x X + + grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device)) + grid_y = torch.reshape(grid_y, [1, Y, 1]) + grid_y = grid_y.repeat(B, 1, X) + + grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device)) + grid_x = torch.reshape(grid_x, [1, 1, X]) + grid_x = grid_x.repeat(B, Y, 1) + + if norm: + grid_y, grid_x = normalize_grid2d( + grid_y, grid_x, Y, X) + + if stack: + # note we stack in xy order + # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) + if on_chans: + grid = torch.stack([grid_x, grid_y], dim=1) + else: + grid = torch.stack([grid_x, grid_y], dim=-1) + return grid + else: + return grid_y, grid_x + +def meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'): + # returns a meshgrid sized B x Z x Y x X + + grid_z = torch.linspace(0.0, Z-1, Z, device=device) + grid_z = torch.reshape(grid_z, [1, Z, 1, 1]) + grid_z = grid_z.repeat(B, 1, Y, X) + + grid_y = torch.linspace(0.0, Y-1, Y, device=device) + grid_y = torch.reshape(grid_y, [1, 1, Y, 1]) + grid_y = grid_y.repeat(B, Z, 1, X) + + grid_x = torch.linspace(0.0, X-1, X, device=device) + grid_x = torch.reshape(grid_x, [1, 1, 1, X]) + grid_x = grid_x.repeat(B, Z, Y, 1) + + # if cuda: + # grid_z = grid_z.cuda() + # grid_y = grid_y.cuda() + # grid_x = grid_x.cuda() + + if norm: + grid_z, grid_y, grid_x = normalize_grid3d( + grid_z, grid_y, grid_x, Z, Y, X) + + if stack: + # note we stack in xyz order + # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) + grid = torch.stack([grid_x, grid_y, grid_z], dim=-1) + return grid + else: + return grid_z, grid_y, grid_x + +def normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True): + # make things in [-1,1] + grid_y = 2.0*(grid_y / float(Y-1)) - 1.0 + grid_x = 2.0*(grid_x / float(X-1)) - 1.0 + + if clamp_extreme: + grid_y = torch.clamp(grid_y, min=-2.0, max=2.0) + grid_x = torch.clamp(grid_x, 
min=-2.0, max=2.0) + + return grid_y, grid_x + +def normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True): + # make things in [-1,1] + grid_z = 2.0*(grid_z / float(Z-1)) - 1.0 + grid_y = 2.0*(grid_y / float(Y-1)) - 1.0 + grid_x = 2.0*(grid_x / float(X-1)) - 1.0 + + if clamp_extreme: + grid_z = torch.clamp(grid_z, min=-2.0, max=2.0) + grid_y = torch.clamp(grid_y, min=-2.0, max=2.0) + grid_x = torch.clamp(grid_x, min=-2.0, max=2.0) + + return grid_z, grid_y, grid_x + +def gridcloud2d(B, Y, X, norm=False, device='cuda'): + # we want to sample for each location in the grid + grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device) + x = torch.reshape(grid_x, [B, -1]) + y = torch.reshape(grid_y, [B, -1]) + # these are B x N + xy = torch.stack([x, y], dim=2) + # this is B x N x 2 + return xy + +def gridcloud3d(B, Z, Y, X, norm=False, device='cuda'): + # we want to sample for each location in the grid + grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device) + x = torch.reshape(grid_x, [B, -1]) + y = torch.reshape(grid_y, [B, -1]) + z = torch.reshape(grid_z, [B, -1]) + # these are B x N + xyz = torch.stack([x, y, z], dim=2) + # this is B x N x 3 + return xyz + +import re +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def normalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin / float(H) + ymax = ymax / float(H) + xmin = xmin / float(W) + xmax = xmax / float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin * float(H) + ymax = ymax * float(H) + xmin = xmin * float(W) + xmax = xmax * float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_box2d(box2d, H, W): + return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def normalize_box2d(box2d, H, W): + return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False): + C = channels + xy_grid = gridcloud2d(C, kernel_size, kernel_size) # C x N x 2 + + mean = (kernel_size - 1)/2.0 + variance = sigma**2.0 + + gaussian_kernel = (1.0/(2.0*np.pi*variance)**1.5) * torch.exp(-torch.sum((xy_grid - mean)**2.0, dim=-1) / (2.0*variance)) # C X N + gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size) # C x 1 x 3 x 3 + kernel_sum = torch.sum(gaussian_kernel, dim=(2,3), keepdim=True) + + gaussian_kernel = gaussian_kernel / kernel_sum # normalize + + if mid_one: + # normalize so that the middle element is 1 + maxval = gaussian_kernel[:,:,(kernel_size//2),(kernel_size//2)].reshape(C, 
1, 1, 1) + gaussian_kernel = gaussian_kernel / maxval + + return gaussian_kernel + +def gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False): + B, C, Z, X = input.shape + kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one) + if reflect_pad: + pad = (kernel_size - 1)//2 + out = F.pad(input, (pad, pad, pad, pad), mode='reflect') + out = F.conv2d(out, kernel, padding=0, groups=C) + else: + out = F.conv2d(input, kernel, padding=(kernel_size - 1)//2, groups=C) + return out + +def gradient2d(x, absolute=False, square=False, return_sum=False): + # x should be B x C x H x W + dh = x[:, :, 1:, :] - x[:, :, :-1, :] + dw = x[:, :, :, 1:] - x[:, :, :, :-1] + + zeros = torch.zeros_like(x) + zero_h = zeros[:, :, 0:1, :] + zero_w = zeros[:, :, :, 0:1] + dh = torch.cat([dh, zero_h], axis=2) + dw = torch.cat([dw, zero_w], axis=3) + if absolute: + dh = torch.abs(dh) + dw = torch.abs(dw) + if square: + dh = dh ** 2 + dw = dw ** 2 + if return_sum: + return dh+dw + else: + return dh, dw diff --git a/das/spatracker/utils/geom.py b/das/spatracker/utils/geom.py new file mode 100644 index 0000000..967adad --- /dev/null +++ b/das/spatracker/utils/geom.py @@ -0,0 +1,547 @@ +import torch +from . import basic as utils +import numpy as np +import torchvision.ops as ops +from .basic import print_ + +def matmul2(mat1, mat2): + return torch.matmul(mat1, mat2) + +def matmul3(mat1, mat2, mat3): + return torch.matmul(mat1, torch.matmul(mat2, mat3)) + +def eye_3x3(B, device='cuda'): + rt = torch.eye(3, device=torch.device(device)).view(1,3,3).repeat([B, 1, 1]) + return rt + +def eye_4x4(B, device='cuda'): + rt = torch.eye(4, device=torch.device(device)).view(1,4,4).repeat([B, 1, 1]) + return rt + +def safe_inverse(a): #parallel version + B, _, _ = list(a.shape) + inv = a.clone() + r_transpose = a[:, :3, :3].transpose(1,2) #inverse of rotation matrix + + inv[:, :3, :3] = r_transpose + inv[:, :3, 3:4] = -torch.matmul(r_transpose, a[:, :3, 3:4]) + + return inv + +def safe_inverse_single(a): + r, t = split_rt_single(a) + t = t.view(3,1) + r_transpose = r.t() + inv = torch.cat([r_transpose, -torch.matmul(r_transpose, t)], 1) + bottom_row = a[3:4, :] # this is [0, 0, 0, 1] + # bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4) + inv = torch.cat([inv, bottom_row], 0) + return inv + +def split_intrinsics(K): + # K is B x 3 x 3 or B x 4 x 4 + fx = K[:,0,0] + fy = K[:,1,1] + x0 = K[:,0,2] + y0 = K[:,1,2] + return fx, fy, x0, y0 + +def apply_pix_T_cam(pix_T_cam, xyz): + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + + # xyz is shaped B x H*W x 3 + # returns xy, shaped B x H*W x 2 + + B, N, C = list(xyz.shape) + assert(C==3) + + x, y, z = torch.unbind(xyz, axis=-1) + + fx = torch.reshape(fx, [B, 1]) + fy = torch.reshape(fy, [B, 1]) + x0 = torch.reshape(x0, [B, 1]) + y0 = torch.reshape(y0, [B, 1]) + + EPS = 1e-4 + z = torch.clamp(z, min=EPS) + x = (x*fx)/(z)+x0 + y = (y*fy)/(z)+y0 + xy = torch.stack([x, y], axis=-1) + return xy + +def apply_pix_T_cam_py(pix_T_cam, xyz): + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + + # xyz is shaped B x H*W x 3 + # returns xy, shaped B x H*W x 2 + + B, N, C = list(xyz.shape) + assert(C==3) + + x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2] + + fx = np.reshape(fx, [B, 1]) + fy = np.reshape(fy, [B, 1]) + x0 = np.reshape(x0, [B, 1]) + y0 = np.reshape(y0, [B, 1]) + + EPS = 1e-4 + z = np.clip(z, EPS, None) + x = (x*fx)/(z)+x0 + y = (y*fy)/(z)+y0 + xy = np.stack([x, y], axis=-1) + return xy + +def get_camM_T_camXs(origin_T_camXs, ind=0): + B, S = 
list(origin_T_camXs.shape)[0:2] + camM_T_camXs = torch.zeros_like(origin_T_camXs) + for b in list(range(B)): + camM_T_origin = safe_inverse_single(origin_T_camXs[b,ind]) + for s in list(range(S)): + camM_T_camXs[b,s] = torch.matmul(camM_T_origin, origin_T_camXs[b,s]) + return camM_T_camXs + +def apply_4x4(RT, xyz): + B, N, _ = list(xyz.shape) + ones = torch.ones_like(xyz[:,:,0:1]) + xyz1 = torch.cat([xyz, ones], 2) + xyz1_t = torch.transpose(xyz1, 1, 2) + # this is B x 4 x N + xyz2_t = torch.matmul(RT, xyz1_t) + xyz2 = torch.transpose(xyz2_t, 1, 2) + xyz2 = xyz2[:,:,:3] + return xyz2 + +def apply_4x4_py(RT, xyz): + # print('RT', RT.shape) + B, N, _ = list(xyz.shape) + ones = np.ones_like(xyz[:,:,0:1]) + xyz1 = np.concatenate([xyz, ones], 2) + # print('xyz1', xyz1.shape) + xyz1_t = xyz1.transpose(0,2,1) + # print('xyz1_t', xyz1_t.shape) + # this is B x 4 x N + xyz2_t = np.matmul(RT, xyz1_t) + # print('xyz2_t', xyz2_t.shape) + xyz2 = xyz2_t.transpose(0,2,1) + # print('xyz2', xyz2.shape) + xyz2 = xyz2[:,:,:3] + return xyz2 + +def apply_3x3(RT, xy): + B, N, _ = list(xy.shape) + ones = torch.ones_like(xy[:,:,0:1]) + xy1 = torch.cat([xy, ones], 2) + xy1_t = torch.transpose(xy1, 1, 2) + # this is B x 4 x N + xy2_t = torch.matmul(RT, xy1_t) + xy2 = torch.transpose(xy2_t, 1, 2) + xy2 = xy2[:,:,:2] + return xy2 + +def generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts): + ''' + Start with the center of the polygon at ctr_x, ctr_y, + Then creates the polygon by sampling points on a circle around the center. + Random noise is added by varying the angular spacing between sequential points, + and by varying the radial distance of each point from the centre. + + Params: + ctr_x, ctr_y - coordinates of the "centre" of the polygon + avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude. + irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/numberOfVerts] + spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r] +pp num_verts + + Returns: + np.array [num_verts, 2] - CCW order. 
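+        Example (illustrative; the exact vertices depend on the random seed):
+            pts = generate_polygon(ctr_x=32.0, ctr_y=32.0, avg_r=10.0,
+                                   irregularity=0.3, spikiness=0.2, num_verts=12)
+            # pts is an int array of shape (12, 2); vertices lie roughly 10 px from (32, 32)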
+ ''' + # spikiness + spikiness = np.clip(spikiness, 0, 1) * avg_r + + # generate n angle steps + irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts + lower = (2*np.pi / num_verts) - irregularity + upper = (2*np.pi / num_verts) + irregularity + + # angle steps + angle_steps = np.random.uniform(lower, upper, num_verts) + sc = (2 * np.pi) / angle_steps.sum() + angle_steps *= sc + + # get all radii + angle = np.random.uniform(0, 2*np.pi) + radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r) + + # compute all points + points = [] + for i in range(num_verts): + x = ctr_x + radii[i] * np.cos(angle) + y = ctr_y + radii[i] * np.sin(angle) + points.append([x, y]) + angle += angle_steps[i] + + return np.array(points).astype(int) + + +def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05): + ''' + Params: + rot_min: rotation amount min + rot_max: rotation amount max + + tx_min: translation x min + tx_max: translation x max + + ty_min: translation y min + ty_max: translation y max + + sx_min: scaling x min + sx_max: scaling x max + + sy_min: scaling y min + sy_max: scaling y max + + shx_min: shear x min + shx_max: shear x max + + shy_min: shear y min + shy_max: shear y max + + Returns: + transformation matrix: (B, 3, 3) + ''' + # rotation + if rot_max - rot_min != 0: + rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B) + rot_amount = np.pi/180.0*rot_amount + else: + rot_amount = rot_min + rotation = np.zeros((B, 3, 3)) # B, 3, 3 + rotation[:, 2, 2] = 1 + rotation[:, 0, 0] = np.cos(rot_amount) + rotation[:, 0, 1] = -np.sin(rot_amount) + rotation[:, 1, 0] = np.sin(rot_amount) + rotation[:, 1, 1] = np.cos(rot_amount) + + # translation + translation = np.zeros((B, 3, 3)) # B, 3, 3 + translation[:, [0,1,2], [0,1,2]] = 1 + if (tx_max - tx_min) > 0: + trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B) + translation[:, 0, 2] = trans_x + # else: + # translation[:, 0, 2] = tx_max + if ty_max - ty_min != 0: + trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B) + translation[:, 1, 2] = trans_y + # else: + # translation[:, 1, 2] = ty_max + + # scaling + scaling = np.zeros((B, 3, 3)) # B, 3, 3 + scaling[:, [0,1,2], [0,1,2]] = 1 + if (sx_max - sx_min) > 0: + scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B) + scaling[:, 0, 0] = scale_x + # else: + # scaling[:, 0, 0] = sx_max + if (sy_max - sy_min) > 0: + scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B) + scaling[:, 1, 1] = scale_y + # else: + # scaling[:, 1, 1] = sy_max + + # shear + shear = np.zeros((B, 3, 3)) # B, 3, 3 + shear[:, [0,1,2], [0,1,2]] = 1 + if (shx_max - shx_min) > 0: + shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B) + shear[:, 0, 1] = shear_x + # else: + # shear[:, 0, 1] = shx_max + if (shy_max - shy_min) > 0: + shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B) + shear[:, 1, 0] = shear_y + # else: + # shear[:, 1, 0] = shy_max + + # compose all those + rt = np.einsum("ijk,ikl->ijl", rotation, translation) + ss = np.einsum("ijk,ikl->ijl", scaling, shear) + trans = np.einsum("ijk,ikl->ijl", rt, ss) + + return trans + +def get_centroid_from_box2d(box2d): + ymin = box2d[:,0] + xmin = box2d[:,1] + ymax = box2d[:,2] + xmax = box2d[:,3] + x = (xmin+xmax)/2.0 + y = (ymin+ymax)/2.0 + return y, x + +def normalize_boxlist2d(boxlist2d, H, W): + boxlist2d = 
boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin / float(H) + ymax = ymax / float(H) + xmin = xmin / float(W) + xmax = xmax / float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_boxlist2d(boxlist2d, H, W): + boxlist2d = boxlist2d.clone() + ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2) + ymin = ymin * float(H) + ymax = ymax * float(H) + xmin = xmin * float(W) + xmax = xmax * float(W) + boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2) + return boxlist2d + +def unnormalize_box2d(box2d, H, W): + return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def normalize_box2d(box2d, H, W): + return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1) + +def get_size_from_box2d(box2d): + ymin = box2d[:,0] + xmin = box2d[:,1] + ymax = box2d[:,2] + xmax = box2d[:,3] + height = ymax-ymin + width = xmax-xmin + return height, width + +def crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False): + B, C, H, W = im.shape + B2, N, D = boxlist.shape + assert(B==B2) + assert(D==4) + # PH, PW is the size to resize to + + # output is B,N,C,PH,PW + + # pt wants xy xy, unnormalized + if boxlist_is_normalized: + boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W) + else: + boxlist_unnorm = boxlist + + ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2) + # boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1) + boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2) + # we want a B-len list of K x 4 arrays + + # print('im', im.shape) + # print('boxlist', boxlist.shape) + # print('boxlist_pt', boxlist_pt.shape) + + # boxlist_pt = list(boxlist_pt.unbind(0)) + + crops = [] + for b in range(B): + crops_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW)) + crops.append(crops_b) + # # crops = im + + # print('crops', crops.shape) + # crops = crops.reshape(B,N,C,PH,PW) + + + # crops = [] + # for b in range(B): + # crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW)) + # print('crop_b', crop_b.shape) + # crops.append(crop_b) + crops = torch.stack(crops, dim=0) + + # print('crops', crops.shape) + # boxlist_list = boxlist_pt.unbind(0) + # print('rgb_crop', rgb_crop.shape) + + return crops + + +# def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True): +# # cy,cx are both B,N +# ymin = cy - h/2 +# ymax = cy + h/2 +# xmin = cx - w/2 +# xmax = cx + w/2 + +# box = torch.stack([ymin, xmin, ymax, xmax], dim=-1) +# if clip: +# box = torch.clamp(box, 0, 1) +# return box + + +def get_boxlist_from_centroid_and_size(cy, cx, h, w):#, clip=False): + # cy,cx are the same shape + ymin = cy - h/2 + ymax = cy + h/2 + xmin = cx - w/2 + xmax = cx + w/2 + + # if clip: + # ymin = torch.clamp(ymin, 0, H-1) + # ymax = torch.clamp(ymax, 0, H-1) + # xmin = torch.clamp(xmin, 0, W-1) + # xmax = torch.clamp(xmax, 0, W-1) + + box = torch.stack([ymin, xmin, ymax, xmax], dim=-1) + return box + + +def get_box2d_from_mask(mask, normalize=False): + # mask is B, 1, H, W + + B, C, H, W = mask.shape + assert(C==1) + xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device) # B, H*W, 2 + + box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device) + for b in range(B): + xy_b = xy[b] # H*W, 2 + mask_b = mask[b].reshape(H*W) + xy_ = xy_b[mask_b > 0] + x_ = xy_[:,0] + y_ = xy_[:,1] + ymin = torch.min(y_) + ymax = torch.max(y_) + xmin = torch.min(x_) + xmax = torch.max(x_) + box[b] = torch.stack([ymin, 
xmin, ymax, xmax], dim=0) + if normalize: + box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1) + return box + +def convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, mult_padding=1.0): + # box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords + # ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1) + # H, W is the original size of the image + # mult_padding is relative to object size in pixels + + # i assume we're rendering an image the same size as the original (H, W) + + if not mult_padding==1.0: + y, x = get_centroid_from_box2d(box2d) + h, w = get_size_from_box2d(box2d) + box2d = get_box2d_from_centroid_and_size( + y, x, h*mult_padding, w*mult_padding, clip=False) + + if use_image_aspect_ratio: + h, w = get_size_from_box2d(box2d) + y, x = get_centroid_from_box2d(box2d) + + # note h,w are relative right now + # we need to undo this, to see the real ratio + + h = h*float(H) + w = w*float(W) + box_ratio = h/w + im_ratio = H/float(W) + + # print('box_ratio:', box_ratio) + # print('im_ratio:', im_ratio) + + if box_ratio >= im_ratio: + w = h/im_ratio + # print('setting w:', h/im_ratio) + else: + h = w*im_ratio + # print('setting h:', w*im_ratio) + + box2d = get_box2d_from_centroid_and_size( + y, x, h/float(H), w/float(W), clip=False) + + assert(h > 1e-4) + assert(w > 1e-4) + + ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1) + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + + # the topleft of the new image will now have a different offset from the center of projection + + new_x0 = x0 - xmin*W + new_y0 = y0 - ymin*H + + pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0) + # this alone will give me an image in original resolution, + # with its topleft at the box corner + + box_h, box_w = get_size_from_box2d(box2d) + # these are normalized, and shaped B. 
(e.g., [0.4], [0.3]) + + # we are going to scale the image by the inverse of this, + # since we are zooming into this area + + sy = 1./box_h + sx = 1./box_w + + pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy) + return pix_T_cam, box2d + +def pixels2camera(x,y,z,fx,fy,x0,y0): + # x and y are locations in pixel coordinates, z is a depth in meters + # they can be images or pointclouds + # fx, fy, x0, y0 are camera intrinsics + # returns xyz, sized B x N x 3 + + B = x.shape[0] + + fx = torch.reshape(fx, [B,1]) + fy = torch.reshape(fy, [B,1]) + x0 = torch.reshape(x0, [B,1]) + y0 = torch.reshape(y0, [B,1]) + + x = torch.reshape(x, [B,-1]) + y = torch.reshape(y, [B,-1]) + z = torch.reshape(z, [B,-1]) + + # unproject + x = (z/fx)*(x-x0) + y = (z/fy)*(y-y0) + + xyz = torch.stack([x,y,z], dim=2) + # B x N x 3 + return xyz + +def camera2pixels(xyz, pix_T_cam): + # xyz is shaped B x H*W x 3 + # returns xy, shaped B x H*W x 2 + + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + x, y, z = torch.unbind(xyz, dim=-1) + B = list(z.shape)[0] + + fx = torch.reshape(fx, [B,1]) + fy = torch.reshape(fy, [B,1]) + x0 = torch.reshape(x0, [B,1]) + y0 = torch.reshape(y0, [B,1]) + x = torch.reshape(x, [B,-1]) + y = torch.reshape(y, [B,-1]) + z = torch.reshape(z, [B,-1]) + + EPS = 1e-4 + z = torch.clamp(z, min=EPS) + x = (x*fx)/z + x0 + y = (y*fy)/z + y0 + xy = torch.stack([x, y], dim=-1) + return xy + +def depth2pointcloud(z, pix_T_cam): + B, C, H, W = list(z.shape) + device = z.device + y, x = utils.basic.meshgrid2d(B, H, W, device=device) + z = torch.reshape(z, [B, H, W]) + fx, fy, x0, y0 = split_intrinsics(pix_T_cam) + xyz = pixels2camera(x, y, z, fx, fy, x0, y0) + return xyz diff --git a/das/spatracker/utils/improc.py b/das/spatracker/utils/improc.py new file mode 100644 index 0000000..364daef --- /dev/null +++ b/das/spatracker/utils/improc.py @@ -0,0 +1,1447 @@ +import torch +import numpy as np +import models.spatracker.utils.basic +from sklearn.decomposition import PCA +from matplotlib import cm +import matplotlib.pyplot as plt +import cv2 +import torch.nn.functional as F +import torchvision +EPS = 1e-6 + +from skimage.color import ( + rgb2lab, rgb2yuv, rgb2ycbcr, lab2rgb, yuv2rgb, ycbcr2rgb, + rgb2hsv, hsv2rgb, rgb2xyz, xyz2rgb, rgb2hed, hed2rgb) + +def _convert(input_, type_): + return { + 'float': input_.float(), + 'double': input_.double(), + }.get(type_, input_) + +def _generic_transform_sk_3d(transform, in_type='', out_type=''): + def apply_transform_individual(input_): + device = input_.device + input_ = input_.cpu() + input_ = _convert(input_, in_type) + + input_ = input_.permute(1, 2, 0).detach().numpy() + transformed = transform(input_) + output = torch.from_numpy(transformed).float().permute(2, 0, 1) + output = _convert(output, out_type) + return output.to(device) + + def apply_transform(input_): + to_stack = [] + for image in input_: + to_stack.append(apply_transform_individual(image)) + return torch.stack(to_stack) + return apply_transform + +hsv_to_rgb = _generic_transform_sk_3d(hsv2rgb) + +def preprocess_color_tf(x): + import tensorflow as tf + return tf.cast(x,tf.float32) * 1./255 - 0.5 + +def preprocess_color(x): + if isinstance(x, np.ndarray): + return x.astype(np.float32) * 1./255 - 0.5 + else: + return x.float() * 1./255 - 0.5 + +def pca_embed(emb, keep, valid=None): + ## emb -- [S,H/2,W/2,C] + ## keep is the number of principal components to keep + ## Helper function for reduce_emb. 
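+    ## Note: a separate PCA is fit per image in the batch, so the 3 output channels are
+    ## generally not comparable across images; pca_embed_together below fits one PCA
+    ## over the whole batch instead. Illustrative call (shapes assumed):
+    ##   emb = torch.randn(4, 64, 32, 32)   # B x C x H x W features
+    ##   emb3 = pca_embed(emb, keep=3)      # B x 3 x H x W, suitable for display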
+ emb = emb + EPS + #emb is B x C x H x W + emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C + + if valid: + valid = valid.cpu().detach().numpy().reshape((H*W)) + + emb_reduced = list() + + B, H, W, C = np.shape(emb) + for img in emb: + if np.isnan(img).any(): + emb_reduced.append(np.zeros([H, W, keep])) + continue + + pixels_kd = np.reshape(img, (H*W, C)) + + if valid: + pixels_kd_pca = pixels_kd[valid] + else: + pixels_kd_pca = pixels_kd + + P = PCA(keep) + P.fit(pixels_kd_pca) + + if valid: + pixels3d = P.transform(pixels_kd)*valid + else: + pixels3d = P.transform(pixels_kd) + + out_img = np.reshape(pixels3d, [H,W,keep]).astype(np.float32) + if np.isnan(out_img).any(): + emb_reduced.append(np.zeros([H, W, keep])) + continue + + emb_reduced.append(out_img) + + emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32) + + return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2) + +def pca_embed_together(emb, keep): + ## emb -- [S,H/2,W/2,C] + ## keep is the number of principal components to keep + ## Helper function for reduce_emb. + emb = emb + EPS + #emb is B x C x H x W + emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy() #this is B x H x W x C + + B, H, W, C = np.shape(emb) + if np.isnan(emb).any(): + return torch.zeros(B, keep, H, W) + + pixelskd = np.reshape(emb, (B*H*W, C)) + P = PCA(keep) + P.fit(pixelskd) + pixels3d = P.transform(pixelskd) + out_img = np.reshape(pixels3d, [B,H,W,keep]).astype(np.float32) + + if np.isnan(out_img).any(): + return torch.zeros(B, keep, H, W) + + return torch.from_numpy(out_img).permute(0, 3, 1, 2) + +def reduce_emb(emb, valid=None, inbound=None, together=False): + ## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2] + ## Reduce number of chans to 3 with PCA. For vis. + # S,H,W,C = emb.shape.as_list() + S, C, H, W = list(emb.size()) + keep = 3 + + if together: + reduced_emb = pca_embed_together(emb, keep) + else: + reduced_emb = pca_embed(emb, keep, valid) #not im + + reduced_emb = utils.basic.normalize(reduced_emb) - 0.5 + if inbound is not None: + emb_inbound = emb*inbound + else: + emb_inbound = None + + return reduced_emb, emb_inbound + +def get_feat_pca(feat, valid=None): + B, C, D, W = list(feat.size()) + # feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function. 
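+    # Illustrative use with a 3D feature volume (shapes assumed):
+    #   feat3d = torch.randn(2, 32, 16, 24, 24)   # B x C x D x H x W
+    #   feat = feat3d.mean(dim=3)                 # average over H -> B x C x D x W
+    #   pca = get_feat_pca(feat)                  # 3-channel PCA map over the D x W grid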
+ + pca, _ = reduce_emb(feat, valid=valid,inbound=None, together=True) + # pca is B x 3 x W x D + return pca + +def gif_and_tile(ims, just_gif=False): + S = len(ims) + # each im is B x H x W x C + # i want a gif in the left, and the tiled frames on the right + # for the gif tool, this means making a B x S x H x W tensor + # where the leftmost part is sequential and the rest is tiled + gif = torch.stack(ims, dim=1) + if just_gif: + return gif + til = torch.cat(ims, dim=2) + til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1) + im = torch.cat([gif, til], dim=3) + return im + +def back2color(i, blacken_zeros=False): + if blacken_zeros: + const = torch.tensor([-0.5]) + i = torch.where(i==0.0, const.cuda() if i.is_cuda else const, i) + return back2color(i) + else: + return ((i+0.5)*255).type(torch.ByteTensor) + +def convert_occ_to_height(occ, reduce_axis=3): + B, C, D, H, W = list(occ.shape) + assert(C==1) + # note that height increases DOWNWARD in the tensor + # (like pixel/camera coordinates) + + G = list(occ.shape)[reduce_axis] + values = torch.linspace(float(G), 1.0, steps=G, dtype=torch.float32, device=occ.device) + if reduce_axis==2: + # fro view + values = values.view(1, 1, G, 1, 1) + elif reduce_axis==3: + # top view + values = values.view(1, 1, 1, G, 1) + elif reduce_axis==4: + # lateral view + values = values.view(1, 1, 1, 1, G) + else: + assert(False) # you have to reduce one of the spatial dims (2-4) + values = torch.max(occ*values, dim=reduce_axis)[0]/float(G) + # values = values.view([B, C, D, W]) + return values + +def xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False): + # xy is B x N x 2, containing float x and y coordinates of N things + # grid_xs and grid_ys are B x N x Y x X + + B, N, Y, X = list(grid_xs.shape) + + mu_x = xy[:,:,0].clone() + mu_y = xy[:,:,1].clone() + + x_valid = (mu_x>-0.5) & (mu_x-0.5) & (mu_y 0.5).float() + return prior + +def seq2color(im, norm=True, colormap='coolwarm'): + B, S, H, W = list(im.shape) + # S is sequential + + # prep a mask of the valid pixels, so we can blacken the invalids later + mask = torch.max(im, dim=1, keepdim=True)[0] + + # turn the S dim into an explicit sequence + coeffs = np.linspace(1.0, float(S), S).astype(np.float32)/float(S) + + # # increase the spacing from the center + # coeffs[:int(S/2)] -= 2.0 + # coeffs[int(S/2)+1:] += 2.0 + + coeffs = torch.from_numpy(coeffs).float().cuda() + coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W) + # scale each channel by the right coeff + im = im * coeffs + # now im is in [1/S, 1], except for the invalid parts which are 0 + # keep the highest valid coeff at each pixel + im = torch.max(im, dim=1, keepdim=True)[0] + + out = [] + for b in range(B): + im_ = im[b] + # move channels out to last dim_ + im_ = im_.detach().cpu().numpy() + im_ = np.squeeze(im_) + # im_ is H x W + if colormap=='coolwarm': + im_ = cm.coolwarm(im_)[:, :, :3] + elif colormap=='PiYG': + im_ = cm.PiYG(im_)[:, :, :3] + elif colormap=='winter': + im_ = cm.winter(im_)[:, :, :3] + elif colormap=='spring': + im_ = cm.spring(im_)[:, :, :3] + elif colormap=='onediff': + im_ = np.reshape(im_, (-1)) + im0_ = cm.spring(im_)[:, :3] + im1_ = cm.winter(im_)[:, :3] + im1_[im_==1/float(S)] = im0_[im_==1/float(S)] + im_ = np.reshape(im1_, (H, W, 3)) + else: + assert(False) # invalid colormap + # move channels into dim 0 + im_ = np.transpose(im_, [2, 0, 1]) + im_ = torch.from_numpy(im_).float().cuda() + out.append(im_) + out = torch.stack(out, dim=0) + + # blacken the invalid pixels, instead of using the 0-color + out = out*mask + # 
out = out*255.0 + + # put it in [-0.5, 0.5] + out = out - 0.5 + + return out + +def colorize(d): + # this is actually just grayscale right now + + if d.ndim==2: + d = d.unsqueeze(dim=0) + else: + assert(d.ndim==3) + + # color_map = cm.get_cmap('plasma') + color_map = cm.get_cmap('inferno') + # S1, D = traj.shape + + # print('d1', d.shape) + C,H,W = d.shape + assert(C==1) + d = d.reshape(-1) + d = d.detach().cpu().numpy() + # print('d2', d.shape) + color = np.array(color_map(d)) * 255 # rgba + # print('color1', color.shape) + color = np.reshape(color[:,:3], [H*W, 3]) + # print('color2', color.shape) + color = torch.from_numpy(color).permute(1,0).reshape(3,H,W) + # # gather + # cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray') + # if cmap=='RdBu' or cmap=='RdYlGn': + # colors = cm(np.arange(256))[:, :3] + # else: + # colors = cm.colors + # colors = np.array(colors).astype(np.float32) + # colors = np.reshape(colors, [-1, 3]) + # colors = tf.constant(colors, dtype=tf.float32) + + # value = tf.gather(colors, indices) + # colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255) + + # copy to the three chans + # d = d.repeat(3, 1, 1) + return color + + +def oned2inferno(d, norm=True, do_colorize=False): + # convert a 1chan input to a 3chan image output + + # if it's just B x H x W, add a C dim + if d.ndim==3: + d = d.unsqueeze(dim=1) + # d should be B x C x H x W, where C=1 + B, C, H, W = list(d.shape) + assert(C==1) + + if norm: + d = utils.basic.normalize(d) + + if do_colorize: + rgb = torch.zeros(B, 3, H, W) + for b in list(range(B)): + rgb[b] = colorize(d[b]) + else: + rgb = d.repeat(1, 3, 1, 1)*255.0 + # rgb = (255.0*rgb).type(torch.ByteTensor) + rgb = rgb.type(torch.ByteTensor) + + # rgb = tf.cast(255.0*rgb, tf.uint8) + # rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3]) + # rgb = tf.expand_dims(rgb, axis=0) + return rgb + +def oned2gray(d, norm=True): + # convert a 1chan input to a 3chan image output + + # if it's just B x H x W, add a C dim + if d.ndim==3: + d = d.unsqueeze(dim=1) + # d should be B x C x H x W, where C=1 + B, C, H, W = list(d.shape) + assert(C==1) + + if norm: + d = utils.basic.normalize(d) + + rgb = d.repeat(1,3,1,1) + rgb = (255.0*rgb).type(torch.ByteTensor) + return rgb + + +def draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20): + + rgb = vis.detach().cpu().numpy()[0] + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + color = (255, 255, 255) + # print('putting frame id', frame_id) + + frame_str = utils.basic.strnum(frame_id) + + text_color_bg = (0,0,0) + font = cv2.FONT_HERSHEY_SIMPLEX + text_size, _ = cv2.getTextSize(frame_str, font, scale, 1) + text_w, text_h = text_size + cv2.rectangle(rgb, (left, top-text_h), (left + text_w, top+1), text_color_bg, -1) + + cv2.putText( + rgb, + frame_str, + (left, top), # from left, from top + font, + scale, # font scale (float) + color, + 1) # font thickness (int) + rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) + vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + return vis + +COLORMAP_FILE = "./utils/bremm.png" +class ColorMap2d: + def __init__(self, filename=None): + self._colormap_file = filename or COLORMAP_FILE + self._img = plt.imread(self._colormap_file) + + self._height = self._img.shape[0] + self._width = self._img.shape[1] + + def __call__(self, X): + assert len(X.shape) == 2 + output = np.zeros((X.shape[0], 3)) + for i in range(X.shape[0]): + x, y = X[i, :] + xp = int((self._width-1) * x) + yp = 
int((self._height-1) * y) + xp = np.clip(xp, 0, self._width-1) + yp = np.clip(yp, 0, self._height-1) + output[i, :] = self._img[yp, xp] + return output + +def get_n_colors(N, sequential=False): + label_colors = [] + for ii in range(N): + if sequential: + rgb = cm.winter(ii/(N-1)) + rgb = (np.array(rgb) * 255).astype(np.uint8)[:3] + else: + rgb = np.zeros(3) + while np.sum(rgb) < 128: # ensure min brightness + rgb = np.random.randint(0,256,3) + label_colors.append(rgb) + return label_colors + +class Summ_writer(object): + def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False): + self.writer = writer + self.global_step = global_step + self.log_freq = log_freq + self.fps = fps + self.just_gif = just_gif + self.maxwidth = 10000 + self.save_this = (self.global_step % self.log_freq == 0) + self.scalar_freq = max(scalar_freq,1) + + + def summ_gif(self, name, tensor, blacken_zeros=False): + # tensor should be in B x S x C x H x W + + assert tensor.dtype in {torch.uint8,torch.float32} + shape = list(tensor.shape) + + if tensor.dtype == torch.float32: + tensor = back2color(tensor, blacken_zeros=blacken_zeros) + + video_to_write = tensor[0:1] + + S = video_to_write.shape[1] + if S==1: + # video_to_write is 1 x 1 x C x H x W + self.writer.add_image(name, video_to_write[0,0], global_step=self.global_step) + else: + self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step) + + return video_to_write + + def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1): + B, C, H, W = list(rgb.shape) + assert(C==3) + B2, N, D = list(boxlist.shape) + assert(B2==B) + assert(D==4) # ymin, xmin, ymax, xmax + + rgb = back2color(rgb) + if scores is None: + scores = torch.ones(B2, N).float() + if tids is None: + tids = torch.arange(N).reshape(1,N).repeat(B2,N).long() + # tids = torch.zeros(B2, N).long() + out = self.draw_boxlist2d_on_image_py( + rgb[0].cpu().detach().numpy(), + boxlist[0].cpu().detach().numpy(), + scores[0].cpu().detach().numpy(), + tids[0].cpu().detach().numpy(), + linewidth=linewidth) + out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1) + out = torch.unsqueeze(out, dim=0) + out = preprocess_color(out) + out = torch.reshape(out, [1, C, H, W]) + return out + + def draw_boxlist2d_on_image_py(self, rgb, boxlist, scores, tids, linewidth=1): + # all inputs are numpy tensors + # rgb is H x W x 3 + # boxlist is N x 4 + # scores is N + # tids is N + + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + # rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + + rgb = rgb.astype(np.uint8).copy() + + + H, W, C = rgb.shape + assert(C==3) + N, D = boxlist.shape + assert(D==4) + + # color_map = cm.get_cmap('tab20') + # color_map = cm.get_cmap('set1') + color_map = cm.get_cmap('Accent') + color_map = color_map.colors + # print('color_map', color_map) + + # draw + for ind, box in enumerate(boxlist): + # box is 4 + if not np.isclose(scores[ind], 0.0): + # box = utils.geom.scale_box2d(box, H, W) + ymin, xmin, ymax, xmax = box + + # ymin, ymax = ymin*H, ymax*H + # xmin, xmax = xmin*W, xmax*W + + # print 'score = %.2f' % scores[ind] + # color_id = tids[ind] % 20 + color_id = tids[ind] + color = color_map[color_id] + color = np.array(color)*255.0 + color = color.round() + # color = color.astype(np.uint8) + # color = color[::-1] + # print('color', color) + + # print 'tid = %d; score = %.3f' % (tids[ind], scores[ind]) + + # if False: + if scores[ind] < 1.0: # not gt + cv2.putText(rgb, + # '%d (%.2f)' % (tids[ind], 
scores[ind]), + '%.2f' % (scores[ind]), + (int(xmin), int(ymin)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, # font size + color), + #1) # font weight + + xmin = np.clip(int(xmin), 0, W-1) + xmax = np.clip(int(xmax), 0, W-1) + ymin = np.clip(int(ymin), 0, H-1) + ymax = np.clip(int(ymax), 0, H-1) + + cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_AA) + cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_AA) + cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_AA) + cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_AA) + + # rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB) + return rgb + + def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, only_return=False, linewidth=2): + B, C, H, W = list(rgb.shape) + boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth) + return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, only_return=only_return) + + def summ_rgbs(self, name, ims, frame_ids=None, blacken_zeros=False, only_return=False): + if self.save_this: + + ims = gif_and_tile(ims, just_gif=self.just_gif) + vis = ims + + assert vis.dtype in {torch.uint8,torch.float32} + + if vis.dtype == torch.float32: + vis = back2color(vis, blacken_zeros) + + B, S, C, H, W = list(vis.shape) + + if frame_ids is not None: + assert(len(frame_ids)==S) + for s in range(S): + vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) + + if int(W) > self.maxwidth: + vis = vis[:,:,:,:self.maxwidth] + + if only_return: + return vis + else: + return self.summ_gif(name, vis, blacken_zeros) + + def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, only_return=False, halfres=False): + if self.save_this: + assert ims.dtype in {torch.uint8,torch.float32} + + if ims.dtype == torch.float32: + ims = back2color(ims, blacken_zeros) + + #ims is B x C x H x W + vis = ims[0:1] # just the first one + B, C, H, W = list(vis.shape) + + if halfres: + vis = F.interpolate(vis, scale_factor=0.5) + + if frame_id is not None: + vis = draw_frame_id_on_vis(vis, frame_id) + + if int(W) > self.maxwidth: + vis = vis[:,:,:,:self.maxwidth] + + if only_return: + return vis + else: + return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros) + + def flow2color(self, flow, clip=50.0): + """ + :param flow: Optical flow tensor. + :return: RGB image normalized between 0 and 1. + """ + + # flow is B x C x H x W + + B, C, H, W = list(flow.size()) + + flow = flow.clone().detach() + + abs_image = torch.abs(flow) + flow_mean = abs_image.mean(dim=[1,2,3]) + flow_std = abs_image.std(dim=[1,2,3]) + + if clip: + flow = torch.clamp(flow, -clip, clip)/clip + else: + # Apply some kind of normalization. 
Divide by the perceived maximum (mean + std*2) + flow_max = flow_mean + flow_std*2 + 1e-10 + for b in range(B): + flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1) + + radius = torch.sqrt(torch.sum(flow**2, dim=1, keepdim=True)) #B x 1 x H x W + radius_clipped = torch.clamp(radius, 0.0, 1.0) + + angle = torch.atan2(flow[:, 1:], flow[:, 0:1]) / np.pi #B x 1 x H x W + + hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0) + saturation = torch.ones_like(hue) * 0.75 + value = radius_clipped + hsv = torch.cat([hue, saturation, value], dim=1) #B x 3 x H x W + + #flow = tf.image.hsv_to_rgb(hsv) + flow = hsv_to_rgb(hsv) + flow = (flow*255.0).type(torch.ByteTensor) + return flow + + def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None): + # flow is B x C x D x W + if self.save_this: + return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id) + else: + return None + + def summ_oneds(self, name, ims, frame_ids=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0, norm=True, only_return=False, do_colorize=False): + if self.save_this: + if bev: + B, C, H, _, W = list(ims[0].shape) + if reduce_max: + ims = [torch.max(im, dim=3)[0] for im in ims] + else: + ims = [torch.mean(im, dim=3) for im in ims] + elif fro: + B, C, _, H, W = list(ims[0].shape) + if reduce_max: + ims = [torch.max(im, dim=2)[0] for im in ims] + else: + ims = [torch.mean(im, dim=2) for im in ims] + + + if len(ims) != 1: # sequence + im = gif_and_tile(ims, just_gif=self.just_gif) + else: + im = torch.stack(ims, dim=1) # single frame + + B, S, C, H, W = list(im.shape) + + if logvis and max_val: + max_val = np.log(max_val) + im = torch.log(torch.clamp(im, 0)+1.0) + im = torch.clamp(im, 0, max_val) + im = im/max_val + norm = False + elif max_val: + im = torch.clamp(im, 0, max_val) + im = im/max_val + norm = False + + if norm: + # normalize before oned2inferno, + # so that the ranges are similar within B across S + im = utils.basic.normalize(im) + + im = im.view(B*S, C, H, W) + vis = oned2inferno(im, norm=norm, do_colorize=do_colorize) + vis = vis.view(B, S, 3, H, W) + + if frame_ids is not None: + assert(len(frame_ids)==S) + for s in range(S): + vis[:,s] = draw_frame_id_on_vis(vis[:,s], frame_ids[s]) + + if W > self.maxwidth: + vis = vis[...,:self.maxwidth] + + if only_return: + return vis + else: + self.summ_gif(name, vis) + + def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True, frame_id=None, only_return=False): + if self.save_this: + + if bev: + B, C, H, _, W = list(im.shape) + if max_along_y: + im = torch.max(im, dim=3)[0] + else: + im = torch.mean(im, dim=3) + elif fro: + B, C, _, H, W = list(im.shape) + if max_along_y: + im = torch.max(im, dim=2)[0] + else: + im = torch.mean(im, dim=2) + else: + B, C, H, W = list(im.shape) + + im = im[0:1] # just the first one + assert(C==1) + + if logvis and max_val: + max_val = np.log(max_val) + im = torch.log(im) + im = torch.clamp(im, 0, max_val) + im = im/max_val + norm = False + elif max_val: + im = torch.clamp(im, 0, max_val)/max_val + norm = False + + vis = oned2inferno(im, norm=norm) + if W > self.maxwidth: + vis = vis[...,:self.maxwidth] + return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, only_return=only_return) + + def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None): + if self.save_this: + if valids is not None: + valids = torch.stack(valids, 
dim=1) + + feats = torch.stack(feats, dim=1) + # feats leads with B x S x C + + if feats.ndim==6: + + # feats is B x S x C x D x H x W + if fro: + reduce_dim = 3 + else: + reduce_dim = 4 + + if valids is None: + feats = torch.mean(feats, dim=reduce_dim) + else: + valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1) + feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim) + + B, S, C, D, W = list(feats.size()) + + if not pca: + # feats leads with B x S x C + feats = torch.mean(torch.abs(feats), dim=2, keepdims=True) + # feats leads with B x S x 1 + feats = torch.unbind(feats, dim=1) + return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids) + + else: + __p = lambda x: utils.basic.pack_seqdim(x, B) + __u = lambda x: utils.basic.unpack_seqdim(x, B) + + feats_ = __p(feats) + + if valids is None: + feats_pca_ = get_feat_pca(feats_) + else: + valids_ = __p(valids) + feats_pca_ = get_feat_pca(feats_, valids) + + feats_pca = __u(feats_pca_) + + return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return, frame_ids=frame_ids) + + def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None): + if self.save_this: + if feat.ndim==5: # B x C x D x H x W + + if bev: + reduce_axis = 3 + elif fro: + reduce_axis = 2 + else: + # default to bev + reduce_axis = 3 + + if valid is None: + feat = torch.mean(feat, dim=reduce_axis) + else: + valid = valid.repeat(1, feat.size()[1], 1, 1, 1) + feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis) + + B, C, D, W = list(feat.shape) + + if not pca: + feat = torch.mean(torch.abs(feat), dim=1, keepdims=True) + # feat is B x 1 x D x W + return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id) + else: + feat_pca = get_feat_pca(feat, valid) + return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id) + + def summ_scalar(self, name, value): + if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and ('torch' in value.type()): + value = value.detach().cpu().numpy() + if not np.isnan(value): + if (self.log_freq == 1): + self.writer.add_scalar(name, value, global_step=self.global_step) + elif self.save_this or np.mod(self.global_step, self.scalar_freq)==0: + self.writer.add_scalar(name, value, global_step=self.global_step) + + def summ_seg(self, name, seg, only_return=False, frame_id=None, colormap='tab20', label_colors=None): + if not self.save_this: + return + + B,H,W = seg.shape + + if label_colors is None: + custom_label_colors = False + # label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True) + label_colors = cm.get_cmap(colormap).colors + label_colors = [[int(i*255) for i in l] for l in label_colors] + else: + custom_label_colors = True + # label_colors = matplotlib.cm.get_cmap(colormap).colors + # label_colors = [[int(i*255) for i in l] for l in label_colors] + # print('label_colors', label_colors) + + # label_colors = [ + # (0, 0, 0), # None + # (70, 70, 70), # Buildings + # (190, 153, 153), # Fences + # (72, 0, 90), # Other + # (220, 20, 60), # Pedestrians + # (153, 153, 153), # Poles + # (157, 234, 50), # RoadLines + # (128, 64, 128), # Roads + # (244, 35, 232), # Sidewalks + # (107, 142, 35), # Vegetation + # (0, 0, 255), # Vehicles + # (102, 102, 156), # Walls + # (220, 220, 0) # TrafficSigns + # ] + + r = torch.zeros_like(seg,dtype=torch.uint8) + g = torch.zeros_like(seg,dtype=torch.uint8) + 
b = torch.zeros_like(seg,dtype=torch.uint8) + + for label in range(0,len(label_colors)): + if (not custom_label_colors):# and (N > 20): + label_ = label % 20 + else: + label_ = label + + idx = (seg == label+1) + r[idx] = label_colors[label_][0] + g[idx] = label_colors[label_][1] + b[idx] = label_colors[label_][2] + + rgb = torch.stack([r,g,b],axis=1) + return self.summ_rgb(name,rgb,only_return=only_return, frame_id=frame_id) + + def summ_pts_on_rgb(self, name, trajs, rgb, valids=None, frame_id=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, C, H, W = rgb.shape + B, S, N, D = trajs.shape + + rgb = rgb[0] # C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + valids = valids.long().detach().cpu().numpy() # S, N + + rgb = rgb.astype(np.uint8).copy() + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + valid = valids[:,i] # S + + color_map = cm.get_cmap(cmap) + color = np.array(color_map(i)[:3]) * 255 # rgb + for s in range(S): + if valid[s]: + cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, color, -1) + rgb = torch.from_numpy(rgb).permute(2,0,1).unsqueeze(0) + rgb = preprocess_color(rgb) + return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) + + def summ_pts_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap='coolwarm', linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, S, C, H, W = rgbs.shape + B, S2, N, D = trajs.shape + assert(S==S2) + + rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + rgbs_color = [] + for rgb in rgbs: + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgbs_color.append(rgb) # each element 3 x H x W + + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + valids = valids.long().detach().cpu().numpy() # S, N + + rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color] + + for i in range(N): + traj = trajs[:,i] # S,2 + valid = valids[:,i] # S + + color_map = cm.get_cmap(cmap) + color = np.array(color_map(0)[:3]) * 255 # rgb + for s in range(S): + if valid[s]: + cv2.circle(rgbs_color[s], (traj[s,0], traj[s,1]), linewidth, color, -1) + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + + def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=False, cmap='coolwarm', vals=None, linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, S, C, H, W = rgbs.shape + B, S2, N, D = trajs.shape + assert(S==S2) + + rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + + # print('trajs', trajs.shape) + # print('valids', 
valids.shape) + + if vals is not None: + vals = vals[0] # N + # print('vals', vals.shape) + + rgbs_color = [] + for rgb in rgbs: + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgbs_color.append(rgb) # each element 3 x H x W + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i].long().detach().cpu().numpy() # S, 2 + valid = valids[:,i].long().detach().cpu().numpy() # S + + # print('traj', traj.shape) + # print('valid', valid.shape) + + if vals is not None: + # val = vals[:,i].float().detach().cpu().numpy() # [] + val = vals[i].float().detach().cpu().numpy() # [] + # print('val', val.shape) + else: + val = None + + for t in range(S): + # if valid[t]: + # traj_seq = traj[max(t-16,0):t+1] + traj_seq = traj[max(t-8,0):t+1] + val_seq = np.linspace(0,1,len(traj_seq)) + # if t<2: + # val_seq = np.zeros_like(val_seq) + # print('val_seq', val_seq) + # val_seq = 1.0 + # val_seq = np.arange(8)/8.0 + # val_seq = val_seq[-len(traj_seq):] + # rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) + rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth) + # input() + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + # vis = visibles[:,i] # S + vis = torch.ones_like(traj[:,0]) # S + valid = valids[:,i] # S + rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) + + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True, cmap=None, linewidth=1): + # trajs is B, S, N, 2 + # rgbs is B, S, C, H, W + B, S, C, H, W = rgbs.shape + B, S2, N, D = trajs.shape + assert(S==S2) + + rgbs = rgbs[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + visibles = visibles[0] # S, N + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) # S, N + else: + valids = valids[0] + # print('trajs', trajs.shape) + # print('valids', valids.shape) + + rgbs_color = [] + for rgb in rgbs: + rgb = back2color(rgb).detach().cpu().numpy() + rgb = np.transpose(rgb, [1, 2, 0]) # put channels last + rgbs_color.append(rgb) # each element 3 x H x W + + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + visibles = visibles.float().detach().cpu().numpy() # S, N + valids = valids.long().detach().cpu().numpy() # S, N + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + vis = visibles[:,i] # S + valid = valids[:,i] # S + rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_, linewidth=linewidth) + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S,2 + vis = visibles[:,i] # S + valid = valids[:,i] # S + if valid[0]: + rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, 
show_dots=show_dots, cmap=None, linewidth=linewidth) + + rgbs = [] + for rgb in rgbs_color: + rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0) + rgbs.append(preprocess_color(rgb)) + + return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids) + + def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=False, show_lines=True, frame_id=None, only_return=False, cmap='coolwarm', linewidth=1): + # trajs is B, S, N, 2 + # rgb is B, C, H, W + B, C, H, W = rgb.shape + B, S, N, D = trajs.shape + + rgb = rgb[0] # S, C, H, W + trajs = trajs[0] # S, N, 2 + + if valids is None: + valids = torch.ones_like(trajs[:,:,0]) + else: + valids = valids[0] + + rgb_color = back2color(rgb).detach().cpu().numpy() + rgb_color = np.transpose(rgb_color, [1, 2, 0]) # put channels last + + # using maxdist will dampen the colors for short motions + norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0])**2, dim=1)) # N + maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy() + maxdist = None + trajs = trajs.long().detach().cpu().numpy() # S, N, 2 + valids = valids.long().detach().cpu().numpy() # S, N + + for i in range(N): + if cmap=='onediff' and i==0: + cmap_ = 'spring' + elif cmap=='onediff': + cmap_ = 'winter' + else: + cmap_ = cmap + traj = trajs[:,i] # S, 2 + valid = valids[:,i] # S + if valid[0]==1: + traj = traj[valid>0] + rgb_color = self.draw_traj_on_image_py( + rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist, linewidth=linewidth) + + rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0) + rgb = preprocess_color(rgb_color) + return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id) + + def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm', val=None, maxdist=None): + # all inputs are numpy tensors + # rgb is 3 x H x W + # traj is S x 2 + + H, W, C = rgb.shape + assert(C==3) + + rgb = rgb.astype(np.uint8).copy() + + S1, D = traj.shape + assert(D==2) + + color_map = cm.get_cmap(cmap) + S1, D = traj.shape + + for s in range(S1): + if val is not None: + # if len(val) == S1: + color = np.array(color_map(val[s])[:3]) * 255 # rgb + # else: + # color = np.array(color_map(val)[:3]) * 255 # rgb + else: + if maxdist is not None: + val = (np.sqrt(np.sum((traj[s]-traj[0])**2))/maxdist).clip(0,1) + color = np.array(color_map(val)[:3]) * 255 # rgb + else: + color = np.array(color_map((s)/max(1,float(S-2)))[:3]) * 255 # rgb + + if show_lines and s<(S1-1): + cv2.line(rgb, + (int(traj[s,0]), int(traj[s,1])), + (int(traj[s+1,0]), int(traj[s+1,1])), + color, + linewidth, + cv2.LINE_AA) + if show_dots: + cv2.circle(rgb, (int(traj[s,0]), int(traj[s,1])), linewidth, np.array(color_map(1)[:3])*255, -1) + + # if maxdist is not None: + # val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1) + # color = np.array(color_map(val)[:3]) * 255 # rgb + # else: + # # draw the endpoint of traj, using the next color (which may be the last color) + # color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb + + # # emphasize endpoint + # cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1) + + return rgb + + + + def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None): + # all inputs are numpy tensors + # rgbs is a list of H,W,3 + # traj is S,2 + H, W, C = rgbs[0].shape + assert(C==3) + + rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] + + S1, D = traj.shape + 
assert(D==2) + + x = int(np.clip(traj[0,0], 0, W-1)) + y = int(np.clip(traj[0,1], 0, H-1)) + color = rgbs[0][y,x] + color = (int(color[0]),int(color[1]),int(color[2])) + for s in range(S): + # bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb + # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1) + cv2.polylines(rgbs[s], + [traj[:s+1]], + False, + color, + linewidth, + cv2.LINE_AA) + return rgbs + + def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None): + # all inputs are numpy tensors + # rgbs is a list of 3,H,W + # xy is N,2 + H, W, C = rgb.shape + assert(C==3) + + rgb = rgb.astype(np.uint8).copy() + + N, D = xy.shape + assert(D==2) + + + xy = xy.astype(np.float32) + xy[:,0] = np.clip(xy[:,0], 0, W-1) + xy[:,1] = np.clip(xy[:,1], 0, H-1) + xy = xy.astype(np.int32) + + + + if colors is None: + colors = get_n_colors(N) + + for n in range(N): + color = colors[n] + # print('color', color) + # color = (color[0]*255).astype(np.uint8) + color = (int(color[0]),int(color[1]),int(color[2])) + + # x = int(np.clip(xy[0,0], 0, W-1)) + # y = int(np.clip(xy[0,1], 0, H-1)) + # color_ = rgbs[0][y,x] + # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) + # color_ = (int(color_[0]),int(color_[1]),int(color_[2])) + + cv2.circle(rgb, (xy[n,0], xy[n,1]), linewidth, color, 3) + # vis_color = int(np.squeeze(vis[s])*255) + # vis_color = (vis_color,vis_color,vis_color) + # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1) + return rgb + + def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None): + # all inputs are numpy tensors + # rgbs is a list of 3,H,W + # traj is S,2 + H, W, C = rgbs[0].shape + assert(C==3) + + rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs] + + S1, D = traj.shape + assert(D==2) + + if cmap is None: + bremm = ColorMap2d() + traj_ = traj[0:1].astype(np.float32) + traj_[:,0] /= float(W) + traj_[:,1] /= float(H) + color = bremm(traj_) + # print('color', color) + color = (color[0]*255).astype(np.uint8) + # color = (int(color[0]),int(color[1]),int(color[2])) + color = (int(color[2]),int(color[1]),int(color[0])) + + for s in range(S1): + if cmap is not None: + color_map = cm.get_cmap(cmap) + # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb + color = np.array(color_map((s+1)/max(1,float(S-1)))[:3]) * 255 # rgb + # color = color.astype(np.uint8) + # color = (color[0], color[1], color[2]) + # print('color', color) + # import ipdb; ipdb.set_trace() + + cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, color, -1) + # vis_color = int(np.squeeze(vis[s])*255) + # vis_color = (vis_color,vis_color,vis_color) + # cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1) + + return rgbs + + def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, only_return=False, show_circ=False, trajs_g=None, is_g=False): + B, S, N, D = trajs_e.shape + assert(N==1) + assert(D==2) + + rgbs_vis = [] + n = 0 + pad_amount = 100 + trajs_e_py = trajs_e[0].detach().cpu().numpy() + # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun + trajs_e_py = trajs_e_py + pad_amount + + if trajs_g is not None: + trajs_g_py = trajs_g[0].detach().cpu().numpy() + trajs_g_py = trajs_g_py + pad_amount + + for s in range(S): + rgb = rgbs[0,s].detach().cpu().numpy() + # print('orig rgb', rgb.shape) + rgb = np.transpose(rgb,(1,2,0)) # H, W, 3 + + rgb = np.pad(rgb, 
((pad_amount,pad_amount),(pad_amount,pad_amount),(0,0))) + # print('pad rgb', rgb.shape) + H, W, C = rgb.shape + + if trajs_g is not None: + xy_g = trajs_g_py[s,n] + xy_g[0] = np.clip(xy_g[0], pad_amount, W-pad_amount) + xy_g[1] = np.clip(xy_g[1], pad_amount, H-pad_amount) + rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) + + xy_e = trajs_e_py[s,n] + xy_e[0] = np.clip(xy_e[0], pad_amount, W-pad_amount) + xy_e[1] = np.clip(xy_e[1], pad_amount, H-pad_amount) + + if show_circ: + if is_g: + rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(0,255,0)], linewidth=2, radius=3) + else: + rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1,2), colors=[(255,0,255)], linewidth=2, radius=3) + + + xmin = int(xy_e[0])-pad_amount//2 + xmax = int(xy_e[0])+pad_amount//2 + ymin = int(xy_e[1])-pad_amount//2 + ymax = int(xy_e[1])+pad_amount//2 + + rgb_ = rgb[ymin:ymax, xmin:xmax] + + H_, W_ = rgb_.shape[:2] + # if np.any(rgb_.shape==0): + # input() + if H_==0 or W_==0: + import ipdb; ipdb.set_trace() + + rgb_ = rgb_.transpose(2,0,1) + rgb_ = torch.from_numpy(rgb_) + + rgbs_vis.append(rgb_) + + # nrow = int(np.sqrt(S)*(16.0/9)/2.0) + nrow = int(np.sqrt(S)*1.5) + grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0) + # print('grid_img', grid_img.shape) + return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, only_return=only_return) + + def summ_occ(self, name, occ, reduce_axes=[3], bev=False, fro=False, pro=False, frame_id=None, only_return=False): + if self.save_this: + B, C, D, H, W = list(occ.shape) + if bev: + reduce_axes = [3] + elif fro: + reduce_axes = [2] + elif pro: + reduce_axes = [4] + for reduce_axis in reduce_axes: + height = convert_occ_to_height(occ, reduce_axis=reduce_axis) + if reduce_axis == reduce_axes[-1]: + return self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) + else: + self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id, only_return=only_return) + +def erode2d(im, times=1, device='cuda'): + weights2d = torch.ones(1, 1, 3, 3, device=device) + for time in range(times): + im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1) + return im + +def dilate2d(im, times=1, device='cuda', mode='square'): + weights2d = torch.ones(1, 1, 3, 3, device=device) + if mode=='cross': + weights2d[:,:,0,0] = 0.0 + weights2d[:,:,0,2] = 0.0 + weights2d[:,:,2,0] = 0.0 + weights2d[:,:,2,2] = 0.0 + for time in range(times): + im = F.conv2d(im, weights2d, padding=1).clamp(0, 1) + return im + + diff --git a/das/spatracker/utils/misc.py b/das/spatracker/utils/misc.py new file mode 100644 index 0000000..adc3196 --- /dev/null +++ b/das/spatracker/utils/misc.py @@ -0,0 +1,166 @@ +import torch +import numpy as np +import math +from prettytable import PrettyTable + +def count_parameters(model): + table = PrettyTable(["Modules", "Parameters"]) + total_params = 0 + for name, parameter in model.named_parameters(): + if not parameter.requires_grad: + continue + param = parameter.numel() + if param > 100000: + table.add_row([name, param]) + total_params+=param + print(table) + print('total params: %.2f M' % (total_params/1000000.0)) + return total_params + +def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False): + device = xy.device + dtype = xy.dtype + B, S, D = xy.shape + assert(D==2) + x = xy[:,:,0] + y = xy[:,:,1] + assert (C % 4) == 0, 'feature 
dimension must be multiple of 4 for sincos emb' + omega = torch.arange(C // 4, device=device) / (C // 4 - 1) + omega = 1. / (temperature ** omega) + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + pe = pe.reshape(B,S,C).type(dtype) + if cat_coords: + pe = torch.cat([pe, xy], dim=2) # B,N,C+2 + return pe + +class SimplePool(): + def __init__(self, pool_size, version='pt'): + self.pool_size = pool_size + self.version = version + self.items = [] + + if not (version=='pt' or version=='np'): + print('version = %s; please choose pt or np') + assert(False) # please choose pt or np + + def __len__(self): + return len(self.items) + + def mean(self, min_size=1): + if min_size=='half': + pool_size_thresh = self.pool_size/2 + else: + pool_size_thresh = min_size + + if self.version=='np': + if len(self.items) >= pool_size_thresh: + return np.sum(self.items)/float(len(self.items)) + else: + return np.nan + if self.version=='pt': + if len(self.items) >= pool_size_thresh: + return torch.sum(self.items)/float(len(self.items)) + else: + return torch.from_numpy(np.nan) + + def sample(self, with_replacement=True): + idx = np.random.randint(len(self.items)) + if with_replacement: + return self.items[idx] + else: + return self.items.pop(idx) + + def fetch(self, num=None): + if self.version=='pt': + item_array = torch.stack(self.items) + elif self.version=='np': + item_array = np.stack(self.items) + if num is not None: + # there better be some items + assert(len(self.items) >= num) + + # if there are not that many elements just return however many there are + if len(self.items) < num: + return item_array + else: + idxs = np.random.randint(len(self.items), size=num) + return item_array[idxs] + else: + return item_array + + def is_full(self): + full = len(self.items)==self.pool_size + return full + + def empty(self): + self.items = [] + + def update(self, items): + for item in items: + if len(self.items) < self.pool_size: + # the pool is not full, so let's add this in + self.items.append(item) + else: + # the pool is full + # pop from the front + self.items.pop(0) + # add to the back + self.items.append(item) + return self.items + +def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False): + """ + Input: + xyz: pointcloud data, [B, N, C], where C is probably 3 + npoint: number of samples + Return: + inds: sampled pointcloud index, [B, npoint] + """ + device = xyz.device + B, N, C = xyz.shape + xyz = xyz.float() + inds = torch.zeros(B, npoint, dtype=torch.long).to(device) + distance = torch.ones(B, N).to(device) * 1e10 + if deterministic: + farthest = torch.randint(0, 1, (B,), dtype=torch.long).to(device) + else: + farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) + batch_indices = torch.arange(B, dtype=torch.long).to(device) + for i in range(npoint): + if include_ends: + if i==0: + farthest = 0 + elif i==1: + farthest = N-1 + inds[:, i] = farthest + centroid = xyz[batch_indices, farthest, :].view(B, 1, C) + dist = torch.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = torch.max(distance, -1)[1] + + if npoint > N: + # if we need more samples, make them random + distance += torch.randn_like(distance) + return inds + +def farthest_point_sample_py(xyz, npoint): + N,C = xyz.shape + inds = np.zeros(npoint, dtype=np.int32) + distance = np.ones(N) * 1e10 + farthest = np.random.randint(0, N, dtype=np.int32) + for i in range(npoint): + 
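+        # greedy FPS step: 'distance' holds each point's squared distance to its nearest
+        # already-selected point; the next pick is the point that maximises that distance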
inds[i] = farthest + centroid = xyz[farthest, :].reshape(1,C) + dist = np.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = np.argmax(distance, -1) + if npoint > N: + # if we need more samples, make them random + distance += np.random.randn(*distance.shape) + return inds + diff --git a/das/spatracker/utils/samp.py b/das/spatracker/utils/samp.py new file mode 100644 index 0000000..3632c9c --- /dev/null +++ b/das/spatracker/utils/samp.py @@ -0,0 +1,152 @@ +import torch +import utils.basic +import torch.nn.functional as F + +def bilinear_sample2d(im, x, y, return_inbounds=False): + # x and y are each B, N + # output is B, C, N + B, C, H, W = list(im.shape) + N = list(x.shape)[1] + + x = x.float() + y = y.float() + H_f = torch.tensor(H, dtype=torch.float32) + W_f = torch.tensor(W, dtype=torch.float32) + + # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte() + y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() + inbounds = (x_valid & y_valid).float() + inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) + return output, inbounds + + return output # B, C, N + +def paste_crop_on_canvas(crop, box2d_unnorm, H, W, fast=True, mask=None, canvas=None): + # this is the inverse of crop_and_resize_box2d + B, C, Y, X = list(crop.shape) + B2, D = list(box2d_unnorm.shape) + assert(B == B2) + assert(D == 4) + + # here, we want to place the crop into a bigger image, + # at the location specified by the box2d. + + if canvas is None: + canvas = torch.zeros((B, C, H, W), device=crop.device) + else: + B2, C2, H2, W2 = canvas.shape + assert(B==B2) + assert(C==C2) + assert(H==H2) + assert(W==W2) + + # box2d_unnorm = utils.geom.unnormalize_box2d(box2d, H, W) + + if fast: + ymin = box2d_unnorm[:, 0].long() + xmin = box2d_unnorm[:, 1].long() + ymax = box2d_unnorm[:, 2].long() + xmax = box2d_unnorm[:, 3].long() + w = (xmax - xmin).float() + h = (ymax - ymin).float() + + grids = utils.basic.gridcloud2d(B, H, W) + grids_flat = grids.reshape(B, -1, 2) + # grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * X + # grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * Y + + # for each pixel in the main image, + # grids_flat tells us where to sample in the crop image + + # print('grids_flat', grids_flat.shape) + # print('crop', crop.shape) + + grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * 2.0 - 1.0 + grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * 2.0 - 1.0 + + grid = grids_flat.reshape(B,H,W,2) + + canvas = F.grid_sample(crop, grid, align_corners=False) + # print('canvas', canvas.shape) + + # if mask is None: + # crop_resamp, inb = bilinear_sample2d(crop, grids_flat[:, :, 0], grids_flat[:, :, 1], return_inbounds=True) + # crop_resamp = crop_resamp.reshape(B, C, H, W) + # inb = inb.reshape(B, 1, H, W) + # canvas = canvas * (1 - inb) + crop_resamp * inb + # else: + # full_resamp = bilinear_sample2d(torch.cat([crop, mask], dim=1), grids_flat[:, :, 0], grids_flat[:, :, 1]) + # full_resamp = full_resamp.reshape(B, C+1, H, W) + # crop_resamp = full_resamp[:,:3] + # mask_resamp = full_resamp[:,3:4] + # canvas = canvas * (1 - mask_resamp) + crop_resamp * mask_resamp + else: + for b in range(B): + ymin = box2d_unnorm[b, 0].long() + xmin = box2d_unnorm[b, 1].long() + ymax = box2d_unnorm[b, 2].long() 
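+ # Illustrative note (not in the upstream SpaTracker code): box2d_unnorm rows are
+ # (ymin, xmin, ymax, xmax) in pixels; this slow path resizes the crop to the box
+ # extent and writes it directly into the matching canvas slice for each batch element.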
+ xmax = box2d_unnorm[b, 3].long() + + crop_b = F.interpolate(crop[b:b + 1], (ymax - ymin, xmax - xmin)).squeeze(0) + + # print('canvas[b,:,...', canvas[b,:,ymin:ymax,xmin:xmax].shape) + # print('crop_b', crop_b.shape) + + canvas[b, :, ymin:ymax, xmin:xmax] = crop_b + return canvas diff --git a/das/spatracker/utils/visualizer.py b/das/spatracker/utils/visualizer.py new file mode 100644 index 0000000..66736be --- /dev/null +++ b/das/spatracker/utils/visualizer.py @@ -0,0 +1,409 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import numpy as np +import cv2 +import torch +import flow_vis + +from matplotlib import cm +import torch.nn.functional as F +import torchvision.transforms as transforms +#from moviepy.editor import ImageSequenceClip +import matplotlib.pyplot as plt +from tqdm import tqdm + +def read_video_from_path(path): + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + print("Error opening video file") + else: + frames = [] + while cap.isOpened(): + ret, frame = cap.read() + if ret == True: + frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))) + else: + break + cap.release() + return np.stack(frames) + + +class Visualizer: + def __init__( + self, + save_dir: str = "./results", + grayscale: bool = False, + pad_value: int = 0, + fps: int = 10, + mode: str = "rainbow", # 'cool', 'optical_flow' + linewidth: int = 1, + show_first_frame: int = 10, + tracks_leave_trace: int = 0, # -1 for infinite + ): + self.mode = mode + self.save_dir = save_dir + self.vtxt_path = os.path.join(save_dir, "videos.txt") + self.ttxt_path = os.path.join(save_dir, "trackings.txt") + if mode == "rainbow": + self.color_map = cm.get_cmap("gist_rainbow") + elif mode == "cool": + self.color_map = cm.get_cmap(mode) + self.show_first_frame = show_first_frame + self.grayscale = grayscale + self.tracks_leave_trace = tracks_leave_trace + self.pad_value = pad_value + self.linewidth = linewidth + self.fps = fps + + def visualize( + self, + video: torch.Tensor, # (B,T,C,H,W) + tracks: torch.Tensor, # (B,T,N,2) + visibility: torch.Tensor = None, # (B, T, N, 1) bool + gt_tracks: torch.Tensor = None, # (B,T,N,2) + segm_mask: torch.Tensor = None, # (B,1,H,W) + filename: str = "video", + writer=None, # tensorboard Summary Writer, used for visualization during training + step: int = 0, + query_frame: int = 0, + save_video: bool = True, + compensate_for_camera_motion: bool = False, + rigid_part = None, + video_depth = None # (B,T,C,H,W) + ): + if compensate_for_camera_motion: + assert segm_mask is not None + if segm_mask is not None: + coords = tracks[0, query_frame].round().long() + segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long() + + video = F.pad( + video, + (self.pad_value, self.pad_value, self.pad_value, self.pad_value), + "constant", + 255, + ) + + if video_depth is not None: + video_depth = (video_depth*255).cpu().numpy().astype(np.uint8) + video_depth = ([cv2.applyColorMap(video_depth[0,i,0], cv2.COLORMAP_INFERNO) + for i in range(video_depth.shape[1])]) + video_depth = np.stack(video_depth, axis=0) + video_depth = torch.from_numpy(video_depth).permute(0, 3, 1, 2)[None] + + tracks = tracks + self.pad_value + + if self.grayscale: + transform = transforms.Grayscale() + video = transform(video) + video = video.repeat(1, 1, 3, 1, 1) + + tracking_video = self.draw_tracks_on_video( + video=video, + tracks=tracks, + 
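+ # Illustrative note (not in the upstream SpaTracker code): tracks are expected here as
+ # (B, T, N, 3) -- pixel x, y plus a depth value; draw_tracks_on_video asserts D == 3
+ # and uses the third channel both to colour and to depth-sort the points.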
visibility=visibility, + segm_mask=segm_mask, + gt_tracks=gt_tracks, + query_frame=query_frame, + compensate_for_camera_motion=compensate_for_camera_motion, + rigid_part=rigid_part + ) + + if save_video: + # import ipdb; ipdb.set_trace() + tracking_dir = os.path.join(self.save_dir, "tracking") + if not os.path.exists(tracking_dir): + os.makedirs(tracking_dir) + self.save_video(tracking_video, filename=filename+"_tracking", + savedir=tracking_dir, writer=writer, step=step) + # with open(self.ttxt_path, 'a') as file: + # file.write(f"tracking/{filename}_tracking.mp4\n") + + videos_dir = os.path.join(self.save_dir, "videos") + if not os.path.exists(videos_dir): + os.makedirs(videos_dir) + self.save_video(video, filename=filename, + savedir=videos_dir, writer=writer, step=step) + # with open(self.vtxt_path, 'a') as file: + # file.write(f"videos/{filename}.mp4\n") + if video_depth is not None: + self.save_video(video_depth, filename=filename+"_depth", + savedir=os.path.join(self.save_dir, "depth"), writer=writer, step=step) + return tracking_video + + def save_video(self, video, filename, savedir=None, writer=None, step=0): + if writer is not None: + writer.add_video( + f"{filename}", + video.to(torch.uint8), + global_step=step, + fps=self.fps, + ) + else: + os.makedirs(self.save_dir, exist_ok=True) + wide_list = list(video.unbind(1)) + wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list] + # clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps) + clip = ImageSequenceClip(wide_list, fps=self.fps) + + # Write the video file + if savedir is None: + save_path = os.path.join(self.save_dir, f"{filename}.mp4") + else: + save_path = os.path.join(savedir, f"{filename}.mp4") + clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None) + + print(f"Video saved to {save_path}") + + def draw_tracks_on_video( + self, + video: torch.Tensor, + tracks: torch.Tensor, + visibility: torch.Tensor = None, + segm_mask: torch.Tensor = None, + gt_tracks=None, + query_frame: int = 0, + compensate_for_camera_motion=False, + rigid_part=None, + ): + B, T, C, H, W = video.shape + _, _, N, D = tracks.shape + + assert D == 3 + assert C == 3 + video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C + tracks = tracks[0].detach().cpu().numpy() # S, N, 2 + if gt_tracks is not None: + gt_tracks = gt_tracks[0].detach().cpu().numpy() + + res_video = [] + + # process input video + # for rgb in video: + # res_video.append(rgb.copy()) + + # create a blank tensor with the same shape as the video + for rgb in video: + black_frame = np.zeros_like(rgb.copy(), dtype=rgb.dtype) + res_video.append(black_frame) + + vector_colors = np.zeros((T, N, 3)) + + if self.mode == "optical_flow": + + vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None]) + + elif segm_mask is None: + if self.mode == "rainbow": + x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max() + y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max() + + z_inv = 1/tracks[0, :, 2] + z_min, z_max = np.percentile(z_inv, [2, 98]) + + norm_x = plt.Normalize(x_min, x_max) + norm_y = plt.Normalize(y_min, y_max) + norm_z = plt.Normalize(z_min, z_max) + + for n in range(N): + r = norm_x(tracks[0, n, 0]) + g = norm_y(tracks[0, n, 1]) + # r = 0 + # g = 0 + b = norm_z(1/tracks[0, n, 2]) + color = np.array([r, g, b])[None] * 255 + vector_colors[:, n] = np.repeat(color, T, axis=0) + else: + # color changes with time + for t in range(T): + color = np.array(self.color_map(t / T)[:3])[None] * 255 + 
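+ # Illustrative note (not in the upstream SpaTracker code): in this branch every point
+ # shares one colour per frame, and the colour sweeps through the colormap as t grows,
+ # so the trails encode time rather than position.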
vector_colors[t] = np.repeat(color, N, axis=0) + else: + if self.mode == "rainbow": + vector_colors[:, segm_mask <= 0, :] = 255 + + x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max() + y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max() + z_min, z_max = tracks[0, :, 2].min(), tracks[0, :, 2].max() + + norm_x = plt.Normalize(x_min, x_max) + norm_y = plt.Normalize(y_min, y_max) + norm_z = plt.Normalize(z_min, z_max) + + for n in range(N): + r = norm_x(tracks[0, n, 0]) + g = norm_y(tracks[0, n, 1]) + b = norm_z(tracks[0, n, 2]) + color = np.array([r, g, b])[None] * 255 + vector_colors[:, n] = np.repeat(color, T, axis=0) + + else: + # color changes with segm class + segm_mask = segm_mask.cpu() + color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32) + color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0 + color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0 + vector_colors = np.repeat(color[None], T, axis=0) + + # Draw tracks + if self.tracks_leave_trace != 0: + for t in range(1, T): + first_ind = ( + max(0, t - self.tracks_leave_trace) + if self.tracks_leave_trace >= 0 + else 0 + ) + curr_tracks = tracks[first_ind : t + 1] + curr_colors = vector_colors[first_ind : t + 1] + if compensate_for_camera_motion: + diff = ( + tracks[first_ind : t + 1, segm_mask <= 0] + - tracks[t : t + 1, segm_mask <= 0] + ).mean(1)[:, None] + + curr_tracks = curr_tracks - diff + curr_tracks = curr_tracks[:, segm_mask > 0] + curr_colors = curr_colors[:, segm_mask > 0] + + res_video[t] = self._draw_pred_tracks( + res_video[t], + curr_tracks, + curr_colors, + ) + if gt_tracks is not None: + res_video[t] = self._draw_gt_tracks( + res_video[t], gt_tracks[first_ind : t + 1] + ) + + if rigid_part is not None: + cls_label = torch.unique(rigid_part) + cls_num = len(torch.unique(rigid_part)) + # visualize the clustering results + cmap = plt.get_cmap('jet') # get the color mapping + colors = cmap(np.linspace(0, 1, cls_num)) + colors = (colors[:, :3] * 255) + color_map = {lable.item(): color for lable, color in zip(cls_label, colors)} + + # Draw points + for t in tqdm(range(T)): + # Create a list to store information for each point + points_info = [] + for i in range(N): + coord = (tracks[t, i, 0], tracks[t, i, 1]) + depth = tracks[t, i, 2] # assume the third dimension is depth + visibile = True + if visibility is not None: + visibile = visibility[0, t, i] + if coord[0] != 0 and coord[1] != 0: + if not compensate_for_camera_motion or ( + compensate_for_camera_motion and segm_mask[i] > 0 + ): + points_info.append((i, coord, depth, visibile)) + + # Sort points by depth, points with smaller depth (closer) will be drawn later + points_info.sort(key=lambda x: x[2], reverse=True) + + for i, coord, _, visibile in points_info: + if rigid_part is not None: + color = color_map[rigid_part.squeeze()[i].item()] + cv2.circle( + res_video[t], + coord, + int(self.linewidth * 2), + color.tolist(), + thickness=-1 if visibile else 2 + -1, + ) + else: + # Determine rectangle width based on the distance between adjacent tracks in the first frame + if t == 0: + distances = np.linalg.norm(tracks[0] - tracks[0, i], axis=1) + distances = distances[distances > 0] + rect_size = int(np.min(distances))/2 + + # Define coordinates for top-left and bottom-right corners of the rectangle + top_left = (int(coord[0] - rect_size), int(coord[1] - rect_size/1.5)) # Rectangle width is 1.5x (video aspect ratio is 1.5:1) + bottom_right = (int(coord[0] + rect_size), int(coord[1] + rect_size/1.5)) + + # Draw rectangle + 
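+ # Illustrative note (not in the upstream SpaTracker code): rect_size is computed once
+ # at t == 0 from the nearest-neighbour spacing of the first-frame tracks, and the box
+ # is drawn wider than tall (ratio ~1.5:1) to match the assumed video aspect ratio.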
cv2.rectangle( + res_video[t], + top_left, + bottom_right, + vector_colors[t, i].tolist(), + thickness=-1 if visibile else 0 + -1, + ) + + # Construct the final rgb sequence + return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte() + + def _draw_pred_tracks( + self, + rgb: np.ndarray, # H x W x 3 + tracks: np.ndarray, # T x 2 + vector_colors: np.ndarray, + alpha: float = 0.5, + ): + T, N, _ = tracks.shape + + for s in range(T - 1): + vector_color = vector_colors[s] + original = rgb.copy() + alpha = (s / T) ** 2 + for i in range(N): + coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1])) + coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1])) + if coord_y[0] != 0 and coord_y[1] != 0: + cv2.line( + rgb, + coord_y, + coord_x, + vector_color[i].tolist(), + self.linewidth, + cv2.LINE_AA, + ) + if self.tracks_leave_trace > 0: + rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0) + return rgb + + def _draw_gt_tracks( + self, + rgb: np.ndarray, # H x W x 3, + gt_tracks: np.ndarray, # T x 2 + ): + T, N, _ = gt_tracks.shape + color = np.array((211.0, 0.0, 0.0)) + + for t in range(T): + for i in range(N): + gt_tracks = gt_tracks[t][i] + # draw a red cross + if gt_tracks[0] > 0 and gt_tracks[1] > 0: + length = self.linewidth * 3 + coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length) + coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length) + cv2.line( + rgb, + coord_y, + coord_x, + color, + self.linewidth, + cv2.LINE_AA, + ) + coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length) + coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length) + cv2.line( + rgb, + coord_y, + coord_x, + color, + self.linewidth, + cv2.LINE_AA, + ) + return rgb diff --git a/das/spatracker/utils/vox.py b/das/spatracker/utils/vox.py new file mode 100644 index 0000000..203097b --- /dev/null +++ b/das/spatracker/utils/vox.py @@ -0,0 +1,500 @@ +import numpy as np +import torch +import torch.nn.functional as F + +import utils.geom + +class Vox_util(object): + def __init__(self, Z, Y, X, scene_centroid, bounds, pad=None, assert_cube=False): + self.XMIN, self.XMAX, self.YMIN, self.YMAX, self.ZMIN, self.ZMAX = bounds + B, D = list(scene_centroid.shape) + self.Z, self.Y, self.X = Z, Y, X + + scene_centroid = scene_centroid.detach().cpu().numpy() + x_centroid, y_centroid, z_centroid = scene_centroid[0] + self.XMIN += x_centroid + self.XMAX += x_centroid + self.YMIN += y_centroid + self.YMAX += y_centroid + self.ZMIN += z_centroid + self.ZMAX += z_centroid + + self.default_vox_size_X = (self.XMAX-self.XMIN)/float(X) + self.default_vox_size_Y = (self.YMAX-self.YMIN)/float(Y) + self.default_vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z) + + if pad: + Z_pad, Y_pad, X_pad = pad + self.ZMIN -= self.default_vox_size_Z * Z_pad + self.ZMAX += self.default_vox_size_Z * Z_pad + self.YMIN -= self.default_vox_size_Y * Y_pad + self.YMAX += self.default_vox_size_Y * Y_pad + self.XMIN -= self.default_vox_size_X * X_pad + self.XMAX += self.default_vox_size_X * X_pad + + if assert_cube: + # we assume cube voxels + if (not np.isclose(self.default_vox_size_X, self.default_vox_size_Y)) or (not np.isclose(self.default_vox_size_X, self.default_vox_size_Z)): + print('Z, Y, X', Z, Y, X) + print('bounds for this iter:', + 'X = %.2f to %.2f' % (self.XMIN, self.XMAX), + 'Y = %.2f to %.2f' % (self.YMIN, self.YMAX), + 'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX), + ) + print('self.default_vox_size_X', self.default_vox_size_X) + print('self.default_vox_size_Y', self.default_vox_size_Y) + 
print('self.default_vox_size_Z', self.default_vox_size_Z) + assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Y)) + assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Z)) + + def Ref2Mem(self, xyz, Z, Y, X, assert_cube=False): + # xyz is B x N x 3, in ref coordinates + # transforms ref coordinates into mem coordinates + B, N, C = list(xyz.shape) + device = xyz.device + assert(C==3) + mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device) + xyz = utils.geom.apply_4x4(mem_T_ref, xyz) + return xyz + + def Mem2Ref(self, xyz_mem, Z, Y, X, assert_cube=False): + # xyz is B x N x 3, in mem coordinates + # transforms mem coordinates into ref coordinates + B, N, C = list(xyz_mem.shape) + ref_T_mem = self.get_ref_T_mem(B, Z, Y, X, assert_cube=assert_cube, device=xyz_mem.device) + xyz_ref = utils.geom.apply_4x4(ref_T_mem, xyz_mem) + return xyz_ref + + def get_mem_T_ref(self, B, Z, Y, X, assert_cube=False, device='cuda'): + vox_size_X = (self.XMAX-self.XMIN)/float(X) + vox_size_Y = (self.YMAX-self.YMIN)/float(Y) + vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z) + + if assert_cube: + if (not np.isclose(vox_size_X, vox_size_Y)) or (not np.isclose(vox_size_X, vox_size_Z)): + print('Z, Y, X', Z, Y, X) + print('bounds for this iter:', + 'X = %.2f to %.2f' % (self.XMIN, self.XMAX), + 'Y = %.2f to %.2f' % (self.YMIN, self.YMAX), + 'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX), + ) + print('vox_size_X', vox_size_X) + print('vox_size_Y', vox_size_Y) + print('vox_size_Z', vox_size_Z) + assert(np.isclose(vox_size_X, vox_size_Y)) + assert(np.isclose(vox_size_X, vox_size_Z)) + + # translation + # (this makes the left edge of the leftmost voxel correspond to XMIN) + center_T_ref = utils.geom.eye_4x4(B, device=device) + center_T_ref[:,0,3] = -self.XMIN-vox_size_X/2.0 + center_T_ref[:,1,3] = -self.YMIN-vox_size_Y/2.0 + center_T_ref[:,2,3] = -self.ZMIN-vox_size_Z/2.0 + + # scaling + # (this makes the right edge of the rightmost voxel correspond to XMAX) + mem_T_center = utils.geom.eye_4x4(B, device=device) + mem_T_center[:,0,0] = 1./vox_size_X + mem_T_center[:,1,1] = 1./vox_size_Y + mem_T_center[:,2,2] = 1./vox_size_Z + mem_T_ref = utils.geom.matmul2(mem_T_center, center_T_ref) + + return mem_T_ref + + def get_ref_T_mem(self, B, Z, Y, X, assert_cube=False, device='cuda'): + mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device) + # note safe_inverse is inapplicable here, + # since the transform is nonrigid + ref_T_mem = mem_T_ref.inverse() + return ref_T_mem + + def get_inbounds(self, xyz, Z, Y, X, already_mem=False, padding=0.0, assert_cube=False): + # xyz is B x N x 3 + # padding should be 0 unless you are trying to account for some later cropping + if not already_mem: + xyz = self.Ref2Mem(xyz, Z, Y, X, assert_cube=assert_cube) + + x = xyz[:,:,0] + y = xyz[:,:,1] + z = xyz[:,:,2] + + x_valid = ((x-padding)>-0.5).byte() & ((x+padding)-0.5).byte() & ((y+padding)-0.5).byte() & ((z+padding) 0: + # only take points that are already near centers + xyz_round = torch.round(xyz) # B, N, 3 + dist = torch.norm(xyz_round - xyz, dim=2) + mask[dist > clean_eps] = 0 + + # set the invalid guys to zero + # we then need to zero out 0,0,0 + # (this method seems a bit clumsy) + x = x*mask + y = y*mask + z = z*mask + + x = torch.round(x) + y = torch.round(y) + z = torch.round(z) + x = torch.clamp(x, 0, X-1).int() + y = torch.clamp(y, 0, Y-1).int() + z = torch.clamp(z, 0, Z-1).int() + + x = x.view(B*N) + y = y.view(B*N) + z = z.view(B*N) + + dim3 = X + dim2 = X 
* Y + dim1 = X * Y * Z + + base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1 + base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N) + + vox_inds = base + z * dim2 + y * dim3 + x + voxels = torch.zeros(B*Z*Y*X, device=xyz.device).float() + voxels[vox_inds.long()] = 1.0 + # zero out the singularity + voxels[base.long()] = 0.0 + voxels = voxels.reshape(B, 1, Z, Y, X) + # B x 1 x Z x Y x X + return voxels + + def get_feat_occupancy(self, xyz, feat, Z, Y, X, clean_eps=0, xyz_zero=None): + # xyz is B x N x 3 and in mem coords + # feat is B x N x D + # we want to fill a voxel tensor with 1's at these inds + B, N, C = list(xyz.shape) + B2, N2, D2 = list(feat.shape) + assert(C==3) + assert(B==B2) + assert(N==N2) + + # these papers say simple 1/0 occupancy is ok: + # http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_PIXOR_Real-Time_3d_CVPR_2018_paper.pdf + # http://openaccess.thecvf.com/content_cvpr_2018/papers/Luo_Fast_and_Furious_CVPR_2018_paper.pdf + # cont fusion says they do 8-neighbor interp + # voxelnet does occupancy but with a bit of randomness in terms of the reflectance value i think + + inbounds = self.get_inbounds(xyz, Z, Y, X, already_mem=True) + x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2] + mask = torch.zeros_like(x) + mask[inbounds] = 1.0 + + if xyz_zero is not None: + # only take points that are beyond a thresh of zero + dist = torch.norm(xyz_zero-xyz, dim=2) + mask[dist < 0.1] = 0 + + if clean_eps > 0: + # only take points that are already near centers + xyz_round = torch.round(xyz) # B, N, 3 + dist = torch.norm(xyz_round - xyz, dim=2) + mask[dist > clean_eps] = 0 + + # set the invalid guys to zero + # we then need to zero out 0,0,0 + # (this method seems a bit clumsy) + x = x*mask # B, N + y = y*mask + z = z*mask + feat = feat*mask.unsqueeze(-1) # B, N, D + + x = torch.round(x) + y = torch.round(y) + z = torch.round(z) + x = torch.clamp(x, 0, X-1).int() + y = torch.clamp(y, 0, Y-1).int() + z = torch.clamp(z, 0, Z-1).int() + + # permute point orders + perm = torch.randperm(N) + x = x[:, perm] + y = y[:, perm] + z = z[:, perm] + feat = feat[:, perm] + + x = x.view(B*N) + y = y.view(B*N) + z = z.view(B*N) + feat = feat.view(B*N, -1) + + dim3 = X + dim2 = X * Y + dim1 = X * Y * Z + + base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1 + base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N) + + vox_inds = base + z * dim2 + y * dim3 + x + feat_voxels = torch.zeros((B*Z*Y*X, D2), device=xyz.device).float() + feat_voxels[vox_inds.long()] = feat + # zero out the singularity + feat_voxels[base.long()] = 0.0 + feat_voxels = feat_voxels.reshape(B, Z, Y, X, D2).permute(0, 4, 1, 2, 3) + # B x C x Z x Y x X + return feat_voxels + + def unproject_image_to_mem(self, rgb_camB, pixB_T_camA, camB_T_camA, Z, Y, X, assert_cube=False, xyz_camA=None): + # rgb_camB is B x C x H x W + # pixB_T_camA is B x 4 x 4 + + # rgb lives in B pixel coords + # we want everything in A memory coords + + # this puts each C-dim pixel in the rgb_camB + # along a ray in the voxelgrid + B, C, H, W = list(rgb_camB.shape) + + if xyz_camA is None: + xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device) + xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube) + + xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA) + z = xyz_camB[:,:,2] + + xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA) + normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2) + EPS=1e-6 + # z = xyz_pixB[:,:,2] + xy_pixB = 
xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS) + # this is B x N x 2 + # this is the (floating point) pixel coordinate of each voxel + x, y = xy_pixB[:,:,0], xy_pixB[:,:,1] + # these are B x N + + x_valid = (x>-0.5).bool() & (x-0.5).bool() & (y0.0).bool() + valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float() + + if (0): + # handwritten version + values = torch.zeros([B, C, Z*Y*X], dtype=torch.float32) + for b in list(range(B)): + values[b] = utils.samp.bilinear_sample_single(rgb_camB[b], x_pixB[b], y_pixB[b]) + else: + # native pytorch version + y_pixB, x_pixB = utils.basic.normalize_grid2d(y, x, H, W) + # since we want a 3d output, we need 5d tensors + z_pixB = torch.zeros_like(x) + xyz_pixB = torch.stack([x_pixB, y_pixB, z_pixB], axis=2) + rgb_camB = rgb_camB.unsqueeze(2) + xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3]) + values = F.grid_sample(rgb_camB, xyz_pixB, align_corners=False) + + values = torch.reshape(values, (B, C, Z, Y, X)) + values = values * valid_mem + return values + + def warp_tiled_to_mem(self, rgb_tileB, pixB_T_camA, camB_T_camA, Z, Y, X, DMIN, DMAX, assert_cube=False): + # rgb_tileB is B,C,D,H,W + # pixB_T_camA is B,4,4 + # camB_T_camA is B,4,4 + + # rgb_tileB lives in B pixel coords but it has been tiled across the Z dimension + # we want everything in A memory coords + + # this resamples the so that each C-dim pixel in rgb_tilB + # is put into its correct place in the voxelgrid + # (using the pinhole camera model) + + B, C, D, H, W = list(rgb_tileB.shape) + + xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device) + + xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube) + + xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA) + z_camB = xyz_camB[:,:,2] + + # rgb_tileB has depth=DMIN in tile 0, and depth=DMAX in tile D-1 + z_tileB = (D-1.0) * (z_camB-float(DMIN)) / float(DMAX-DMIN) + + xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA) + normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2) + EPS=1e-6 + # z = xyz_pixB[:,:,2] + xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS) + # this is B x N x 2 + # this is the (floating point) pixel coordinate of each voxel + x, y = xy_pixB[:,:,0], xy_pixB[:,:,1] + # these are B x N + + x_valid = (x>-0.5).bool() & (x-0.5).bool() & (y0.0).bool() + valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float() + + z_tileB, y_pixB, x_pixB = utils.basic.normalize_grid3d(z_tileB, y, x, D, H, W) + xyz_pixB = torch.stack([x_pixB, y_pixB, z_tileB], axis=2) + xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3]) + values = F.grid_sample(rgb_tileB, xyz_pixB, align_corners=False) + + values = torch.reshape(values, (B, C, Z, Y, X)) + values = values * valid_mem + return values + + + def apply_mem_T_ref_to_lrtlist(self, lrtlist_cam, Z, Y, X, assert_cube=False): + # lrtlist is B x N x 19, in cam coordinates + # transforms them into mem coordinates, including a scale change for the lengths + B, N, C = list(lrtlist_cam.shape) + assert(C==19) + mem_T_cam = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=lrtlist_cam.device) + + def xyz2circles(self, xyz, radius, Z, Y, X, soft=True, already_mem=True, also_offset=False, grid=None): + # xyz is B x N x 3 + # radius is B x N or broadcastably so + # output is B x N x Z x Y x X + B, N, D = list(xyz.shape) + assert(D==3) + if not already_mem: + xyz = self.Ref2Mem(xyz, Z, Y, X) + + if grid is None: + grid_z, grid_y, grid_x = utils.basic.meshgrid3d(B, Z, Y, X, stack=False, norm=False, 
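+ # Illustrative note (not in the upstream SpaTracker code): meshgrid3d is taken here to
+ # return one voxel-index grid per axis (stack=False); below they are stacked to B x 3 x Z x Y x X
+ # and compared against the rounded xyz so each point gets a Gaussian mask with radius acting as sigma.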
device=xyz.device) + # note the default stack is on -1 + grid = torch.stack([grid_x, grid_y, grid_z], dim=1) + # this is B x 3 x Z x Y x X + + xyz = xyz.reshape(B, N, 3, 1, 1, 1) + grid = grid.reshape(B, 1, 3, Z, Y, X) + # this is B x N x Z x Y x X + + # round the xyzs, so that at least one value matches the grid perfectly, + # and we get a value of 1 there (since exp(0)==1) + xyz = xyz.round() + + if torch.is_tensor(radius): + radius = radius.clamp(min=0.01) + + if soft: + off = grid - xyz # B,N,3,Z,Y,X + # interpret radius as sigma + dist_grid = torch.sum(off**2, dim=2, keepdim=False) + # this is B x N x Z x Y x X + if torch.is_tensor(radius): + radius = radius.reshape(B, N, 1, 1, 1) + mask = torch.exp(-dist_grid/(2*radius*radius)) + # zero out near zero + mask[mask < 0.001] = 0.0 + # h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + # h[h < np.finfo(h.dtype).eps * h.max()] = 0 + # return h + if also_offset: + return mask, off + else: + return mask + else: + assert(False) # something is wrong with this. come back later to debug + + dist_grid = torch.norm(grid - xyz, dim=2, keepdim=False) + # this is 0 at/near the xyz, and increases by 1 for each voxel away + + radius = radius.reshape(B, N, 1, 1, 1) + + within_radius_mask = (dist_grid < radius).float() + within_radius_mask = torch.sum(within_radius_mask, dim=1, keepdim=True).clamp(0, 1) + return within_radius_mask + + def xyz2circles_bev(self, xyz, radius, Z, Y, X, already_mem=True, also_offset=False): + # xyz is B x N x 3 + # radius is B x N or broadcastably so + # output is B x N x Z x Y x X + B, N, D = list(xyz.shape) + assert(D==3) + if not already_mem: + xyz = self.Ref2Mem(xyz, Z, Y, X) + + xz = torch.stack([xyz[:,:,0], xyz[:,:,2]], dim=2) + + grid_z, grid_x = utils.basic.meshgrid2d(B, Z, X, stack=False, norm=False, device=xyz.device) + # note the default stack is on -1 + grid = torch.stack([grid_x, grid_z], dim=1) + # this is B x 2 x Z x X + + xz = xz.reshape(B, N, 2, 1, 1) + grid = grid.reshape(B, 1, 2, Z, X) + # these are ready to broadcast to B x N x Z x X + + # round the points, so that at least one value matches the grid perfectly, + # and we get a value of 1 there (since exp(0)==1) + xz = xz.round() + + if torch.is_tensor(radius): + radius = radius.clamp(min=0.01) + + off = grid - xz # B,N,2,Z,X + # interpret radius as sigma + dist_grid = torch.sum(off**2, dim=2, keepdim=False) + # this is B x N x Z x X + if torch.is_tensor(radius): + radius = radius.reshape(B, N, 1, 1, 1) + mask = torch.exp(-dist_grid/(2*radius*radius)) + # zero out near zero + mask[mask < 0.001] = 0.0 + + # add a Y dim + mask = mask.unsqueeze(-2) + off = off.unsqueeze(-2) + # # B,N,2,Z,1,X + + if also_offset: + return mask, off + else: + return mask + diff --git a/model_loading.py b/model_loading.py index e9c9e4c..236b428 100644 --- a/model_loading.py +++ b/model_loading.py @@ -712,6 +712,7 @@ class CogVideoXModelLoader: device = mm.get_torch_device() offload_device = mm.unet_offload_device() manual_offloading = True + das_transformer = False transformer_load_device = device if load_device == "main_device" else offload_device mm.soft_empty_cache() @@ -719,9 +720,14 @@ class CogVideoXModelLoader: model_path = folder_paths.get_full_path_or_raise("diffusion_models", model) sd = load_torch_file(model_path, device=transformer_load_device) + first_key = next(iter(sd.keys())) model_type = "" - if sd["patch_embed.proj.weight"].shape == (3072, 33, 2, 2): + if first_key == "combine_linears.0.bias": + log.info("Detected 'Diffusion As Shader' model") + model_type = 
"I2V_5b" + das_transformer = True + elif sd["patch_embed.proj.weight"].shape == (3072, 33, 2, 2): model_type = "fun_5b" elif sd["patch_embed.proj.weight"].shape == (3072, 16, 2, 2): model_type = "5b" @@ -770,7 +776,7 @@ class CogVideoXModelLoader: transformer_config["sample_width"] = 300 with init_empty_weights(): - transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode) + transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode, das_transformer=das_transformer) #load weights #params_to_keep = {} diff --git a/nodes.py b/nodes.py index 338d9f5..96ffac8 100644 --- a/nodes.py +++ b/nodes.py @@ -631,8 +631,9 @@ class CogVideoSampler: "controlnet": ("COGVIDECONTROLNET",), "tora_trajectory": ("TORAFEATURES", ), "fastercache": ("FASTERCACHEARGS", ), - "feta_args": ("FETAARGS", ), - "teacache_args": ("TEACACHEARGS", ), + "feta_args": ("FETAARGS", {"tooltip": "Arguments for Enhance-a-video"} ), + "teacache_args": ("TEACACHEARGS",{"tooltip": "Arguments for TeaCache"} ), + "das_tracking": ("DASTRACKING", {"tooltip": "Enable tracking for Diffusion As Shader"} ), } } @@ -642,17 +643,14 @@ class CogVideoSampler: CATEGORY = "CogVideoWrapper" def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None, - denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None, teacache_args=None): + denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, + das_tracking=None, fastercache=None, feta_args=None, teacache_args=None): mm.unload_all_models() mm.soft_empty_cache() model_name = model.get("model_name", "") - supports_image_conds = True if ( - "I2V" in model_name or - "interpolation" in model_name.lower() or - "fun" in model_name.lower() or - "img2vid" in model_name.lower() - ) else False + supports_image_conds = True if model["pipe"].transformer.config.in_channels == 32 else False + if "fun" in model_name.lower() and not ("pose" in model_name.lower() or "control" in model_name.lower()) and image_cond_latents is not None: assert image_cond_latents["mask"] is not None, "For fun inpaint models use CogVideoImageEncodeFunInP" fun_mask = image_cond_latents["mask"] @@ -771,6 +769,7 @@ class CogVideoSampler: image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0, image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0, feta_args=feta_args, + das_tracking=das_tracking, ) if not model["cpu_offloading"] and model["manual_offloading"]: pipe.transformer.to(offload_device) @@ -1033,4 +1032,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP", "CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video", "CogVideoXTeaCache": "CogVideoX TeaCache", + "CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode", } diff --git a/pipeline_cogvideox.py b/pipeline_cogvideox.py index 9ce8529..bf182b4 100644 --- a/pipeline_cogvideox.py +++ b/pipeline_cogvideox.py @@ -383,6 +383,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): image_cond_start_percent: float = 0.0, image_cond_end_percent: float = 1.0, feta_args: Optional[dict] = None, + das_tracking: Optional[dict] = None, ): """ @@ -625,6 +626,24 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): else: disable_enhance() + #das + if das_tracking is not None: + tracking_maps = 
das_tracking["tracking_maps"] + tracking_image_latents = das_tracking["tracking_image_latents"] + das_start_percent = das_tracking["start_percent"] + das_end_percent = das_tracking["end_percent"] + + padding_shape = ( + batch_size, + (latents.shape[1] - 1), + self.vae_latent_channels, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype) + tracking_image_latents = torch.cat([tracking_image_latents, latent_padding], dim=1) + # reset TeaCache if hasattr(self.transformer, 'accumulated_rel_l1_distance'): delattr(self.transformer, 'accumulated_rel_l1_distance') @@ -805,6 +824,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype) latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2) + if das_tracking is not None and das_start_percent <= current_step_percentage <= das_end_percent: + logger.info("DAS tracking enabled") + latents_tracking_image = torch.cat([tracking_image_latents] * 2) if do_classifier_free_guidance else tracking_image_latents + tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps + tracking_maps_input = torch.cat([tracking_maps_input, latents_tracking_image], dim=2) + del latents_tracking_image + else: + tracking_maps_input = None + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -836,6 +864,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): controlnet_states=controlnet_states, controlnet_weights=control_weights, video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None, + tracking_maps=tracking_maps_input, )[0] noise_pred = noise_pred.float() if isinstance(self.scheduler, CogVideoXDPMScheduler):
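+ # Illustrative note (not part of the upstream patch): the DAS additions above only extend the
+ # transformer inputs -- tracking_maps is duplicated for classifier-free guidance and concatenated
+ # with the (zero-padded) tracking image latents along dim=2 before being passed as tracking_maps --
+ # while the scheduler update applied to noise_pred is not modified here.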