mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 04:44:22 +08:00)
add moge

commit 1124c77d56 (parent ee7d04d342)
das/das_nodes.py (259 changed lines)
@@ -1,9 +1,14 @@
 import torch
 import comfy.model_management as mm
+from comfy.utils import ProgressBar, common_upscale
 from ..utils import log
 import os
 import numpy as np
 import folder_paths
+from tqdm import tqdm
+from PIL import Image, ImageDraw
+
+from .motion import CameraMotionGenerator, ObjectMotionGenerator
 
 class CogVideoDASTrackingEncode:
     @classmethod
@@ -139,32 +144,278 @@ class DAS_SpaTracker:
             progressive_tracking=False
         )
 
+        # cam_motion = CameraMotionGenerator(
+        #     motion_type="trans",
+        #     frame_num=49,
+        #     W=720,
+        #     H=480,
+        #     fx=None,
+        #     fy=None,
+        #     fov=55,
+        #     device=device,
+        # )
+        # poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
+        # pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses)
+        # print("Camera motion applied")
+
         spatracker.to(offload_device)
 
         from .spatracker.utils.visualizer import Visualizer
-        vis = Visualizer(grayscale=False, fps=24, pad_value=0)
+        vis = Visualizer(
+            grayscale=False,
+            fps=24,
+            pad_value=0,
+            #tracks_leave_trace=-1
+        )
 
         msk_query = (T_Firsts == 0)
         pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
         pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
 
-        tracking_video = vis.visualize(video=video, tracks=pred_tracks,
-                                       visibility=pred_visibility, save_video=False,
-                                       filename="temp")
+        tracking_video = vis.visualize(
+            video=video,
+            tracks=pred_tracks,
+            visibility=pred_visibility,
+            save_video=False,
+        )
 
         tracking_video = tracking_video.squeeze(0).permute(0, 2, 3, 1) # [T, H, W, C]
         tracking_video = (tracking_video / 255.0).float()
 
         return (tracking_video,)
 
+
+class DAS_MoGeTracker:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MOGEMODEL",),
+            "image": ("IMAGE", ),
+            "num_frames": ("INT", {"default": 49, "min": 1, "max": 100, "step": 1}),
+            "width": ("INT", {"default": 720, "min": 1, "max": 10000, "step": 1}),
+            "height": ("INT", {"default": 480, "min": 1, "max": 10000, "step": 1}),
+            "fov": ("FLOAT", {"default": 55.0, "min": 1.0, "max": 180.0, "step": 1.0}),
+            "object_motion_type": (["none", "up", "down", "left", "right", "front", "back"],),
+            "object_motion_distance": ("INT", {"default": 50, "min": 1, "max": 1000, "step": 1}),
+            "camera_motion_type": (["none","translation", "rotation", "spiral"],),
+            },
+            "optional": {
+                "mask": ("MASK", ),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    RETURN_NAMES = ("tracking_video",)
+    FUNCTION = "encode"
+    CATEGORY = "CogVideoWrapper"
+
+    def encode(self, model, image, num_frames, width, height, fov, object_motion_type, object_motion_distance, camera_motion_type, mask=None):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        B, H, W, C = image.shape
+
+        image_resized = common_upscale(image.movedim(-1,1), width, height, "lanczos", "disabled").movedim(1,-1)
+
+        # Use the first frame from previously loaded video_tensor
+        infer_result = model.infer(image_resized.permute(0, 3, 1, 2).to(device)[0].to(device)) # [C, H, W] in range [0,1]
+        H, W = infer_result["points"].shape[0:2]
+
+        motion_generator = ObjectMotionGenerator(num_frames, device=device)
+
+        if mask is not None:
+            mask = mask[0].bool()
+            mask = torch.nn.functional.interpolate(
+                mask[None, None].float(),
+                size=(H, W),
+                mode='nearest'
+            )[0, 0].bool()
+        else:
+            mask = torch.ones(H, W, dtype=torch.bool)
+
+        # Generate motion dictionary
+        motion_dict = motion_generator.generate_motion(
+            mask=mask,
+            motion_type=object_motion_type,
+            distance=object_motion_distance,
+            num_frames=num_frames,
+        )
+
+        pred_tracks = motion_generator.apply_motion(
+            infer_result["points"],
+            motion_dict,
+            tracking_method="moge"
+        )
+        print("pred_tracks shape: ", pred_tracks.shape)
+        print("Object motion applied")
+
+        camera_motion_type_mapping = {
+            "none": "none",
+            "translation": "trans",
+            "rotation": "rot",
+            "spiral": "spiral"
+        }
+        cam_motion = CameraMotionGenerator(
+            motion_type=camera_motion_type_mapping[camera_motion_type],
+            frame_num=num_frames,
+            W=width,
+            H=height,
+            fx=None,
+            fy=None,
+            fov=fov,
+            device=device,
+        )
+        # Apply camera motion if specified
+        cam_motion.set_intr(infer_result["intrinsics"])
+        poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
+
+        pred_tracks_flatten = pred_tracks.reshape(num_frames, H*W, 3)
+        pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([num_frames, H, W, 3]) # [T, H, W, 3]
+        print("Camera motion applied")
+
+        points = pred_tracks.cpu().numpy()
+        mask = infer_result["mask"].cpu().numpy()
+        # Create color array
+        T, H, W, _ = pred_tracks.shape
+
+        print("points shape: ", points.shape)
+        print("mask shape: ", mask.shape)
+        colors = np.zeros((H, W, 3), dtype=np.uint8)
+
+        # Set R channel - based on x coordinates (smaller on the left)
+        colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))
+
+        # Set G channel - based on y coordinates (smaller on the top)
+        colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T
+
+        # Set B channel - based on depth
+        z_values = points[0, :, :, 2] # get z values
+        inv_z = 1 / z_values # calculate 1/z
+        # Calculate 2% and 98% percentiles
+        p2 = np.percentile(inv_z, 2)
+        p98 = np.percentile(inv_z, 98)
+        # Normalize to [0,1] range
+        normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
+        colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
+        colors = colors.astype(np.uint8)
+
+        # First reshape points and colors
+        points = points.reshape(T, -1, 3) # (T, H*W, 3)
+        colors = colors.reshape(-1, 3) # (H*W, 3)
+
+        # Create mask for each frame
+        mask = mask.reshape(-1) # Flatten mask to (H*W,)
+
+        # Apply mask
+        points = points[:, mask, :] # (T, masked_points, 3)
+        colors = colors[mask] # (masked_points, 3)
+
+        # Repeat colors for each frame
+        colors = colors.reshape(1, -1, 3).repeat(T, axis=0) # (T, masked_points, 3)
+
+        # Initialize list to store frames
+        frames = []
+        pbar = ProgressBar(len(points))
+
+        for i, pts_i in enumerate(tqdm(points)):
+            pixels, depths = pts_i[..., :2], pts_i[..., 2]
+            pixels[..., 0] = pixels[..., 0] * W
+            pixels[..., 1] = pixels[..., 1] * H
+            pixels = pixels.astype(int)
+
+            valid = self.valid_mask(pixels, W, H)
+
+            frame_rgb = colors[i][valid]
+            pixels = pixels[valid]
+            depths = depths[valid]
+
+            img = Image.fromarray(np.uint8(np.zeros([H, W, 3])), mode="RGB")
+            sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
+            step = 1
+            sorted_pixels = sorted_pixels[::step]
+            sorted_rgb = frame_rgb[sort_index][::step]
+
+            for j in range(sorted_pixels.shape[0]):
+                self.draw_rectangle(
+                    img,
+                    coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
+                    side_length=2,
+                    color=sorted_rgb[j],
+                )
+            frames.append(np.array(img))
+            pbar.update(1)
+
+        # Convert frames to video tensor in range [0,1]
+        tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
+        tracking_video = tracking_video.permute(0, 2, 3, 1) # [B, H, W, C]
+        print("tracking_video shape: ", tracking_video.shape)
+        return (tracking_video,)
+
+    def valid_mask(self, pixels, W, H):
+        """Check if pixels are within valid image bounds
+
+        Args:
+            pixels (numpy.ndarray): Pixel coordinates of shape [N, 2]
+            W (int): Image width
+            H (int): Image height
+
+        Returns:
+            numpy.ndarray: Boolean mask of valid pixels
+        """
+        return ((pixels[:, 0] >= 0) & (pixels[:, 0] < W) & (pixels[:, 1] > 0) & \
+                (pixels[:, 1] < H))
+
+    def sort_points_by_depth(self, points, depths):
+        """Sort points by depth values
+
+        Args:
+            points (numpy.ndarray): Points array of shape [N, 2]
+            depths (numpy.ndarray): Depth values of shape [N]
+
+        Returns:
+            tuple: (sorted_points, sorted_depths, sort_index)
+        """
+        # Combine points and depths into a single array for sorting
+        combined = np.hstack((points, depths[:, None])) # Nx3 (points + depth)
+        # Sort by depth (last column) in descending order
+        sort_index = combined[:, -1].argsort()[::-1]
+        sorted_combined = combined[sort_index]
+        # Split back into points and depths
+        sorted_points = sorted_combined[:, :-1]
+        sorted_depths = sorted_combined[:, -1]
+        return sorted_points, sorted_depths, sort_index
+
+    def draw_rectangle(self, rgb, coord, side_length, color=(255, 0, 0)):
+        """Draw a rectangle on the image
+
+        Args:
+            rgb (PIL.Image): Image to draw on
+            coord (tuple): Center coordinates (x, y)
+            side_length (int): Length of rectangle sides
+            color (tuple): RGB color tuple
+        """
+        draw = ImageDraw.Draw(rgb)
+        # Calculate the bounding box of the rectangle
+        left_up_point = (coord[0] - side_length//2, coord[1] - side_length//2)
+        right_down_point = (coord[0] + side_length//2, coord[1] + side_length//2)
+        color = tuple(list(color))
+
+        draw.rectangle(
+            [left_up_point, right_down_point],
+            fill=tuple(color),
+            outline=tuple(color),
+        )
+
+
 NODE_CLASS_MAPPINGS = {
     "CogVideoDASTrackingEncode": CogVideoDASTrackingEncode,
     "DAS_SpaTracker": DAS_SpaTracker,
     "DAS_SpaTrackerModelLoader": DAS_SpaTrackerModelLoader,
+    "DAS_MoGeTracker": DAS_MoGeTracker,
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode",
     "DAS_SpaTracker": "DAS SpaTracker",
     "DAS_SpaTrackerModelLoader": "DAS SpaTracker Model Loader",
+    "DAS_MoGeTracker": "DAS MoGe Tracker",
     }
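For reference, the colour coding the new DAS_MoGeTracker node uses for its tracking video (R from x position, G from y position, B from inverse depth clipped to the 2nd/98th percentiles) can be exercised on its own. The sketch below is illustrative only: the helper name point_colors is not part of this commit, and it assumes a MoGe-style point map of shape [H, W, 3].

import numpy as np

def point_colors(points_hw3):
    # points_hw3: [H, W, 3] point map; returns per-pixel RGB colours as uint8.
    H, W, _ = points_hw3.shape
    colors = np.zeros((H, W, 3), dtype=np.uint8)
    colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))    # R: left-to-right ramp
    colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T  # G: top-to-bottom ramp
    inv_z = 1 / points_hw3[:, :, 2]                              # B: inverse depth
    p2, p98 = np.percentile(inv_z, 2), np.percentile(inv_z, 98)
    colors[:, :, 2] = (np.clip((inv_z - p2) / (p98 - p2), 0, 1) * 255).astype(np.uint8)
    return colors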
das/motion.py (new file, 382 added lines)
@@ -0,0 +1,382 @@
import torch
import numpy as np
import math

class CameraMotionGenerator:
    def __init__(self, motion_type, frame_num=49, H=480, W=720, fx=None, fy=None, fov=55, device='cuda'):
        self.motion_type = motion_type
        self.frame_num = frame_num
        self.fov = fov
        self.device = device
        self.W = W
        self.H = H
        self.intr = torch.tensor([
            [0, 0, W / 2],
            [0, 0, H / 2],
            [0, 0, 1]
        ], dtype=torch.float32, device=device)
        # if fx, fy not provided
        if not fx or not fy:
            fov_rad = math.radians(fov)
            fx = fy = (W / 2) / math.tan(fov_rad / 2)

        self.intr[0, 0] = fx
        self.intr[1, 1] = fy

    def _apply_poses(self, pts, poses):
        """
        Args:
            pts (torch.Tensor): pointclouds coordinates [T, N, 3]
            intr (torch.Tensor): camera intrinsics [T, 3, 3]
            poses (numpy.ndarray): camera poses [T, 4, 4]
        """
        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)

        intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1).to(torch.float)
        T, N, _ = pts.shape
        ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
        pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
        pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
        pts_cam[:,:, :3] *= pts[:, :, 2:3]

        # to homogeneous
        pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)

        if poses.shape[0] == 1:
            poses = poses.repeat(T, 1, 1)
        elif poses.shape[0] != T:
            raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")

        poses = poses.to(torch.float).to(self.device)
        pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
        pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
        pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]

        return pts_proj

    def w2s(self, pts, poses):
        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)
        assert poses.shape[0] == self.frame_num
        poses = poses.to(torch.float32).to(self.device)
        T, N, _ = pts.shape # (T, N, 3)
        intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
        # Step 1: expand the points to (T, N, 4), padding the last dim with 1 (homogeneous coordinates)
        ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
        points_world_h = torch.cat([pts, ones], dim=-1)
        points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
        points_camera = points_camera_h[:, :3, :].permute(0, 2, 1)

        points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))

        uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]

        # Step 5: extract the depth (Z) and concatenate
        depth = points_camera[:, :, 2:3] # (T, N, 1)
        uvd = torch.cat([uv, depth], dim=-1) # (T, N, 3)

        return uvd # screen coordinates + depth (T, N, 3)

    def apply_motion_on_pts(self, pts, camera_motion):
        tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
        return tracking_pts

    def set_intr(self, K):
        if isinstance(K, np.ndarray):
            K = torch.from_numpy(K)
        self.intr = K.to(self.device)

    def rot_poses(self, angle, axis='y'):
        """
        pts (torch.Tensor): [T, N, 3]
        angle (int): angle of rotation (degree)
        """
        angle_rad = math.radians(angle)
        angles = torch.linspace(0, angle_rad, self.frame_num)
        rot_mats = torch.zeros(self.frame_num, 4, 4)

        for i, theta in enumerate(angles):
            cos_theta = torch.cos(theta)
            sin_theta = torch.sin(theta)
            if axis == 'x':
                rot_mats[i] = torch.tensor([
                    [1, 0, 0, 0],
                    [0, cos_theta, -sin_theta, 0],
                    [0, sin_theta, cos_theta, 0],
                    [0, 0, 0, 1]
                ], dtype=torch.float32)
            elif axis == 'y':
                rot_mats[i] = torch.tensor([
                    [cos_theta, 0, sin_theta, 0],
                    [0, 1, 0, 0],
                    [-sin_theta, 0, cos_theta, 0],
                    [0, 0, 0, 1]
                ], dtype=torch.float32)

            elif axis == 'z':
                rot_mats[i] = torch.tensor([
                    [cos_theta, -sin_theta, 0, 0],
                    [sin_theta, cos_theta, 0, 0],
                    [0, 0, 1, 0],
                    [0, 0, 0, 1]
                ], dtype=torch.float32)
            else:
                raise ValueError("Invalid axis value. Choose 'x', 'y', or 'z'.")

        return rot_mats.to(self.device)

    def trans_poses(self, dx, dy, dz):
        """
        params:
        - dx: float, displacement along x axis.
        - dy: float, displacement along y axis.
        - dz: float, displacement along z axis.

        ret:
        - matrices: torch.Tensor
        """
        trans_mats = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1) # (n, 4, 4)

        delta_x = dx / (self.frame_num - 1)
        delta_y = dy / (self.frame_num - 1)
        delta_z = dz / (self.frame_num - 1)

        for i in range(self.frame_num):
            trans_mats[i, 0, 3] = i * delta_x
            trans_mats[i, 1, 3] = i * delta_y
            trans_mats[i, 2, 3] = i * delta_z

        return trans_mats.to(self.device)


    def _look_at(self, camera_position, target_position):
        # look at direction
        # import ipdb;ipdb.set_trace()
        direction = target_position - camera_position
        direction /= np.linalg.norm(direction)
        # calculate rotation matrix
        up = np.array([0, 1, 0])
        right = np.cross(up, direction)
        right /= np.linalg.norm(right)
        up = np.cross(direction, right)
        rotation_matrix = np.vstack([right, up, direction])
        rotation_matrix = np.linalg.inv(rotation_matrix)
        return rotation_matrix

    def spiral_poses(self, radius, forward_ratio = 0.5, backward_ratio = 0.5, rotation_times = 0.1, look_at_times = 0.5):
        """Generate spiral camera poses

        Args:
            radius (float): Base radius of the spiral
            forward_ratio (float): Scale factor for forward motion
            backward_ratio (float): Scale factor for backward motion
            rotation_times (float): Number of rotations to complete
            look_at_times (float): Scale factor for look-at point distance

        Returns:
            torch.Tensor: Camera poses of shape [num_frames, 4, 4]
        """
        # Generate spiral trajectory
        t = np.linspace(0, 1, self.frame_num)
        r = np.sin(np.pi * t) * radius * rotation_times
        theta = 2 * np.pi * t

        # Calculate camera positions
        # Limit y motion for better floor/sky view
        y = r * np.cos(theta) * 0.3
        x = r * np.sin(theta)
        z = -r
        z[z < 0] *= forward_ratio
        z[z > 0] *= backward_ratio

        # Set look-at target
        target_pos = np.array([0, 0, radius * look_at_times])
        cam_pos = np.vstack([x, y, z]).T
        cam_poses = []

        for pos in cam_pos:
            rot_mat = self._look_at(pos, target_pos)
            trans_mat = np.eye(4)
            trans_mat[:3, :3] = rot_mat
            trans_mat[:3, 3] = pos
            cam_poses.append(trans_mat[None])

        camera_poses = np.concatenate(cam_poses, axis=0)
        return torch.from_numpy(camera_poses).to(self.device)

    def rot(self, pts, angle, axis):
        """
        pts: torch.Tensor, (T, N, 2)
        """
        rot_mats = self.rot_poses(angle, axis)
        pts = self.apply_motion_on_pts(pts, rot_mats)
        return pts

    def trans(self, pts, dx, dy, dz):
        if pts.shape[-1] != 3:
            raise ValueError("points should be in the 3d coordinate.")
        trans_mats = self.trans_poses(dx, dy, dz)
        pts = self.apply_motion_on_pts(pts, trans_mats)
        return pts

    def spiral(self, pts, radius):
        spiral_poses = self.spiral_poses(radius)
        pts = self.apply_motion_on_pts(pts, spiral_poses)
        return pts

    def get_default_motion(self):
        if self.motion_type == 'none':
            motion = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1).to(self.device)
        elif self.motion_type == 'trans':
            motion = self.trans_poses(0.02, 0, 0)
        elif self.motion_type == 'spiral':
            motion = self.spiral_poses(1)
        elif self.motion_type == 'rot':
            motion = self.rot_poses(-25, 'y')
        else:
            raise ValueError(f'camera_motion must be in [trans, spiral, rot], but get {self.motion_type}.')

        return motion

class ObjectMotionGenerator:
    def __init__(self, num_frames=49, device="cuda:0"):
        """Initialize ObjectMotionGenerator

        Args:
            device (str): Device to run on
        """
        self.device = device
        self.num_frames = num_frames

    def _get_points_in_mask(self, pred_tracks, mask):
        """Get points that fall within the mask in first frame

        Args:
            pred_tracks (torch.Tensor): [num_frames, num_points, 3]
            mask (torch.Tensor): [H, W] binary mask

        Returns:
            torch.Tensor: Boolean mask of selected points [num_points]
        """
        first_frame_points = pred_tracks[0] # [num_points, 3]
        xy_points = first_frame_points[:, :2] # [num_points, 2]

        # Convert xy coordinates to pixel indices
        xy_pixels = xy_points.round().long() # Convert to integer pixel coordinates

        # Clamp coordinates to valid range
        xy_pixels[:, 0].clamp_(0, mask.shape[1] - 1) # x coordinates
        xy_pixels[:, 1].clamp_(0, mask.shape[0] - 1) # y coordinates

        # Get mask values at point locations
        points_in_mask = mask[xy_pixels[:, 1], xy_pixels[:, 0]] # Index using y, x order

        return points_in_mask

    def generate_motion(self, mask, motion_type, distance, num_frames=49):
        """Generate motion dictionary for the given parameters

        Args:
            mask (torch.Tensor): [H, W] binary mask
            motion_type (str): Motion direction ('up', 'down', 'left', 'right')
            distance (float): Total distance to move
            num_frames (int): Number of frames

        Returns:
            dict: Motion dictionary containing:
                - mask (torch.Tensor): Binary mask
                - motions (torch.Tensor): Per-frame motion vectors [num_frames, 4, 4]
        """

        self.num_frames = num_frames
        # Define motion template vectors
        template = {
            "none": torch.tensor([0, 0, 0]),
            'up': torch.tensor([0, -1, 0]),
            'down': torch.tensor([0, 1, 0]),
            'left': torch.tensor([-1, 0, 0]),
            'right': torch.tensor([1, 0, 0]),
            'front': torch.tensor([0, 0, 1]),
            'back': torch.tensor([0, 0, -1])
        }

        if motion_type not in template:
            raise ValueError(f"Unknown motion type: {motion_type}")

        # Move mask to device
        mask = mask.to(self.device)

        # Generate per-frame motion matrices
        motions = []
        base_vec = template[motion_type].to(self.device) * distance

        for frame_idx in range(num_frames):
            # Calculate interpolation factor (0 to 1)
            t = frame_idx / (num_frames - 1)

            # Create motion matrix for current frame
            current_motion = torch.eye(4, device=self.device)
            current_motion[:3, 3] = base_vec * t
            motions.append(current_motion)

        motions = torch.stack(motions) # [num_frames, 4, 4]

        return {
            'mask': mask,
            'motions': motions
        }

    def apply_motion(self, pred_tracks, motion_dict, tracking_method="spatracker"):
        """Apply motion to selected points

        Args:
            pred_tracks (torch.Tensor): [num_frames, num_points, 3] for spatracker
                                        or [T, H, W, 3] for moge
            motion_dict (dict): Motion dictionary containing mask and motions
            tracking_method (str): "spatracker" or "moge"

        Returns:
            torch.Tensor: Modified pred_tracks with same shape as input
        """
        pred_tracks = pred_tracks.to(self.device).float()

        if tracking_method == "moge":

            H = pred_tracks.shape[0]
            W = pred_tracks.shape[1]

            initial_points = pred_tracks # [H, W, 3]
            selected_mask = motion_dict['mask']
            valid_selected = ~torch.any(torch.isnan(initial_points), dim=2) & selected_mask
            valid_selected = valid_selected.reshape([-1])
            modified_tracks = pred_tracks.clone().reshape(-1, 3).unsqueeze(0).repeat(self.num_frames, 1, 1)
            # import ipdb;ipdb.set_trace()
            for frame_idx in range(self.num_frames):
                # Get current frame motion
                motion_mat = motion_dict['motions'][frame_idx]
                # MoGe's pointcloud is scale-invariant
                motion_mat[0, 3] /= W
                motion_mat[1, 3] /= H
                # Apply motion to selected points
                points = modified_tracks[frame_idx, valid_selected]
                # Convert to homogeneous coordinates
                points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
                # Apply transformation
                transformed_points = torch.matmul(points_homo, motion_mat.T)
                # Convert back to 3D coordinates
                modified_tracks[frame_idx, valid_selected] = transformed_points[:, :3]
            return modified_tracks

        else:
            points_in_mask = self._get_points_in_mask(pred_tracks, motion_dict['mask'])
            modified_tracks = pred_tracks.clone()

            for frame_idx in range(pred_tracks.shape[0]):
                motion_mat = motion_dict['motions'][frame_idx]
                points = modified_tracks[frame_idx, points_in_mask]
                points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
                transformed_points = torch.matmul(points_homo, motion_mat.T)
                modified_tracks[frame_idx, points_in_mask] = transformed_points[:, :3]

            return modified_tracks
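As a rough guide to how this new module is driven by the DAS_MoGeTracker node above, the minimal sketch below runs both generators on the CPU with a random point map. The placeholder tensors, the CPU device, and the flat `from motion import ...` path (inside the package the import is the relative `.motion`) are assumptions for illustration, not part of the commit.

import torch
from motion import CameraMotionGenerator, ObjectMotionGenerator  # assumes das/ is on sys.path

num_frames, H, W = 49, 480, 720
points = torch.rand(H, W, 3)               # stand-in for MoGe's per-pixel point map
mask = torch.ones(H, W, dtype=torch.bool)  # move every point

obj = ObjectMotionGenerator(num_frames, device="cpu")
motion_dict = obj.generate_motion(mask=mask, motion_type="up",
                                  distance=50, num_frames=num_frames)
tracks = obj.apply_motion(points, motion_dict, tracking_method="moge")  # [T, H*W, 3]

cam = CameraMotionGenerator(motion_type="rot", frame_num=num_frames,
                            W=W, H=H, fov=55, device="cpu")
poses = cam.get_default_motion()           # [num_frames, 4, 4] camera poses
uvd = cam.w2s(tracks, poses)               # screen-space xy plus depth, [T, H*W, 3]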