This commit is contained in:
kijai 2025-02-12 00:50:38 +02:00
parent dbc63f622d
commit d3601e3fa3
41 changed files with 10724 additions and 18 deletions

View File

@ -1,7 +1,8 @@
from .nodes import NODE_CLASS_MAPPINGS as NODES_CLASS, NODE_DISPLAY_NAME_MAPPINGS as NODES_DISPLAY
from .model_loading import NODE_CLASS_MAPPINGS as MODEL_CLASS, NODE_DISPLAY_NAME_MAPPINGS as MODEL_DISPLAY
from .das.das_nodes import NODE_CLASS_MAPPINGS as DAS_CLASS, NODE_DISPLAY_NAME_MAPPINGS as DAS_DISPLAY
NODE_CLASS_MAPPINGS = {**NODES_CLASS, **MODEL_CLASS}
NODE_DISPLAY_NAME_MAPPINGS = {**NODES_DISPLAY, **MODEL_DISPLAY}
NODE_CLASS_MAPPINGS = {**NODES_CLASS, **MODEL_CLASS, **DAS_CLASS}
NODE_DISPLAY_NAME_MAPPINGS = {**NODES_DISPLAY, **MODEL_DISPLAY, **DAS_DISPLAY}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]

View File

@ -453,6 +453,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
use_learned_positional_embeddings: bool = False,
patch_bias: bool = True,
attention_mode: Optional[str] = "sdpa",
das_transformer: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@ -557,6 +558,41 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
else:
#CogVideoX-5B
self.teacache_coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
#das
# Create linear layers for combining hidden states and tracking maps
if das_transformer:
num_tracking_blocks = 18
self.combine_linears = nn.ModuleList(
[nn.Linear(inner_dim, inner_dim) for _ in range(num_tracking_blocks)]
)
# Initialize weights of combine_linears to zero
for linear in self.combine_linears:
linear.weight.data.zero_()
linear.bias.data.zero_()
# Create transformer blocks for processing tracking maps
self.transformer_blocks_copy = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
time_embed_dim=self.config.time_embed_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
)
for _ in range(num_tracking_blocks)
]
)
# For initial combination of hidden states and tracking maps
self.initial_combine_linear = nn.Linear(inner_dim, inner_dim)
self.initial_combine_linear.weight.data.zero_()
self.initial_combine_linear.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
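For context, a minimal sketch (not part of this commit) of why the zero initialization above matters: with the combine layers' weights and biases zeroed, the tracking branch adds exactly nothing to the main stream at the start of training, so the pretrained transformer's behaviour is preserved until the new layers learn. The dimension and shapes below are arbitrary toy values.
import torch
import torch.nn as nn
dim = 8  # toy hidden size, illustrative only
combine = nn.Linear(dim, dim)
combine.weight.data.zero_()
combine.bias.data.zero_()
hidden = torch.randn(2, 5, dim)     # stand-in for hidden_states tokens
tracking = torch.randn(2, 5, dim)   # stand-in for tracking-map tokens
# the projected tracking contribution is exactly zero at initialization
assert torch.equal(combine(tracking), torch.zeros_like(tracking))
assert torch.equal(hidden + combine(tracking), hidden)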
@ -573,6 +609,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
controlnet_states: torch.Tensor = None,
controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
video_flow_features: Optional[torch.Tensor] = None,
tracking_maps: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
@ -602,13 +639,27 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
#print("hidden_states before patch_embedding", hidden_states.shape) #torch.Size([2, 4, 16, 60, 90])
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
#print("hidden_states after patch_embedding", hidden_states.shape) #1.5: torch.Size([2, 2926, 3072]) #1.0: torch.Size([2, 5626, 3072])
hidden_states = self.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
#print("hidden_states after split", hidden_states.shape) #1.5: torch.Size([2, 2700, 3072]) #1.0: torch.Size([2, 5400, 3072])
if tracking_maps is not None:
# Process tracking maps
prompt_embed = encoder_hidden_states.clone()
tracking_maps_hidden_states = self.patch_embed(prompt_embed, tracking_maps)
tracking_maps_hidden_states = self.embedding_dropout(tracking_maps_hidden_states)
del prompt_embed
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
tracking_maps = tracking_maps_hidden_states[:, text_seq_length:]
# Combine hidden states and tracking maps initially
combined = hidden_states + tracking_maps
tracking_maps = self.initial_combine_linear(combined)
else:
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
if self.use_fastercache:
self.fastercache_counter+=1
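As a rough illustration of the token bookkeeping above (toy shapes, not taken from the model): patch_embed returns the text tokens followed by the video tokens in one sequence, and the same text_seq_length split is applied to the embedded tracking maps so both streams stay aligned token for token.
import torch
text_seq_length, video_tokens, dim = 226, 2700, 16              # illustrative sizes only
sequence = torch.randn(2, text_seq_length + video_tokens, dim)  # stand-in for patch_embed output
tracking_sequence = torch.randn_like(sequence)                  # stand-in for embedded tracking maps
encoder_hidden_states = sequence[:, :text_seq_length]           # text part
hidden_states = sequence[:, text_seq_length:]                   # video part
tracking_tokens = tracking_sequence[:, text_seq_length:]        # same split for the tracking stream
assert hidden_states.shape == tracking_tokens.shape == (2, video_tokens, dim)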
@ -706,6 +757,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
if self.use_teacache:
ori_hidden_states = hidden_states.clone()
ori_encoder_hidden_states = encoder_hidden_states.clone()
for i, block in enumerate(self.transformer_blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
@ -731,6 +783,18 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
controlnet_block_weight = controlnet_weights
hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
#das
if i < len(self.transformer_blocks_copy) and tracking_maps is not None:
tracking_maps, _ = self.transformer_blocks_copy[i](
hidden_states=tracking_maps,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)
# Combine hidden states and tracking maps
tracking_maps = self.combine_linears[i](tracking_maps)
hidden_states = hidden_states + tracking_maps
if self.use_teacache:
self.previous_residual = hidden_states - ori_hidden_states
self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
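A condensed, runnable sketch of the per-block injection pattern above (plain linear layers stand in for CogVideoXBlock; all sizes are invented): the first num_tracking_blocks blocks have trainable copies that refine the tracking stream, and a zero-initialized linear both gates what is added back into the main stream and becomes the tracking stream carried to the next copy, mirroring the loop in the diff.
import torch
import torch.nn as nn
dim, num_blocks, num_tracking_blocks = 8, 4, 2  # toy sizes
blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))                 # stand-ins for transformer_blocks
blocks_copy = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_tracking_blocks))   # stand-ins for transformer_blocks_copy
combine = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_tracking_blocks))
for lin in combine:
    nn.init.zeros_(lin.weight)
    nn.init.zeros_(lin.bias)
hidden = torch.randn(1, 5, dim)
tracking = torch.randn(1, 5, dim)
for i, block in enumerate(blocks):
    hidden = block(hidden)
    if i < len(blocks_copy):
        tracking = blocks_copy[i](tracking)   # copy block processes the tracking stream
        tracking = combine[i](tracking)       # zero-init projection (zero output at initialization)
        hidden = hidden + tracking            # injected into the main stream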

das/das_nodes.py (new file, 146 lines)

View File

@ -0,0 +1,146 @@
import torch
import comfy.model_management as mm
from ..utils import log
import os
import numpy as np
import folder_paths
class CogVideoDASTrackingEncode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"vae": ("VAE",),
"images": ("IMAGE", ),
},
"optional": {
"enable_tiling": ("BOOLEAN", {"default": True, "tooltip": "Enable tiling for the VAE to reduce memory usage"}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
},
}
RETURN_TYPES = ("DASTRACKING",)
RETURN_NAMES = ("das_tracking",)
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, vae, images, enable_tiling=False, strength=1.0, start_percent=0.0, end_percent=1.0):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
try:
vae.enable_slicing()
except Exception:
pass
vae_scaling_factor = vae.config.scaling_factor
if enable_tiling:
from ..mz_enable_vae_encode_tiling import enable_vae_encode_tiling
enable_vae_encode_tiling(vae)
vae.to(device)
try:
vae._clear_fake_context_parallel_cache()
except Exception:
pass
tracking_maps = images.to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
tracking_first_frame = tracking_maps[:, :, 0:1, :, :]
tracking_first_frame = tracking_first_frame * 2.0 - 1.0 # rescale [0, 1] images to the [-1, 1] range expected by the VAE
log.info(f"tracking_first_frame shape: {tracking_first_frame.shape}")
tracking_first_frame_latent = vae.encode(tracking_first_frame).latent_dist.sample(generator).permute(0, 2, 1, 3, 4)
tracking_first_frame_latent = tracking_first_frame_latent * vae_scaling_factor * strength
log.info(f"Encoded tracking first frame latents shape: {tracking_first_frame_latent.shape}")
tracking_latents = vae.encode(tracking_maps).latent_dist.sample(generator).permute(0, 2, 1, 3, 4) # B, T, C, H, W
tracking_latents = tracking_latents * vae_scaling_factor * strength
log.info(f"Encoded tracking latents shape: {tracking_latents.shape}")
vae.to(offload_device)
return ({
"tracking_maps": tracking_latents,
"tracking_image_latents": tracking_first_frame_latent,
"start_percent": start_percent,
"end_percent": end_percent
}, )
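A small self-contained sketch of the layout conversion done in encode() above (toy sizes, no real VAE): ComfyUI IMAGE batches arrive as (T, H, W, C) floats in [0, 1], while the encoder expects (B, C, T, H, W) scaled to [-1, 1].
import torch
images = torch.rand(9, 48, 72, 3)                     # T, H, W, C in [0, 1]
maps = images.unsqueeze(0).permute(0, 4, 1, 2, 3)     # -> B, C, T, H, W
first = maps[:, :, 0:1] * 2.0 - 1.0                   # first frame rescaled to [-1, 1]
assert maps.shape == (1, 3, 9, 48, 72)
assert first.min() >= -1.0
assert first.max() <= 1.0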
class DAS_SpaTracker:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"images": ("IMAGE", ),
"depth_images": ("IMAGE", ),
"density": ("INT", {"default": 70, "min": 1, "max": 100, "step": 1}),
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("tracking_video",)
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, images, depth_images, density):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
model_path = os.path.join(folder_paths.models_dir, 'spatracker', 'spaT_final.pth')
from .spatracker.predictor import SpaTrackerPredictor
from .spatracker.utils.visualizer import Visualizer
if not hasattr(self, "tracker"):
self.tracker = SpaTrackerPredictor(
checkpoint=model_path,
interp_shape=(384, 576),
seq_length=12
).to(device)
segm_mask = np.ones((480, 720), dtype=np.uint8)
video = images.permute(0, 3, 1, 2).to(device).unsqueeze(0)
video_depth = depth_images.permute(0, 3, 1, 2).to(device)
video_depth = video_depth[:, 0:1, :, :]
pred_tracks, pred_visibility, T_Firsts = self.tracker(
video * 255,
video_depth=video_depth,
grid_size=density,
backward_tracking=False,
depth_predictor=None,
grid_query_frame=0,
segm_mask=torch.from_numpy(segm_mask)[None, None].to(device),
wind_length=12,
progressive_tracking=False
)
vis = Visualizer(grayscale=False, fps=24, pad_value=0)
msk_query = (T_Firsts == 0)
pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
tracking_video = vis.visualize(video=video, tracks=pred_tracks,
visibility=pred_visibility, save_video=False,
filename="temp")
tracking_video = tracking_video.squeeze(0).permute(0, 2, 3, 1) # [T, H, W, C]
tracking_video = (tracking_video / 255.0).float()
return (tracking_video,)
NODE_CLASS_MAPPINGS = {
"CogVideoDASTrackingEncode": CogVideoDASTrackingEncode,
"DAS_SpaTracker": DAS_SpaTracker,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode",
"DAS_SpaTracker": "DAS SpaTracker",
}

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .core.spatracker.spatracker import SpaTracker
def build_spatracker(
checkpoint: str,
seq_length: int = 8,
):
model_name = checkpoint.split("/")[-1].split(".")[0]
return build_spatracker_from_cfg(checkpoint=checkpoint, seq_length=seq_length)
# model used to produce the results in the paper
def build_spatracker_from_cfg(checkpoint=None, seq_length=8):
return _build_spatracker(
stride=4,
sequence_len=seq_length,
checkpoint=checkpoint,
)
def _build_spatracker(
stride,
sequence_len,
checkpoint=None,
):
spatracker = SpaTracker(
stride=stride,
S=sequence_len,
add_space_attn=True,
space_depth=6,
time_depth=6,
)
if checkpoint is not None:
with open(checkpoint, "rb") as f:
if "safetensors" in checkpoint:
from safetensors.torch import load_file
state_dict = load_file(checkpoint)
else:
state_dict = torch.load(f, map_location="cpu", weights_only=True)
if "model" in state_dict:
model_paras = spatracker.state_dict()
paras_dict = {k: v for k,v in state_dict["model"].items() if k in spatracker.state_dict()}
model_paras.update(paras_dict)
state_dict = model_paras
spatracker.load_state_dict(state_dict)
return spatracker
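A hedged usage sketch for the factory above (the checkpoint path is a placeholder and must exist on disk; the DAS node looks for spaT_final.pth under the spatracker models folder):
import torch
ckpt = "models/spatracker/spaT_final.pth"   # placeholder path, adjust to your install
tracker = build_spatracker(checkpoint=ckpt, seq_length=12)
tracker = tracker.eval().to("cuda" if torch.cuda.is_available() else "cpu")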

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,250 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, tuple):
grid_size_h, grid_size_w = grid_size
else:
grid_size_h = grid_size_w = grid_size
grid_h = np.arange(grid_size_h, dtype=np.float32)
grid_w = np.arange(grid_size_w, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 3 == 0
# use a third of the dimensions to encode each of x, y and z
B, S, N, _ = grid.shape
gridx = grid[..., 0].view(B*S*N).detach().cpu().numpy()
gridy = grid[..., 1].view(B*S*N).detach().cpu().numpy()
gridz = grid[..., 2].view(B*S*N).detach().cpu().numpy()
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridx) # (N, D/3)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridy) # (N, D/3)
emb_z = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridz) # (N, D/3)
emb = np.concatenate([emb_h, emb_w, emb_z], axis=1) # (N, D)
emb = torch.from_numpy(emb).to(grid.device)
return emb.view(B, S, N, embed_dim)
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, tuple):
grid_size_h, grid_size_w = grid_size
else:
grid_size_h = grid_size_w = grid_size
grid_h = np.arange(grid_size_h, dtype=np.float32)
grid_w = np.arange(grid_size_w, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
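A quick shape check for the sin-cos helpers above (arbitrary sizes, assuming the functions are in scope):
import numpy as np
pos_2d = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(4, 6))
assert pos_2d.shape == (4 * 6, 64)     # one 64-d embedding per grid cell
pos_1d = get_1d_sincos_pos_embed_from_grid(32, np.arange(10, dtype=np.float32))
assert pos_1d.shape == (10, 32)        # sin half concatenated with cos half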
def get_2d_embedding(xy, C, cat_coords=True):
B, N, D = xy.shape
assert D == 2
x = xy[:, :, 0:1]
y = xy[:, :, 1:2]
div_term = (
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*2
if cat_coords:
pe = torch.cat([xy, pe], dim=2) # B, N, C*2+2
return pe
def get_3d_embedding(xyz, C, cat_coords=True):
B, N, D = xyz.shape
assert D == 3
x = xyz[:, :, 0:1]
y = xyz[:, :, 1:2]
z = xyz[:, :, 2:3]
div_term = (
torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe_z[:, :, 0::2] = torch.sin(z * div_term)
pe_z[:, :, 1::2] = torch.cos(z * div_term)
pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
if cat_coords:
pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
return pe
def get_4d_embedding(xyzw, C, cat_coords=True):
B, N, D = xyzw.shape
assert D == 4
x = xyzw[:, :, 0:1]
y = xyzw[:, :, 1:2]
z = xyzw[:, :, 2:3]
w = xyzw[:, :, 3:4]
div_term = (
torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe_z[:, :, 0::2] = torch.sin(z * div_term)
pe_z[:, :, 1::2] = torch.cos(z * div_term)
pe_w[:, :, 0::2] = torch.sin(w * div_term)
pe_w[:, :, 1::2] = torch.cos(w * div_term)
pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*4
if cat_coords:
pe = torch.cat([pe, xyzw], dim=2) # B, N, C*4+4
return pe
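For reference, the output widths of the coordinate embeddings above when cat_coords=True (toy batch, assuming the functions are in scope): with C channels per axis the result is axes*C plus the raw coordinates.
import torch
xy = torch.rand(2, 7, 2)
xyz = torch.rand(2, 7, 3)
assert get_2d_embedding(xy, C=64).shape == (2, 7, 2 * 64 + 2)    # 130
assert get_3d_embedding(xyz, C=64).shape == (2, 7, 3 * 64 + 3)   # 195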
import torch.nn as nn
class Embedder_Fourier(nn.Module):
def __init__(self, input_dim, max_freq_log2, N_freqs,
log_sampling=True, include_input=True,
periodic_fns=(torch.sin, torch.cos)):
'''
:param input_dim: dimension of input to be embedded
:param max_freq_log2: log2 of max freq; min freq is 1 by default
:param N_freqs: number of frequency bands
:param log_sampling: if True, frequency bands are linearly sampled in log-space
:param include_input: if True, raw input is included in the embedding
:param periodic_fns: periodic functions used to embed input
'''
super(Embedder_Fourier, self).__init__()
self.input_dim = input_dim
self.include_input = include_input
self.periodic_fns = periodic_fns
self.out_dim = 0
if self.include_input:
self.out_dim += self.input_dim
self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns)
if log_sampling:
self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
else:
self.freq_bands = torch.linspace(
2. ** 0., 2. ** max_freq_log2, N_freqs)
self.freq_bands = self.freq_bands.numpy().tolist()
def forward(self,
input: torch.Tensor,
rescale: float = 1.0):
'''
:param input: tensor of shape [..., self.input_dim]
:return: tensor of shape [..., self.out_dim]
'''
assert (input.shape[-1] == self.input_dim)
out = []
if self.include_input:
out.append(input/rescale)
for i in range(len(self.freq_bands)):
freq = self.freq_bands[i]
for p_fn in self.periodic_fns:
out.append(p_fn(input * freq))
out = torch.cat(out, dim=-1)
assert (out.shape[-1] == self.out_dim)
return out
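A small sanity check of the Fourier embedder above (illustrative numbers): out_dim equals input_dim (when include_input is True) plus input_dim * N_freqs * len(periodic_fns).
import torch
emb = Embedder_Fourier(input_dim=3, max_freq_log2=3, N_freqs=4, include_input=True)
x = torch.rand(5, 3)
assert emb.out_dim == 3 + 3 * 4 * 2 == 27
assert emb(x).shape == (5, 27)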

View File

@ -0,0 +1,477 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from easydict import EasyDict as edict
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
EPS = 1e-6
def nearest_sample2d(im, x, y, return_inbounds=False):
# x and y are each B, N
# output is B, C, N
if len(im.shape) == 5:
B, N, C, H, W = list(im.shape)
else:
B, C, H, W = list(im.shape)
N = list(x.shape)[1]
x = x.float()
y = y.float()
H_f = torch.tensor(H, dtype=torch.float32)
W_f = torch.tensor(W, dtype=torch.float32)
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
max_y = (H_f - 1).int()
max_x = (W_f - 1).int()
x0 = torch.floor(x).int()
x1 = x0 + 1
y0 = torch.floor(y).int()
y1 = y0 + 1
x0_clip = torch.clamp(x0, 0, max_x)
x1_clip = torch.clamp(x1, 0, max_x)
y0_clip = torch.clamp(y0, 0, max_y)
y1_clip = torch.clamp(y1, 0, max_y)
dim2 = W
dim1 = W * H
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
base = torch.reshape(base, [B, 1]).repeat([1, N])
base_y0 = base + y0_clip * dim2
base_y1 = base + y1_clip * dim2
idx_y0_x0 = base_y0 + x0_clip
idx_y0_x1 = base_y0 + x1_clip
idx_y1_x0 = base_y1 + x0_clip
idx_y1_x1 = base_y1 + x1_clip
# use the indices to lookup pixels in the flat image
# im is B x C x H x W
# move C out to last dim
if len(im.shape) == 5:
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
else:
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
i_y0_x0 = im_flat[idx_y0_x0.long()]
i_y0_x1 = im_flat[idx_y0_x1.long()]
i_y1_x0 = im_flat[idx_y1_x0.long()]
i_y1_x1 = im_flat[idx_y1_x1.long()]
# Finally calculate interpolated values.
x0_f = x0.float()
x1_f = x1.float()
y0_f = y0.float()
y1_f = y1.float()
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
# w_yi_xo is B * N * 1
max_idx = torch.cat([w_y0_x0, w_y0_x1, w_y1_x0, w_y1_x1], dim=-1).max(dim=-1)[1]
output = torch.stack([i_y0_x0, i_y0_x1, i_y1_x0, i_y1_x1], dim=-1).gather(-1, max_idx[...,None,None].repeat(1,1,C,1)).squeeze(-1)
# output is B*N x C
output = output.view(B, -1, C)
output = output.permute(0, 2, 1)
# output is B x C x N
if return_inbounds:
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
inbounds = (x_valid & y_valid).float()
inbounds = inbounds.reshape(
B, N
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
return output, inbounds
return output # B, C, N
def smart_cat(tensor1, tensor2, dim):
if tensor1 is None:
return tensor2
return torch.cat([tensor1, tensor2], dim=dim)
def normalize_single(d):
# d is a whatever shape torch tensor
dmin = torch.min(d)
dmax = torch.max(d)
d = (d - dmin) / (EPS + (dmax - dmin))
return d
def normalize(d):
# d is B x whatever. normalize within each element of the batch
out = torch.zeros(d.size())
if d.is_cuda:
out = out.cuda()
B = list(d.size())[0]
for b in list(range(B)):
out[b] = normalize_single(d[b])
return out
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
# returns a meshgrid sized B x Y x X
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
grid_y = torch.reshape(grid_y, [1, Y, 1])
grid_y = grid_y.repeat(B, 1, X)
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
grid_x = torch.reshape(grid_x, [1, 1, X])
grid_x = grid_x.repeat(B, Y, 1)
if stack:
# note we stack in xy order
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
grid = torch.stack([grid_x, grid_y], dim=-1)
return grid
else:
return grid_y, grid_x
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
# returns shape-1
# axis can be a list of axes
for (a, b) in zip(x.size(), mask.size()):
assert a == b # some shape mismatch!
prod = x * mask
if dim is None:
numer = torch.sum(prod)
denom = EPS + torch.sum(mask)
else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer / denom
return mean
def bilinear_sample2d(im, x, y, return_inbounds=False):
# x and y are each B, N
# output is B, C, N
if len(im.shape) == 5:
B, N, C, H, W = list(im.shape)
else:
B, C, H, W = list(im.shape)
N = list(x.shape)[1]
x = x.float()
y = y.float()
H_f = torch.tensor(H, dtype=torch.float32)
W_f = torch.tensor(W, dtype=torch.float32)
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
max_y = (H_f - 1).int()
max_x = (W_f - 1).int()
x0 = torch.floor(x).int()
x1 = x0 + 1
y0 = torch.floor(y).int()
y1 = y0 + 1
x0_clip = torch.clamp(x0, 0, max_x)
x1_clip = torch.clamp(x1, 0, max_x)
y0_clip = torch.clamp(y0, 0, max_y)
y1_clip = torch.clamp(y1, 0, max_y)
dim2 = W
dim1 = W * H
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
base = torch.reshape(base, [B, 1]).repeat([1, N])
base_y0 = base + y0_clip * dim2
base_y1 = base + y1_clip * dim2
idx_y0_x0 = base_y0 + x0_clip
idx_y0_x1 = base_y0 + x1_clip
idx_y1_x0 = base_y1 + x0_clip
idx_y1_x1 = base_y1 + x1_clip
# use the indices to lookup pixels in the flat image
# im is B x C x H x W
# move C out to last dim
if len(im.shape) == 5:
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
else:
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
i_y0_x0 = im_flat[idx_y0_x0.long()]
i_y0_x1 = im_flat[idx_y0_x1.long()]
i_y1_x0 = im_flat[idx_y1_x0.long()]
i_y1_x1 = im_flat[idx_y1_x1.long()]
# Finally calculate interpolated values.
x0_f = x0.float()
x1_f = x1.float()
y0_f = y0.float()
y1_f = y1.float()
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
output = (
w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
)
# output is B*N x C
output = output.view(B, -1, C)
output = output.permute(0, 2, 1)
# output is B x C x N
if return_inbounds:
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
inbounds = (x_valid & y_valid).float()
inbounds = inbounds.reshape(
B, N
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
return output, inbounds
return output # B, C, N
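A toy shape check for the sampler above (4-D input path, assuming bilinear_sample2d is in scope): x and y hold per-point pixel coordinates, and the output carries one C-dimensional feature per point.
import torch
im = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)  # B, C, H, W
x = torch.tensor([[1.0, 3.0], [0.0, 2.5]])   # B, N pixel x coordinates
y = torch.tensor([[2.0, 0.0], [1.0, 1.5]])   # B, N pixel y coordinates
out, inbounds = bilinear_sample2d(im, x, y, return_inbounds=True)
assert out.shape == (2, 3, 2)        # B, C, N
assert inbounds.shape == (2, 2)      # 1.0 where the sample lies inside the image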
def procrustes_analysis(X0,X1,Weight): # [B,N,3]; note: Weight is accepted but not used by this implementation
# translation
t0 = X0.mean(dim=1,keepdim=True)
t1 = X1.mean(dim=1,keepdim=True)
X0c = X0-t0
X1c = X1-t1
# scale
# s0 = (X0c**2).sum(dim=-1).mean().sqrt()
# s1 = (X1c**2).sum(dim=-1).mean().sqrt()
# X0cs = X0c/s0
# X1cs = X1c/s1
# rotation (use double for SVD, float loses precision)
U,_,V = (X0c.t()@X1c).double().svd(some=True)
R = (U@V.t()).float()
if R.det()<0: R[2] *= -1
# align X1 to X0: X1to0 = (X1-t1)@R.t()+t0
se3 = edict(t0=t0[0],t1=t1[0],R=R)
return se3
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
r"""Sample a tensor using bilinear interpolation
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
coordinates :attr:`coords` using bilinear interpolation. It is the same
as `torch.nn.functional.grid_sample()` but with a different coordinate
convention.
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
:math:`B` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of the image, and :math:`W` is the width of the
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
that in this case the order of the components is slightly different
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
left-most image pixel :math:`W-1` to the center of the right-most
pixel.
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
the left-most pixel :math:`W` to the right edge of the right-most
pixel.
Similar conventions apply to the :math:`y` for the range
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
:math:`[0,T-1]` and :math:`[0,T]`.
Args:
input (Tensor): batch of input images.
coords (Tensor): batch of coordinates.
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
Returns:
Tensor: sampled points.
"""
sizes = input.shape[2:]
assert len(sizes) in [2, 3]
if len(sizes) == 3:
# t x y -> x y t to match dimensions T H W in grid_sample
coords = coords[..., [1, 2, 0]]
if align_corners:
coords = coords * torch.tensor(
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
)
else:
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
coords -= 1
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
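A tiny check of the coordinate convention above (toy tensor): with align_corners=True, integer (x, y) pixel coordinates land exactly on pixel centres, unlike raw grid_sample, which expects coordinates normalised to [-1, 1].
import torch
img = torch.arange(20, dtype=torch.float32).reshape(1, 1, 4, 5)   # B, C, H, W
coords = torch.tensor([[[[2.0, 3.0]]]])                            # B, H_o, W_o, 2 as (x, y)
sampled = bilinear_sampler(img, coords)                            # B, C, H_o, W_o
assert torch.allclose(sampled.flatten(), img[0, 0, 3, 2], atol=1e-4)   # row y=3, column x=2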
def sample_features4d(input, coords):
r"""Sample spatial features
`sample_features4d(input, coords)` samples the spatial features
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
The field is sampled at coordinates :attr:`coords` using bilinear
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
same convention as :func:`bilinear_sampler` with `align_corners=True`.
The output tensor has one feature per point, and has shape :math:`(B,
R, C)`.
Args:
input (Tensor): spatial features.
coords (Tensor): points.
Returns:
Tensor: sampled features.
"""
B, _, _, _ = input.shape
# B R 2 -> B R 1 2
coords = coords.unsqueeze(2)
# B C R 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 1, 3).view(
B, -1, feats.shape[1] * feats.shape[3]
) # B C R 1 -> B R C
def sample_features5d(input, coords):
r"""Sample spatio-temporal features
`sample_features5d(input, coords)` works in the same way as
:func:`sample_features4d` but for spatio-temporal features and points:
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
Args:
input (Tensor): spatio-temporal features.
coords (Tensor): spatio-temporal points.
Returns:
Tensor: sampled features.
"""
B, T, _, _, _ = input.shape
# B T C H W -> B C T H W
input = input.permute(0, 2, 1, 3, 4)
# B R1 R2 3 -> B R1 R2 1 3
coords = coords.unsqueeze(3)
# B C R1 R2 1
feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 3, 1, 4).view(
B, feats.shape[2], feats.shape[3], feats.shape[1]
) # B C R1 R2 1 -> B R1 R2 C
def vis_PCA(fmaps, save_dir):
"""
visualize the PCA of the feature maps
args:
fmaps: feature maps 1 C H W
save_dir: the directory to save the PCA visualization
"""
pca = PCA(n_components=3)
fmap_vis = fmaps[0,...]
fmap_vnorm = (
(fmap_vis-fmap_vis.min())/
(fmap_vis.max()-fmap_vis.min()))
H_vis, W_vis = fmap_vis.shape[1:]
fmap_vnorm = fmap_vnorm.reshape(fmap_vnorm.shape[0],
-1).permute(1,0)
fmap_pca = pca.fit_transform(fmap_vnorm.detach().cpu().numpy())
pca = fmap_pca.reshape(H_vis,W_vis,3)
plt.imsave(save_dir,
(
(pca-pca.min())/
(pca.max()-pca.min())
))
# debug=False
# if debug==True:
# pcd_idx = 60
# vis_PCA(fmapYZ[0,:1], "./yz.png")
# vis_PCA(fmapXZ[0,:1], "./xz.png")
# vis_PCA(fmaps[0,:1], "./xy.png")
# vis_PCA(fmaps[0,-1:], "./xy_.png")
# fxy_q = fxy[0,0,pcd_idx:pcd_idx+1, :, None, None]
# fyz_q = fyz[0,0,pcd_idx:pcd_idx+1, :, None, None]
# fxz_q = fxz[0,0,pcd_idx:pcd_idx+1, :, None, None]
# corr_map = (fxy_q*fmaps[0,-1:]).sum(dim=1)
# corr_map_yz = (fyz_q*fmapYZ[0,-1:]).sum(dim=1)
# corr_map_xz = (fxz_q*fmapXZ[0,-1:]).sum(dim=1)
# coord_last = coords[0,-1,pcd_idx:pcd_idx+1]
# coord_last_neigh = coords[0,-1, self.neigh_indx[pcd_idx]]
# depth_last = depths_dnG[-1,0]
# abs_res = (depth_last-coord_last[-1,-1]).abs()
# abs_res = (abs_res - abs_res.min())/(abs_res.max()-abs_res.min())
# res_dp = torch.exp(-abs_res)
# enhance_corr = res_dp*corr_map
# plt.imsave("./res.png", res_dp.detach().cpu().numpy())
# plt.imsave("./enhance_corr.png", enhance_corr[0].detach().cpu().numpy())
# plt.imsave("./corr_map.png", corr_map[0].detach().cpu().numpy())
# plt.imsave("./corr_map_yz.png", corr_map_yz[0].detach().cpu().numpy())
# plt.imsave("./corr_map_xz.png", corr_map_xz[0].detach().cpu().numpy())
# img_feat = cv2.imread("./xy.png")
# cv2.circle(img_feat, (int(coord_last[0,0]), int(coord_last[0,1])), 2, (0, 0, 255), -1)
# for p_i in coord_last_neigh:
# cv2.circle(img_feat, (int(p_i[0]), int(p_i[1])), 1, (0, 255, 0), -1)
# cv2.imwrite("./xy_coord.png", img_feat)
# import ipdb; ipdb.set_trace()

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,999 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import numpy as np  # needed for np.pi in Lie.SO3_to_so3 below
from einops import rearrange
import collections
from functools import partial
from itertools import repeat
import torchvision.models as tvm
from .vit.encoder import ImageEncoderViT as vitEnc
from .dpt.models import DPTEncoder
from .loftr import LocalFeatureTransformer
# from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
to_2tuple = _ntuple(2)
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class Attention(nn.Module):
def __init__(self, query_dim, context_dim=None,
num_heads=8, dim_head=48, qkv_bias=False, flash=False):
super().__init__()
inner_dim = self.inner_dim = dim_head * num_heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = num_heads
self.flash = flash
self.qkv = nn.Linear(query_dim, inner_dim*3, bias=qkv_bias)
self.proj = nn.Linear(inner_dim, query_dim)
def forward(self, x, context=None, attn_bias=None):
B, N1, _ = x.shape
C = self.inner_dim
h = self.heads
# q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
# k, v = self.to_kv(context).chunk(2, dim=-1)
# context = default(context, x)
qkv = self.qkv(x).reshape(B, N1, 3, h, C // h)
q, k, v = qkv[:,:, 0], qkv[:,:, 1], qkv[:,:, 2]
N2 = x.shape[1]
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
q = q.reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
if self.flash==False:
sim = (q @ k.transpose(-2, -1)) * self.scale
if attn_bias is not None:
sim = sim + attn_bias
attn = sim.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
else:
input_args = [x.half().contiguous() for x in [q, k, v]]
x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
# return self.to_out(x.float())
return self.proj(x.float())
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
padding=1,
stride=stride,
padding_mode="zeros",
)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Module):
def __init__(
self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0,
Embed3D=False
):
super(BasicEncoder, self).__init__()
self.stride = stride
self.norm_fn = norm_fn
self.in_planes = 64
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(self.in_planes)
self.norm2 = nn.BatchNorm2d(output_dim * 2)
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(
input_dim,
self.in_planes,
kernel_size=7,
stride=2,
padding=3,
padding_mode="zeros",
)
self.relu1 = nn.ReLU(inplace=True)
self.shallow = False
if self.shallow:
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
else:
if Embed3D:
self.conv_fuse = nn.Conv2d(64+63,
self.in_planes, kernel_size=3, padding=1)
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
self.layer4 = self._make_layer(128, stride=2)
self.conv2 = nn.Conv2d(
128 + 128 + 96 + 64,
output_dim * 2,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out",
nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x, feat_PE=None):
_, _, H, W = x.shape
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
if self.shallow:
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
a = F.interpolate(
a,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
b = F.interpolate(
b,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
c = F.interpolate(
c,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
x = self.conv2(torch.cat([a, b, c], dim=1))
else:
if feat_PE is not None:
x = self.conv_fuse(torch.cat([x, feat_PE], dim=1))
a = self.layer1(x)
else:
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)
a = F.interpolate(
a,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
b = F.interpolate(
b,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
c = F.interpolate(
c,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
d = F.interpolate(
d,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
return x
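A rough shape check for the encoder above (toy input): every intermediate scale is resized back to H//stride by W//stride before fusion, so the output resolution follows the stride argument rather than the deepest layer.
import torch
enc = BasicEncoder(input_dim=3, output_dim=128, stride=8)
x = torch.rand(2, 3, 64, 96)
with torch.no_grad():
    feats = enc(x)
assert feats.shape == (2, 128, 64 // 8, 96 // 8)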
class VitEncoder(nn.Module):
def __init__(self, input_dim=4, output_dim=128, stride=4):
super(VitEncoder, self).__init__()
self.vit = vitEnc(img_size=512,
depth=6, num_heads=8, in_chans=input_dim,
out_chans=output_dim,embed_dim=384).cuda()
self.stride = stride
def forward(self, x):
T, C, H, W = x.shape
x_resize = F.interpolate(x.view(-1, C, H, W), size=(512, 512),
mode='bilinear', align_corners=False)
x_resize = self.vit(x_resize)
x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride),
mode='bilinear', align_corners=False)
return x
class DPTEnc(nn.Module):
def __init__(self, input_dim=3, output_dim=128, stride=2):
super(DPTEnc, self).__init__()
self.dpt = DPTEncoder()
self.stride = stride
def forward(self, x):
T, C, H, W = x.shape
x = (x-0.5)/0.5
x_resize = F.interpolate(x.view(-1, C, H, W), size=(384, 384),
mode='bilinear', align_corners=False)
x_resize = self.dpt(x_resize)
x = F.interpolate(x_resize, size=(H//self.stride, W//self.stride),
mode='bilinear', align_corners=False)
return x
# class DPT_DINOv2(nn.Module):
# def __init__(self, encoder='vits', features=64, out_channels=[48, 96, 192, 384],
# use_bn=True, use_clstoken=False, localhub=True, stride=2, enc_only=True):
# super(DPT_DINOv2, self).__init__()
# self.stride = stride
# self.enc_only = enc_only
# assert encoder in ['vits', 'vitb', 'vitl']
# if localhub:
# self.pretrained = torch.hub.load('models/torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
# else:
# self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
# state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vits14_pretrain.pth")
# self.pretrained.load_state_dict(state_dict, strict=True)
# self.pretrained.requires_grad_(False)
# dim = self.pretrained.blocks[0].attn.qkv.in_features
# if enc_only == True:
# out_channels=[128, 128, 128, 128]
# self.DPThead = DPTHeadEnc(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
# def forward(self, x):
# mean_ = torch.tensor([0.485, 0.456, 0.406],
# device=x.device).view(1, 3, 1, 1)
# std_ = torch.tensor([0.229, 0.224, 0.225],
# device=x.device).view(1, 3, 1, 1)
# x = (x+1)/2
# x = (x - mean_)/std_
# h, w = x.shape[-2:]
# h_re, w_re = 560, 560
# x_resize = F.interpolate(x, size=(h_re, w_re),
# mode='bilinear', align_corners=False)
# with torch.no_grad():
# features = self.pretrained.get_intermediate_layers(x_resize, 4, return_class_token=True)
# patch_h, patch_w = h_re // 14, w_re // 14
# feat = self.DPThead(features, patch_h, patch_w, self.enc_only)
# feat = F.interpolate(feat, size=(h//self.stride, w//self.stride), mode="bilinear", align_corners=True)
# return feat
class VGG19(nn.Module):
def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
super().__init__()
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
self.amp = amp
self.amp_dtype = amp_dtype
def forward(self, x, **kwargs):
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
feats = {}
scale = 1
for layer in self.layers:
if isinstance(layer, nn.MaxPool2d):
feats[scale] = x
scale = scale*2
x = layer(x)
return feats
class CNNandDinov2(nn.Module):
def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16):
super().__init__()
# in case the Internet connection is not stable, please load the DINOv2 locally
self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
self.cnn = VGG19(**cnn_kwargs)
self.amp = amp
self.amp_dtype = amp_dtype
if self.amp:
dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
def train(self, mode: bool = True):
return self.cnn.train(mode)
def forward(self, x, upsample = False):
B,C,H,W = x.shape
feature_pyramid = self.cnn(x)
if not upsample:
with torch.no_grad():
if self.dinov2_vitl14[0].device != x.device:
self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
del dinov2_features_16
feature_pyramid[16] = features_16
return feature_pyramid
class Dinov2(nn.Module):
def __init__(self, amp = True, amp_dtype = torch.float16):
super().__init__()
# in case the Internet connection is not stable, please load the DINOv2 locally
self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
self.amp = amp
self.amp_dtype = amp_dtype
if self.amp:
self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
def forward(self, x, upsample = False):
B,C,H,W = x.shape
mean_ = torch.tensor([0.485, 0.456, 0.406],
device=x.device).view(1, 3, 1, 1)
std_ = torch.tensor([0.229, 0.224, 0.225],
device=x.device).view(1, 3, 1, 1)
x = (x+1)/2
x = (x - mean_)/std_
h_re, w_re = 560, 560
x_resize = F.interpolate(x, size=(h_re, w_re),
mode='bilinear', align_corners=True)
if not upsample:
with torch.no_grad():
dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype))
features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14)
del dinov2_features_16
features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True)
return features_16
class AttnBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,
flash=False, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.flash=flash
self.attn = Attention(
hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,
**block_kwargs
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class CrossAttnBlock(nn.Module):
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0,
flash=True, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm_context = nn.LayerNorm(hidden_size)
self.cross_attn = Attention(
hidden_size, context_dim=context_dim,
num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x, context):
with autocast():
x = x + self.cross_attn(
self.norm1(x), self.norm_context(context)
)
x = x + self.mlp(self.norm2(x))
return x
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
"""Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1, 1], dim=-1)
# go to 0,1 then 0,2 then -1,1
xgrid = 2 * xgrid / (W - 1) - 1
ygrid = 2 * ygrid / (H - 1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
class CorrBlock:
def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):
B, S, C, H_prev, W_prev = fmaps.shape
self.S, self.C, self.H, self.W = S, C, H_prev, W_prev
self.num_levels = num_levels
self.radius = radius
self.fmaps_pyramid = []
self.depth_pyramid = []
self.fmaps_pyramid.append(fmaps)
if depths_dnG is not None:
self.depth_pyramid.append(depths_dnG)
for i in range(self.num_levels - 1):
if depths_dnG is not None:
depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)
depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)
_, _, H, W = depths_dnG_.shape
depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)
self.depth_pyramid.append(depths_dnG)
fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
_, _, H, W = fmaps_.shape
fmaps = fmaps_.reshape(B, S, C, H, W)
H_prev = H
W_prev = W
self.fmaps_pyramid.append(fmaps)
def sample(self, coords):
r = self.radius
B, S, N, D = coords.shape
assert D == 2
H, W = self.H, self.W
out_pyramid = []
for i in range(self.num_levels):
corrs = self.corrs_pyramid[i] # B, S, N, H, W
_, _, _, H, W = corrs.shape
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
coords.device
)
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
corrs = corrs.view(B, S, N, -1)
out_pyramid.append(corrs)
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
return out.contiguous().float()
def corr(self, targets):
B, S, N, C = targets.shape
assert C == self.C
assert S == self.S
fmap1 = targets
self.corrs_pyramid = []
for fmaps in self.fmaps_pyramid:
_, _, _, H, W = fmaps.shape
fmap2s = fmaps.view(B, S, C, H * W)
corrs = torch.matmul(fmap1, fmap2s)
corrs = corrs.view(B, S, N, H, W)
corrs = corrs / torch.sqrt(torch.tensor(C).float())
self.corrs_pyramid.append(corrs)
def corr_sample(self, targets, coords, coords_dp=None):
B, S, N, C = targets.shape
r = self.radius
Dim_c = (2*r+1)**2
assert C == self.C
assert S == self.S
out_pyramid = []
out_pyramid_dp = []
for i in range(self.num_levels):
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
coords.device
)
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
fmaps = self.fmaps_pyramid[i]
_, _, _, H, W = fmaps.shape
fmap2s = fmaps.view(B*S, C, H, W)
if len(self.depth_pyramid)>0:
depths_dnG_i = self.depth_pyramid[i]
depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W)
dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2))
dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0]
out_pyramid_dp.append(dp_corrs)
fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2))
fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1
corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1))
corrs = corrs / torch.sqrt(torch.tensor(C).float())
corrs = corrs.view(B, S, N, -1)
out_pyramid.append(corrs)
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
if len(self.depth_pyramid)>0:
out_dp = torch.cat(out_pyramid_dp, dim=-1)
self.fcorrD = out_dp.contiguous().float()
else:
self.fcorrD = torch.zeros_like(out).contiguous().float()
return out.contiguous().float()
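A toy walkthrough of the correlation pyramid above (small sizes, assuming CorrBlock is in scope): corr() builds a per-level correlation volume against the target features, and sample() gathers a (2r+1)^2 neighbourhood around each track at every level.
import torch
B, S, N, C, H, W = 1, 4, 6, 32, 16, 20
fmaps = torch.randn(B, S, C, H, W)
targets = torch.randn(B, S, N, C)
coords = torch.rand(B, S, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # pixel (x, y) per track
cb = CorrBlock(fmaps, num_levels=3, radius=3)
cb.corr(targets)
corr_feat = cb.sample(coords)
assert corr_feat.shape == (B, S, N, 3 * (2 * 3 + 1) ** 2)   # num_levels * (2r+1)^2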
class EUpdateFormer(nn.Module):
"""
Transformer model that updates track estimates.
"""
def __init__(
self,
space_depth=12,
time_depth=12,
input_dim=320,
hidden_size=384,
num_heads=8,
output_dim=130,
mlp_ratio=4.0,
vq_depth=3,
add_space_attn=True,
add_time_attn=True,
flash=True
):
super().__init__()
self.out_channels = 2
self.num_heads = num_heads
self.hidden_size = hidden_size
self.add_space_attn = add_space_attn
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
self.flash = flash
self.flow_head = nn.Sequential(
nn.Linear(hidden_size, output_dim, bias=True),
nn.ReLU(inplace=True),
nn.Linear(output_dim, output_dim, bias=True),
nn.ReLU(inplace=True),
nn.Linear(output_dim, output_dim, bias=True)
)
cross_attn_kwargs = {
"d_model": 384,
"nhead": 4,
"layer_names": ['self', 'cross'] * 3,
}
self.gnn = LocalFeatureTransformer(cross_attn_kwargs)
# Attention Modules in the temporal dimension
self.time_blocks = nn.ModuleList(
[
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash) if add_time_attn else nn.Identity()
for _ in range(time_depth)
]
)
if add_space_attn:
self.space_blocks = nn.ModuleList(
[
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, flash=flash)
for _ in range(space_depth)
]
)
assert len(self.time_blocks) >= len(self.space_blocks)
# Placeholder for the rigid transformation
self.RigidProj = nn.Linear(self.hidden_size, 128, bias=True)
self.Proj = nn.Linear(self.hidden_size, 128, bias=True)
self.se3_dec = nn.Linear(384, 3, bias=True)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, input_tensor, se3_feature):
""" Updating with Transformer
Args:
input_tensor: B, N, T, C
arap_embed: B, N, T, C
"""
B, N, T, C = input_tensor.shape
x = self.input_transform(input_tensor)
tokens = x
K = 0
j = 0
for i in range(len(self.time_blocks)):
tokens_time = rearrange(tokens, "b n t c -> (b n) t c", b=B, t=T, n=N+K)
tokens_time = self.time_blocks[i](tokens_time)
tokens = rearrange(tokens_time, "(b n) t c -> b n t c ", b=B, t=T, n=N+K)
if self.add_space_attn and (
i % (len(self.time_blocks) // len(self.space_blocks)) == 0
):
tokens_space = rearrange(tokens, "b n t c -> (b t) n c ", b=B, t=T, n=N)
tokens_space = self.space_blocks[j](tokens_space)
tokens = rearrange(tokens_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
j += 1
B, N, S, _ = tokens.shape
feat0, feat1 = self.gnn(tokens.view(B*N*S, -1)[None,...], se3_feature[None, ...])
so3 = F.tanh(self.se3_dec(feat0.view(B*N*S, -1)[None,...].view(B, N, S, -1))/100)
flow = self.flow_head(feat0.view(B,N,S,-1))
return flow, _, _, feat1, so3
class FusionFormer(nn.Module):
"""
Fuse the feature tracks info with the low rank motion tokens
"""
def __init__(
self,
d_model=64,
nhead=8,
attn_iters=4,
mlp_ratio=4.0,
flash=False,
input_dim=35,
output_dim=384+3,
):
super().__init__()
self.flash = flash
self.in_proj = nn.ModuleList(
[
nn.Linear(input_dim, d_model)
for _ in range(2)
]
)
self.out_proj = nn.Linear(d_model, output_dim, bias=True)
self.time_blocks = nn.ModuleList(
[
CrossAttnBlock(d_model, d_model, nhead, mlp_ratio=mlp_ratio)
for _ in range(attn_iters)
]
)
self.space_blocks = nn.ModuleList(
[
AttnBlock(d_model, nhead, mlp_ratio=mlp_ratio, flash=self.flash)
for _ in range(attn_iters)
]
)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
self.out_proj.weight.data.fill_(0)
self.out_proj.bias.data.fill_(0)
def forward(self, x, token_cls):
""" Fuse the feature tracks info with the low rank motion tokens
Args:
x: B, S, N, C
Traj_whole: B T N C
"""
B, S, N, C = x.shape
_, T, _, _ = token_cls.shape
x = self.in_proj[0](x)
token_cls = self.in_proj[1](token_cls)
token_cls = rearrange(token_cls, 'b t n c -> (b n) t c')
for i in range(len(self.space_blocks)):
x = rearrange(x, 'b s n c -> (b n) s c')
x = self.time_blocks[i](x, token_cls)
x = self.space_blocks[i](x.permute(1,0,2))
x = rearrange(x, '(b s) n c -> b s n c', b=B, s=S, n=N)
x = self.out_proj(x)
delta_xyz = x[..., :3]
feat_traj = x[..., 3:]
return delta_xyz, feat_traj
class Lie():
"""
Lie algebra for SO(3) and SE(3) operations in PyTorch
"""
def so3_to_SO3(self,w): # [...,3]
wx = self.skew_symmetric(w)
theta = w.norm(dim=-1)[...,None,None]
I = torch.eye(3,device=w.device,dtype=torch.float32)
A = self.taylor_A(theta)
B = self.taylor_B(theta)
R = I+A*wx+B*wx@wx
return R
def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
trace = R[...,0,0]+R[...,1,1]+R[...,2,2]
theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi
lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird
w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0]
w = torch.stack([w0,w1,w2],dim=-1)
return w
def se3_to_SE3(self,wu): # [...,3]
w,u = wu.split([3,3],dim=-1)
wx = self.skew_symmetric(w)
theta = w.norm(dim=-1)[...,None,None]
I = torch.eye(3,device=w.device,dtype=torch.float32)
A = self.taylor_A(theta)
B = self.taylor_B(theta)
C = self.taylor_C(theta)
R = I+A*wx+B*wx@wx
V = I+B*wx+C*wx@wx
Rt = torch.cat([R,(V@u[...,None])],dim=-1)
return Rt
def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
R,t = Rt.split([3,1],dim=-1)
w = self.SO3_to_so3(R)
wx = self.skew_symmetric(w)
theta = w.norm(dim=-1)[...,None,None]
I = torch.eye(3,device=w.device,dtype=torch.float32)
A = self.taylor_A(theta)
B = self.taylor_B(theta)
invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx
u = (invV@t)[...,0]
wu = torch.cat([w,u],dim=-1)
return wu
def skew_symmetric(self,w):
w0,w1,w2 = w.unbind(dim=-1)
O = torch.zeros_like(w0)
wx = torch.stack([torch.stack([O,-w2,w1],dim=-1),
torch.stack([w2,O,-w0],dim=-1),
torch.stack([-w1,w0,O],dim=-1)],dim=-2)
return wx
def taylor_A(self,x,nth=10):
# Taylor expansion of sin(x)/x
ans = torch.zeros_like(x)
denom = 1.
for i in range(nth+1):
if i>0: denom *= (2*i)*(2*i+1)
ans = ans+(-1)**i*x**(2*i)/denom
return ans
def taylor_B(self,x,nth=10):
# Taylor expansion of (1-cos(x))/x**2
ans = torch.zeros_like(x)
denom = 1.
for i in range(nth+1):
denom *= (2*i+1)*(2*i+2)
ans = ans+(-1)**i*x**(2*i)/denom
return ans
def taylor_C(self,x,nth=10):
# Taylor expansion of (x-sin(x))/x**3
ans = torch.zeros_like(x)
denom = 1.
for i in range(nth+1):
denom *= (2*i+2)*(2*i+3)
ans = ans+(-1)**i*x**(2*i)/denom
return ans
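# A minimal usage sketch of the Lie helper above (illustrative only, assuming torch
# is already imported in this module): exp/log round trip on SE(3). Small twists
# keep SO3_to_so3 away from the theta == pi singularity noted in its implementation.
def _lie_roundtrip_sketch():
    lie = Lie()
    wu = 0.1 * torch.randn(4, 6)      # [N, 6] twists: rotation w, translation u
    Rt = lie.se3_to_SE3(wu)           # [N, 3, 4] rigid transforms
    wu_rec = lie.SE3_to_se3(Rt)       # log map back to twist coordinates
    return (wu - wu_rec).abs().max()  # expected to be close to zero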
def pix2cam(coords,
intr):
"""
Args:
coords: [B, T, N, 3]
intr: [B, T, 3, 3]
"""
coords=coords.detach()
B, S, N, _, = coords.shape
xy_src = coords.reshape(B*S*N, 3)
intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3)
xy_src = torch.cat([xy_src[..., :2], torch.ones_like(xy_src[..., :1])], dim=-1)
xyz_src = (torch.inverse(intr)@xy_src[...,None])[...,0]
dp_pred = coords[..., 2]
xyz_src_ = (xyz_src*(dp_pred.reshape(S*N, 1)))
xyz_src_ = xyz_src_.reshape(B, S, N, 3)
return xyz_src_
def cam2pix(coords,
intr):
"""
Args:
coords: [B, T, N, 3]
intr: [B, T, 3, 3]
"""
coords=coords.detach()
B, S, N, _, = coords.shape
xy_src = coords.reshape(B*S*N, 3).clone()
intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B*S*N, 3, 3)
xy_src = xy_src / (xy_src[..., 2:]+1e-5)
xyz_src = (intr@xy_src[...,None])[...,0]
dp_pred = coords[..., 2]
xyz_src[...,2] *= dp_pred.reshape(S*N)
xyz_src = xyz_src.reshape(B, S, N, 3)
return xyz_src
def edgeMat(traj3d):
"""
Args:
traj3d: [B, T, N, 3]
"""
B, T, N, _ = traj3d.shape
traj3d = traj3d
traj3d = traj3d.view(B, T, N, 3)
traj3d = traj3d[..., None, :] - traj3d[..., None, :, :] # B, T, N, N, 3
edgeMat = traj3d.norm(dim=-1) # B, T, N, N
return edgeMat
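# A minimal sketch of the pix2cam/cam2pix helpers above, assuming batch size 1
# (their internal reshapes use S*N and therefore expect B == 1): unproject
# (u, v, depth) pixels to camera space and project them back.
def _pix2cam_roundtrip_sketch():
    intr = torch.tensor([[500.0, 0.0, 320.0],
                         [0.0, 500.0, 240.0],
                         [0.0, 0.0, 1.0]]).repeat(1, 4, 1, 1)   # [B=1, T=4, 3, 3]
    uvd = torch.cat([torch.rand(1, 4, 16, 2) * 480.0,
                     torch.rand(1, 4, 16, 1) + 1.0], dim=-1)    # [B, T, N, (u, v, depth)]
    xyz = pix2cam(uvd, intr)                                    # camera-space points
    uvd_rec = cam2pix(xyz, intr)                                # back to (u, v, depth)
    return (uvd_rec[..., :2] - uvd[..., :2]).abs().max()        # sub-pixel error from the 1e-5 stabilizer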

View File

@ -0,0 +1,16 @@
import torch
class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device("cpu"))
if "optimizer" in parameters:
parameters = parameters["model"]
self.load_state_dict(parameters)

View File

@ -0,0 +1,394 @@
import torch
import torch.nn as nn
from .vit import (
_make_pretrained_vitb_rn50_384,
_make_pretrained_vitl16_384,
_make_pretrained_vitb16_384,
forward_vit,
_make_pretrained_vit_tiny
)
def _make_encoder(
backbone,
features,
use_pretrained,
groups=1,
expand=False,
exportable=True,
hooks=None,
use_vit_only=False,
use_readout="ignore",
enable_attention_hooks=False,
):
if backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained,
hooks=hooks,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == "vitb_rn50_384":
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups, expand=expand
) # ViT-H/16 - 85.0% Top1 (backbone)
elif backbone == "vitb16_384":
pretrained = _make_pretrained_vitb16_384(
use_pretrained,
hooks=hooks,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch(
[256, 512, 1024, 2048], features, groups=groups, expand=expand
        ) # ResNeXt-101 32x8d WSL
elif backbone == "vit_tiny_r_s16_p8_384":
pretrained = _make_pretrained_vit_tiny(
use_pretrained,
hooks=hooks,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
)
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand == True:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0],
out_shape1,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1],
out_shape2,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2],
out_shape3,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3],
out_shape4,
kernel_size=3,
stride=1,
padding=1,
bias=False,
groups=groups,
)
return scratch
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
"""Interpolation module."""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x,
scale_factor=self.scale_factor,
mode=self.mode,
align_corners=self.align_corners,
)
return x
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=True
)
return output
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = 1
self.conv1 = nn.Conv2d(
features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=not self.bn,
groups=self.groups,
)
self.conv2 = nn.Conv2d(
features,
features,
kernel_size=3,
stride=1,
padding=1,
bias=not self.bn,
groups=self.groups,
)
if self.bn == True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn == True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn == True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
# return out + x
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block."""
def __init__(
self,
features,
activation,
deconv=False,
bn=False,
expand=False,
align_corners=True,
):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups = 1
self.expand = expand
out_features = features
if self.expand == True:
out_features = features // 2
self.out_conv = nn.Conv2d(
features,
out_features,
kernel_size=1,
stride=1,
padding=0,
bias=True,
groups=1,
)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
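# A minimal sketch of FeatureFusionBlock_custom above (illustrative only, assuming
# torch/nn are in scope): fuse a same-resolution skip feature into a decoder
# feature and upsample the result by a factor of two.
def _fusion_block_sketch():
    block = FeatureFusionBlock_custom(64, nn.ReLU(False), bn=True)
    coarse = torch.randn(1, 64, 16, 16)
    skip = torch.randn(1, 64, 16, 16)
    return block(coarse, skip).shape   # torch.Size([1, 64, 32, 32])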

View File

@ -0,0 +1,77 @@
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
class MidasNet_large(BaseModel):
"""Network for monocular depth estimation."""
def __init__(self, path=None, features=256, non_negative=True):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
            backbone: the encoder backbone is fixed to "resnext101_wsl" in this variant.
"""
print("Loading weights: ", path)
super(MidasNet_large, self).__init__()
use_pretrained = False if path is None else True
self.pretrained, self.scratch = _make_encoder(
backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
)
self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
self.scratch.refinenet2 = FeatureFusionBlock(features)
self.scratch.refinenet1 = FeatureFusionBlock(features)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)

View File

@ -0,0 +1,231 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_model import BaseModel
from .blocks import (
FeatureFusionBlock,
FeatureFusionBlock_custom,
Interpolate,
_make_encoder,
forward_vit,
)
def _make_fusion_block(features, use_bn):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
)
class DPT(BaseModel):
def __init__(
self,
head,
features=256,
backbone="vitb_rn50_384",
readout="project",
channels_last=False,
use_bn=True,
enable_attention_hooks=False,
):
super(DPT, self).__init__()
self.channels_last = channels_last
hooks = {
"vitb_rn50_384": [0, 1, 8, 11],
"vitb16_384": [2, 5, 8, 11],
"vitl16_384": [5, 11, 17, 23],
"vit_tiny_r_s16_p8_384": [0, 1, 2, 3],
}
# Instantiate backbone and reassemble blocks
self.pretrained, self.scratch = _make_encoder(
backbone,
features,
            False,  # set to True to initialize the backbone with ImageNet-pretrained weights
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
use_readout=readout,
enable_attention_hooks=enable_attention_hooks,
)
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.output_conv = head
self.proj_out = nn.Sequential(
nn.Conv2d(
256+512+384+384,
256,
kernel_size=3,
padding=1,
padding_mode="zeros",
),
nn.BatchNorm2d(128 * 2),
nn.ReLU(True),
nn.Conv2d(
128 * 2,
128,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
)
def forward(self, x, only_enc=False):
if self.channels_last == True:
x.contiguous(memory_format=torch.channels_last)
if only_enc:
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
a = (layer_1)
b = (
F.interpolate(
layer_2,
scale_factor=2,
mode="bilinear",
align_corners=True,
)
)
c = (
F.interpolate(
layer_3,
scale_factor=8,
mode="bilinear",
align_corners=True,
)
)
d = (
F.interpolate(
layer_4,
scale_factor=16,
mode="bilinear",
align_corners=True,
)
)
x = self.proj_out(torch.cat([a, b, c, d], dim=1))
return x
else:
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
_,_,H_out,W_out = path_1.size()
path_2_up = F.interpolate(path_2, size=(H_out,W_out), mode="bilinear", align_corners=True)
path_3_up = F.interpolate(path_3, size=(H_out,W_out), mode="bilinear", align_corners=True)
path_4_up = F.interpolate(path_4, size=(H_out,W_out), mode="bilinear", align_corners=True)
out = self.scratch.output_conv(path_1+path_2_up+path_3_up+path_4_up)
return out
class DPTDepthModel(DPT):
def __init__(
self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
):
features = kwargs["features"] if "features" in kwargs else 256
self.scale = scale
self.shift = shift
self.invert = invert
head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
super().__init__(head, **kwargs)
if path is not None:
self.load(path)
def forward(self, x):
inv_depth = super().forward(x).squeeze(dim=1)
if self.invert:
depth = self.scale * inv_depth + self.shift
depth[depth < 1e-8] = 1e-8
depth = 1.0 / depth
return depth
else:
return inv_depth
class DPTEncoder(DPT):
def __init__(
self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
):
features = kwargs["features"] if "features" in kwargs else 256
self.scale = scale
self.shift = shift
head = nn.Sequential(
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
)
super().__init__(head, **kwargs)
if path is not None:
self.load(path)
def forward(self, x):
features = super().forward(x, only_enc=True).squeeze(dim=1)
return features
class DPTSegmentationModel(DPT):
def __init__(self, num_classes, path=None, **kwargs):
features = kwargs["features"] if "features" in kwargs else 256
kwargs["use_bn"] = True
head = nn.Sequential(
nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features),
nn.ReLU(True),
nn.Dropout(0.1, False),
nn.Conv2d(features, num_classes, kernel_size=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
)
super().__init__(head, **kwargs)
self.auxlayer = nn.Sequential(
nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features),
nn.ReLU(True),
nn.Dropout(0.1, False),
nn.Conv2d(features, num_classes, kernel_size=1),
)
if path is not None:
self.load(path)

View File

@ -0,0 +1,231 @@
import numpy as np
import cv2
import math
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
)
sample["disparity"] = cv2.resize(
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height)."""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
                # scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(
f"resize_method {self.__resize_method} not implemented"
)
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, min_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, min_val=self.__width
)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, max_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, max_val=self.__width
)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(
sample["image"].shape[1], sample["image"].shape[0]
)
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std."""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input."""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "disparity" in sample:
disparity = sample["disparity"].astype(np.float32)
sample["disparity"] = np.ascontiguousarray(disparity)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
return sample
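# A minimal sketch of the transform pipeline above on a dummy sample (mean/std
# values are illustrative; resize_target=True requires a "mask" entry as well):
def _transform_pipeline_sketch():
    sample = {
        "image": np.random.rand(241, 321, 3).astype(np.float32),
        "mask": np.ones((241, 321), dtype=bool),
    }
    resize = Resize(384, 384, keep_aspect_ratio=True, ensure_multiple_of=32,
                    resize_method="lower_bound",
                    image_interpolation_method=cv2.INTER_CUBIC)
    sample = resize(sample)
    sample = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(sample)
    sample = PrepareForNet()(sample)
    return sample["image"].shape   # (3, H, W) with H, W multiples of 32 and >= 384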

View File

@ -0,0 +1,596 @@
import torch
import torch.nn as nn
import timm
import types
import math
import torch.nn.functional as F
activations = {}
def get_activation(name):
def hook(model, input, output):
activations[name] = output
return hook
attention = {}
def get_attention(name):
def hook(module, input, output):
x = input[0]
B, N, C = x.shape
qkv = (
module.qkv(x)
.reshape(B, N, 3, module.num_heads, C // module.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * module.scale
attn = attn.softmax(dim=-1) # [:,:,1,1:]
attention[name] = attn
return hook
def get_mean_attention_map(attn, token, shape):
attn = attn[:, :, token, 1:]
attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
attn = torch.nn.functional.interpolate(
attn, size=shape[2:], mode="bicubic", align_corners=False
).squeeze(0)
all_attn = torch.mean(attn, 0)
return all_attn
class Slice(nn.Module):
def __init__(self, start_index=1):
super(Slice, self).__init__()
self.start_index = start_index
def forward(self, x):
return x[:, self.start_index :]
class AddReadout(nn.Module):
def __init__(self, start_index=1):
super(AddReadout, self).__init__()
self.start_index = start_index
def forward(self, x):
if self.start_index == 2:
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index :] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
def __init__(self, in_features, start_index=1):
super(ProjectReadout, self).__init__()
self.start_index = start_index
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
features = torch.cat((x[:, self.start_index :], readout), -1)
return self.project(features)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x
def forward_vit(pretrained, x):
b, c, h, w = x.shape
glob = pretrained.model.forward_flex(x)
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_4 = pretrained.activations["4"]
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
unflatten = nn.Sequential(
nn.Unflatten(
2,
torch.Size(
[
h // pretrained.model.patch_size[1],
w // pretrained.model.patch_size[0],
]
),
)
)
if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten(layer_3)
if layer_4.ndim == 3:
layer_4 = unflatten(layer_4)
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
return layer_1, layer_2, layer_3, layer_4
def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, : self.start_index],
posemb[0, self.start_index :],
)
gs_old = int(math.sqrt(len(posemb_grid)))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
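# A minimal sketch of _resize_pos_embed above (types is imported in this module;
# a SimpleNamespace stands in for the ViT instance, assuming start_index == 1):
def _resize_pos_embed_sketch():
    dummy = types.SimpleNamespace(start_index=1)
    posemb = torch.randn(1, 1 + 24 * 24, 768)           # class token + 24x24 grid
    resized = _resize_pos_embed(dummy, posemb, 30, 40)   # interpolate the grid to 30x40
    return resized.shape                                 # torch.Size([1, 1 + 1200, 768])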
def forward_flex(self, x):
b, c, h, w = x.shape
pos_embed = self._resize_pos_embed(
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
)
B = x.shape[0]
if hasattr(self.patch_embed, "backbone"):
x = self.patch_embed.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
if getattr(self, "dist_token", None) is not None:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
else:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == "ignore":
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == "add":
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == "project":
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
assert (
False
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
return readout_oper
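# A minimal sketch of the readout options above: "ignore" simply drops the class
# token, while "project" concatenates it with every patch token and projects back to C.
def _readout_sketch():
    tokens = torch.randn(2, 1 + 196, 768)                    # [B, 1 + N, C]
    project = get_readout_oper(768, [96, 192, 384, 768], "project")[0]
    return Slice()(tokens).shape, project(tokens).shape      # both [2, 196, 768]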
def _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout="ignore",
start_index=1,
enable_attention_hooks=False,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
if enable_attention_hooks:
pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
get_attention("attn_1")
)
pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
get_attention("attn_2")
)
pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
get_attention("attn_3")
)
pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
get_attention("attn_4")
)
pretrained.attention = attention
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=[0, 1, 8, 11],
vit_features=384,
use_vit_only=False,
use_readout="ignore",
start_index=1,
enable_attention_hooks=False,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.patch_size = [32, 32]
ps = pretrained.model.patch_size[0]
if use_vit_only == True:
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
else:
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
get_activation("1")
)
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
get_activation("2")
)
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
if enable_attention_hooks:
pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
pretrained.attention = attention
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
if use_vit_only == True:
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
else:
pretrained.act_postprocess1 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess2 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [32, 32]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitb_rn50_384(
pretrained,
use_readout="ignore",
hooks=None,
use_vit_only=False,
enable_attention_hooks=False,
):
# model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
# model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained)
model = timm.create_model("vit_small_r26_s32_384", pretrained=pretrained)
hooks = [0, 1, 8, 11] if hooks == None else hooks
return _make_vit_b_rn50_backbone(
model,
features=[128, 256, 384, 384],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
def _make_pretrained_vit_tiny(
pretrained,
use_readout="ignore",
hooks=None,
use_vit_only=False,
enable_attention_hooks=False,
):
# model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
model = timm.create_model("vit_tiny_r_s16_p8_384", pretrained=pretrained)
hooks = [0, 1, 8, 11] if hooks == None else hooks
return _make_vit_tiny_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
def _make_pretrained_vitl16_384(
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
):
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
def _make_pretrained_vitb16_384(
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
):
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
def _make_pretrained_deitb16_384(
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
):
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
enable_attention_hooks=enable_attention_hooks,
)
def _make_pretrained_deitb16_distil_384(
pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
):
model = timm.create_model(
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
start_index=2,
enable_attention_hooks=enable_attention_hooks,
)

View File

@ -0,0 +1,916 @@
"""
Adapted from ConvONet
https://github.com/autonomousvision/convolutional_occupancy_networks/blob/838bea5b2f1314f2edbb68d05ebb0db49f1f3bd2/src/encoder/pointnet.py#L1
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch_scatter import scatter_mean, scatter_max  # needed by LocalPoolPointnet below; install torch_scatter to use it
from .unet import UNet
from ..model_utils import (
vis_PCA
)
from einops import rearrange
import numpy as np
import math
def compute_iou(occ1, occ2):
''' Computes the Intersection over Union (IoU) value for two sets of
occupancy values.
Args:
occ1 (tensor): first set of occupancy values
occ2 (tensor): second set of occupancy values
'''
occ1 = np.asarray(occ1)
occ2 = np.asarray(occ2)
# Put all data in second dimension
# Also works for 1-dimensional data
if occ1.ndim >= 2:
occ1 = occ1.reshape(occ1.shape[0], -1)
if occ2.ndim >= 2:
occ2 = occ2.reshape(occ2.shape[0], -1)
# Convert to boolean values
occ1 = (occ1 >= 0.5)
occ2 = (occ2 >= 0.5)
# Compute IOU
area_union = (occ1 | occ2).astype(np.float32).sum(axis=-1)
area_intersect = (occ1 & occ2).astype(np.float32).sum(axis=-1)
iou = (area_intersect / area_union)
return iou
def chamfer_distance(points1, points2, use_kdtree=True, give_id=False):
''' Returns the chamfer distance for the sets of points.
Args:
points1 (numpy array): first point set
points2 (numpy array): second point set
use_kdtree (bool): whether to use a kdtree
give_id (bool): whether to return the IDs of nearest points
'''
if use_kdtree:
return chamfer_distance_kdtree(points1, points2, give_id=give_id)
else:
return chamfer_distance_naive(points1, points2)
def chamfer_distance_naive(points1, points2):
''' Naive implementation of the Chamfer distance.
Args:
points1 (numpy array): first point set
points2 (numpy array): second point set
'''
assert(points1.size() == points2.size())
batch_size, T, _ = points1.size()
points1 = points1.view(batch_size, T, 1, 3)
points2 = points2.view(batch_size, 1, T, 3)
distances = (points1 - points2).pow(2).sum(-1)
chamfer1 = distances.min(dim=1)[0].mean(dim=1)
chamfer2 = distances.min(dim=2)[0].mean(dim=1)
chamfer = chamfer1 + chamfer2
return chamfer
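# A minimal sketch of the naive Chamfer distance on two random point sets
# (the KD-tree variant below needs a KDTree implementation, see the NotImplementedError):
def _chamfer_sketch():
    points1 = torch.rand(2, 128, 3)
    points2 = torch.rand(2, 128, 3)
    return chamfer_distance_naive(points1, points2)   # shape [2], one value per batch item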
def chamfer_distance_kdtree(points1, points2, give_id=False):
''' KD-tree based implementation of the Chamfer distance.
Args:
points1 (numpy array): first point set
points2 (numpy array): second point set
give_id (bool): whether to return the IDs of the nearest points
'''
# Points have size batch_size x T x 3
batch_size = points1.size(0)
# First convert points to numpy
points1_np = points1.detach().cpu().numpy()
points2_np = points2.detach().cpu().numpy()
    # Get list of nearest neighbor indices
idx_nn_12, _ = get_nearest_neighbors_indices_batch(points1_np, points2_np)
idx_nn_12 = torch.LongTensor(idx_nn_12).to(points1.device)
# Expands it as batch_size x 1 x 3
idx_nn_12_expand = idx_nn_12.view(batch_size, -1, 1).expand_as(points1)
    # Get list of nearest neighbor indices
idx_nn_21, _ = get_nearest_neighbors_indices_batch(points2_np, points1_np)
idx_nn_21 = torch.LongTensor(idx_nn_21).to(points1.device)
# Expands it as batch_size x T x 3
idx_nn_21_expand = idx_nn_21.view(batch_size, -1, 1).expand_as(points2)
# Compute nearest neighbors in points2 to points in points1
# points_12[i, j, k] = points2[i, idx_nn_12_expand[i, j, k], k]
points_12 = torch.gather(points2, dim=1, index=idx_nn_12_expand)
# Compute nearest neighbors in points1 to points in points2
# points_21[i, j, k] = points2[i, idx_nn_21_expand[i, j, k], k]
points_21 = torch.gather(points1, dim=1, index=idx_nn_21_expand)
# Compute chamfer distance
chamfer1 = (points1 - points_12).pow(2).sum(2).mean(1)
chamfer2 = (points2 - points_21).pow(2).sum(2).mean(1)
# Take sum
chamfer = chamfer1 + chamfer2
# If required, also return nearest neighbors
if give_id:
return chamfer1, chamfer2, idx_nn_12, idx_nn_21
return chamfer
def get_nearest_neighbors_indices_batch(points_src, points_tgt, k=1):
''' Returns the nearest neighbors for point sets batchwise.
Args:
points_src (numpy array): source points
points_tgt (numpy array): target points
k (int): number of nearest neighbors to return
'''
indices = []
distances = []
for (p1, p2) in zip(points_src, points_tgt):
raise NotImplementedError()
# kdtree = KDTree(p2)
dist, idx = kdtree.query(p1, k=k)
indices.append(idx)
distances.append(dist)
return indices, distances
def make_3d_grid(bb_min, bb_max, shape):
''' Makes a 3D grid.
Args:
bb_min (tuple): bounding box minimum
bb_max (tuple): bounding box maximum
shape (tuple): output shape
'''
size = shape[0] * shape[1] * shape[2]
pxs = torch.linspace(bb_min[0], bb_max[0], shape[0])
pys = torch.linspace(bb_min[1], bb_max[1], shape[1])
pzs = torch.linspace(bb_min[2], bb_max[2], shape[2])
pxs = pxs.view(-1, 1, 1).expand(*shape).contiguous().view(size)
pys = pys.view(1, -1, 1).expand(*shape).contiguous().view(size)
pzs = pzs.view(1, 1, -1).expand(*shape).contiguous().view(size)
p = torch.stack([pxs, pys, pzs], dim=1)
return p
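# A minimal sketch of make_3d_grid: a 4x4x4 lattice of query points in the unit cube.
def _make_3d_grid_sketch():
    p = make_3d_grid((-0.5, -0.5, -0.5), (0.5, 0.5, 0.5), (4, 4, 4))
    return p.shape   # torch.Size([64, 3])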
def transform_points(points, transform):
''' Transforms points with regard to passed camera information.
Args:
points (tensor): points tensor
transform (tensor): transformation matrices
'''
assert(points.size(2) == 3)
assert(transform.size(1) == 3)
assert(points.size(0) == transform.size(0))
if transform.size(2) == 4:
R = transform[:, :, :3]
t = transform[:, :, 3:]
points_out = points @ R.transpose(1, 2) + t.transpose(1, 2)
elif transform.size(2) == 3:
K = transform
points_out = points @ K.transpose(1, 2)
return points_out
def b_inv(b_mat):
''' Performs batch matrix inversion.
Arguments:
b_mat: the batch of matrices that should be inverted
'''
eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
b_inv, _ = torch.gesv(eye, b_mat)
return b_inv
def project_to_camera(points, transform):
''' Projects points to the camera plane.
Args:
points (tensor): points tensor
transform (tensor): transformation matrices
'''
p_camera = transform_points(points, transform)
p_camera = p_camera[..., :2] / p_camera[..., 2:]
return p_camera
def fix_Rt_camera(Rt, loc, scale):
''' Fixes Rt camera matrix.
Args:
Rt (tensor): Rt camera matrix
loc (tensor): location
scale (float): scale
'''
# Rt is B x 3 x 4
# loc is B x 3 and scale is B
batch_size = Rt.size(0)
R = Rt[:, :, :3]
t = Rt[:, :, 3:]
scale = scale.view(batch_size, 1, 1)
R_new = R * scale
t_new = t + R @ loc.unsqueeze(2)
Rt_new = torch.cat([R_new, t_new], dim=2)
assert(Rt_new.size() == (batch_size, 3, 4))
return Rt_new
def normalize_coordinate(p, padding=0.1, plane='xz'):
''' Normalize coordinate to [0, 1] for unit cube experiments
Args:
p (tensor): point
        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
plane (str): plane feature type, ['xz', 'xy', 'yz']
'''
# breakpoint()
if plane == 'xz':
xy = p[:, :, [0, 2]]
elif plane =='xy':
xy = p[:, :, [0, 1]]
else:
xy = p[:, :, [1, 2]]
xy = torch.clamp(xy, min=1e-6, max=1. - 1e-6)
# xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
# xy_new = xy_new + 0.5 # range (0, 1)
    # # if there are outliers out of the range
# if xy_new.max() >= 1:
# xy_new[xy_new >= 1] = 1 - 10e-6
# if xy_new.min() < 0:
# xy_new[xy_new < 0] = 0.0
# xy_new = (xy + 1.) / 2.
return xy
def normalize_3d_coordinate(p, padding=0.1):
''' Normalize coordinate to [0, 1] for unit cube experiments.
Corresponds to our 3D model
Args:
p (tensor): point
        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
'''
p_nor = p / (1 + padding + 10e-4) # (-0.5, 0.5)
p_nor = p_nor + 0.5 # range (0, 1)
    # if there are outliers out of the range
if p_nor.max() >= 1:
p_nor[p_nor >= 1] = 1 - 10e-4
if p_nor.min() < 0:
p_nor[p_nor < 0] = 0.0
return p_nor
def normalize_coord(p, vol_range, plane='xz'):
''' Normalize coordinate to [0, 1] for sliding-window experiments
Args:
p (tensor): point
vol_range (numpy array): volume boundary
plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume
'''
p[:, 0] = (p[:, 0] - vol_range[0][0]) / (vol_range[1][0] - vol_range[0][0])
p[:, 1] = (p[:, 1] - vol_range[0][1]) / (vol_range[1][1] - vol_range[0][1])
p[:, 2] = (p[:, 2] - vol_range[0][2]) / (vol_range[1][2] - vol_range[0][2])
if plane == 'xz':
x = p[:, [0, 2]]
elif plane =='xy':
x = p[:, [0, 1]]
elif plane =='yz':
x = p[:, [1, 2]]
else:
x = p
return x
def coordinate2index(x, reso, coord_type='2d'):
    ''' Convert normalized coordinates to flat cell indices on a plane (2d) or grid (3d).
Args:
x (tensor): coordinate
reso (int): defined resolution
coord_type (str): coordinate type
'''
x = (x * reso).long()
if coord_type == '2d': # plane
index = x[:, :, 0] + reso * x[:, :, 1]
elif coord_type == '3d': # grid
index = x[:, :, 0] + reso * (x[:, :, 1] + reso * x[:, :, 2])
index = index[:, None, :]
return index
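# A minimal sketch: map normalized points onto flat cell indices of a 64x64 'xz' plane
# (normalize_coordinate clamps to (0, 1), so every index lands inside the plane).
def _plane_index_sketch():
    p = torch.rand(1, 500, 3)                    # [B, N, 3]
    xy = normalize_coordinate(p.clone(), plane='xz')
    index = coordinate2index(xy, 64)             # [B, 1, N] flat indices
    return int(index.min()), int(index.max())    # both within [0, 64 * 64 - 1]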
def coord2index(p, vol_range, reso=None, plane='xz'):
    ''' Normalize coordinates and convert them to flat indices for sliding-window experiments.
Args:
p (tensor): points
vol_range (numpy array): volume boundary
reso (int): defined resolution
plane (str): feature type, ['xz', 'xy', 'yz'] - canonical planes; ['grid'] - grid volume
'''
# normalize to [0, 1]
x = normalize_coord(p, vol_range, plane=plane)
if isinstance(x, np.ndarray):
x = np.floor(x * reso).astype(int)
else: #* pytorch tensor
x = (x * reso).long()
if x.shape[1] == 2:
index = x[:, 0] + reso * x[:, 1]
index[index > reso**2] = reso**2
elif x.shape[1] == 3:
index = x[:, 0] + reso * (x[:, 1] + reso * x[:, 2])
index[index > reso**3] = reso**3
return index[None]
def update_reso(reso, depth):
''' Update the defined resolution so that UNet can process.
Args:
reso (int): defined resolution
depth (int): U-Net number of layers
'''
base = 2**(int(depth) - 1)
    if not (reso / base).is_integer(): # when this is not an integer, U-Net dimension error
for i in range(base):
if ((reso + i) / base).is_integer():
reso = reso + i
break
return reso
def decide_total_volume_range(query_vol_metric, recep_field, unit_size, unet_depth):
    ''' Decide the input/query volume ranges and resolution so that the U-Net can process them.
Args:
query_vol_metric (numpy array): query volume size
recep_field (int): defined the receptive field for U-Net
unit_size (float): the defined voxel size
unet_depth (int): U-Net number of layers
'''
reso = query_vol_metric / unit_size + recep_field - 1
reso = update_reso(int(reso), unet_depth) # make sure input reso can be processed by UNet
input_vol_metric = reso * unit_size
p_c = np.array([0.0, 0.0, 0.0]).astype(np.float32)
lb_input_vol, ub_input_vol = p_c - input_vol_metric/2, p_c + input_vol_metric/2
lb_query_vol, ub_query_vol = p_c - query_vol_metric/2, p_c + query_vol_metric/2
input_vol = [lb_input_vol, ub_input_vol]
query_vol = [lb_query_vol, ub_query_vol]
# handle the case when resolution is too large
if reso > 10000:
reso = 1
return input_vol, query_vol, reso
def add_key(base, new, base_name, new_name, device=None):
''' Add new keys to the given input
Args:
base (tensor): inputs
new (tensor): new info for the inputs
base_name (str): name for the input
new_name (str): name for the new info
device (device): pytorch device
'''
if (new is not None) and (isinstance(new, dict)):
if device is not None:
for key in new.keys():
new[key] = new[key].to(device)
base = {base_name: base,
new_name: new}
return base
class map2local(object):
    ''' Map global coordinates into local per-cell coordinates, with optional positional encoding
Args:
s (float): the defined voxel size
pos_encoding (str): method for the positional encoding, linear|sin_cos
'''
def __init__(self, s, pos_encoding='linear'):
super().__init__()
self.s = s
self.pe = positional_encoding(basis_function=pos_encoding)
def __call__(self, p):
        p = torch.remainder(p, self.s) / self.s # always positive
# p = torch.fmod(p, self.s) / self.s # same sign as input p!
p = self.pe(p)
return p
class positional_encoding(object):
''' Positional Encoding (presented in NeRF)
Args:
basis_function (str): basis function
'''
def __init__(self, basis_function='sin_cos'):
super().__init__()
self.func = basis_function
L = 10
freq_bands = 2.**(np.linspace(0, L-1, L))
self.freq_bands = freq_bands * math.pi
def __call__(self, p):
if self.func == 'sin_cos':
out = []
            p = 2.0 * p - 1.0 # change to the range [-1, 1]
for freq in self.freq_bands:
out.append(torch.sin(freq * p))
out.append(torch.cos(freq * p))
p = torch.cat(out, dim=2)
return p
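# A minimal sketch of the NeRF-style encoding above: each of the 3 input channels
# becomes 2 * L = 20 sin/cos features, so the last dimension grows from 3 to 60.
def _positional_encoding_sketch():
    pe = positional_encoding(basis_function='sin_cos')
    p = torch.rand(1, 8, 3)     # [B, N, 3] coordinates in [0, 1]
    return pe(p).shape          # torch.Size([1, 8, 60])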
# Resnet Blocks
class ResnetBlockFC(nn.Module):
''' Fully connected ResNet Block class.
Args:
size_in (int): input dimension
size_out (int): output dimension
size_h (int): hidden dimension
'''
def __init__(self, size_in, size_out=None, size_h=None):
super().__init__()
# Attributes
if size_out is None:
size_out = size_in
if size_h is None:
size_h = min(size_in, size_out)
self.size_in = size_in
self.size_h = size_h
self.size_out = size_out
# Submodules
self.fc_0 = nn.Linear(size_in, size_h)
self.fc_1 = nn.Linear(size_h, size_out)
self.actvn = nn.ReLU()
if size_in == size_out:
self.shortcut = None
else:
self.shortcut = nn.Linear(size_in, size_out, bias=False)
# Initialization
nn.init.zeros_(self.fc_1.weight)
def forward(self, x):
net = self.fc_0(self.actvn(x))
dx = self.fc_1(self.actvn(net))
if self.shortcut is not None:
x_s = self.shortcut(x)
else:
x_s = x
return x_s + dx
'''
------------------ the key model for Pointnet ----------------------------
'''
class LocalSoftSplat(nn.Module):
def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max',
unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
hw=None, grid_resolution=None, plane_type='xz', padding=0.1,
n_blocks=4, splat_func=None):
super().__init__()
c_dim = ch
self.c_dim = c_dim
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
self.blocks = nn.ModuleList([
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
])
self.fc_c = nn.Linear(hidden_dim, c_dim)
self.actvn = nn.ReLU()
self.hidden_dim = hidden_dim
if unet:
self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
else:
self.unet = None
# get splat func
self.splat_func = splat_func
def forward(self, img_feat,
Fxy2xz, Fxy2yz, Dz, gridxy=None):
"""
Args:
img_feat (tensor): image features
            Fxy2xz (tensor): flow/offset field that splats xy-plane features onto the xz plane
            Fxy2yz (tensor): flow/offset field that splats xy-plane features onto the yz plane
"""
B, T, _, H, W = img_feat.shape
fea_reshp = rearrange(img_feat, 'b t c h w -> (b h w) t c',
c=img_feat.shape[2], h=H, w=W)
gridyz = gridxy + Fxy2yz
gridxz = gridxy + Fxy2xz
# normalize
gridyz[:, 0, ...] = (gridyz[:, 0, ...] / (H - 1) - 0.5) * 2
gridyz[:, 1, ...] = (gridyz[:, 1, ...] / (Dz - 1) - 0.5) * 2
gridxz[:, 0, ...] = (gridxz[:, 0, ...] / (W - 1) - 0.5) * 2
gridxz[:, 1, ...] = (gridxz[:, 1, ...] / (Dz - 1) - 0.5) * 2
if len(self.blocks) > 0:
net = self.fc_pos(fea_reshp)
net = self.blocks[0](net)
for block in self.blocks[1:]:
# splat and fusion
net_plane = rearrange(net, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W)
net_planeYZ = self.splat_func(net_plane, Fxy2yz, None,
strMode="avg", tenoutH=Dz, tenoutW=H)
net_planeXZ = self.splat_func(net_plane, Fxy2xz, None,
strMode="avg", tenoutH=Dz, tenoutW=W)
net_plane = net_plane + (
F.grid_sample(
net_planeYZ, gridyz.permute(0,2,3,1), mode='bilinear', padding_mode='border') +
F.grid_sample(
net_planeXZ, gridxz.permute(0,2,3,1), mode='bilinear', padding_mode='border')
)
pooled = rearrange(net_plane, 't c h w -> (h w) t c',
c=net_plane.shape[1], h=H, w=W)
net = torch.cat([net, pooled], dim=2)
net = block(net)
c = self.fc_c(net)
net_plane = rearrange(c, '(b h w) t c -> (b t) c h w', b=B, h=H, w=W)
else:
net_plane = rearrange(img_feat, 'b t c h w -> (b t) c h w',
c=img_feat.shape[2], h=H, w=W)
net_planeYZ = self.splat_func(net_plane, Fxy2yz, None,
strMode="avg", tenoutH=Dz, tenoutW=H)
net_planeXZ = self.splat_func(net_plane, Fxy2xz, None,
strMode="avg", tenoutH=Dz, tenoutW=W)
return net_plane[None], net_planeYZ[None], net_planeXZ[None]
class LocalPoolPointnet(nn.Module):
''' PointNet-based encoder network with ResNet blocks for each point.
Number of input points are fixed.
Args:
c_dim (int): dimension of latent code c
dim (int): input points dimension
hidden_dim (int): hidden dimension of the network
scatter_type (str): feature aggregation when doing local pooling
        unet (bool): whether to use a U-Net
        unet_kwargs (dict): U-Net parameters
        unet3d (bool): whether to use a 3D U-Net
        unet3d_kwargs (dict): 3D U-Net parameters
plane_resolution (int): defined resolution for plane feature
grid_resolution (int): defined resolution for grid feature
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        n_blocks (int): number of ResnetBlockFC blocks
'''
def __init__(self, ch=128, dim=3, hidden_dim=128, scatter_type='max',
unet=True, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
hw=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5):
super().__init__()
c_dim = ch
unet3d = False
plane_type = ['xy', 'xz', 'yz']
plane_resolution = hw
self.c_dim = c_dim
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
self.blocks = nn.ModuleList([
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
])
self.fc_c = nn.Linear(hidden_dim, c_dim)
self.actvn = nn.ReLU()
self.hidden_dim = hidden_dim
if unet:
self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
else:
self.unet = None
if unet3d:
# self.unet3d = UNet3D(**unet3d_kwargs)
raise NotImplementedError()
else:
self.unet3d = None
self.reso_plane = plane_resolution
self.reso_grid = grid_resolution
self.plane_type = plane_type
self.padding = padding
if scatter_type == 'max':
self.scatter = scatter_max
elif scatter_type == 'mean':
self.scatter = scatter_mean
else:
raise ValueError('incorrect scatter type')
def generate_plane_features(self, p, c, plane='xz'):
# acquire indices of features in plane
xy = normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
index = coordinate2index(xy, self.reso_plane)
# scatter plane features from points
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
c = c.permute(0, 2, 1) # B x 512 x T
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x 512 x reso x reso)
# process the plane features with UNet
if self.unet is not None:
fea_plane = self.unet(fea_plane)
return fea_plane
def generate_grid_features(self, p, c):
p_nor = normalize_3d_coordinate(p.clone(), padding=self.padding)
index = coordinate2index(p_nor, self.reso_grid, coord_type='3d')
# scatter grid features from points
fea_grid = c.new_zeros(p.size(0), self.c_dim, self.reso_grid**3)
c = c.permute(0, 2, 1)
fea_grid = scatter_mean(c, index, out=fea_grid) # B x C x reso^3
fea_grid = fea_grid.reshape(p.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid) # sparse tensor (B x C x reso x reso x reso)
if self.unet3d is not None:
fea_grid = self.unet3d(fea_grid)
return fea_grid
def pool_local(self, xy, index, c):
bs, fea_dim = c.size(0), c.size(2)
keys = xy.keys()
c_out = 0
for key in keys:
# scatter plane features from points
if key == 'grid':
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_grid**3)
else:
c_permute = c.permute(0, 2, 1)
fea = self.scatter(c_permute, index[key], dim_size=self.reso_plane**2)
if self.scatter == scatter_max:
fea = fea[0]
# gather feature back to points
fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
c_out = c_out + fea
return c_out.permute(0, 2, 1)
def forward(self, p_input, img_feats=None):
"""
Args:
p_input (tensor): input points T 3 H W
img_feats (tensor): image features T C H W
"""
T, _, H, W = img_feats.size()
p = rearrange(p_input, 't c h w -> (h w) t c', c=3, h=H, w=W)
fea_reshp = rearrange(img_feats, 't c h w -> (h w) t c',
c=img_feats.shape[1], h=H, w=W)
# acquire the index for each point
coord = {}
index = {}
if 'xz' in self.plane_type:
coord['xz'] = normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
index['xz'] = coordinate2index(coord['xz'], self.reso_plane)
if 'xy' in self.plane_type:
coord['xy'] = normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
index['xy'] = coordinate2index(coord['xy'], self.reso_plane)
if 'yz' in self.plane_type:
coord['yz'] = normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
index['yz'] = coordinate2index(coord['yz'], self.reso_plane)
if 'grid' in self.plane_type:
coord['grid'] = normalize_3d_coordinate(p.clone(), padding=self.padding)
index['grid'] = coordinate2index(coord['grid'], self.reso_grid, coord_type='3d')
net = self.fc_pos(p) + fea_reshp
net = self.blocks[0](net)
for block in self.blocks[1:]:
pooled = self.pool_local(coord, index, net)
net = torch.cat([net, pooled], dim=2)
net = block(net)
c = self.fc_c(net)
fea = {}
if 'grid' in self.plane_type:
fea['grid'] = self.generate_grid_features(p, c)
if 'xz' in self.plane_type:
fea['xz'] = self.generate_plane_features(p, c, plane='xz')
if 'xy' in self.plane_type:
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
if 'yz' in self.plane_type:
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
ret = torch.stack([fea['xy'], fea['xz'], fea['yz']]).permute((1, 0, 2, 3, 4))
return ret
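# Illustrative usage sketch (not from the original code; shapes and kwargs are
# assumptions for the example). img_feats must have 2*hidden_dim channels so it
# can be added to fc_pos(p_input):
#   enc = LocalPoolPointnet(ch=128, hidden_dim=128, hw=32, unet_kwargs={'depth': 4})
#   p = torch.rand(8, 3, 16, 16)        # T x 3 x H x W per-pixel 3D coordinates
#   feats = torch.rand(8, 256, 16, 16)  # T x 2*hidden_dim x H x W image features
#   planes = enc(p, feats)              # stacked ('xy', 'xz', 'yz') plane features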
class PatchLocalPoolPointnet(nn.Module):
''' PointNet-based encoder network with ResNet blocks.
First transforms the input points to a local system based on the given voxel size.
Supports a non-fixed number of points, but the index must be precomputed.
Args:
c_dim (int): dimension of latent code c
dim (int): input points dimension
hidden_dim (int): hidden dimension of the network
scatter_type (str): feature aggregation when doing local pooling
unet (bool): whether to use U-Net
unet_kwargs (dict): U-Net parameters
unet3d (bool): whether to use 3D U-Net
unet3d_kwargs (dict): 3D U-Net parameters
plane_resolution (int): defined resolution for plane feature
grid_resolution (int): defined resolution for grid feature
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
n_blocks (int): number of ResnetBlockFC blocks
local_coord (bool): whether to use local coordinates
pos_encoding (str): method for the positional encoding, linear|sin_cos
unit_size (float): defined voxel unit size for the local system
'''
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
unet=False, unet_kwargs=None, unet3d=False, unet3d_kwargs=None,
plane_resolution=None, grid_resolution=None, plane_type='xz', padding=0.1, n_blocks=5,
local_coord=False, pos_encoding='linear', unit_size=0.1):
super().__init__()
self.c_dim = c_dim
self.blocks = nn.ModuleList([
ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
])
self.fc_c = nn.Linear(hidden_dim, c_dim)
self.actvn = nn.ReLU()
self.hidden_dim = hidden_dim
self.reso_plane = plane_resolution
self.reso_grid = grid_resolution
self.plane_type = plane_type
self.padding = padding
if unet:
self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
else:
self.unet = None
if unet3d:
# self.unet3d = UNet3D(**unet3d_kwargs)
raise NotImplementedError()
else:
self.unet3d = None
if scatter_type == 'max':
self.scatter = scatter_max
elif scatter_type == 'mean':
self.scatter = scatter_mean
else:
raise ValueError('incorrect scatter type')
if local_coord:
self.map2local = map2local(unit_size, pos_encoding=pos_encoding)
else:
self.map2local = None
if pos_encoding == 'sin_cos':
self.fc_pos = nn.Linear(60, 2*hidden_dim)
else:
self.fc_pos = nn.Linear(dim, 2*hidden_dim)
def generate_plane_features(self, index, c):
c = c.permute(0, 2, 1)
# scatter plane features from points
if index.max() < self.reso_plane**2:
fea_plane = c.new_zeros(c.size(0), self.c_dim, self.reso_plane**2)
fea_plane = scatter_mean(c, index, out=fea_plane) # B x c_dim x reso^2
else:
fea_plane = scatter_mean(c, index) # B x c_dim x reso^2
if fea_plane.shape[-1] > self.reso_plane**2: # deal with outliers
fea_plane = fea_plane[:, :, :-1]
fea_plane = fea_plane.reshape(c.size(0), self.c_dim, self.reso_plane, self.reso_plane)
# process the plane features with UNet
if self.unet is not None:
fea_plane = self.unet(fea_plane)
return fea_plane
def generate_grid_features(self, index, c):
# scatter grid features from points
c = c.permute(0, 2, 1)
if index.max() < self.reso_grid**3:
fea_grid = c.new_zeros(c.size(0), self.c_dim, self.reso_grid**3)
fea_grid = scatter_mean(c, index, out=fea_grid) # B x c_dim x reso^3
else:
fea_grid = scatter_mean(c, index) # B x c_dim x reso^3
if fea_grid.shape[-1] > self.reso_grid**3: # deal with outliers
fea_grid = fea_grid[:, :, :-1]
fea_grid = fea_grid.reshape(c.size(0), self.c_dim, self.reso_grid, self.reso_grid, self.reso_grid)
if self.unet3d is not None:
fea_grid = self.unet3d(fea_grid)
return fea_grid
def pool_local(self, index, c):
bs, fea_dim = c.size(0), c.size(2)
keys = index.keys()
c_out = 0
for key in keys:
# scatter plane features from points
if key == 'grid':
fea = self.scatter(c.permute(0, 2, 1), index[key])
else:
fea = self.scatter(c.permute(0, 2, 1), index[key])
if self.scatter == scatter_max:
fea = fea[0]
# gather feature back to points
fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
c_out += fea
return c_out.permute(0, 2, 1)
def forward(self, inputs):
p = inputs['points']
index = inputs['index']
batch_size, T, D = p.size()
if self.map2local:
pp = self.map2local(p)
net = self.fc_pos(pp)
else:
net = self.fc_pos(p)
net = self.blocks[0](net)
for block in self.blocks[1:]:
pooled = self.pool_local(index, net)
net = torch.cat([net, pooled], dim=2)
net = block(net)
c = self.fc_c(net)
fea = {}
if 'grid' in self.plane_type:
fea['grid'] = self.generate_grid_features(index['grid'], c)
if 'xz' in self.plane_type:
fea['xz'] = self.generate_plane_features(index['xz'], c)
if 'xy' in self.plane_type:
fea['xy'] = self.generate_plane_features(index['xy'], c)
if 'yz' in self.plane_type:
fea['yz'] = self.generate_plane_features(index['yz'], c)
return fea

View File

@@ -0,0 +1 @@
from .transformer import LocalFeatureTransformer

View File

@@ -0,0 +1,81 @@
"""
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""
import torch
from torch.nn import Module, Dropout
def elu_feature_map(x):
return torch.nn.functional.elu(x) + 1
class LinearAttention(Module):
def __init__(self, eps=1e-6):
super().__init__()
self.feature_map = elu_feature_map
self.eps = eps
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
""" Multi-Head linear attention proposed in "Transformers are RNNs"
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""
Q = self.feature_map(queries)
K = self.feature_map(keys)
# set padded position to zero
if q_mask is not None:
Q = Q * q_mask[:, :, None, None]
if kv_mask is not None:
K = K * kv_mask[:, :, None, None]
values = values * kv_mask[:, :, None, None]
v_length = values.size(1)
values = values / v_length # prevent fp16 overflow
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
return queried_values.contiguous()
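# Minimal usage sketch (not part of the original file; shapes are arbitrary).
# With the elu(x)+1 feature map phi, the layer computes
#   out_l = phi(q_l) @ (sum_s phi(k_s) v_s^T) / (phi(q_l) . sum_s phi(k_s)),
# i.e. attention with memory linear in L and S instead of O(L*S):
#   attn = LinearAttention()
#   q = torch.randn(1, 4096, 8, 32)   # [N, L, H, D]
#   k = torch.randn(1, 4096, 8, 32)   # [N, S, H, D]
#   v = torch.randn(1, 4096, 8, 32)   # [N, S, H, D]
#   out = attn(q, k, v)               # [N, L, H, D]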
class FullAttention(Module):
def __init__(self, use_dropout=False, attention_dropout=0.1):
super().__init__()
self.use_dropout = use_dropout
self.dropout = Dropout(attention_dropout)
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
""" Multi-head scaled dot-product attention, a.k.a full attention.
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""
# Compute the unnormalized attention and apply the masks
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
if kv_mask is not None:
QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
# Compute the attention and the weighted average
softmax_temp = 1. / queries.size(3)**.5 # 1 / sqrt(D)
A = torch.softmax(softmax_temp * QK, dim=2)
if self.use_dropout:
A = self.dropout(A)
queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
return queried_values.contiguous()

View File

@@ -0,0 +1,142 @@
'''
modified from
https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
'''
import torch
from torch.nn import Module, Dropout
import copy
import torch.nn as nn
import torch.nn.functional as F
def elu_feature_map(x):
return torch.nn.functional.elu(x) + 1
class FullAttention(Module):
def __init__(self, use_dropout=False, attention_dropout=0.1):
super().__init__()
self.use_dropout = use_dropout
self.dropout = Dropout(attention_dropout)
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
""" Multi-head scaled dot-product attention, a.k.a full attention.
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""
# Compute the unnormalized attention and apply the masks
# QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
# if kv_mask is not None:
# QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float(-1e12))
# softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
# A = torch.softmax(softmax_temp * QK, dim=2)
# if self.use_dropout:
# A = self.dropout(A)
# queried_values_ = torch.einsum("nlsh,nshd->nlhd", A, values)
# Compute the attention and the weighted average
input_args = [x.half().contiguous() for x in [queries.permute(0,2,1,3), keys.permute(0,2,1,3), values.permute(0,2,1,3)]]
queried_values = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).float() # type: ignore
return queried_values.contiguous()
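# Note: unlike the commented-out reference implementation above, this path casts
# q/k/v to fp16 and calls F.scaled_dot_product_attention directly, so the
# q_mask/kv_mask arguments are accepted but not applied here.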
class TransformerEncoderLayer(nn.Module):
def __init__(self,
d_model,
nhead,):
super(TransformerEncoderLayer, self).__init__()
self.dim = d_model // nhead
self.nhead = nhead
# multi-head attention
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.attention = FullAttention()
self.merge = nn.Linear(d_model, d_model, bias=False)
# feed-forward network
self.mlp = nn.Sequential(
nn.Linear(d_model*2, d_model*2, bias=False),
nn.ReLU(True),
nn.Linear(d_model*2, d_model, bias=False),
)
# norm and dropout
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, source, x_mask=None, source_mask=None):
"""
Args:
x (torch.Tensor): [N, L, C]
source (torch.Tensor): [N, S, C]
x_mask (torch.Tensor): [N, L] (optional)
source_mask (torch.Tensor): [N, S] (optional)
"""
bs = x.size(0)
query, key, value = x, source, source
# multi-head attention
query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
message = self.norm1(message)
# feed-forward network
message = self.mlp(torch.cat([x, message], dim=2))
message = self.norm2(message)
return x + message
class LocalFeatureTransformer(nn.Module):
"""A Local Feature Transformer module."""
def __init__(self, config):
super(LocalFeatureTransformer, self).__init__()
self.config = config
self.d_model = config['d_model']
self.nhead = config['nhead']
self.layer_names = config['layer_names']
encoder_layer = TransformerEncoderLayer(config['d_model'], config['nhead'])
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, feat0, feat1, mask0=None, mask1=None):
"""
Args:
feat0 (torch.Tensor): [N, L, C]
feat1 (torch.Tensor): [N, S, C]
mask0 (torch.Tensor): [N, L] (optional)
mask1 (torch.Tensor): [N, S] (optional)
"""
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
for layer, name in zip(self.layers, self.layer_names):
if name == 'self':
feat0 = layer(feat0, feat0, mask0, mask0)
feat1 = layer(feat1, feat1, mask1, mask1)
elif name == 'cross':
feat0 = layer(feat0, feat1, mask0, mask1)
feat1 = layer(feat1, feat0, mask1, mask0)
else:
raise KeyError
return feat0, feat1
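# Illustrative usage sketch (not from the original code; the config values are
# assumptions). 'layer_names' alternates self- and cross-attention layers, and the
# attention path casts to fp16, so this is intended to run on GPU:
#   config = {'d_model': 256, 'nhead': 8, 'layer_names': ['self', 'cross'] * 2}
#   loftr = LocalFeatureTransformer(config).cuda()
#   feat0 = torch.randn(1, 1200, 256).cuda()   # [N, L, C]
#   feat1 = torch.randn(1, 1200, 256).cuda()   # [N, S, C]
#   feat0, feat1 = loftr(feat0, feat1)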

View File

@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from ..model_utils import reduce_masked_mean
from .blocks import (
pix2cam
)
from ..model_utils import (
bilinear_sample2d
)
EPS = 1e-6
import torchvision.transforms.functional as TF
sigma = 3
x_grid = torch.arange(-7,8,1)
y_grid = torch.arange(-7,8,1)
x_grid, y_grid = torch.meshgrid(x_grid, y_grid)
gridxy = torch.stack([x_grid, y_grid], dim=-1).float()
gs_kernel = torch.exp(-torch.sum(gridxy**2, dim=-1)/(2*sigma**2))
def balanced_ce_loss(pred, gt, valid=None):
total_balanced_loss = 0.0
for j in range(len(gt)):
B, S, N = gt[j].shape
# pred and gt are the same shape
for (a, b) in zip(pred[j].size(), gt[j].size()):
assert a == b # some shape mismatch!
# if valid is not None:
for (a, b) in zip(pred[j].size(), valid[j].size()):
assert a == b # some shape mismatch!
pos = (gt[j] > 0.95).float()
neg = (gt[j] < 0.05).float()
label = pos * 2.0 - 1.0
a = -label * pred[j]
b = F.relu(a)
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
pos_loss = reduce_masked_mean(loss, pos * valid[j])
neg_loss = reduce_masked_mean(loss, neg * valid[j])
balanced_loss = pos_loss + neg_loss
total_balanced_loss += balanced_loss / float(N)
return total_balanced_loss
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8,
intr=None, trajs_g_all=None):
"""Loss function defined over sequence of flow predictions"""
total_flow_loss = 0.0
for j in range(len(flow_gt)):
B, S, N, D = flow_gt[j].shape
# assert D == 3
B, S1, N = vis[j].shape
B, S2, N = valids[j].shape
assert S == S1
assert S == S2
n_predictions = len(flow_preds[j])
if intr is not None:
intr_i = intr[j]
flow_loss = 0.0
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[j][i][..., -N:, :D]
flow_gt_j = flow_gt[j].clone()
if intr is not None:
xyz_j_gt = pix2cam(flow_gt_j, intr_i)
i_loss = (flow_pred - flow_gt_j).abs() # B, S, N, 3
if D==3:
i_loss[...,2]*=30
i_loss = torch.mean(i_loss, dim=3) # B, S, N
flow_loss += i_weight * (reduce_masked_mean(i_loss, valids[j]))
flow_loss = flow_loss / n_predictions
total_flow_loss += flow_loss / float(N)
return total_flow_loss
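# Weighting note: i_weight = gamma ** (n_predictions - i - 1) discounts earlier
# refinement iterations, e.g. with gamma=0.8 and 4 predictions the weights are
# 0.512, 0.64, 0.8, 1.0, so the final prediction contributes most to the loss.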

View File

@@ -0,0 +1,539 @@
#!/usr/bin/env python
"""The code of softsplat function is modified from:
https://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py
"""
import collections
import cupy
import os
import re
import torch
import typing
##########################################################
objCudacache = {}
def cuda_int32(intIn:int):
return cupy.int32(intIn)
# end
def cuda_float32(fltIn:float):
return cupy.float32(fltIn)
# end
def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
if 'device' not in objCudacache:
objCudacache['device'] = torch.cuda.get_device_name()
# end
strKey = strFunction
for strVariable in objVariables:
objValue = objVariables[strVariable]
strKey += strVariable
if objValue is None:
continue
elif type(objValue) == int:
strKey += str(objValue)
elif type(objValue) == float:
strKey += str(objValue)
elif type(objValue) == bool:
strKey += str(objValue)
elif type(objValue) == str:
strKey += objValue
elif type(objValue) == torch.Tensor:
strKey += str(objValue.dtype)
strKey += str(objValue.shape)
strKey += str(objValue.stride())
elif True:
print(strVariable, type(objValue))
assert(False)
# end
# end
strKey += objCudacache['device']
if strKey not in objCudacache:
for strVariable in objVariables:
objValue = objVariables[strVariable]
if objValue is None:
continue
elif type(objValue) == int:
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
elif type(objValue) == float:
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
elif type(objValue) == bool:
strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
elif type(objValue) == str:
strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
strKernel = strKernel.replace('{{type}}', 'unsigned char')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
strKernel = strKernel.replace('{{type}}', 'half')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
strKernel = strKernel.replace('{{type}}', 'float')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
strKernel = strKernel.replace('{{type}}', 'double')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
strKernel = strKernel.replace('{{type}}', 'int')
elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
strKernel = strKernel.replace('{{type}}', 'long')
elif type(objValue) == torch.Tensor:
print(strVariable, objValue.dtype)
assert(False)
elif True:
print(strVariable, type(objValue))
assert(False)
# end
# end
while True:
objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
if objMatch is None:
break
# end
intArg = int(objMatch.group(2))
strTensor = objMatch.group(4)
intSizes = objVariables[strTensor].size()
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
# end
while True:
objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel)
if objMatch is None:
break
# end
intStart = objMatch.span()[1]
intStop = objMatch.span()[1]
intParentheses = 1
while True:
intParentheses += 1 if strKernel[intStop] == '(' else 0
intParentheses -= 1 if strKernel[intStop] == ')' else 0
if intParentheses == 0:
break
# end
intStop += 1
# end
intArgs = int(objMatch.group(2))
strArgs = strKernel[intStart:intStop].split(',')
assert(intArgs == len(strArgs) - 1)
strTensor = strArgs[0]
intStrides = objVariables[strTensor].stride()
strIndex = []
for intArg in range(intArgs):
strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
# end
strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
# end
while True:
objMatch = re.search('(VALUE_)([0-4])(\()', strKernel)
if objMatch is None:
break
# end
intStart = objMatch.span()[1]
intStop = objMatch.span()[1]
intParentheses = 1
while True:
intParentheses += 1 if strKernel[intStop] == '(' else 0
intParentheses -= 1 if strKernel[intStop] == ')' else 0
if intParentheses == 0:
break
# end
intStop += 1
# end
intArgs = int(objMatch.group(2))
strArgs = strKernel[intStart:intStop].split(',')
assert(intArgs == len(strArgs) - 1)
strTensor = strArgs[0]
intStrides = objVariables[strTensor].stride()
strIndex = []
for intArg in range(intArgs):
strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
# end
strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
# end
objCudacache[strKey] = {
'strFunction': strFunction,
'strKernel': strKernel
}
# end
return strKey
# end
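# Summary of the templating above: '{{variable}}' placeholders are substituted with
# the given ints/floats/bools/strings, '{{type}}' with the CUDA scalar type of the
# tensors, SIZE_n(t) with the size of dimension n, and OFFSET_n/VALUE_n(t, i0, ...)
# with a stride-based flat offset / element access; the specialized source is cached
# in objCudacache keyed by the argument dtypes, shapes, strides and the device name.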
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey:str):
if 'CUDA_HOME' not in os.environ:
os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
# end
return cupy.RawKernel(objCudacache[strKey]['strKernel'], objCudacache[strKey]['strFunction'])
# end
##########################################################
def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor,
tenMetric:torch.Tensor, strMode:str, tenoutH=None, tenoutW=None):
assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
if strMode == 'sum': assert(tenMetric is None)
if strMode == 'avg': assert(tenMetric is None)
if strMode.split('-')[0] == 'linear': assert(tenMetric is not None)
if strMode.split('-')[0] == 'soft': assert(tenMetric is not None)
if strMode == 'avg':
tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)
elif strMode.split('-')[0] == 'linear':
tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
elif strMode.split('-')[0] == 'soft':
tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)
# end
tenOut = softsplat_func.apply(tenIn, tenFlow, tenoutH, tenoutW)
if strMode.split('-')[0] in ['avg', 'linear', 'soft']:
tenNormalize = tenOut[:, -1:, :, :]
if len(strMode.split('-')) == 1:
tenNormalize = tenNormalize + 0.0000001
elif strMode.split('-')[1] == 'addeps':
tenNormalize = tenNormalize + 0.0000001
elif strMode.split('-')[1] == 'zeroeps':
tenNormalize[tenNormalize == 0.0] = 1.0
elif strMode.split('-')[1] == 'clipeps':
tenNormalize = tenNormalize.clip(0.0000001, None)
# end
tenOut = tenOut[:, :-1, :, :] / tenNormalize
# end
return tenOut
# end
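# Minimal usage sketch (not part of the original file; shapes are arbitrary).
# Forward-warps ('splats') a feature map along a per-pixel flow field; in 'avg'
# mode an appended ones-channel normalizes overlapping contributions, and
# tenoutH/tenoutW optionally set a different output resolution (CUDA only):
#   feat = torch.randn(2, 64, 60, 90).cuda()
#   flow = torch.randn(2, 2, 60, 90).cuda() * 4
#   out = softsplat(feat, flow, None, strMode='avg')                  # 2 x 64 x 60 x 90
#   out_yz = softsplat(feat, flow, None, strMode='avg', tenoutH=48, tenoutW=60)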
class softsplat_func(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(self, tenIn, tenFlow, H=None, W=None):
if H is None:
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
else:
tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], H, W])
if tenIn.is_cuda == True:
cuda_launch(cuda_kernel('softsplat_out', '''
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
const int n,
const {{type}}* __restrict__ tenIn,
const {{type}}* __restrict__ tenFlow,
{{type}}* __restrict__ tenOut
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) / SIZE_1(tenIn) ) % SIZE_0(tenIn);
const int intC = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) ) % SIZE_1(tenIn);
const int intY = ( intIndex / SIZE_3(tenIn) ) % SIZE_2(tenIn);
const int intX = ( intIndex ) % SIZE_3(tenIn);
assert(SIZE_1(tenFlow) == 2);
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
if (isfinite(fltX) == false) { return; }
if (isfinite(fltY) == false) { return; }
{{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
int intNorthwestX = (int) (floor(fltX));
int intNorthwestY = (int) (floor(fltY));
int intNortheastX = intNorthwestX + 1;
int intNortheastY = intNorthwestY;
int intSouthwestX = intNorthwestX;
int intSouthwestY = intNorthwestY + 1;
int intSoutheastX = intNorthwestX + 1;
int intSoutheastY = intNorthwestY + 1;
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
}
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
}
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
}
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
}
} }
''', {
'tenIn': tenIn,
'tenFlow': tenFlow,
'tenOut': tenOut
}))(
grid=tuple([int((tenIn.nelement() + 512 - 1) / 512), 1, 1]),
block=tuple([512, 1, 1]),
args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
elif tenIn.is_cuda != True:
assert(False)
# end
self.save_for_backward(tenIn, tenFlow)
return tenOut
# end
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(self, tenOutgrad):
tenIn, tenFlow = self.saved_tensors
tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)
tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None
tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None
Hgrad = None
Wgrad = None
if tenIngrad is not None:
cuda_launch(cuda_kernel('softsplat_ingrad', '''
extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
const int n,
const {{type}}* __restrict__ tenIn,
const {{type}}* __restrict__ tenFlow,
const {{type}}* __restrict__ tenOutgrad,
{{type}}* __restrict__ tenIngrad,
{{type}}* __restrict__ tenFlowgrad
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
const int intX = ( intIndex ) % SIZE_3(tenIngrad);
assert(SIZE_1(tenFlow) == 2);
{{type}} fltIngrad = 0.0f;
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
if (isfinite(fltX) == false) { return; }
if (isfinite(fltY) == false) { return; }
int intNorthwestX = (int) (floor(fltX));
int intNorthwestY = (int) (floor(fltY));
int intNortheastX = intNorthwestX + 1;
int intNortheastY = intNorthwestY;
int intSouthwestX = intNorthwestX;
int intSouthwestY = intNorthwestY + 1;
int intSoutheastX = intNorthwestX + 1;
int intSoutheastY = intNorthwestY + 1;
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
}
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
}
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
}
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
}
tenIngrad[intIndex] = fltIngrad;
} }
''', {
'tenIn': tenIn,
'tenFlow': tenFlow,
'tenOutgrad': tenOutgrad,
'tenIngrad': tenIngrad,
'tenFlowgrad': tenFlowgrad
}))(
grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
block=tuple([512, 1, 1]),
args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None],
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
# end
if tenFlowgrad is not None:
cuda_launch(cuda_kernel('softsplat_flowgrad', '''
extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
const int n,
const {{type}}* __restrict__ tenIn,
const {{type}}* __restrict__ tenFlow,
const {{type}}* __restrict__ tenOutgrad,
{{type}}* __restrict__ tenIngrad,
{{type}}* __restrict__ tenFlowgrad
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
assert(SIZE_1(tenFlow) == 2);
{{type}} fltFlowgrad = 0.0f;
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
if (isfinite(fltX) == false) { return; }
if (isfinite(fltY) == false) { return; }
int intNorthwestX = (int) (floor(fltX));
int intNorthwestY = (int) (floor(fltY));
int intNortheastX = intNorthwestX + 1;
int intNortheastY = intNorthwestY;
int intSouthwestX = intNorthwestX;
int intSouthwestY = intNorthwestY + 1;
int intSoutheastX = intNorthwestX + 1;
int intSoutheastY = intNorthwestY + 1;
{{type}} fltNorthwest = 0.0f;
{{type}} fltNortheast = 0.0f;
{{type}} fltSouthwest = 0.0f;
{{type}} fltSoutheast = 0.0f;
if (intC == 0) {
fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
} else if (intC == 1) {
fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
}
for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
{{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
}
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
}
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
}
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
}
}
tenFlowgrad[intIndex] = fltFlowgrad;
} }
''', {
'tenIn': tenIn,
'tenFlow': tenFlow,
'tenOutgrad': tenOutgrad,
'tenIngrad': tenIngrad,
'tenFlowgrad': tenFlowgrad
}))(
grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
block=tuple([512, 1, 1]),
args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()],
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
)
# end
return tenIngrad, tenFlowgrad, Hgrad, Wgrad
# end
# end

View File

@@ -0,0 +1,736 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from easydict import EasyDict as edict
from einops import rearrange
from sklearn.cluster import SpectralClustering
from .blocks import Lie
import matplotlib.pyplot as plt
import cv2
import torch.nn.functional as F
from .blocks import (
BasicEncoder,
CorrBlock,
EUpdateFormer,
FusionFormer,
pix2cam,
cam2pix,
edgeMat,
VitEncoder,
DPTEnc,
Dinov2
)
from .feature_net import (
LocalSoftSplat
)
from ..model_utils import (
meshgrid2d, bilinear_sample2d, smart_cat, sample_features5d, vis_PCA
)
from ..embeddings import (
get_2d_embedding,
get_3d_embedding,
get_1d_sincos_pos_embed_from_grid,
get_2d_sincos_pos_embed,
get_3d_sincos_pos_embed_from_grid,
Embedder_Fourier,
)
import numpy as np
from .softsplat import softsplat
torch.manual_seed(0)
from comfy.utils import ProgressBar
from tqdm import tqdm
def get_points_on_a_grid(grid_size, interp_shape,
grid_center=(0, 0), device="cuda"):
if grid_size == 1:
return torch.tensor([interp_shape[1] / 2,
interp_shape[0] / 2], device=device)[
None, None
]
grid_y, grid_x = meshgrid2d(
1, grid_size, grid_size, stack=False, norm=False, device=device
)
step = interp_shape[1] // 64
if grid_center[0] != 0 or grid_center[1] != 0:
grid_y = grid_y - grid_size / 2.0
grid_x = grid_x - grid_size / 2.0
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[0] - step * 2
)
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[1] - step * 2
)
grid_y = grid_y + grid_center[0]
grid_x = grid_x + grid_center[1]
xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
return xy
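# Illustrative sketch (values are arbitrary): build a 16x16 grid of query pixels
# over a 384x512 frame, leaving a margin of interp_shape[1] // 64 pixels:
#   xy = get_points_on_a_grid(16, (384, 512))   # 1 x 256 x 2, (x, y) order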
def sample_pos_embed(grid_size, embed_dim, coords):
if coords.shape[-1] == 2:
pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim,
grid_size=grid_size)
pos_embed = (
torch.from_numpy(pos_embed)
.reshape(grid_size[0], grid_size[1], embed_dim)
.float()
.unsqueeze(0)
.to(coords.device)
)
sampled_pos_embed = bilinear_sample2d(
pos_embed.permute(0, 3, 1, 2),
coords[:, 0, :, 0], coords[:, 0, :, 1]
)
elif coords.shape[-1] == 3:
sampled_pos_embed = get_3d_sincos_pos_embed_from_grid(
embed_dim, coords[:, :1, ...]
).float()[:,0,...].permute(0, 2, 1)
return sampled_pos_embed
class SpaTracker(nn.Module):
def __init__(
self,
S=8,
stride=8,
add_space_attn=True,
num_heads=8,
hidden_size=384,
space_depth=12,
time_depth=12,
args=edict({})
):
super(SpaTracker, self).__init__()
# step1: config the arch of the model
self.args=args
# step1.1: config the default value of the model
if getattr(args, "depth_color", None) == None:
self.args.depth_color = False
if getattr(args, "if_ARAP", None) == None:
self.args.if_ARAP = True
if getattr(args, "flash_attn", None) == None:
self.args.flash_attn = True
if getattr(args, "backbone", None) == None:
self.args.backbone = "CNN"
if getattr(args, "Nblock", None) == None:
self.args.Nblock = 0
if getattr(args, "Embed3D", None) == None:
self.args.Embed3D = True
# step1.2: config the model parameters
self.S = S
self.stride = stride
self.hidden_dim = 256
self.latent_dim = latent_dim = 128
self.b_latent_dim = self.latent_dim//3
self.corr_levels = 4
self.corr_radius = 3
self.add_space_attn = add_space_attn
self.lie = Lie()
# step2: config the model components
# @Encoder
self.fnet = BasicEncoder(input_dim=3,
output_dim=self.latent_dim, norm_fn="instance", dropout=0,
stride=stride, Embed3D=False
)
# conv head for the tri-plane features
self.headyz = nn.Sequential(
nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
self.headxz = nn.Sequential(
nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))
# @UpdateFormer
self.updateformer = EUpdateFormer(
space_depth=space_depth,
time_depth=time_depth,
input_dim=456,
hidden_size=hidden_size,
num_heads=num_heads,
output_dim=latent_dim + 3,
mlp_ratio=4.0,
add_space_attn=add_space_attn,
flash=getattr(self.args, "flash_attn", True)
)
self.support_features = torch.zeros(100, 384).to("cuda") + 0.1
self.norm = nn.GroupNorm(1, self.latent_dim)
self.ffeat_updater = nn.Sequential(
nn.Linear(self.latent_dim, self.latent_dim),
nn.GELU(),
)
self.ffeatyz_updater = nn.Sequential(
nn.Linear(self.latent_dim, self.latent_dim),
nn.GELU(),
)
self.ffeatxz_updater = nn.Sequential(
nn.Linear(self.latent_dim, self.latent_dim),
nn.GELU(),
)
#TODO @NeuralArap: optimize the arap
self.embed_traj = Embedder_Fourier(
input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True
)
self.embed3d = Embedder_Fourier(
input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True
)
self.embedConv = nn.Conv2d(self.latent_dim+63,
self.latent_dim, 3, padding=1)
# @Vis_predictor
self.vis_predictor = nn.Sequential(
nn.Linear(128, 1),
)
self.embedProj = nn.Linear(63, 456)
self.zeroMLPflow = nn.Linear(195, 130)
def prepare_track(self, rgbds, queries):
"""
NOTE:
Normalize the rgbs and sort the queries by the time they first appear
Args:
rgbds: the input rgbd images (B T 4 H W)
queries: the input queries (B N 4)
Return:
rgbds: the normalized rgbds (B T 4 H W)
queries: the sorted queries (B N 4)
track_mask: per-frame validity mask (B T N 1), True from each query's first frame onward
"""
assert (rgbds.shape[2]==4) and (queries.shape[2]==4)
#Step1: normalize the rgbs input
device = rgbds.device
rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] / 255.0) - 1.0
B, T, C, H, W = rgbds.shape
B, N, __ = queries.shape
self.traj_e = torch.zeros((B, T, N, 3), device=device)
self.vis_e = torch.zeros((B, T, N), device=device)
#Step2: sort the points via their first appeared time
first_positive_inds = queries[0, :, 0].long()
__, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False)
inv_sort_inds = torch.argsort(sort_inds, dim=0)
first_positive_sorted_inds = first_positive_inds[sort_inds]
# check that the sorting can be inverted
assert torch.allclose(
first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds]
)
# filter out points that never appear during frames 1..T
ind_array = torch.arange(T, device=device)
ind_array = ind_array[None, :, None].repeat(B, 1, N)
track_mask = (ind_array >=
first_positive_inds[None, None, :]).unsqueeze(-1)
# scale the coords_init
coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat(
1, self.S, 1, 1
)
coords_init[..., :2] /= float(self.stride)
#Step3: initialize the regular grid
gridx = torch.linspace(0, W//self.stride - 1, W//self.stride)
gridy = torch.linspace(0, H//self.stride - 1, H//self.stride)
gridx, gridy = torch.meshgrid(gridx, gridy)
gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
2, 1, 0
)
vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
# Step4: initialize the trajectories for neural ARAP
T_series = torch.linspace(0, 5, T).reshape(1, T, 1 , 1).cuda() # 1 T 1 1
T_series = T_series.repeat(B, 1, N, 1)
# get the 3d traj in the camera coordinates
intr_init = self.intrs[:,queries[0,:,0].long()]
Traj_series = pix2cam(queries[:,:,None,1:].double(), intr_init.double())
#torch.inverse(intr_init.double())@queries[:,:,1:,None].double() # B N 3 1
Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float()
Traj_series = torch.cat([T_series, Traj_series], dim=-1)
# get the indicator for the neural arap
Traj_mask = -1e2*torch.ones_like(T_series)
Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1)
return (
rgbds,
first_positive_inds,
first_positive_sorted_inds,
sort_inds, inv_sort_inds,
track_mask, gridxy, coords_init[..., sort_inds, :].clone(),
vis_init, Traj_series[..., sort_inds, :].clone()
)
def sample_trifeat(self, t,
coords,
featMapxy,
featMapyz,
featMapxz):
"""
Sample the features from the 5D triplane feature map 3*(B S C H W)
Args:
t: the time index
coords: the coordinates of the points B S N 3
featMapxy: the feature map B S C Hx Wy
featMapyz: the feature map B S C Hy Wz
featMapxz: the feature map B S C Hx Wz
"""
# get xy_t yz_t xz_t
queried_t = t.reshape(1, 1, -1, 1)
xy_t = torch.cat(
[queried_t, coords[..., [0,1]]],
dim=-1
)
yz_t = torch.cat(
[queried_t, coords[..., [1, 2]]],
dim=-1
)
xz_t = torch.cat(
[queried_t, coords[..., [0, 2]]],
dim=-1
)
featxy_init = sample_features5d(featMapxy, xy_t)
featyz_init = sample_features5d(featMapyz, yz_t)
featxz_init = sample_features5d(featMapxz, xz_t)
featxy_init = featxy_init.repeat(1, self.S, 1, 1)
featyz_init = featyz_init.repeat(1, self.S, 1, 1)
featxz_init = featxz_init.repeat(1, self.S, 1, 1)
return featxy_init, featyz_init, featxz_init
def neural_arap(self, coords, Traj_arap, intrs_S, T_mark):
""" calculate the ARAP embedding and offset
Args:
coords: the coordinates of the current points 1 S N' 3
Traj_arap: the trajectory of the points 1 T N' 5
intrs_S: the camera intrinsics B S 3 3
"""
coords_out = coords.clone()
coords_out[..., :2] *= float(self.stride)
coords_out[..., 2] = coords_out[..., 2]/self.Dz
coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near
intrs_S = intrs_S[:, :, None, ...].repeat(1, 1, coords_out.shape[2], 1, 1)
B, S, N, D = coords_out.shape
if S != intrs_S.shape[1]:
intrs_S = torch.cat(
[intrs_S, intrs_S[:, -1:].repeat(1, S - intrs_S.shape[1],1,1,1)], dim=1
)
T_mark = torch.cat(
[T_mark, T_mark[:, -1:].repeat(1, S - T_mark.shape[1],1)], dim=1
)
xyz_ = pix2cam(coords_out.double(), intrs_S.double()[:,:,0])
xyz_ = xyz_.float()
xyz_embed = torch.cat([T_mark[...,None], xyz_,
torch.zeros_like(T_mark[...,None])], dim=-1)
xyz_embed = self.embed_traj(xyz_embed)
Traj_arap_embed = self.embed_traj(Traj_arap)
d_xyz,traj_feat = self.arapFormer(xyz_embed, Traj_arap_embed)
# update in camera coordinate
xyz_ = xyz_ + d_xyz.clamp(-5, 5)
# project back to the image plane
coords_out = cam2pix(xyz_.double(), intrs_S[:,:,0].double()).float()
# resize back
coords_out[..., :2] /= float(self.stride)
coords_out[..., 2] = (coords_out[..., 2] - self.d_near)/(self.d_far-self.d_near)
coords_out[..., 2] *= self.Dz
return xyz_, coords_out, traj_feat
def gradient_arap(self, coords, aff_avg=None, aff_std=None, aff_f_sg=None,
iter=0, iter_num=4, neigh_idx=None, intr=None, msk_track=None):
with torch.enable_grad():
coords.requires_grad_(True)
y = self.ARAP_ln(coords, aff_f_sg=aff_f_sg, neigh_idx=neigh_idx,
iter=iter, iter_num=iter_num, intr=intr,msk_track=msk_track)
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=coords,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True, allow_unused=True)[0]
return gradients.detach()
def forward_iteration(
self,
fmapXY,
fmapYZ,
fmapXZ,
coords_init,
feat_init=None,
vis_init=None,
track_mask=None,
iters=4,
intrs_S=None,
):
B, S_init, N, D = coords_init.shape
assert D == 3
assert B == 1
B, S, __, H8, W8 = fmapXY.shape
device = fmapXY.device
if S_init < S:
coords = torch.cat(
[coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)],
dim=1
)
vis_init = torch.cat(
[vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
)
intrs_S = torch.cat(
[intrs_S, intrs_S[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
)
else:
coords = coords_init.clone()
fcorr_fnXY = CorrBlock(
fmapXY, num_levels=self.corr_levels, radius=self.corr_radius
)
fcorr_fnYZ = CorrBlock(
fmapYZ, num_levels=self.corr_levels, radius=self.corr_radius
)
fcorr_fnXZ = CorrBlock(
fmapXZ, num_levels=self.corr_levels, radius=self.corr_radius
)
ffeats = torch.split(feat_init.clone(), dim=-1, split_size_or_sections=1)
ffeats = [f.squeeze(-1) for f in ffeats]
times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
pos_embed = sample_pos_embed(
grid_size=(H8, W8),
embed_dim=456,
coords=coords[..., :2],
)
pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
times_embed = (
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
.repeat(B, 1, 1)
.float()
.to(device)
)
coord_predictions = []
attn_predictions = []
Rot_ln = 0
support_feat = self.support_features
comfy_pbar = ProgressBar(iters)
for __ in tqdm(range(iters), desc="Processing iterations", leave=True):
coords = coords.detach()
# if self.args.if_ARAP == True:
# # refine the track with arap
# xyz_pred, coords, flows_cat0 = self.neural_arap(coords.detach(),
# Traj_arap.detach(),
# intrs_S, T_mark)
with torch.no_grad():
fcorrsXY = fcorr_fnXY.corr_sample(ffeats[0], coords[..., :2])
fcorrsYZ = fcorr_fnYZ.corr_sample(ffeats[1], coords[..., [1,2]])
fcorrsXZ = fcorr_fnXZ.corr_sample(ffeats[2], coords[..., [0,2]])
# fcorrs = fcorrsXY
fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ
LRR = fcorrs.shape[3]
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3)
flows_cat = get_3d_embedding(flows_, 64, cat_coords=True)
flows_cat = self.zeroMLPflow(flows_cat)
ffeats_xy = ffeats[0].permute(0,
2, 1, 3).reshape(B * N, S, self.latent_dim)
ffeats_yz = ffeats[1].permute(0,
2, 1, 3).reshape(B * N, S, self.latent_dim)
ffeats_xz = ffeats[2].permute(0,
2, 1, 3).reshape(B * N, S, self.latent_dim)
ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz
if track_mask.shape[1] < vis_init.shape[1]:
track_mask = torch.cat(
[
track_mask,
torch.zeros_like(track_mask[:, 0]).repeat(
1, vis_init.shape[1] - track_mask.shape[1], 1, 1
),
],
dim=1,
)
concat = (
torch.cat([track_mask, vis_init], dim=2)
.permute(0, 2, 1, 3)
.reshape(B * N, S, 2)
)
transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
if transformer_input.shape[-1] < pos_embed.shape[-1]:
# padding the transformer_input to the same dimension as pos_embed
transformer_input = F.pad(
transformer_input, (0, pos_embed.shape[-1] - transformer_input.shape[-1]),
"constant", 0
)
x = transformer_input + pos_embed + times_embed
x = rearrange(x, "(b n) t d -> b n t d", b=B)
delta, AttnMap, so3_dist, delta_se3F, so3 = self.updateformer(x, support_feat)
support_feat = support_feat + delta_se3F[0]/100
delta = rearrange(delta, " b n t d -> (b n) t d")
d_coord = delta[:, :, :3]
d_feats = delta[:, :, 3:]
ffeats_xy = self.ffeat_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xy.reshape(-1, self.latent_dim)
ffeats_yz = self.ffeatyz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_yz.reshape(-1, self.latent_dim)
ffeats_xz = self.ffeatxz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xz.reshape(-1, self.latent_dim)
ffeats[0] = ffeats_xy.reshape(B, N, S, self.latent_dim).permute(
0, 2, 1, 3
) # B,S,N,C
ffeats[1] = ffeats_yz.reshape(B, N, S, self.latent_dim).permute(
0, 2, 1, 3
) # B,S,N,C
ffeats[2] = ffeats_xz.reshape(B, N, S, self.latent_dim).permute(
0, 2, 1, 3
) # B,S,N,C
coords = coords + d_coord.reshape(B, N, S, 3).permute(0, 2, 1, 3)
if torch.isnan(coords).any():
import ipdb; ipdb.set_trace()
coords_out = coords.clone()
coords_out[..., :2] *= float(self.stride)
coords_out[..., 2] = coords_out[..., 2]/self.Dz
coords_out[..., 2] = coords_out[..., 2]*(self.d_far-self.d_near) + self.d_near
coord_predictions.append(coords_out)
attn_predictions.append(AttnMap)
comfy_pbar.update(1)
ffeats_f = ffeats[0] + ffeats[1] + ffeats[2]
vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim)).reshape(
B, S, N
)
self.support_features = support_feat.detach()
return coord_predictions, attn_predictions, vis_e, feat_init, Rot_ln
def forward(self, rgbds, queries, iters=4, feat_init=None,
is_train=False, intrs=None, wind_S=None):
self.support_features = torch.zeros(100, 384).to("cuda") + 0.1
self.is_train=is_train
B, T, C, H, W = rgbds.shape
# set the intrinsic or simply initialized
if intrs is None:
intrs = torch.from_numpy(np.array([[W, 0.0, W//2],
[0.0, W, H//2],
[0.0, 0.0, 1.0]]))
intrs = intrs[None,
None,...].repeat(B, T, 1, 1).float().to(rgbds.device)
self.intrs = intrs
# prepare the input for tracking
(
rgbds,
first_positive_inds,
first_positive_sorted_inds, sort_inds,
inv_sort_inds, track_mask, gridxy,
coords_init, vis_init, Traj_arap
) = self.prepare_track(rgbds.clone(), queries)
coords_init_ = coords_init.clone()
vis_init_ = vis_init[:, :, sort_inds].clone()
depth_all = rgbds[:, :, 3,...]
d_near = self.d_near = depth_all[depth_all>0.01].min().item()
d_far = self.d_far = depth_all[depth_all>0.01].max().item()
if wind_S is not None:
self.S = wind_S
B, N, __ = queries.shape
self.Dz = Dz = W//self.stride
w_idx_start = 0
p_idx_end = 0
p_idx_start = 0
fmaps_ = None
vis_predictions = []
coord_predictions = []
attn_predictions = []
p_idx_end_list = []
Rigid_ln_total = 0
while w_idx_start < T - self.S // 2:
curr_wind_points = torch.nonzero(
first_positive_sorted_inds < w_idx_start + self.S)
if curr_wind_points.shape[0] == 0:
w_idx_start = w_idx_start + self.S // 2
continue
p_idx_end = curr_wind_points[-1] + 1
p_idx_end_list.append(p_idx_end)
# T may not be divisible by self.S
rgbds_seq = rgbds[:, w_idx_start:w_idx_start + self.S].clone()
S = S_local = rgbds_seq.shape[1]
if S < self.S:
rgbds_seq = torch.cat(
[rgbds_seq,
rgbds_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
dim=1,
)
S = rgbds_seq.shape[1]
rgbs_ = rgbds_seq.reshape(B * S, C, H, W)[:, :3]
depths = rgbds_seq.reshape(B * S, C, H, W)[:, 3:].clone()
# open the mask
# Traj_arap[:, w_idx_start:w_idx_start + self.S, :p_idx_end, -1] = 0
#step1: normalize the depth map
depths = (depths - d_near)/(d_far-d_near)
depths_dn = nn.functional.interpolate(
depths, scale_factor=1.0 / self.stride, mode="nearest")
depths_dnG = depths_dn*Dz
#step2: normalize the coordinate
coords_init_[:, :, p_idx_start:p_idx_end, 2] = (
coords_init[:, :, p_idx_start:p_idx_end, 2] - d_near
)/(d_far-d_near)
coords_init_[:, :, p_idx_start:p_idx_end, 2] *= Dz
# efficient triplane splatting
gridxyz = torch.cat([gridxy[None,...].repeat(
depths_dn.shape[0],1,1,1), depths_dnG], dim=1)
Fxy2yz = gridxyz[:,[1, 2], ...] - gridxyz[:,:2]
Fxy2xz = gridxyz[:,[0, 2], ...] - gridxyz[:,:2]
if getattr(self.args, "Embed3D", None) == True:
gridxyz_nm = gridxyz.clone()
gridxyz_nm[:,0,...] = (gridxyz_nm[:,0,...]-gridxyz_nm[:,0,...].min())/(gridxyz_nm[:,0,...].max()-gridxyz_nm[:,0,...].min())
gridxyz_nm[:,1,...] = (gridxyz_nm[:,1,...]-gridxyz_nm[:,1,...].min())/(gridxyz_nm[:,1,...].max()-gridxyz_nm[:,1,...].min())
gridxyz_nm[:,2,...] = (gridxyz_nm[:,2,...]-gridxyz_nm[:,2,...].min())/(gridxyz_nm[:,2,...].max()-gridxyz_nm[:,2,...].min())
gridxyz_nm = 2*(gridxyz_nm-0.5)
_,_,h4,w4 = gridxyz_nm.shape
gridxyz_nm = gridxyz_nm.permute(0,2,3,1).reshape(S*h4*w4, 3)
featPE = self.embed3d(gridxyz_nm).view(S, h4, w4, -1).permute(0,3,1,2)
if fmaps_ is None:
fmaps_ = torch.cat([self.fnet(rgbs_),featPE], dim=1)
fmaps_ = self.embedConv(fmaps_)
else:
fmaps_new = torch.cat([self.fnet(rgbs_[self.S // 2 :]),featPE[self.S // 2 :]], dim=1)
fmaps_new = self.embedConv(fmaps_new)
fmaps_ = torch.cat(
[fmaps_[self.S // 2 :], fmaps_new], dim=0
)
else:
if fmaps_ is None:
fmaps_ = self.fnet(rgbs_)
else:
fmaps_ = torch.cat(
[fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
)
fmapXY = fmaps_[:, :self.latent_dim].reshape(
B, S, self.latent_dim, H // self.stride, W // self.stride
)
fmapYZ = softsplat(fmapXY[0], Fxy2yz, None,
strMode="avg", tenoutH=self.Dz, tenoutW=H//self.stride)
fmapXZ = softsplat(fmapXY[0], Fxy2xz, None,
strMode="avg", tenoutH=self.Dz, tenoutW=W//self.stride)
fmapYZ = self.headyz(fmapYZ)[None, ...]
fmapXZ = self.headxz(fmapXZ)[None, ...]
if p_idx_end - p_idx_start > 0:
queried_t = (first_positive_sorted_inds[p_idx_start:p_idx_end]
- w_idx_start)
(featxy_init,
featyz_init,
featxz_init) = self.sample_trifeat(
t=queried_t,featMapxy=fmapXY,
featMapyz=fmapYZ,featMapxz=fmapXZ,
coords=coords_init_[:, :1, p_idx_start:p_idx_end]
)
# T, S, N, C, 3
feat_init_curr = torch.stack([featxy_init,
featyz_init, featxz_init], dim=-1)
feat_init = smart_cat(feat_init, feat_init_curr, dim=2)
if p_idx_start > 0:
# preprocess the coordinates of last windows
last_coords = coords[-1][:, self.S // 2 :].clone()
last_coords[..., :2] /= float(self.stride)
last_coords[..., 2:] = (last_coords[..., 2:]-d_near)/(d_far-d_near)
last_coords[..., 2:] = last_coords[..., 2:]*Dz
coords_init_[:, : self.S // 2, :p_idx_start] = last_coords
coords_init_[:, self.S // 2 :, :p_idx_start] = last_coords[
:, -1
].repeat(1, self.S // 2, 1, 1)
last_vis = vis[:, self.S // 2 :].unsqueeze(-1)
vis_init_[:, : self.S // 2, :p_idx_start] = last_vis
vis_init_[:, self.S // 2 :, :p_idx_start] = last_vis[:, -1].repeat(
1, self.S // 2, 1, 1
)
coords, attns, vis, __, Rigid_ln = self.forward_iteration(
fmapXY=fmapXY,
fmapYZ=fmapYZ,
fmapXZ=fmapXZ,
coords_init=coords_init_[:, :, :p_idx_end],
feat_init=feat_init[:, :, :p_idx_end],
vis_init=vis_init_[:, :, :p_idx_end],
track_mask=track_mask[:, w_idx_start : w_idx_start + self.S, :p_idx_end],
iters=iters,
intrs_S=self.intrs[:, w_idx_start : w_idx_start + self.S],
)
Rigid_ln_total+=Rigid_ln
if is_train:
vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
coord_predictions.append([coord[:, :S_local] for coord in coords])
attn_predictions.append(attns)
self.traj_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = coords[-1][:, :S_local]
self.vis_e[:, w_idx_start:w_idx_start+self.S, :p_idx_end] = vis[:, :S_local]
track_mask[:, : w_idx_start + self.S, :p_idx_end] = 0.0
w_idx_start = w_idx_start + self.S // 2
p_idx_start = p_idx_end
self.traj_e = self.traj_e[:, :, inv_sort_inds]
self.vis_e = self.vis_e[:, :, inv_sort_inds]
self.vis_e = torch.sigmoid(self.vis_e)
train_data = (
(vis_predictions, coord_predictions, attn_predictions,
p_idx_end_list, sort_inds, Rigid_ln_total)
)
if self.is_train:
return self.traj_e, feat_init, self.vis_e, train_data
else:
return self.traj_e, feat_init, self.vis_e
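# Illustrative usage sketch (not from the original code; shapes and values are
# assumptions). rgbds stacks RGB in [0, 255] with a metric depth map, and each
# query is (first_frame_index, x, y, depth); the model itself assumes CUDA:
#   tracker = SpaTracker(S=12, args=edict({})).cuda()
#   rgb = torch.rand(1, 24, 3, 384, 512) * 255
#   depth = torch.rand(1, 24, 1, 384, 512) * 4 + 0.5
#   rgbds = torch.cat([rgb, depth], dim=2).cuda()
#   queries = torch.tensor([[[0., 100., 120., 1.5]]]).cuda()   # B x N x 4
#   traj_e, feat_init, vis_e = tracker(rgbds, queries, iters=4)
#   # traj_e: B x T x N x 3 trajectories, vis_e: B x T x N visibility in [0, 1]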

View File

@@ -0,0 +1,258 @@
'''
Codes are from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import numpy as np
def conv3x3(in_channels, out_channels, stride=1,
padding=1, bias=True, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def upconv2x2(in_channels, out_channels, mode='transpose'):
if mode == 'transpose':
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=2,
stride=2)
else:
# out_channels is always going to be the same
# as in_channels
return nn.Sequential(
nn.Upsample(mode='bilinear', scale_factor=2),
conv1x1(in_channels, out_channels))
def conv1x1(in_channels, out_channels, groups=1):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
groups=groups,
stride=1)
class DownConv(nn.Module):
"""
A helper Module that performs 2 convolutions and 1 MaxPool.
A ReLU activation follows each convolution.
"""
def __init__(self, in_channels, out_channels, pooling=True):
super(DownConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.pooling = pooling
self.conv1 = conv3x3(self.in_channels, self.out_channels)
self.conv2 = conv3x3(self.out_channels, self.out_channels)
if self.pooling:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
before_pool = x
if self.pooling:
x = self.pool(x)
return x, before_pool
class UpConv(nn.Module):
"""
A helper Module that performs 2 convolutions and 1 UpConvolution.
A ReLU activation follows each convolution.
"""
def __init__(self, in_channels, out_channels,
merge_mode='concat', up_mode='transpose'):
super(UpConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.merge_mode = merge_mode
self.up_mode = up_mode
self.upconv = upconv2x2(self.in_channels, self.out_channels,
mode=self.up_mode)
if self.merge_mode == 'concat':
self.conv1 = conv3x3(
2*self.out_channels, self.out_channels)
else:
# number of input channels to conv1 is unchanged (merge by addition)
self.conv1 = conv3x3(self.out_channels, self.out_channels)
self.conv2 = conv3x3(self.out_channels, self.out_channels)
def forward(self, from_down, from_up):
""" Forward pass
Arguments:
from_down: tensor from the encoder pathway
from_up: upconv'd tensor from the decoder pathway
"""
from_up = self.upconv(from_up)
if self.merge_mode == 'concat':
x = torch.cat((from_up, from_down), 1)
else:
x = from_up + from_down
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
return x
class UNet(nn.Module):
""" `UNet` class is based on https://arxiv.org/abs/1505.04597
The U-Net is a convolutional encoder-decoder neural network.
Contextual spatial information (from the decoding,
expansive pathway) about an input tensor is merged with
information representing the localization of details
(from the encoding, compressive pathway).
Modifications to the original paper:
(1) padding is used in 3x3 convolutions to prevent loss
of border pixels
(2) merging outputs does not require cropping due to (1)
(3) residual connections can be used by specifying
UNet(merge_mode='add')
(4) if non-parametric upsampling is used in the decoder
pathway (specified by up_mode='upsample'), then an
additional 1x1 2d convolution occurs after upsampling
to reduce channel dimensionality by a factor of 2.
With up_mode='transpose', this channel halving happens
inside the transpose convolution itself.
"""
def __init__(self, num_classes, in_channels=3, depth=5,
start_filts=64, up_mode='transpose',
merge_mode='concat', **kwargs):
"""
Arguments:
in_channels: int, number of channels in the input tensor.
Default is 3 for RGB images.
depth: int, number of MaxPools in the U-Net.
start_filts: int, number of convolutional filters for the
first conv.
up_mode: string, type of upconvolution. Choices: 'transpose'
for transpose convolution or 'upsample' for bilinear
upsampling followed by a 1x1 convolution.
"""
super(UNet, self).__init__()
if up_mode in ('transpose', 'upsample'):
self.up_mode = up_mode
else:
raise ValueError("\"{}\" is not a valid mode for "
"upsampling. Only \"transpose\" and "
"\"upsample\" are allowed.".format(up_mode))
if merge_mode in ('concat', 'add'):
self.merge_mode = merge_mode
else:
raise ValueError("\"{}\" is not a valid mode for"
"merging up and down paths. "
"Only \"concat\" and "
"\"add\" are allowed.".format(up_mode))
# NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
if self.up_mode == 'upsample' and self.merge_mode == 'add':
raise ValueError("up_mode \"upsample\" is incompatible "
"with merge_mode \"add\" at the moment "
"because it doesn't make sense to use "
"nearest neighbour to reduce "
"depth channels (by half).")
self.num_classes = num_classes
self.in_channels = in_channels
self.start_filts = start_filts
self.depth = depth
self.down_convs = []
self.up_convs = []
# create the encoder pathway and add to a list
for i in range(depth):
ins = self.in_channels if i == 0 else outs
outs = self.start_filts*(2**i)
pooling = True if i < depth-1 else False
down_conv = DownConv(ins, outs, pooling=pooling)
self.down_convs.append(down_conv)
# create the decoder pathway and add to a list
# - careful! decoding only requires depth-1 blocks
for i in range(depth-1):
ins = outs
outs = ins // 2
up_conv = UpConv(ins, outs, up_mode=up_mode,
merge_mode=merge_mode)
self.up_convs.append(up_conv)
# add the list of modules to current module
self.down_convs = nn.ModuleList(self.down_convs)
self.up_convs = nn.ModuleList(self.up_convs)
self.conv_final = conv1x1(outs, self.num_classes)
self.reset_params()
@staticmethod
def weight_init(m):
if isinstance(m, nn.Conv2d):
init.xavier_normal_(m.weight)
init.constant_(m.bias, 0)
def reset_params(self):
for i, m in enumerate(self.modules()):
self.weight_init(m)
def forward(self, x):
encoder_outs = []
# encoder pathway, save outputs for merging
for i, module in enumerate(self.down_convs):
x, before_pool = module(x)
encoder_outs.append(before_pool)
for i, module in enumerate(self.up_convs):
before_pool = encoder_outs[-(i+2)]
x = module(before_pool, x)
# No softmax is applied here. Use nn.CrossEntropyLoss in your
# training script, since that loss already includes a softmax.
x = self.conv_final(x)
return x
if __name__ == "__main__":
"""
testing
"""
model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
print(model)
print(sum(p.numel() for p in model.parameters()))
reso = 176
x = np.zeros((1, 1, reso, reso))
x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan
x = torch.FloatTensor(x)
out = model(x)
print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))
# loss = torch.sum(out)
# loss.backward()
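# A minimal extra sketch (added here, not part of the original file): the
# docstring above mentions residual decoding via merge_mode='add'; this shows
# that variant end to end on an input whose sides are divisible by 2**(depth-1).
def _demo_unet_residual():
    model = UNet(num_classes=2, in_channels=3, depth=4,
                 start_filts=16, merge_mode='add')
    x = torch.randn(1, 3, 64, 64)
    out = model(x)
    # per-pixel class logits at the input resolution
    print(out.shape)  # torch.Size([1, 2, 64, 64])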

View File

@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from typing import Type
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
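# A quick sketch (an addition for illustration): LayerNorm2d normalizes an
# NCHW tensor over its channel dimension, independently at every spatial
# location, before applying the learned per-channel scale and shift.
def _demo_layernorm2d():
    x = torch.randn(2, 8, 4, 4)
    out = LayerNorm2d(8)(x)
    # with the initial weight=1 / bias=0, each spatial position has
    # (near-)zero mean across channels
    print(out.shape, out.mean(dim=1).abs().max().item())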

View File

@@ -0,0 +1,397 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from .common import (
LayerNorm2d, MLPBlock
)
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
return x
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
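# A small round-trip sketch (added for illustration): partitioning into
# windows and unpartitioning with the recorded padded size recovers the
# original tensor exactly, including when padding was needed.
def _demo_window_roundtrip():
    x = torch.randn(1, 10, 14, 8)                     # B, H, W, C
    windows, pad_hw = window_partition(x, window_size=7)
    x_rec = window_unpartition(windows, 7, pad_hw, (10, 14))
    print(windows.shape, torch.allclose(x, x_rec))    # [4, 7, 7, 8] True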
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
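# An illustrative sketch (not in the original file) of the encoder's
# input/output contract; the tiny hyperparameters below are arbitrary and
# only chosen so the example runs quickly on CPU.
def _demo_image_encoder():
    enc = ImageEncoderViT(img_size=64, patch_size=16, embed_dim=64,
                          depth=2, num_heads=2, out_chans=32)
    x = torch.randn(1, 3, 64, 64)
    out = enc(x)
    print(out.shape)  # torch.Size([1, 32, 4, 4]) -> 64/16 = 4 tokens per side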

288
das/spatracker/predictor.py Normal file
View File

@@ -0,0 +1,288 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
import time
from tqdm import tqdm
from .models.core.spatracker.spatracker import get_points_on_a_grid
from .models.core.model_utils import smart_cat
from .models.build_spatracker import (
build_spatracker,
)
from .models.core.model_utils import (
meshgrid2d, bilinear_sample2d, smart_cat
)
from comfy.utils import ProgressBar
class SpaTrackerPredictor(torch.nn.Module):
def __init__(
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth",
interp_shape=(384, 512),
seq_length=16
):
super().__init__()
self.interp_shape = interp_shape
self.support_grid_size = 6
model = build_spatracker(checkpoint, seq_length=seq_length)
self.model = model
self.model.eval()
@torch.no_grad()
def forward(
self,
video, # (1, T, 3, H, W)
video_depth = None, # (T, 1, H, W)
# input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - grid_size. Grid of N*N points from the first frame. If segm_mask is provided, points are computed only inside the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None,
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
grid_size: int = 0,
grid_query_frame: int = 0, # only for dense and regular grid tracks
backward_tracking: bool = False,
depth_predictor=None,
wind_length: int = 8,
progressive_tracking: bool = False,
):
if queries is None and grid_size == 0:
tracks, visibilities, T_Firsts = self._compute_dense_tracks(
video,
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
video_depth=video_depth,
depth_predictor=depth_predictor,
wind_length=wind_length,
)
else:
tracks, visibilities, T_Firsts = self._compute_sparse_tracks(
video,
queries,
segm_mask,
grid_size,
add_support_grid=False, #(grid_size == 0 or segm_mask is not None),
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
video_depth=video_depth,
depth_predictor=depth_predictor,
wind_length=wind_length,
)
return tracks, visibilities, T_Firsts
def _compute_dense_tracks(
self, video, grid_query_frame, grid_size=30, backward_tracking=False,
depth_predictor=None, video_depth=None, wind_length=8
):
*_, H, W = video.shape
grid_step = W // grid_size
grid_width = W // grid_step
grid_height = H // grid_step
tracks = visibilities = T_Firsts = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
grid_pts[0, :, 0] = grid_query_frame
for offset in tqdm(range(grid_step * grid_step)):
ox = offset % grid_step
oy = offset // grid_step
grid_pts[0, :, 1] = (
torch.arange(grid_width).repeat(grid_height) * grid_step + ox
)
grid_pts[0, :, 2] = (
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
)
tracks_step, visibilities_step, T_First_step = self._compute_sparse_tracks(
video=video,
queries=grid_pts,
backward_tracking=backward_tracking,
wind_length=wind_length,
video_depth=video_depth,
depth_predictor=depth_predictor,
)
tracks = smart_cat(tracks, tracks_step, dim=2)
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
T_Firsts = smart_cat(T_Firsts, T_First_step, dim=1)
return tracks, visibilities, T_Firsts
def _compute_sparse_tracks(
self,
video,
queries,
segm_mask=None,
grid_size=0,
add_support_grid=False,
grid_query_frame=0,
backward_tracking=False,
depth_predictor=None,
video_depth=None,
wind_length=8,
):
B, T, C, H, W = video.shape
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
if queries is not None:
queries = queries.clone()
B, N, D = queries.shape
assert D == 3
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
if segm_mask is not None:
segm_mask = F.interpolate(
segm_mask, tuple(self.interp_shape), mode="nearest"
)
point_mask = segm_mask[0, 0][
(grid_pts[0, :, 1]).round().long().cpu(),
(grid_pts[0, :, 0]).round().long().cpu(),
].bool()
grid_pts_extra = grid_pts[:, point_mask]
else:
grid_pts_extra = None
if grid_pts_extra is not None:
total_num = int(grid_pts_extra.shape[1])
total_num = min(800, total_num)
pick_idx = torch.randperm(grid_pts_extra.shape[1])[:total_num]
grid_pts_extra = grid_pts_extra[:, pick_idx]
queries_extra = torch.cat(
[
torch.ones_like(grid_pts_extra[:, :, :1]) * grid_query_frame,
grid_pts_extra,
],
dim=2,
)
queries = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts],
dim=2,
)
if add_support_grid:
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device)
grid_pts = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
)
queries = torch.cat([queries, grid_pts], dim=1)
## ----------- estimate the video depth -----------##
if video_depth is None:
with torch.no_grad():
if video[0].shape[0]>30:
vidDepths = []
for i in range(video[0].shape[0]//30+1):
if (i+1)*30 > video[0].shape[0]:
end_idx = video[0].shape[0]
else:
end_idx = (i+1)*30
if end_idx == i*30:
break
video_ = video[0][i*30:end_idx]
vidDepths.append(depth_predictor.infer(video_/255))
video_depth = torch.cat(vidDepths, dim=0)
else:
video_depth = depth_predictor.infer(video[0]/255)
video_depth = F.interpolate(video_depth,
tuple(self.interp_shape), mode="nearest")
# from PIL import Image
# import numpy
# depth_frame = video_depth[0].detach().cpu()
# depth_frame = depth_frame.squeeze(0)
# print(depth_frame)
# print(depth_frame.min(), depth_frame.max())
# depth_img = (depth_frame * 255).numpy().astype(numpy.uint8)
# depth_img = Image.fromarray(depth_img, mode='L')
# depth_img.save('outputs/depth_map.png')
# frame = video[0, 0].detach().cpu()
# frame = frame.permute(1, 2, 0)
# frame = (frame * 255).numpy().astype(numpy.uint8)
# frame = Image.fromarray(frame, mode='RGB')
# frame.save('outputs/frame.png')
depths = video_depth
rgbds = torch.cat([video, depths[None,...]], dim=2)
# get the 3D queries
comfy_pbar = ProgressBar(queries.shape[1])
depth_interp=[]
for i in tqdm(range(queries.shape[1]), desc="Processing queries"):
depth_interp_i = bilinear_sample2d(video_depth[queries[:, i:i+1, 0].long()],
queries[:, i:i+1, 1], queries[:, i:i+1, 2])
depth_interp.append(depth_interp_i)
comfy_pbar.update(1)
depth_interp = torch.cat(depth_interp, dim=1)
queries = smart_cat(queries, depth_interp,dim=-1)
#NOTE: free the memory of depth_predictor
del depth_predictor
torch.cuda.empty_cache()
t0 = time.time()
tracks, __, visibilities = self.model(rgbds=rgbds, queries=queries, iters=6, wind_S=wind_length)
print("Time taken for inference: ", time.time()-t0)
if backward_tracking:
tracks, visibilities = self._compute_backward_tracks(
rgbds, queries, tracks, visibilities
)
if add_support_grid:
queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
if add_support_grid:
tracks = tracks[:, :, : -self.support_grid_size ** 2]
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
thr = 0.9
visibilities = visibilities > thr
# correct query-point predictions
# see https://github.com/facebookresearch/co-tracker/issues/28
# TODO: batchify
for i in tqdm(range(len(queries)), desc="Processing queries", leave=False):
queries_t = queries[i, :tracks.size(2), 0].to(torch.int64)
arange = torch.arange(0, len(queries_t))
# overwrite the predictions with the query points
tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:]
# correct visibilities, the query points should be visible
visibilities[i, queries_t, arange] = True
T_First = queries[..., :tracks.size(2), 0].to(torch.uint8)
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
return tracks, visibilities, T_First
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
inv_video = video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
inv_tracks, __, inv_visibilities = self.model(
rgbds=inv_video, queries=queries, iters=6
)
inv_tracks = inv_tracks.flip(1)
inv_visibilities = inv_visibilities.flip(1)
mask = tracks == 0
tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
return tracks, visibilities
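# A usage sketch (an addition for illustration; the checkpoint path, window
# length and tensor shapes below are assumptions, not shipped defaults):
# grid-based tracking with depth supplied up front, so no depth predictor
# has to be passed in.
def _demo_spatracker(checkpoint_path, video, video_depth):
    # video: (1, T, 3, H, W) float tensor, video_depth: (T, 1, H, W)
    tracker = SpaTrackerPredictor(checkpoint=checkpoint_path,
                                  interp_shape=(384, 512), seq_length=12)
    tracks, visibilities, t_first = tracker(
        video, video_depth=video_depth, grid_size=10, wind_length=12
    )
    # tracks hold per-frame (x, y, depth) positions for each grid point;
    # visibilities are booleans after the 0.9 threshold applied above
    return tracks, visibilities, t_first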

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,397 @@
import os
import numpy as np
from os.path import isfile
import torch
import torch.nn.functional as F
EPS = 1e-6
import copy
def sub2ind(height, width, y, x):
return y*width + x
def ind2sub(height, width, ind):
y = ind // width
x = ind % width
return y, x
def get_lr_str(lr):
lrn = "%.1e" % lr # e.g., 5.0e-04
lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4
return lrn
def strnum(x):
s = '%g' % x
if '.' in s:
if x < 1.0:
s = s[s.index('.'):]
s = s[:min(len(s),4)]
return s
def assert_same_shape(t1, t2):
for (x, y) in zip(list(t1.shape), list(t2.shape)):
assert(x==y)
def print_stats(name, tensor):
shape = tensor.shape
tensor = tensor.detach().cpu().numpy()
print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
def print_stats_py(name, tensor):
shape = tensor.shape
print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
def print_(name, tensor):
tensor = tensor.detach().cpu().numpy()
print(name, tensor, tensor.shape)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def normalize_single(d):
# d is a whatever shape torch tensor
dmin = torch.min(d)
dmax = torch.max(d)
d = (d-dmin)/(EPS+(dmax-dmin))
return d
def normalize(d):
# d is B x whatever. normalize within each element of the batch
out = torch.zeros(d.size())
if d.is_cuda:
out = out.cuda()
B = list(d.size())[0]
for b in list(range(B)):
out[b] = normalize_single(d[b])
return out
def hard_argmax2d(tensor):
B, C, Y, X = list(tensor.shape)
assert(C==1)
# flatten the Tensor along the height and width axes
flat_tensor = tensor.reshape(B, -1)
# argmax of the flat tensor
argmax = torch.argmax(flat_tensor, dim=1)
# convert the indices into 2d coordinates
argmax_y = torch.floor(argmax / X) # row
argmax_x = argmax % X # col
argmax_y = argmax_y.reshape(B)
argmax_x = argmax_x.reshape(B)
return argmax_y, argmax_x
def argmax2d(heat, hard=True):
B, C, Y, X = list(heat.shape)
assert(C==1)
if hard:
# hard argmax
loc_y, loc_x = hard_argmax2d(heat)
loc_y = loc_y.float()
loc_x = loc_x.float()
else:
heat = heat.reshape(B, Y*X)
prob = torch.nn.functional.softmax(heat, dim=1)
grid_y, grid_x = meshgrid2d(B, Y, X)
grid_y = grid_y.reshape(B, -1)
grid_x = grid_x.reshape(B, -1)
loc_y = torch.sum(grid_y*prob, dim=1)
loc_x = torch.sum(grid_x*prob, dim=1)
# these are B
return loc_y, loc_x
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
# x and mask must have the same shape (broadcasting is deliberately not allowed here)
# returns a scalar, or a tensor reduced over dim if dim is given
# dim can be a single axis or a list of axes
for (a,b) in zip(x.size(), mask.size()):
# if not b==1:
assert(a==b) # some shape mismatch!
# assert(x.size() == mask.size())
prod = x*mask
if dim is None:
numer = torch.sum(prod)
denom = EPS+torch.sum(mask)
else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer/denom
return mean
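# A tiny worked example (added for clarity): with half of the entries masked
# out, the masked mean averages only the surviving values.
def _demo_reduce_masked_mean():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
    print(reduce_masked_mean(x, mask))  # ~1.5, the mean of the first two values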
def reduce_masked_median(x, mask, keep_batch=False):
# x and mask are the same shape
assert(x.size() == mask.size())
device = x.device
B = list(x.shape)[0]
x = x.detach().cpu().numpy()
mask = mask.detach().cpu().numpy()
if keep_batch:
x = np.reshape(x, [B, -1])
mask = np.reshape(mask, [B, -1])
meds = np.zeros([B], np.float32)
for b in list(range(B)):
xb = x[b]
mb = mask[b]
if np.sum(mb) > 0:
xb = xb[mb > 0]
meds[b] = np.median(xb)
else:
meds[b] = np.nan
meds = torch.from_numpy(meds).to(device)
return meds.float()
else:
x = np.reshape(x, [-1])
mask = np.reshape(mask, [-1])
if np.sum(mask) > 0:
x = x[mask > 0]
med = np.median(x)
else:
med = np.nan
med = np.array([med], np.float32)
med = torch.from_numpy(med).to(device)
return med.float()
def pack_seqdim(tensor, B):
shapelist = list(tensor.shape)
B_, S = shapelist[:2]
assert(B==B_)
otherdims = shapelist[2:]
tensor = torch.reshape(tensor, [B*S]+otherdims)
return tensor
def unpack_seqdim(tensor, B):
shapelist = list(tensor.shape)
BS = shapelist[0]
assert(BS%B==0)
otherdims = shapelist[1:]
S = int(BS/B)
tensor = torch.reshape(tensor, [B,S]+otherdims)
return tensor
def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
# returns a meshgrid sized B x Y x X
grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device))
grid_y = torch.reshape(grid_y, [1, Y, 1])
grid_y = grid_y.repeat(B, 1, X)
grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device))
grid_x = torch.reshape(grid_x, [1, 1, X])
grid_x = grid_x.repeat(B, Y, 1)
if norm:
grid_y, grid_x = normalize_grid2d(
grid_y, grid_x, Y, X)
if stack:
# note we stack in xy order
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
if on_chans:
grid = torch.stack([grid_x, grid_y], dim=1)
else:
grid = torch.stack([grid_x, grid_y], dim=-1)
return grid
else:
return grid_y, grid_x
def meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'):
# returns a meshgrid sized B x Z x Y x X
grid_z = torch.linspace(0.0, Z-1, Z, device=device)
grid_z = torch.reshape(grid_z, [1, Z, 1, 1])
grid_z = grid_z.repeat(B, 1, Y, X)
grid_y = torch.linspace(0.0, Y-1, Y, device=device)
grid_y = torch.reshape(grid_y, [1, 1, Y, 1])
grid_y = grid_y.repeat(B, Z, 1, X)
grid_x = torch.linspace(0.0, X-1, X, device=device)
grid_x = torch.reshape(grid_x, [1, 1, 1, X])
grid_x = grid_x.repeat(B, Z, Y, 1)
# if cuda:
# grid_z = grid_z.cuda()
# grid_y = grid_y.cuda()
# grid_x = grid_x.cuda()
if norm:
grid_z, grid_y, grid_x = normalize_grid3d(
grid_z, grid_y, grid_x, Z, Y, X)
if stack:
# note we stack in xyz order
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
grid = torch.stack([grid_x, grid_y, grid_z], dim=-1)
return grid
else:
return grid_z, grid_y, grid_x
def normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True):
# make things in [-1,1]
grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
grid_x = 2.0*(grid_x / float(X-1)) - 1.0
if clamp_extreme:
grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)
return grid_y, grid_x
def normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True):
# make things in [-1,1]
grid_z = 2.0*(grid_z / float(Z-1)) - 1.0
grid_y = 2.0*(grid_y / float(Y-1)) - 1.0
grid_x = 2.0*(grid_x / float(X-1)) - 1.0
if clamp_extreme:
grid_z = torch.clamp(grid_z, min=-2.0, max=2.0)
grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)
grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)
return grid_z, grid_y, grid_x
def gridcloud2d(B, Y, X, norm=False, device='cuda'):
# we want to sample for each location in the grid
grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)
x = torch.reshape(grid_x, [B, -1])
y = torch.reshape(grid_y, [B, -1])
# these are B x N
xy = torch.stack([x, y], dim=2)
# this is B x N x 2
return xy
def gridcloud3d(B, Z, Y, X, norm=False, device='cuda'):
# we want to sample for each location in the grid
grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device)
x = torch.reshape(grid_x, [B, -1])
y = torch.reshape(grid_y, [B, -1])
z = torch.reshape(grid_z, [B, -1])
# these are B x N
xyz = torch.stack([x, y, z], dim=2)
# this is B x N x 3
return xyz
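# A short sketch (not from the original file) of the coordinate layout
# produced by gridcloud2d: one (x, y) pair per grid location, row-major.
def _demo_gridcloud2d():
    xy = gridcloud2d(B=1, Y=2, X=3, device='cpu')
    print(xy.shape)  # torch.Size([1, 6, 2])
    print(xy[0])     # (0,0), (1,0), (2,0), (0,1), (1,1), (2,1)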
import re
def readPFM(file):
file = open(file, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header == b'PF':
color = True
elif header == b'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data
def normalize_boxlist2d(boxlist2d, H, W):
boxlist2d = boxlist2d.clone()
ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
ymin = ymin / float(H)
ymax = ymax / float(H)
xmin = xmin / float(W)
xmax = xmax / float(W)
boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
return boxlist2d
def unnormalize_boxlist2d(boxlist2d, H, W):
boxlist2d = boxlist2d.clone()
ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
ymin = ymin * float(H)
ymax = ymax * float(H)
xmin = xmin * float(W)
xmax = xmax * float(W)
boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
return boxlist2d
def unnormalize_box2d(box2d, H, W):
return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
def normalize_box2d(box2d, H, W):
return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
def get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False):
C = channels
xy_grid = gridcloud2d(C, kernel_size, kernel_size) # C x N x 2
mean = (kernel_size - 1)/2.0
variance = sigma**2.0
gaussian_kernel = (1.0/(2.0*np.pi*variance)**1.5) * torch.exp(-torch.sum((xy_grid - mean)**2.0, dim=-1) / (2.0*variance)) # C X N
gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size) # C x 1 x 3 x 3
kernel_sum = torch.sum(gaussian_kernel, dim=(2,3), keepdim=True)
gaussian_kernel = gaussian_kernel / kernel_sum # normalize
if mid_one:
# normalize so that the middle element is 1
maxval = gaussian_kernel[:,:,(kernel_size//2),(kernel_size//2)].reshape(C, 1, 1, 1)
gaussian_kernel = gaussian_kernel / maxval
return gaussian_kernel
def gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False):
B, C, Z, X = input.shape
kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one)
if reflect_pad:
pad = (kernel_size - 1)//2
out = F.pad(input, (pad, pad, pad, pad), mode='reflect')
out = F.conv2d(out, kernel, padding=0, groups=C)
else:
out = F.conv2d(input, kernel, padding=(kernel_size - 1)//2, groups=C)
return out
def gradient2d(x, absolute=False, square=False, return_sum=False):
# x should be B x C x H x W
dh = x[:, :, 1:, :] - x[:, :, :-1, :]
dw = x[:, :, :, 1:] - x[:, :, :, :-1]
zeros = torch.zeros_like(x)
zero_h = zeros[:, :, 0:1, :]
zero_w = zeros[:, :, :, 0:1]
dh = torch.cat([dh, zero_h], axis=2)
dw = torch.cat([dw, zero_w], axis=3)
if absolute:
dh = torch.abs(dh)
dw = torch.abs(dw)
if square:
dh = dh ** 2
dw = dw ** 2
if return_sum:
return dh+dw
else:
return dh, dw

View File

@@ -0,0 +1,547 @@
import torch
from . import basic as utils
import numpy as np
import torchvision.ops as ops
from .basic import print_
def matmul2(mat1, mat2):
return torch.matmul(mat1, mat2)
def matmul3(mat1, mat2, mat3):
return torch.matmul(mat1, torch.matmul(mat2, mat3))
def eye_3x3(B, device='cuda'):
rt = torch.eye(3, device=torch.device(device)).view(1,3,3).repeat([B, 1, 1])
return rt
def eye_4x4(B, device='cuda'):
rt = torch.eye(4, device=torch.device(device)).view(1,4,4).repeat([B, 1, 1])
return rt
def safe_inverse(a): #parallel version
B, _, _ = list(a.shape)
inv = a.clone()
r_transpose = a[:, :3, :3].transpose(1,2) #inverse of rotation matrix
inv[:, :3, :3] = r_transpose
inv[:, :3, 3:4] = -torch.matmul(r_transpose, a[:, :3, 3:4])
return inv
def safe_inverse_single(a):
r, t = split_rt_single(a)
t = t.view(3,1)
r_transpose = r.t()
inv = torch.cat([r_transpose, -torch.matmul(r_transpose, t)], 1)
bottom_row = a[3:4, :] # this is [0, 0, 0, 1]
# bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4)
inv = torch.cat([inv, bottom_row], 0)
return inv
def split_intrinsics(K):
# K is B x 3 x 3 or B x 4 x 4
fx = K[:,0,0]
fy = K[:,1,1]
x0 = K[:,0,2]
y0 = K[:,1,2]
return fx, fy, x0, y0
def apply_pix_T_cam(pix_T_cam, xyz):
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
# xyz is shaped B x H*W x 3
# returns xy, shaped B x H*W x 2
B, N, C = list(xyz.shape)
assert(C==3)
x, y, z = torch.unbind(xyz, axis=-1)
fx = torch.reshape(fx, [B, 1])
fy = torch.reshape(fy, [B, 1])
x0 = torch.reshape(x0, [B, 1])
y0 = torch.reshape(y0, [B, 1])
EPS = 1e-4
z = torch.clamp(z, min=EPS)
x = (x*fx)/(z)+x0
y = (y*fy)/(z)+y0
xy = torch.stack([x, y], axis=-1)
return xy
def apply_pix_T_cam_py(pix_T_cam, xyz):
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
# xyz is shaped B x H*W x 3
# returns xy, shaped B x H*W x 2
B, N, C = list(xyz.shape)
assert(C==3)
x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
fx = np.reshape(fx, [B, 1])
fy = np.reshape(fy, [B, 1])
x0 = np.reshape(x0, [B, 1])
y0 = np.reshape(y0, [B, 1])
EPS = 1e-4
z = np.clip(z, EPS, None)
x = (x*fx)/(z)+x0
y = (y*fy)/(z)+y0
xy = np.stack([x, y], axis=-1)
return xy
def get_camM_T_camXs(origin_T_camXs, ind=0):
B, S = list(origin_T_camXs.shape)[0:2]
camM_T_camXs = torch.zeros_like(origin_T_camXs)
for b in list(range(B)):
camM_T_origin = safe_inverse_single(origin_T_camXs[b,ind])
for s in list(range(S)):
camM_T_camXs[b,s] = torch.matmul(camM_T_origin, origin_T_camXs[b,s])
return camM_T_camXs
def apply_4x4(RT, xyz):
B, N, _ = list(xyz.shape)
ones = torch.ones_like(xyz[:,:,0:1])
xyz1 = torch.cat([xyz, ones], 2)
xyz1_t = torch.transpose(xyz1, 1, 2)
# this is B x 4 x N
xyz2_t = torch.matmul(RT, xyz1_t)
xyz2 = torch.transpose(xyz2_t, 1, 2)
xyz2 = xyz2[:,:,:3]
return xyz2
def apply_4x4_py(RT, xyz):
# print('RT', RT.shape)
B, N, _ = list(xyz.shape)
ones = np.ones_like(xyz[:,:,0:1])
xyz1 = np.concatenate([xyz, ones], 2)
# print('xyz1', xyz1.shape)
xyz1_t = xyz1.transpose(0,2,1)
# print('xyz1_t', xyz1_t.shape)
# this is B x 4 x N
xyz2_t = np.matmul(RT, xyz1_t)
# print('xyz2_t', xyz2_t.shape)
xyz2 = xyz2_t.transpose(0,2,1)
# print('xyz2', xyz2.shape)
xyz2 = xyz2[:,:,:3]
return xyz2
def apply_3x3(RT, xy):
B, N, _ = list(xy.shape)
ones = torch.ones_like(xy[:,:,0:1])
xy1 = torch.cat([xy, ones], 2)
xy1_t = torch.transpose(xy1, 1, 2)
# this is B x 4 x N
xy2_t = torch.matmul(RT, xy1_t)
xy2 = torch.transpose(xy2_t, 1, 2)
xy2 = xy2[:,:,:2]
return xy2
def generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts):
'''
Start with the center of the polygon at ctr_x, ctr_y,
Then creates the polygon by sampling points on a circle around the center.
Random noise is added by varying the angular spacing between sequential points,
and by varying the radial distance of each point from the centre.
Params:
ctr_x, ctr_y - coordinates of the "centre" of the polygon
avg_r - average radius of the polygon in px; roughly controls how large the polygon is (order of magnitude only)
irregularity - [0,1], variance in the angular spacing of vertices; maps to [0, 2*pi/num_verts]
spikiness - [0,1], variance of each vertex's distance from the circle of radius avg_r; maps to [0, avg_r]
num_verts - number of vertices
Returns:
np.array [num_verts, 2] - CCW order.
'''
# spikiness
spikiness = np.clip(spikiness, 0, 1) * avg_r
# generate n angle steps
irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts
lower = (2*np.pi / num_verts) - irregularity
upper = (2*np.pi / num_verts) + irregularity
# angle steps
angle_steps = np.random.uniform(lower, upper, num_verts)
sc = (2 * np.pi) / angle_steps.sum()
angle_steps *= sc
# get all radii
angle = np.random.uniform(0, 2*np.pi)
radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r)
# compute all points
points = []
for i in range(num_verts):
x = ctr_x + radii[i] * np.cos(angle)
y = ctr_y + radii[i] * np.sin(angle)
points.append([x, y])
angle += angle_steps[i]
return np.array(points).astype(int)
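# An illustrative call (added here): eight vertices around (50, 50) with an
# average radius of 20 px and mild angular/radial jitter; the result is an
# integer (num_verts, 2) array in counter-clockwise order.
def _demo_generate_polygon():
    pts = generate_polygon(ctr_x=50, ctr_y=50, avg_r=20,
                           irregularity=0.3, spikiness=0.2, num_verts=8)
    print(pts.shape, pts.dtype)  # (8, 2), integer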
def get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05, sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05, shy_max=0.05):
'''
Params:
rot_min: rotation amount min
rot_max: rotation amount max
tx_min: translation x min
tx_max: translation x max
ty_min: translation y min
ty_max: translation y max
sx_min: scaling x min
sx_max: scaling x max
sy_min: scaling y min
sy_max: scaling y max
shx_min: shear x min
shx_max: shear x max
shy_min: shear y min
shy_max: shear y max
Returns:
transformation matrix: (B, 3, 3)
'''
# rotation
if rot_max - rot_min != 0:
rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B)
rot_amount = np.pi/180.0*rot_amount
else:
rot_amount = rot_min
rotation = np.zeros((B, 3, 3)) # B, 3, 3
rotation[:, 2, 2] = 1
rotation[:, 0, 0] = np.cos(rot_amount)
rotation[:, 0, 1] = -np.sin(rot_amount)
rotation[:, 1, 0] = np.sin(rot_amount)
rotation[:, 1, 1] = np.cos(rot_amount)
# translation
translation = np.zeros((B, 3, 3)) # B, 3, 3
translation[:, [0,1,2], [0,1,2]] = 1
if (tx_max - tx_min) > 0:
trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B)
translation[:, 0, 2] = trans_x
# else:
# translation[:, 0, 2] = tx_max
if ty_max - ty_min != 0:
trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B)
translation[:, 1, 2] = trans_y
# else:
# translation[:, 1, 2] = ty_max
# scaling
scaling = np.zeros((B, 3, 3)) # B, 3, 3
scaling[:, [0,1,2], [0,1,2]] = 1
if (sx_max - sx_min) > 0:
scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B)
scaling[:, 0, 0] = scale_x
# else:
# scaling[:, 0, 0] = sx_max
if (sy_max - sy_min) > 0:
scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B)
scaling[:, 1, 1] = scale_y
# else:
# scaling[:, 1, 1] = sy_max
# shear
shear = np.zeros((B, 3, 3)) # B, 3, 3
shear[:, [0,1,2], [0,1,2]] = 1
if (shx_max - shx_min) > 0:
shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B)
shear[:, 0, 1] = shear_x
# else:
# shear[:, 0, 1] = shx_max
if (shy_max - shy_min) > 0:
shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B)
shear[:, 1, 0] = shear_y
# else:
# shear[:, 1, 0] = shy_max
# compose all those
rt = np.einsum("ijk,ikl->ijl", rotation, translation)
ss = np.einsum("ijk,ikl->ijl", scaling, shear)
trans = np.einsum("ijk,ikl->ijl", rt, ss)
return trans
def get_centroid_from_box2d(box2d):
ymin = box2d[:,0]
xmin = box2d[:,1]
ymax = box2d[:,2]
xmax = box2d[:,3]
x = (xmin+xmax)/2.0
y = (ymin+ymax)/2.0
return y, x
def normalize_boxlist2d(boxlist2d, H, W):
boxlist2d = boxlist2d.clone()
ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
ymin = ymin / float(H)
ymax = ymax / float(H)
xmin = xmin / float(W)
xmax = xmax / float(W)
boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
return boxlist2d
def unnormalize_boxlist2d(boxlist2d, H, W):
boxlist2d = boxlist2d.clone()
ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)
ymin = ymin * float(H)
ymax = ymax * float(H)
xmin = xmin * float(W)
xmax = xmax * float(W)
boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)
return boxlist2d
def unnormalize_box2d(box2d, H, W):
return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
def normalize_box2d(box2d, H, W):
return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)
def get_size_from_box2d(box2d):
ymin = box2d[:,0]
xmin = box2d[:,1]
ymax = box2d[:,2]
xmax = box2d[:,3]
height = ymax-ymin
width = xmax-xmin
return height, width
def crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False):
B, C, H, W = im.shape
B2, N, D = boxlist.shape
assert(B==B2)
assert(D==4)
# PH, PW is the size to resize to
# output is B,N,C,PH,PW
# pt wants xy xy, unnormalized
if boxlist_is_normalized:
boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W)
else:
boxlist_unnorm = boxlist
ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2)
# boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1)
boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2)
# we want a B-len list of K x 4 arrays
# print('im', im.shape)
# print('boxlist', boxlist.shape)
# print('boxlist_pt', boxlist_pt.shape)
# boxlist_pt = list(boxlist_pt.unbind(0))
crops = []
for b in range(B):
crops_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
crops.append(crops_b)
# # crops = im
# print('crops', crops.shape)
# crops = crops.reshape(B,N,C,PH,PW)
# crops = []
# for b in range(B):
# crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))
# print('crop_b', crop_b.shape)
# crops.append(crop_b)
crops = torch.stack(crops, dim=0)
# print('crops', crops.shape)
# boxlist_list = boxlist_pt.unbind(0)
# print('rgb_crop', rgb_crop.shape)
return crops
# def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True):
# # cy,cx are both B,N
# ymin = cy - h/2
# ymax = cy + h/2
# xmin = cx - w/2
# xmax = cx + w/2
# box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
# if clip:
# box = torch.clamp(box, 0, 1)
# return box
def get_boxlist_from_centroid_and_size(cy, cx, h, w):#, clip=False):
# cy,cx are the same shape
ymin = cy - h/2
ymax = cy + h/2
xmin = cx - w/2
xmax = cx + w/2
# if clip:
# ymin = torch.clamp(ymin, 0, H-1)
# ymax = torch.clamp(ymax, 0, H-1)
# xmin = torch.clamp(xmin, 0, W-1)
# xmax = torch.clamp(xmax, 0, W-1)
box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)
return box
def get_box2d_from_mask(mask, normalize=False):
# mask is B, 1, H, W
B, C, H, W = mask.shape
assert(C==1)
xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device) # B, H*W, 2
box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device)
for b in range(B):
xy_b = xy[b] # H*W, 2
mask_b = mask[b].reshape(H*W)
xy_ = xy_b[mask_b > 0]
x_ = xy_[:,0]
y_ = xy_[:,1]
ymin = torch.min(y_)
ymax = torch.max(y_)
xmin = torch.min(x_)
xmax = torch.max(x_)
box[b] = torch.stack([ymin, xmin, ymax, xmax], dim=0)
if normalize:
box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1)
return box
def convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, mult_padding=1.0):
# box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords
# ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)
# H, W is the original size of the image
# mult_padding is relative to object size in pixels
# i assume we're rendering an image the same size as the original (H, W)
if not mult_padding==1.0:
y, x = get_centroid_from_box2d(box2d)
h, w = get_size_from_box2d(box2d)
box2d = get_box2d_from_centroid_and_size(
y, x, h*mult_padding, w*mult_padding, clip=False)
if use_image_aspect_ratio:
h, w = get_size_from_box2d(box2d)
y, x = get_centroid_from_box2d(box2d)
# note h,w are relative right now
# we need to undo this, to see the real ratio
h = h*float(H)
w = w*float(W)
box_ratio = h/w
im_ratio = H/float(W)
# print('box_ratio:', box_ratio)
# print('im_ratio:', im_ratio)
if box_ratio >= im_ratio:
w = h/im_ratio
# print('setting w:', h/im_ratio)
else:
h = w*im_ratio
# print('setting h:', w*im_ratio)
box2d = get_box2d_from_centroid_and_size(
y, x, h/float(H), w/float(W), clip=False)
assert(h > 1e-4)
assert(w > 1e-4)
ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
# the topleft of the new image will now have a different offset from the center of projection
new_x0 = x0 - xmin*W
new_y0 = y0 - ymin*H
pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0)
# this alone will give me an image in original resolution,
# with its topleft at the box corner
box_h, box_w = get_size_from_box2d(box2d)
# these are normalized, and shaped B. (e.g., [0.4], [0.3])
# we are going to scale the image by the inverse of this,
# since we are zooming into this area
sy = 1./box_h
sx = 1./box_w
pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy)
return pix_T_cam, box2d
def pixels2camera(x,y,z,fx,fy,x0,y0):
# x and y are locations in pixel coordinates, z is a depth in meters
# they can be images or pointclouds
# fx, fy, x0, y0 are camera intrinsics
# returns xyz, sized B x N x 3
B = x.shape[0]
fx = torch.reshape(fx, [B,1])
fy = torch.reshape(fy, [B,1])
x0 = torch.reshape(x0, [B,1])
y0 = torch.reshape(y0, [B,1])
x = torch.reshape(x, [B,-1])
y = torch.reshape(y, [B,-1])
z = torch.reshape(z, [B,-1])
# unproject
x = (z/fx)*(x-x0)
y = (z/fy)*(y-y0)
xyz = torch.stack([x,y,z], dim=2)
# B x N x 3
return xyz
def camera2pixels(xyz, pix_T_cam):
# xyz is shaped B x H*W x 3
# returns xy, shaped B x H*W x 2
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
x, y, z = torch.unbind(xyz, dim=-1)
B = list(z.shape)[0]
fx = torch.reshape(fx, [B,1])
fy = torch.reshape(fy, [B,1])
x0 = torch.reshape(x0, [B,1])
y0 = torch.reshape(y0, [B,1])
x = torch.reshape(x, [B,-1])
y = torch.reshape(y, [B,-1])
z = torch.reshape(z, [B,-1])
EPS = 1e-4
z = torch.clamp(z, min=EPS)
x = (x*fx)/z + x0
y = (y*fy)/z + y0
xy = torch.stack([x, y], dim=-1)
return xy
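# A round-trip sketch (an addition for illustration): unprojecting pixels with
# pixels2camera and reprojecting with camera2pixels using the same pinhole
# intrinsics recovers the original pixel coordinates.
def _demo_pinhole_roundtrip():
    fx = torch.tensor([100.0]); fy = torch.tensor([100.0])
    x0 = torch.tensor([32.0]);  y0 = torch.tensor([24.0])
    x = torch.tensor([[10.0, 50.0]])
    y = torch.tensor([[5.0, 40.0]])
    z = torch.tensor([[2.0, 3.0]])
    xyz = pixels2camera(x, y, z, fx, fy, x0, y0)        # B x N x 3 camera-frame points
    pix_T_cam = torch.tensor([[[100.0, 0.0, 32.0],
                               [0.0, 100.0, 24.0],
                               [0.0, 0.0, 1.0]]])
    xy = camera2pixels(xyz, pix_T_cam)                  # B x N x 2
    print(torch.allclose(xy, torch.stack([x, y], dim=-1)))  # True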
def depth2pointcloud(z, pix_T_cam):
B, C, H, W = list(z.shape)
device = z.device
y, x = utils.basic.meshgrid2d(B, H, W, device=device)
z = torch.reshape(z, [B, H, W])
fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
return xyz

File diff suppressed because it is too large

View File

@@ -0,0 +1,166 @@
import torch
import numpy as np
import math
from prettytable import PrettyTable
def count_parameters(model):
table = PrettyTable(["Modules", "Parameters"])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad:
continue
param = parameter.numel()
if param > 100000:
table.add_row([name, param])
total_params+=param
print(table)
print('total params: %.2f M' % (total_params/1000000.0))
return total_params
def posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):
device = xy.device
dtype = xy.dtype
B, S, D = xy.shape
assert(D==2)
x = xy[:,:,0]
y = xy[:,:,1]
assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(C // 4, device=device) / (C // 4 - 1)
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
pe = pe.reshape(B,S,C).type(dtype)
if cat_coords:
pe = torch.cat([pe, xy], dim=2) # B,N,C+2
return pe
class SimplePool():
def __init__(self, pool_size, version='pt'):
self.pool_size = pool_size
self.version = version
self.items = []
if not (version=='pt' or version=='np'):
print('version = %s; please choose pt or np' % version)
assert(False) # please choose pt or np
def __len__(self):
return len(self.items)
def mean(self, min_size=1):
if min_size=='half':
pool_size_thresh = self.pool_size/2
else:
pool_size_thresh = min_size
if self.version=='np':
if len(self.items) >= pool_size_thresh:
return np.sum(self.items)/float(len(self.items))
else:
return np.nan
if self.version=='pt':
if len(self.items) >= pool_size_thresh:
return torch.stack(self.items).sum()/float(len(self.items))
else:
return torch.tensor(float('nan'))
def sample(self, with_replacement=True):
idx = np.random.randint(len(self.items))
if with_replacement:
return self.items[idx]
else:
return self.items.pop(idx)
def fetch(self, num=None):
if self.version=='pt':
item_array = torch.stack(self.items)
elif self.version=='np':
item_array = np.stack(self.items)
if num is not None:
# there better be some items
assert(len(self.items) >= num)
# if there are not that many elements just return however many there are
if len(self.items) < num:
return item_array
else:
idxs = np.random.randint(len(self.items), size=num)
return item_array[idxs]
else:
return item_array
def is_full(self):
full = len(self.items)==self.pool_size
return full
def empty(self):
self.items = []
def update(self, items):
for item in items:
if len(self.items) < self.pool_size:
# the pool is not full, so let's add this in
self.items.append(item)
else:
# the pool is full
# pop from the front
self.items.pop(0)
# add to the back
self.items.append(item)
return self.items
def farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):
"""
Input:
xyz: pointcloud data, [B, N, C], where C is probably 3
npoint: number of samples
Return:
inds: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
xyz = xyz.float()
inds = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
if deterministic:
farthest = torch.randint(0, 1, (B,), dtype=torch.long).to(device)
else:
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
if include_ends:
if i==0:
farthest = 0
elif i==1:
farthest = N-1
inds[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
dist = torch.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]
if npoint > N:
# if we need more samples, make them random
distance += torch.randn_like(distance)
return inds
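# A quick usage sketch (added for illustration): sample 8 spread-out points
# from a random CPU point cloud; deterministic=True fixes the first pick at
# index 0 so repeated calls agree on the same cloud.
def _demo_farthest_point_sample():
    xyz = torch.rand(1, 100, 3)
    inds = farthest_point_sample(xyz, 8, deterministic=True)
    print(inds.shape, inds.dtype)  # torch.Size([1, 8]) torch.int64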
def farthest_point_sample_py(xyz, npoint):
N,C = xyz.shape
inds = np.zeros(npoint, dtype=np.int32)
distance = np.ones(N) * 1e10
farthest = np.random.randint(0, N, dtype=np.int32)
for i in range(npoint):
inds[i] = farthest
centroid = xyz[farthest, :].reshape(1,C)
dist = np.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = np.argmax(distance, -1)
if npoint > N:
# if we need more samples, make them random
distance += np.random.randn(*distance.shape)
return inds

View File

@ -0,0 +1,152 @@
import torch
import utils.basic
import torch.nn.functional as F
def bilinear_sample2d(im, x, y, return_inbounds=False):
# x and y are each B, N
# output is B, C, N
B, C, H, W = list(im.shape)
N = list(x.shape)[1]
x = x.float()
y = y.float()
H_f = torch.tensor(H, dtype=torch.float32)
W_f = torch.tensor(W, dtype=torch.float32)
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
max_y = (H_f - 1).int()
max_x = (W_f - 1).int()
x0 = torch.floor(x).int()
x1 = x0 + 1
y0 = torch.floor(y).int()
y1 = y0 + 1
x0_clip = torch.clamp(x0, 0, max_x)
x1_clip = torch.clamp(x1, 0, max_x)
y0_clip = torch.clamp(y0, 0, max_y)
y1_clip = torch.clamp(y1, 0, max_y)
dim2 = W
dim1 = W * H
base = torch.arange(0, B, dtype=torch.int64, device=x.device)*dim1
base = torch.reshape(base, [B, 1]).repeat([1, N])
base_y0 = base + y0_clip * dim2
base_y1 = base + y1_clip * dim2
idx_y0_x0 = base_y0 + x0_clip
idx_y0_x1 = base_y0 + x1_clip
idx_y1_x0 = base_y1 + x0_clip
idx_y1_x1 = base_y1 + x1_clip
# use the indices to lookup pixels in the flat image
# im is B x C x H x W
# move C out to last dim
im_flat = (im.permute(0, 2, 3, 1)).reshape(B*H*W, C)
i_y0_x0 = im_flat[idx_y0_x0.long()]
i_y0_x1 = im_flat[idx_y0_x1.long()]
i_y1_x0 = im_flat[idx_y1_x0.long()]
i_y1_x1 = im_flat[idx_y1_x1.long()]
# Finally calculate interpolated values.
x0_f = x0.float()
x1_f = x1.float()
y0_f = y0.float()
y1_f = y1.float()
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + \
w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
# output is B*N x C
output = output.view(B, -1, C)
output = output.permute(0, 2, 1)
# output is B x C x N
if return_inbounds:
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
inbounds = (x_valid & y_valid).float()
inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
return output, inbounds
return output # B, C, N
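# Usage sketch (hypothetical shapes): sample per-point features from a feature map.
#   im = torch.rand(2, 64, 48, 64)        # B, C, H, W
#   x = torch.rand(2, 100) * 63           # pixel x coords, B x N
#   y = torch.rand(2, 100) * 47           # pixel y coords, B x N
#   feats, inb = bilinear_sample2d(im, x, y, return_inbounds=True)  # (2, 64, 100), (2, 100)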
def paste_crop_on_canvas(crop, box2d_unnorm, H, W, fast=True, mask=None, canvas=None):
# this is the inverse of crop_and_resize_box2d
B, C, Y, X = list(crop.shape)
B2, D = list(box2d_unnorm.shape)
assert(B == B2)
assert(D == 4)
# here, we want to place the crop into a bigger image,
# at the location specified by the box2d.
if canvas is None:
canvas = torch.zeros((B, C, H, W), device=crop.device)
else:
B2, C2, H2, W2 = canvas.shape
assert(B==B2)
assert(C==C2)
assert(H==H2)
assert(W==W2)
# box2d_unnorm = utils.geom.unnormalize_box2d(box2d, H, W)
if fast:
ymin = box2d_unnorm[:, 0].long()
xmin = box2d_unnorm[:, 1].long()
ymax = box2d_unnorm[:, 2].long()
xmax = box2d_unnorm[:, 3].long()
w = (xmax - xmin).float()
h = (ymax - ymin).float()
grids = utils.basic.gridcloud2d(B, H, W)
grids_flat = grids.reshape(B, -1, 2)
# grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * X
# grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * Y
# for each pixel in the main image,
# grids_flat tells us where to sample in the crop image
# print('grids_flat', grids_flat.shape)
# print('crop', crop.shape)
grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * 2.0 - 1.0
grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * 2.0 - 1.0
grid = grids_flat.reshape(B,H,W,2)
canvas = F.grid_sample(crop, grid, align_corners=False)
# print('canvas', canvas.shape)
# if mask is None:
# crop_resamp, inb = bilinear_sample2d(crop, grids_flat[:, :, 0], grids_flat[:, :, 1], return_inbounds=True)
# crop_resamp = crop_resamp.reshape(B, C, H, W)
# inb = inb.reshape(B, 1, H, W)
# canvas = canvas * (1 - inb) + crop_resamp * inb
# else:
# full_resamp = bilinear_sample2d(torch.cat([crop, mask], dim=1), grids_flat[:, :, 0], grids_flat[:, :, 1])
# full_resamp = full_resamp.reshape(B, C+1, H, W)
# crop_resamp = full_resamp[:,:3]
# mask_resamp = full_resamp[:,3:4]
# canvas = canvas * (1 - mask_resamp) + crop_resamp * mask_resamp
else:
for b in range(B):
ymin = box2d_unnorm[b, 0].long()
xmin = box2d_unnorm[b, 1].long()
ymax = box2d_unnorm[b, 2].long()
xmax = box2d_unnorm[b, 3].long()
crop_b = F.interpolate(crop[b:b + 1], (ymax - ymin, xmax - xmin)).squeeze(0)
# print('canvas[b,:,...', canvas[b,:,ymin:ymax,xmin:xmax].shape)
# print('crop_b', crop_b.shape)
canvas[b, :, ymin:ymax, xmin:xmax] = crop_b
return canvas
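# Usage sketch (hypothetical shapes): paste a crop back into a full-size canvas.
#   crop = torch.rand(1, 3, 64, 96)
#   box2d = torch.tensor([[10., 20., 74., 116.]])   # ymin, xmin, ymax, xmax in pixels
#   canvas = paste_crop_on_canvas(crop, box2d, H=240, W=360)   # (1, 3, 240, 360)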

View File

@ -0,0 +1,409 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import cv2
import torch
import flow_vis
from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
#from moviepy.editor import ImageSequenceClip
import matplotlib.pyplot as plt
from tqdm import tqdm
def read_video_from_path(path):
cap = cv2.VideoCapture(path)
if not cap.isOpened():
print("Error opening video file")
else:
frames = []
while cap.isOpened():
ret, frame = cap.read()
if ret == True:
frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
else:
break
cap.release()
return np.stack(frames)
class Visualizer:
def __init__(
self,
save_dir: str = "./results",
grayscale: bool = False,
pad_value: int = 0,
fps: int = 10,
mode: str = "rainbow", # 'cool', 'optical_flow'
linewidth: int = 1,
show_first_frame: int = 10,
tracks_leave_trace: int = 0, # -1 for infinite
):
self.mode = mode
self.save_dir = save_dir
self.vtxt_path = os.path.join(save_dir, "videos.txt")
self.ttxt_path = os.path.join(save_dir, "trackings.txt")
if mode == "rainbow":
self.color_map = cm.get_cmap("gist_rainbow")
elif mode == "cool":
self.color_map = cm.get_cmap(mode)
self.show_first_frame = show_first_frame
self.grayscale = grayscale
self.tracks_leave_trace = tracks_leave_trace
self.pad_value = pad_value
self.linewidth = linewidth
self.fps = fps
def visualize(
self,
video: torch.Tensor, # (B,T,C,H,W)
tracks: torch.Tensor, # (B,T,N,3), last dim is (x, y, depth)
visibility: torch.Tensor = None, # (B, T, N, 1) bool
gt_tracks: torch.Tensor = None, # (B,T,N,2)
segm_mask: torch.Tensor = None, # (B,1,H,W)
filename: str = "video",
writer=None, # tensorboard Summary Writer, used for visualization during training
step: int = 0,
query_frame: int = 0,
save_video: bool = True,
compensate_for_camera_motion: bool = False,
rigid_part = None,
video_depth = None # (B,T,C,H,W)
):
if compensate_for_camera_motion:
assert segm_mask is not None
if segm_mask is not None:
coords = tracks[0, query_frame].round().long()
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
video = F.pad(
video,
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
"constant",
255,
)
if video_depth is not None:
video_depth = (video_depth*255).cpu().numpy().astype(np.uint8)
video_depth = ([cv2.applyColorMap(video_depth[0,i,0], cv2.COLORMAP_INFERNO)
for i in range(video_depth.shape[1])])
video_depth = np.stack(video_depth, axis=0)
video_depth = torch.from_numpy(video_depth).permute(0, 3, 1, 2)[None]
tracks = tracks + self.pad_value
if self.grayscale:
transform = transforms.Grayscale()
video = transform(video)
video = video.repeat(1, 1, 3, 1, 1)
tracking_video = self.draw_tracks_on_video(
video=video,
tracks=tracks,
visibility=visibility,
segm_mask=segm_mask,
gt_tracks=gt_tracks,
query_frame=query_frame,
compensate_for_camera_motion=compensate_for_camera_motion,
rigid_part=rigid_part
)
if save_video:
# import ipdb; ipdb.set_trace()
tracking_dir = os.path.join(self.save_dir, "tracking")
if not os.path.exists(tracking_dir):
os.makedirs(tracking_dir)
self.save_video(tracking_video, filename=filename+"_tracking",
savedir=tracking_dir, writer=writer, step=step)
# with open(self.ttxt_path, 'a') as file:
# file.write(f"tracking/{filename}_tracking.mp4\n")
videos_dir = os.path.join(self.save_dir, "videos")
if not os.path.exists(videos_dir):
os.makedirs(videos_dir)
self.save_video(video, filename=filename,
savedir=videos_dir, writer=writer, step=step)
# with open(self.vtxt_path, 'a') as file:
# file.write(f"videos/{filename}.mp4\n")
if video_depth is not None:
self.save_video(video_depth, filename=filename+"_depth",
savedir=os.path.join(self.save_dir, "depth"), writer=writer, step=step)
return tracking_video
def save_video(self, video, filename, savedir=None, writer=None, step=0):
if writer is not None:
writer.add_video(
f"{filename}",
video.to(torch.uint8),
global_step=step,
fps=self.fps,
)
else:
os.makedirs(self.save_dir, exist_ok=True)
wide_list = list(video.unbind(1))
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
# clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
from moviepy.editor import ImageSequenceClip # lazy import; moviepy is only needed when saving without a tensorboard writer
clip = ImageSequenceClip(wide_list, fps=self.fps)
# Write the video file
if savedir is None:
save_path = os.path.join(self.save_dir, f"{filename}.mp4")
else:
save_path = os.path.join(savedir, f"{filename}.mp4")
clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
print(f"Video saved to {save_path}")
def draw_tracks_on_video(
self,
video: torch.Tensor,
tracks: torch.Tensor,
visibility: torch.Tensor = None,
segm_mask: torch.Tensor = None,
gt_tracks=None,
query_frame: int = 0,
compensate_for_camera_motion=False,
rigid_part=None,
):
B, T, C, H, W = video.shape
_, _, N, D = tracks.shape
assert D == 3
assert C == 3
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
tracks = tracks[0].detach().cpu().numpy() # S, N, 3
if gt_tracks is not None:
gt_tracks = gt_tracks[0].detach().cpu().numpy()
res_video = []
# process input video
# for rgb in video:
# res_video.append(rgb.copy())
# create a blank tensor with the same shape as the video
for rgb in video:
black_frame = np.zeros_like(rgb.copy(), dtype=rgb.dtype)
res_video.append(black_frame)
vector_colors = np.zeros((T, N, 3))
if self.mode == "optical_flow":
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
elif segm_mask is None:
if self.mode == "rainbow":
x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()
z_inv = 1/tracks[0, :, 2]
z_min, z_max = np.percentile(z_inv, [2, 98])
norm_x = plt.Normalize(x_min, x_max)
norm_y = plt.Normalize(y_min, y_max)
norm_z = plt.Normalize(z_min, z_max)
for n in range(N):
r = norm_x(tracks[0, n, 0])
g = norm_y(tracks[0, n, 1])
# r = 0
# g = 0
b = norm_z(1/tracks[0, n, 2])
color = np.array([r, g, b])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with time
for t in range(T):
color = np.array(self.color_map(t / T)[:3])[None] * 255
vector_colors[t] = np.repeat(color, N, axis=0)
else:
if self.mode == "rainbow":
vector_colors[:, segm_mask <= 0, :] = 255
x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()
z_min, z_max = tracks[0, :, 2].min(), tracks[0, :, 2].max()
norm_x = plt.Normalize(x_min, x_max)
norm_y = plt.Normalize(y_min, y_max)
norm_z = plt.Normalize(z_min, z_max)
for n in range(N):
r = norm_x(tracks[0, n, 0])
g = norm_y(tracks[0, n, 1])
b = norm_z(tracks[0, n, 2])
color = np.array([r, g, b])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with segm class
segm_mask = segm_mask.cpu()
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
vector_colors = np.repeat(color[None], T, axis=0)
# Draw tracks
if self.tracks_leave_trace != 0:
for t in range(1, T):
first_ind = (
max(0, t - self.tracks_leave_trace)
if self.tracks_leave_trace >= 0
else 0
)
curr_tracks = tracks[first_ind : t + 1]
curr_colors = vector_colors[first_ind : t + 1]
if compensate_for_camera_motion:
diff = (
tracks[first_ind : t + 1, segm_mask <= 0]
- tracks[t : t + 1, segm_mask <= 0]
).mean(1)[:, None]
curr_tracks = curr_tracks - diff
curr_tracks = curr_tracks[:, segm_mask > 0]
curr_colors = curr_colors[:, segm_mask > 0]
res_video[t] = self._draw_pred_tracks(
res_video[t],
curr_tracks,
curr_colors,
)
if gt_tracks is not None:
res_video[t] = self._draw_gt_tracks(
res_video[t], gt_tracks[first_ind : t + 1]
)
if rigid_part is not None:
cls_label = torch.unique(rigid_part)
cls_num = len(torch.unique(rigid_part))
# visualize the clustering results
cmap = plt.get_cmap('jet') # get the color mapping
colors = cmap(np.linspace(0, 1, cls_num))
colors = (colors[:, :3] * 255)
color_map = {label.item(): color for label, color in zip(cls_label, colors)}
# Draw points
for t in tqdm(range(T)):
# Create a list to store information for each point
points_info = []
for i in range(N):
coord = (tracks[t, i, 0], tracks[t, i, 1])
depth = tracks[t, i, 2] # assume the third dimension is depth
visibile = True
if visibility is not None:
visibile = visibility[0, t, i]
if coord[0] != 0 and coord[1] != 0:
if not compensate_for_camera_motion or (
compensate_for_camera_motion and segm_mask[i] > 0
):
points_info.append((i, coord, depth, visibile))
# Sort points by depth, points with smaller depth (closer) will be drawn later
points_info.sort(key=lambda x: x[2], reverse=True)
for i, coord, _, visibile in points_info:
if rigid_part is not None:
color = color_map[rigid_part.squeeze()[i].item()]
cv2.circle(
res_video[t],
(int(coord[0]), int(coord[1])), # cv2 expects integer pixel coordinates
int(self.linewidth * 2),
color.tolist(),
thickness=-1 if visibile else 1,
)
else:
# Determine rectangle width based on the distance between adjacent tracks in the first frame
if t == 0:
distances = np.linalg.norm(tracks[0] - tracks[0, i], axis=1)
distances = distances[distances > 0]
rect_size = int(np.min(distances))/2
# Define coordinates for top-left and bottom-right corners of the rectangle
top_left = (int(coord[0] - rect_size), int(coord[1] - rect_size/1.5)) # Rectangle width is 1.5x (video aspect ratio is 1.5:1)
bottom_right = (int(coord[0] + rect_size), int(coord[1] + rect_size/1.5))
# Draw rectangle
cv2.rectangle(
res_video[t],
top_left,
bottom_right,
vector_colors[t, i].tolist(),
thickness=-1 if visibile else 0 - 1,
)
# Construct the final rgb sequence
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
def _draw_pred_tracks(
self,
rgb: np.ndarray, # H x W x 3
tracks: np.ndarray, # T x 2
vector_colors: np.ndarray,
alpha: float = 0.5,
):
T, N, _ = tracks.shape
for s in range(T - 1):
vector_color = vector_colors[s]
original = rgb.copy()
alpha = (s / T) ** 2
for i in range(N):
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
if coord_y[0] != 0 and coord_y[1] != 0:
cv2.line(
rgb,
coord_y,
coord_x,
vector_color[i].tolist(),
self.linewidth,
cv2.LINE_AA,
)
if self.tracks_leave_trace > 0:
rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
return rgb
def _draw_gt_tracks(
self,
rgb: np.ndarray, # H x W x 3,
gt_tracks: np.ndarray, # T x 2
):
T, N, _ = gt_tracks.shape
color = np.array((211.0, 0.0, 0.0))
for t in range(T):
for i in range(N):
gt_track = gt_tracks[t][i] # use a local variable; don't overwrite gt_tracks inside the loop
# draw a red cross
if gt_track[0] > 0 and gt_track[1] > 0:
length = self.linewidth * 3
coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
cv2.line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
cv2.LINE_AA,
)
coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
cv2.line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
cv2.LINE_AA,
)
return rgb
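# Usage sketch (hypothetical shapes): render tracking overlays for a clip.
#   vis = Visualizer(save_dir="./results", fps=8, mode="rainbow")
#   video = torch.randint(0, 255, (1, 49, 3, 480, 720)).float()   # B, T, C, H, W
#   tracks = torch.rand(1, 49, 256, 3) * 480                      # B, T, N, 3 (x, y, depth)
#   tracking_video = vis.visualize(video, tracks, save_video=False)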

das/spatracker/utils/vox.py (new file, 500 lines)
View File

@ -0,0 +1,500 @@
import numpy as np
import torch
import torch.nn.functional as F
import utils.geom
import utils.basic # used below for gridcloud/meshgrid/normalize helpers
class Vox_util(object):
def __init__(self, Z, Y, X, scene_centroid, bounds, pad=None, assert_cube=False):
self.XMIN, self.XMAX, self.YMIN, self.YMAX, self.ZMIN, self.ZMAX = bounds
B, D = list(scene_centroid.shape)
self.Z, self.Y, self.X = Z, Y, X
scene_centroid = scene_centroid.detach().cpu().numpy()
x_centroid, y_centroid, z_centroid = scene_centroid[0]
self.XMIN += x_centroid
self.XMAX += x_centroid
self.YMIN += y_centroid
self.YMAX += y_centroid
self.ZMIN += z_centroid
self.ZMAX += z_centroid
self.default_vox_size_X = (self.XMAX-self.XMIN)/float(X)
self.default_vox_size_Y = (self.YMAX-self.YMIN)/float(Y)
self.default_vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z)
if pad:
Z_pad, Y_pad, X_pad = pad
self.ZMIN -= self.default_vox_size_Z * Z_pad
self.ZMAX += self.default_vox_size_Z * Z_pad
self.YMIN -= self.default_vox_size_Y * Y_pad
self.YMAX += self.default_vox_size_Y * Y_pad
self.XMIN -= self.default_vox_size_X * X_pad
self.XMAX += self.default_vox_size_X * X_pad
if assert_cube:
# we assume cube voxels
if (not np.isclose(self.default_vox_size_X, self.default_vox_size_Y)) or (not np.isclose(self.default_vox_size_X, self.default_vox_size_Z)):
print('Z, Y, X', Z, Y, X)
print('bounds for this iter:',
'X = %.2f to %.2f' % (self.XMIN, self.XMAX),
'Y = %.2f to %.2f' % (self.YMIN, self.YMAX),
'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX),
)
print('self.default_vox_size_X', self.default_vox_size_X)
print('self.default_vox_size_Y', self.default_vox_size_Y)
print('self.default_vox_size_Z', self.default_vox_size_Z)
assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Y))
assert(np.isclose(self.default_vox_size_X, self.default_vox_size_Z))
def Ref2Mem(self, xyz, Z, Y, X, assert_cube=False):
# xyz is B x N x 3, in ref coordinates
# transforms ref coordinates into mem coordinates
B, N, C = list(xyz.shape)
device = xyz.device
assert(C==3)
mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device)
xyz = utils.geom.apply_4x4(mem_T_ref, xyz)
return xyz
def Mem2Ref(self, xyz_mem, Z, Y, X, assert_cube=False):
# xyz is B x N x 3, in mem coordinates
# transforms mem coordinates into ref coordinates
B, N, C = list(xyz_mem.shape)
ref_T_mem = self.get_ref_T_mem(B, Z, Y, X, assert_cube=assert_cube, device=xyz_mem.device)
xyz_ref = utils.geom.apply_4x4(ref_T_mem, xyz_mem)
return xyz_ref
def get_mem_T_ref(self, B, Z, Y, X, assert_cube=False, device='cuda'):
vox_size_X = (self.XMAX-self.XMIN)/float(X)
vox_size_Y = (self.YMAX-self.YMIN)/float(Y)
vox_size_Z = (self.ZMAX-self.ZMIN)/float(Z)
if assert_cube:
if (not np.isclose(vox_size_X, vox_size_Y)) or (not np.isclose(vox_size_X, vox_size_Z)):
print('Z, Y, X', Z, Y, X)
print('bounds for this iter:',
'X = %.2f to %.2f' % (self.XMIN, self.XMAX),
'Y = %.2f to %.2f' % (self.YMIN, self.YMAX),
'Z = %.2f to %.2f' % (self.ZMIN, self.ZMAX),
)
print('vox_size_X', vox_size_X)
print('vox_size_Y', vox_size_Y)
print('vox_size_Z', vox_size_Z)
assert(np.isclose(vox_size_X, vox_size_Y))
assert(np.isclose(vox_size_X, vox_size_Z))
# translation
# (this makes the left edge of the leftmost voxel correspond to XMIN)
center_T_ref = utils.geom.eye_4x4(B, device=device)
center_T_ref[:,0,3] = -self.XMIN-vox_size_X/2.0
center_T_ref[:,1,3] = -self.YMIN-vox_size_Y/2.0
center_T_ref[:,2,3] = -self.ZMIN-vox_size_Z/2.0
# scaling
# (this makes the right edge of the rightmost voxel correspond to XMAX)
mem_T_center = utils.geom.eye_4x4(B, device=device)
mem_T_center[:,0,0] = 1./vox_size_X
mem_T_center[:,1,1] = 1./vox_size_Y
mem_T_center[:,2,2] = 1./vox_size_Z
mem_T_ref = utils.geom.matmul2(mem_T_center, center_T_ref)
return mem_T_ref
def get_ref_T_mem(self, B, Z, Y, X, assert_cube=False, device='cuda'):
mem_T_ref = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=device)
# note safe_inverse is inapplicable here,
# since the transform is nonrigid
ref_T_mem = mem_T_ref.inverse()
return ref_T_mem
def get_inbounds(self, xyz, Z, Y, X, already_mem=False, padding=0.0, assert_cube=False):
# xyz is B x N x 3
# padding should be 0 unless you are trying to account for some later cropping
if not already_mem:
xyz = self.Ref2Mem(xyz, Z, Y, X, assert_cube=assert_cube)
x = xyz[:,:,0]
y = xyz[:,:,1]
z = xyz[:,:,2]
x_valid = ((x-padding)>-0.5).byte() & ((x+padding)<float(X-0.5)).byte()
y_valid = ((y-padding)>-0.5).byte() & ((y+padding)<float(Y-0.5)).byte()
z_valid = ((z-padding)>-0.5).byte() & ((z+padding)<float(Z-0.5)).byte()
nonzero = (~(z==0.0)).byte()
inbounds = x_valid & y_valid & z_valid & nonzero
return inbounds.bool()
def voxelize_xyz(self, xyz_ref, Z, Y, X, already_mem=False, assert_cube=False, clean_eps=0):
B, N, D = list(xyz_ref.shape)
assert(D==3)
if already_mem:
xyz_mem = xyz_ref
else:
xyz_mem = self.Ref2Mem(xyz_ref, Z, Y, X, assert_cube=assert_cube)
xyz_zero = self.Ref2Mem(xyz_ref[:,0:1]*0, Z, Y, X, assert_cube=assert_cube)
vox = self.get_occupancy(xyz_mem, Z, Y, X, clean_eps=clean_eps, xyz_zero=xyz_zero)
return vox
def voxelize_xyz_and_feats(self, xyz_ref, feats, Z, Y, X, already_mem=False, assert_cube=False, clean_eps=0):
B, N, D = list(xyz_ref.shape)
B2, N2, D2 = list(feats.shape)
assert(D==3)
assert(B==B2)
assert(N==N2)
if already_mem:
xyz_mem = xyz_ref
else:
xyz_mem = self.Ref2Mem(xyz_ref, Z, Y, X, assert_cube=assert_cube)
xyz_zero = self.Ref2Mem(xyz_ref[:,0:1]*0, Z, Y, X, assert_cube=assert_cube)
feats = self.get_feat_occupancy(xyz_mem, feats, Z, Y, X, clean_eps=clean_eps, xyz_zero=xyz_zero)
return feats
def get_occupancy(self, xyz, Z, Y, X, clean_eps=0, xyz_zero=None):
# xyz is B x N x 3 and in mem coords
# we want to fill a voxel tensor with 1's at these inds
B, N, C = list(xyz.shape)
assert(C==3)
# these papers say simple 1/0 occupancy is ok:
# http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_PIXOR_Real-Time_3d_CVPR_2018_paper.pdf
# http://openaccess.thecvf.com/content_cvpr_2018/papers/Luo_Fast_and_Furious_CVPR_2018_paper.pdf
# cont fusion says they do 8-neighbor interp
# voxelnet does occupancy but with a bit of randomness in terms of the reflectance value i think
inbounds = self.get_inbounds(xyz, Z, Y, X, already_mem=True)
x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
mask = torch.zeros_like(x)
mask[inbounds] = 1.0
if xyz_zero is not None:
# only take points that are beyond a thresh of zero
dist = torch.norm(xyz_zero-xyz, dim=2)
mask[dist < 0.1] = 0
if clean_eps > 0:
# only take points that are already near centers
xyz_round = torch.round(xyz) # B, N, 3
dist = torch.norm(xyz_round - xyz, dim=2)
mask[dist > clean_eps] = 0
# set the invalid guys to zero
# we then need to zero out 0,0,0
# (this method seems a bit clumsy)
x = x*mask
y = y*mask
z = z*mask
x = torch.round(x)
y = torch.round(y)
z = torch.round(z)
x = torch.clamp(x, 0, X-1).int()
y = torch.clamp(y, 0, Y-1).int()
z = torch.clamp(z, 0, Z-1).int()
x = x.view(B*N)
y = y.view(B*N)
z = z.view(B*N)
dim3 = X
dim2 = X * Y
dim1 = X * Y * Z
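# flatten (b, z, y, x) into a single index: b*Z*Y*X + z*Y*X + y*X + x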
base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1
base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N)
vox_inds = base + z * dim2 + y * dim3 + x
voxels = torch.zeros(B*Z*Y*X, device=xyz.device).float()
voxels[vox_inds.long()] = 1.0
# zero out the singularity
voxels[base.long()] = 0.0
voxels = voxels.reshape(B, 1, Z, Y, X)
# B x 1 x Z x Y x X
return voxels
def get_feat_occupancy(self, xyz, feat, Z, Y, X, clean_eps=0, xyz_zero=None):
# xyz is B x N x 3 and in mem coords
# feat is B x N x D
# we want to fill a voxel tensor with 1's at these inds
B, N, C = list(xyz.shape)
B2, N2, D2 = list(feat.shape)
assert(C==3)
assert(B==B2)
assert(N==N2)
# these papers say simple 1/0 occupancy is ok:
# http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_PIXOR_Real-Time_3d_CVPR_2018_paper.pdf
# http://openaccess.thecvf.com/content_cvpr_2018/papers/Luo_Fast_and_Furious_CVPR_2018_paper.pdf
# cont fusion says they do 8-neighbor interp
# voxelnet does occupancy but with a bit of randomness in terms of the reflectance value i think
inbounds = self.get_inbounds(xyz, Z, Y, X, already_mem=True)
x, y, z = xyz[:,:,0], xyz[:,:,1], xyz[:,:,2]
mask = torch.zeros_like(x)
mask[inbounds] = 1.0
if xyz_zero is not None:
# only take points that are beyond a thresh of zero
dist = torch.norm(xyz_zero-xyz, dim=2)
mask[dist < 0.1] = 0
if clean_eps > 0:
# only take points that are already near centers
xyz_round = torch.round(xyz) # B, N, 3
dist = torch.norm(xyz_round - xyz, dim=2)
mask[dist > clean_eps] = 0
# set the invalid guys to zero
# we then need to zero out 0,0,0
# (this method seems a bit clumsy)
x = x*mask # B, N
y = y*mask
z = z*mask
feat = feat*mask.unsqueeze(-1) # B, N, D
x = torch.round(x)
y = torch.round(y)
z = torch.round(z)
x = torch.clamp(x, 0, X-1).int()
y = torch.clamp(y, 0, Y-1).int()
z = torch.clamp(z, 0, Z-1).int()
# permute point orders
perm = torch.randperm(N)
x = x[:, perm]
y = y[:, perm]
z = z[:, perm]
feat = feat[:, perm]
x = x.view(B*N)
y = y.view(B*N)
z = z.view(B*N)
feat = feat.view(B*N, -1)
dim3 = X
dim2 = X * Y
dim1 = X * Y * Z
base = torch.arange(0, B, dtype=torch.int32, device=xyz.device)*dim1
base = torch.reshape(base, [B, 1]).repeat([1, N]).view(B*N)
vox_inds = base + z * dim2 + y * dim3 + x
feat_voxels = torch.zeros((B*Z*Y*X, D2), device=xyz.device).float()
feat_voxels[vox_inds.long()] = feat
# zero out the singularity
feat_voxels[base.long()] = 0.0
feat_voxels = feat_voxels.reshape(B, Z, Y, X, D2).permute(0, 4, 1, 2, 3)
# B x C x Z x Y x X
return feat_voxels
def unproject_image_to_mem(self, rgb_camB, pixB_T_camA, camB_T_camA, Z, Y, X, assert_cube=False, xyz_camA=None):
# rgb_camB is B x C x H x W
# pixB_T_camA is B x 4 x 4
# rgb lives in B pixel coords
# we want everything in A memory coords
# this puts each C-dim pixel in the rgb_camB
# along a ray in the voxelgrid
B, C, H, W = list(rgb_camB.shape)
if xyz_camA is None:
xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device)
xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube)
xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA)
z = xyz_camB[:,:,2]
xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA)
normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2)
EPS=1e-6
# z = xyz_pixB[:,:,2]
xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS)
# this is B x N x 2
# this is the (floating point) pixel coordinate of each voxel
x, y = xy_pixB[:,:,0], xy_pixB[:,:,1]
# these are B x N
x_valid = (x>-0.5).bool() & (x<float(W-0.5)).bool()
y_valid = (y>-0.5).bool() & (y<float(H-0.5)).bool()
z_valid = (z>0.0).bool()
valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float()
if (0):
# handwritten version
values = torch.zeros([B, C, Z*Y*X], dtype=torch.float32)
for b in list(range(B)):
values[b] = utils.samp.bilinear_sample_single(rgb_camB[b], x_pixB[b], y_pixB[b])
else:
# native pytorch version
y_pixB, x_pixB = utils.basic.normalize_grid2d(y, x, H, W)
# since we want a 3d output, we need 5d tensors
z_pixB = torch.zeros_like(x)
xyz_pixB = torch.stack([x_pixB, y_pixB, z_pixB], axis=2)
rgb_camB = rgb_camB.unsqueeze(2)
xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
values = F.grid_sample(rgb_camB, xyz_pixB, align_corners=False)
values = torch.reshape(values, (B, C, Z, Y, X))
values = values * valid_mem
return values
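# Usage sketch (hypothetical shapes/matrices): lift a 2D feature map into a 3D
# feature volume by projecting every voxel centre into the image and sampling.
#   vox = Vox_util(Z=32, Y=16, X=32, scene_centroid=torch.zeros(1, 3),
#                  bounds=(-8, 8, -4, 4, 0, 16))
#   feat2d = torch.rand(1, 64, 240, 360)        # B, C, H, W
#   pix_T_cam = torch.eye(4).unsqueeze(0)       # placeholder 4x4 projection matrix
#   cam_T_cam = torch.eye(4).unsqueeze(0)       # identity: same camera for A and B
#   feat3d = vox.unproject_image_to_mem(feat2d, pix_T_cam, cam_T_cam, 32, 16, 32)
#   # feat3d: (1, 64, 32, 16, 32)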
def warp_tiled_to_mem(self, rgb_tileB, pixB_T_camA, camB_T_camA, Z, Y, X, DMIN, DMAX, assert_cube=False):
# rgb_tileB is B,C,D,H,W
# pixB_T_camA is B,4,4
# camB_T_camA is B,4,4
# rgb_tileB lives in B pixel coords but it has been tiled across the Z dimension
# we want everything in A memory coords
# this resamples so that each C-dim pixel in rgb_tileB
# is put into its correct place in the voxelgrid
# (using the pinhole camera model)
B, C, D, H, W = list(rgb_tileB.shape)
xyz_memA = utils.basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device)
xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube)
xyz_camB = utils.geom.apply_4x4(camB_T_camA, xyz_camA)
z_camB = xyz_camB[:,:,2]
# rgb_tileB has depth=DMIN in tile 0, and depth=DMAX in tile D-1
z_tileB = (D-1.0) * (z_camB-float(DMIN)) / float(DMAX-DMIN)
xyz_pixB = utils.geom.apply_4x4(pixB_T_camA, xyz_camA)
normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2)
EPS=1e-6
# z = xyz_pixB[:,:,2]
xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS)
# this is B x N x 2
# this is the (floating point) pixel coordinate of each voxel
x, y = xy_pixB[:,:,0], xy_pixB[:,:,1]
# these are B x N
x_valid = (x>-0.5).bool() & (x<float(W-0.5)).bool()
y_valid = (y>-0.5).bool() & (y<float(H-0.5)).bool()
z_valid = (z_camB>0.0).bool()
valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float()
z_tileB, y_pixB, x_pixB = utils.basic.normalize_grid3d(z_tileB, y, x, D, H, W)
xyz_pixB = torch.stack([x_pixB, y_pixB, z_tileB], axis=2)
xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
values = F.grid_sample(rgb_tileB, xyz_pixB, align_corners=False)
values = torch.reshape(values, (B, C, Z, Y, X))
values = values * valid_mem
return values
def apply_mem_T_ref_to_lrtlist(self, lrtlist_cam, Z, Y, X, assert_cube=False):
# lrtlist is B x N x 19, in cam coordinates
# transforms them into mem coordinates, including a scale change for the lengths
B, N, C = list(lrtlist_cam.shape)
assert(C==19)
mem_T_cam = self.get_mem_T_ref(B, Z, Y, X, assert_cube=assert_cube, device=lrtlist_cam.device)
def xyz2circles(self, xyz, radius, Z, Y, X, soft=True, already_mem=True, also_offset=False, grid=None):
# xyz is B x N x 3
# radius is B x N or broadcastably so
# output is B x N x Z x Y x X
B, N, D = list(xyz.shape)
assert(D==3)
if not already_mem:
xyz = self.Ref2Mem(xyz, Z, Y, X)
if grid is None:
grid_z, grid_y, grid_x = utils.basic.meshgrid3d(B, Z, Y, X, stack=False, norm=False, device=xyz.device)
# note the default stack is on -1
grid = torch.stack([grid_x, grid_y, grid_z], dim=1)
# this is B x 3 x Z x Y x X
xyz = xyz.reshape(B, N, 3, 1, 1, 1)
grid = grid.reshape(B, 1, 3, Z, Y, X)
# this is B x N x Z x Y x X
# round the xyzs, so that at least one value matches the grid perfectly,
# and we get a value of 1 there (since exp(0)==1)
xyz = xyz.round()
if torch.is_tensor(radius):
radius = radius.clamp(min=0.01)
if soft:
off = grid - xyz # B,N,3,Z,Y,X
# interpret radius as sigma
dist_grid = torch.sum(off**2, dim=2, keepdim=False)
# this is B x N x Z x Y x X
if torch.is_tensor(radius):
radius = radius.reshape(B, N, 1, 1, 1)
mask = torch.exp(-dist_grid/(2*radius*radius))
# zero out near zero
mask[mask < 0.001] = 0.0
# h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
# h[h < np.finfo(h.dtype).eps * h.max()] = 0
# return h
if also_offset:
return mask, off
else:
return mask
else:
assert(False) # something is wrong with this. come back later to debug
dist_grid = torch.norm(grid - xyz, dim=2, keepdim=False)
# this is 0 at/near the xyz, and increases by 1 for each voxel away
radius = radius.reshape(B, N, 1, 1, 1)
within_radius_mask = (dist_grid < radius).float()
within_radius_mask = torch.sum(within_radius_mask, dim=1, keepdim=True).clamp(0, 1)
return within_radius_mask
def xyz2circles_bev(self, xyz, radius, Z, Y, X, already_mem=True, also_offset=False):
# xyz is B x N x 3
# radius is B x N or broadcastably so
# output is B x N x Z x Y x X
B, N, D = list(xyz.shape)
assert(D==3)
if not already_mem:
xyz = self.Ref2Mem(xyz, Z, Y, X)
xz = torch.stack([xyz[:,:,0], xyz[:,:,2]], dim=2)
grid_z, grid_x = utils.basic.meshgrid2d(B, Z, X, stack=False, norm=False, device=xyz.device)
# note the default stack is on -1
grid = torch.stack([grid_x, grid_z], dim=1)
# this is B x 2 x Z x X
xz = xz.reshape(B, N, 2, 1, 1)
grid = grid.reshape(B, 1, 2, Z, X)
# these are ready to broadcast to B x N x Z x X
# round the points, so that at least one value matches the grid perfectly,
# and we get a value of 1 there (since exp(0)==1)
xz = xz.round()
if torch.is_tensor(radius):
radius = radius.clamp(min=0.01)
off = grid - xz # B,N,2,Z,X
# interpret radius as sigma
dist_grid = torch.sum(off**2, dim=2, keepdim=False)
# this is B x N x Z x X
if torch.is_tensor(radius):
radius = radius.reshape(B, N, 1, 1) # dist_grid is B,N,Z,X here (no Y dim yet)
mask = torch.exp(-dist_grid/(2*radius*radius))
# zero out near zero
mask[mask < 0.001] = 0.0
# add a Y dim
mask = mask.unsqueeze(-2)
off = off.unsqueeze(-2)
# # B,N,2,Z,1,X
if also_offset:
return mask, off
else:
return mask
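# Usage sketch (hypothetical bounds): voxelize a point cloud into a binary occupancy grid.
#   scene_centroid = torch.zeros(1, 3)
#   bounds = (-8, 8, -4, 4, 0, 16)              # XMIN, XMAX, YMIN, YMAX, ZMIN, ZMAX
#   vox = Vox_util(Z=64, Y=32, X=64, scene_centroid=scene_centroid, bounds=bounds)
#   xyz_cam = torch.rand(1, 2048, 3) * 8        # B, N, 3 points in ref/camera coords
#   occ = vox.voxelize_xyz(xyz_cam, 64, 32, 64) # (1, 1, 64, 32, 64)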

View File

@ -712,6 +712,7 @@ class CogVideoXModelLoader:
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
manual_offloading = True
das_transformer = False
transformer_load_device = device if load_device == "main_device" else offload_device
mm.soft_empty_cache()
@ -719,9 +720,14 @@ class CogVideoXModelLoader:
model_path = folder_paths.get_full_path_or_raise("diffusion_models", model)
sd = load_torch_file(model_path, device=transformer_load_device)
first_key = next(iter(sd.keys()))
model_type = ""
if sd["patch_embed.proj.weight"].shape == (3072, 33, 2, 2):
if first_key == "combine_linears.0.bias":
log.info("Detected 'Diffusion As Shader' model")
model_type = "I2V_5b"
das_transformer = True
elif sd["patch_embed.proj.weight"].shape == (3072, 33, 2, 2):
model_type = "fun_5b"
elif sd["patch_embed.proj.weight"].shape == (3072, 16, 2, 2):
model_type = "5b"
@ -770,7 +776,7 @@ class CogVideoXModelLoader:
transformer_config["sample_width"] = 300
with init_empty_weights():
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode)
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, attention_mode=attention_mode, das_transformer=das_transformer)
#load weights
#params_to_keep = {}

View File

@ -631,8 +631,9 @@ class CogVideoSampler:
"controlnet": ("COGVIDECONTROLNET",),
"tora_trajectory": ("TORAFEATURES", ),
"fastercache": ("FASTERCACHEARGS", ),
"feta_args": ("FETAARGS", ),
"teacache_args": ("TEACACHEARGS", ),
"feta_args": ("FETAARGS", {"tooltip": "Arguments for Enhance-a-video"} ),
"teacache_args": ("TEACACHEARGS",{"tooltip": "Arguments for TeaCache"} ),
"das_tracking": ("DASTRACKING", {"tooltip": "Enable tracking for Diffusion As Shader"} ),
}
}
@ -642,17 +643,14 @@ class CogVideoSampler:
CATEGORY = "CogVideoWrapper"
def process(self, model, positive, negative, steps, cfg, seed, scheduler, num_frames, samples=None,
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None, fastercache=None, feta_args=None, teacache_args=None):
denoise_strength=1.0, image_cond_latents=None, context_options=None, controlnet=None, tora_trajectory=None,
das_tracking=None, fastercache=None, feta_args=None, teacache_args=None):
mm.unload_all_models()
mm.soft_empty_cache()
model_name = model.get("model_name", "")
supports_image_conds = True if (
"I2V" in model_name or
"interpolation" in model_name.lower() or
"fun" in model_name.lower() or
"img2vid" in model_name.lower()
) else False
supports_image_conds = True if model["pipe"].transformer.config.in_channels == 32 else False
if "fun" in model_name.lower() and not ("pose" in model_name.lower() or "control" in model_name.lower()) and image_cond_latents is not None:
assert image_cond_latents["mask"] is not None, "For fun inpaint models use CogVideoImageEncodeFunInP"
fun_mask = image_cond_latents["mask"]
@ -771,6 +769,7 @@ class CogVideoSampler:
image_cond_start_percent=image_cond_start_percent if image_cond_latents is not None else 0.0,
image_cond_end_percent=image_cond_end_percent if image_cond_latents is not None else 1.0,
feta_args=feta_args,
das_tracking=das_tracking,
)
if not model["cpu_offloading"] and model["manual_offloading"]:
pipe.transformer.to(offload_device)
@ -1033,4 +1032,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoImageEncodeFunInP": "CogVideo ImageEncode FunInP",
"CogVideoEnhanceAVideo": "CogVideo Enhance-A-Video",
"CogVideoXTeaCache": "CogVideoX TeaCache",
"CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode",
}

View File

@ -383,6 +383,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
image_cond_start_percent: float = 0.0,
image_cond_end_percent: float = 1.0,
feta_args: Optional[dict] = None,
das_tracking: Optional[dict] = None,
):
"""
@ -625,6 +626,24 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
else:
disable_enhance()
#das
if das_tracking is not None:
tracking_maps = das_tracking["tracking_maps"]
tracking_image_latents = das_tracking["tracking_image_latents"]
das_start_percent = das_tracking["start_percent"]
das_end_percent = das_tracking["end_percent"]
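# the tracking image latent covers only the first latent frame; pad the remaining
# (latents.shape[1] - 1) frames with zeros so its frame count matches the tracking
# maps before they are concatenated channel-wise during sampling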
padding_shape = (
batch_size,
(latents.shape[1] - 1),
self.vae_latent_channels,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
latent_padding = torch.zeros(padding_shape, device=device, dtype=self.vae_dtype)
tracking_image_latents = torch.cat([tracking_image_latents, latent_padding], dim=1)
# reset TeaCache
if hasattr(self.transformer, 'accumulated_rel_l1_distance'):
delattr(self.transformer, 'accumulated_rel_l1_distance')
@ -805,6 +824,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
fun_inpaint_latents = torch.cat([fun_inpaint_mask, fun_inpaint_masked_video_latents], dim=2).to(latents.dtype)
latent_model_input = torch.cat([latent_model_input, fun_inpaint_latents], dim=2)
if das_tracking is not None and das_start_percent <= current_step_percentage <= das_end_percent:
logger.info("DAS tracking enabled")
latents_tracking_image = torch.cat([tracking_image_latents] * 2) if do_classifier_free_guidance else tracking_image_latents
tracking_maps_input = torch.cat([tracking_maps] * 2) if do_classifier_free_guidance else tracking_maps
tracking_maps_input = torch.cat([tracking_maps_input, latents_tracking_image], dim=2)
del latents_tracking_image
else:
tracking_maps_input = None
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
@ -836,6 +864,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
controlnet_states=controlnet_states,
controlnet_weights=control_weights,
video_flow_features=video_flow_features if (tora is not None and tora["start_percent"] <= current_step_percentage <= tora["end_percent"]) else None,
tracking_maps=tracking_maps_input,
)[0]
noise_pred = noise_pred.float()
if isinstance(self.scheduler, CogVideoXDPMScheduler):