Mirror of https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git (synced 2025-12-09 04:44:22 +08:00)

add moge

parent ee7d04d342
commit 1124c77d56

das/das_nodes.py | 259
@@ -1,9 +1,14 @@
import torch
import comfy.model_management as mm
from comfy.utils import ProgressBar, common_upscale
from ..utils import log
import os
import numpy as np
import folder_paths
from tqdm import tqdm
from PIL import Image, ImageDraw

from .motion import CameraMotionGenerator, ObjectMotionGenerator

class CogVideoDASTrackingEncode:
    @classmethod
@@ -139,32 +144,278 @@ class DAS_SpaTracker:
            progressive_tracking=False
        )

        # cam_motion = CameraMotionGenerator(
        #     motion_type="trans",
        #     frame_num=49,
        #     W=720,
        #     H=480,
        #     fx=None,
        #     fy=None,
        #     fov=55,
        #     device=device,
        # )
        # poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
        # pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses)
        # print("Camera motion applied")

        spatracker.to(offload_device)

        from .spatracker.utils.visualizer import Visualizer
        vis = Visualizer(grayscale=False, fps=24, pad_value=0)
        vis = Visualizer(
            grayscale=False,
            fps=24,
            pad_value=0,
            #tracks_leave_trace=-1
        )

        msk_query = (T_Firsts == 0)
        pred_tracks = pred_tracks[:, :, msk_query.squeeze()]
        pred_visibility = pred_visibility[:, :, msk_query.squeeze()]

        tracking_video = vis.visualize(video=video, tracks=pred_tracks,
                                       visibility=pred_visibility, save_video=False,
                                       filename="temp")
        tracking_video = vis.visualize(
            video=video,
            tracks=pred_tracks,
            visibility=pred_visibility,
            save_video=False,
        )

        tracking_video = tracking_video.squeeze(0).permute(0, 2, 3, 1)  # [T, H, W, C]
        tracking_video = (tracking_video / 255.0).float()

        return (tracking_video,)

class DAS_MoGeTracker:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MOGEMODEL",),
            "image": ("IMAGE", ),
            "num_frames": ("INT", {"default": 49, "min": 1, "max": 100, "step": 1}),
            "width": ("INT", {"default": 720, "min": 1, "max": 10000, "step": 1}),
            "height": ("INT", {"default": 480, "min": 1, "max": 10000, "step": 1}),
            "fov": ("FLOAT", {"default": 55.0, "min": 1.0, "max": 180.0, "step": 1.0}),
            "object_motion_type": (["none", "up", "down", "left", "right", "front", "back"],),
            "object_motion_distance": ("INT", {"default": 50, "min": 1, "max": 1000, "step": 1}),
            "camera_motion_type": (["none", "translation", "rotation", "spiral"],),
            },
            "optional": {
                "mask": ("MASK", ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("tracking_video",)
    FUNCTION = "encode"
    CATEGORY = "CogVideoWrapper"

    def encode(self, model, image, num_frames, width, height, fov, object_motion_type, object_motion_distance, camera_motion_type, mask=None):
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        B, H, W, C = image.shape

        image_resized = common_upscale(image.movedim(-1, 1), width, height, "lanczos", "disabled").movedim(1, -1)

        # Use the first frame from the previously loaded video tensor
        infer_result = model.infer(image_resized.permute(0, 3, 1, 2).to(device)[0].to(device))  # [C, H, W] in range [0,1]
        H, W = infer_result["points"].shape[0:2]

        motion_generator = ObjectMotionGenerator(num_frames, device=device)

        if mask is not None:
            mask = mask[0].bool()
            mask = torch.nn.functional.interpolate(
                mask[None, None].float(),
                size=(H, W),
                mode='nearest'
            )[0, 0].bool()
        else:
            mask = torch.ones(H, W, dtype=torch.bool)

        # Generate motion dictionary
        motion_dict = motion_generator.generate_motion(
            mask=mask,
            motion_type=object_motion_type,
            distance=object_motion_distance,
            num_frames=num_frames,
        )

        pred_tracks = motion_generator.apply_motion(
            infer_result["points"],
            motion_dict,
            tracking_method="moge"
        )
        print("pred_tracks shape: ", pred_tracks.shape)
        print("Object motion applied")

        camera_motion_type_mapping = {
            "none": "none",
            "translation": "trans",
            "rotation": "rot",
            "spiral": "spiral"
        }
        cam_motion = CameraMotionGenerator(
            motion_type=camera_motion_type_mapping[camera_motion_type],
            frame_num=num_frames,
            W=width,
            H=height,
            fx=None,
            fy=None,
            fov=fov,
            device=device,
        )
        # Apply camera motion if specified
        cam_motion.set_intr(infer_result["intrinsics"])
        poses = cam_motion.get_default_motion()  # shape: [num_frames, 4, 4]

        pred_tracks_flatten = pred_tracks.reshape(num_frames, H * W, 3)
        pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([num_frames, H, W, 3])  # [T, H, W, 3]
        print("Camera motion applied")

        points = pred_tracks.cpu().numpy()
        mask = infer_result["mask"].cpu().numpy()
        # Create color array
        T, H, W, _ = pred_tracks.shape

        print("points shape: ", points.shape)
        print("mask shape: ", mask.shape)
        colors = np.zeros((H, W, 3), dtype=np.uint8)

        # Set R channel - based on x coordinates (smaller on the left)
        colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))

        # Set G channel - based on y coordinates (smaller at the top)
        colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T

        # Set B channel - based on depth
        z_values = points[0, :, :, 2]  # get z values
        inv_z = 1 / z_values  # calculate 1/z
        # Calculate 2% and 98% percentiles
        p2 = np.percentile(inv_z, 2)
        p98 = np.percentile(inv_z, 98)
        # Normalize to [0,1] range
        normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
        colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
        colors = colors.astype(np.uint8)

        # First reshape points and colors
        points = points.reshape(T, -1, 3)  # (T, H*W, 3)
        colors = colors.reshape(-1, 3)  # (H*W, 3)

        # Create mask for each frame
        mask = mask.reshape(-1)  # Flatten mask to (H*W,)

        # Apply mask
        points = points[:, mask, :]  # (T, masked_points, 3)
        colors = colors[mask]  # (masked_points, 3)

        # Repeat colors for each frame
        colors = colors.reshape(1, -1, 3).repeat(T, axis=0)  # (T, masked_points, 3)

        # Initialize list to store frames
        frames = []
        pbar = ProgressBar(len(points))

        for i, pts_i in enumerate(tqdm(points)):
            pixels, depths = pts_i[..., :2], pts_i[..., 2]
            pixels[..., 0] = pixels[..., 0] * W
            pixels[..., 1] = pixels[..., 1] * H
            pixels = pixels.astype(int)

            valid = self.valid_mask(pixels, W, H)

            frame_rgb = colors[i][valid]
            pixels = pixels[valid]
            depths = depths[valid]

            img = Image.fromarray(np.uint8(np.zeros([H, W, 3])), mode="RGB")
            sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
            step = 1
            sorted_pixels = sorted_pixels[::step]
            sorted_rgb = frame_rgb[sort_index][::step]

            for j in range(sorted_pixels.shape[0]):
                self.draw_rectangle(
                    img,
                    coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
                    side_length=2,
                    color=sorted_rgb[j],
                )
            frames.append(np.array(img))
            pbar.update(1)

        # Convert frames to video tensor in range [0,1]
        tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
        tracking_video = tracking_video.permute(0, 2, 3, 1)  # [B, H, W, C]
        print("tracking_video shape: ", tracking_video.shape)
        return (tracking_video,)

    def valid_mask(self, pixels, W, H):
        """Check if pixels are within valid image bounds

        Args:
            pixels (numpy.ndarray): Pixel coordinates of shape [N, 2]
            W (int): Image width
            H (int): Image height

        Returns:
            numpy.ndarray: Boolean mask of valid pixels
        """
        return ((pixels[:, 0] >= 0) & (pixels[:, 0] < W) &
                (pixels[:, 1] > 0) & (pixels[:, 1] < H))

    def sort_points_by_depth(self, points, depths):
        """Sort points by depth values

        Args:
            points (numpy.ndarray): Points array of shape [N, 2]
            depths (numpy.ndarray): Depth values of shape [N]

        Returns:
            tuple: (sorted_points, sorted_depths, sort_index)
        """
        # Combine points and depths into a single array for sorting
        combined = np.hstack((points, depths[:, None]))  # Nx3 (points + depth)
        # Sort by depth (last column) in descending order
        sort_index = combined[:, -1].argsort()[::-1]
        sorted_combined = combined[sort_index]
        # Split back into points and depths
        sorted_points = sorted_combined[:, :-1]
        sorted_depths = sorted_combined[:, -1]
        return sorted_points, sorted_depths, sort_index

    def draw_rectangle(self, rgb, coord, side_length, color=(255, 0, 0)):
        """Draw a rectangle on the image

        Args:
            rgb (PIL.Image): Image to draw on
            coord (tuple): Center coordinates (x, y)
            side_length (int): Length of rectangle sides
            color (tuple): RGB color tuple
        """
        draw = ImageDraw.Draw(rgb)
        # Calculate the bounding box of the rectangle
        left_up_point = (coord[0] - side_length // 2, coord[1] - side_length // 2)
        right_down_point = (coord[0] + side_length // 2, coord[1] + side_length // 2)
        color = tuple(list(color))

        draw.rectangle(
            [left_up_point, right_down_point],
            fill=tuple(color),
            outline=tuple(color),
        )


NODE_CLASS_MAPPINGS = {
    "CogVideoDASTrackingEncode": CogVideoDASTrackingEncode,
    "DAS_SpaTracker": DAS_SpaTracker,
    "DAS_SpaTrackerModelLoader": DAS_SpaTrackerModelLoader,
    "DAS_MoGeTracker": DAS_MoGeTracker,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode",
    "DAS_SpaTracker": "DAS SpaTracker",
    "DAS_SpaTrackerModelLoader": "DAS SpaTracker Model Loader",
    "DAS_MoGeTracker": "DAS MoGe Tracker",
}
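For reference, the colour encoding built in DAS_MoGeTracker.encode maps image x to the R channel, image y to the G channel, and inverse depth (clipped to its 2nd-98th percentile) to the B channel. Below is a minimal standalone sketch of just that depth-to-blue mapping; the depth map, H, and W are synthetic placeholders standing in for MoGe's infer_result["points"][..., 2], not output from the node itself.

import numpy as np

# Hypothetical depth map standing in for the MoGe z values used above.
H, W = 480, 720
z_values = np.random.uniform(0.5, 10.0, size=(H, W)).astype(np.float32)

# Same normalisation as in the node: work in inverse depth and clip to the
# 2nd-98th percentile so a few outliers don't flatten the colour range.
inv_z = 1.0 / z_values
p2, p98 = np.percentile(inv_z, 2), np.percentile(inv_z, 98)
normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)

blue_channel = (normalized_z * 255).astype(np.uint8)  # nearer points end up brighter
print(blue_channel.shape, blue_channel.min(), blue_channel.max())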
das/motion.py | 382 (new file)

@@ -0,0 +1,382 @@
import torch
import numpy as np
import math

class CameraMotionGenerator:
    def __init__(self, motion_type, frame_num=49, H=480, W=720, fx=None, fy=None, fov=55, device='cuda'):
        self.motion_type = motion_type
        self.frame_num = frame_num
        self.fov = fov
        self.device = device
        self.W = W
        self.H = H
        self.intr = torch.tensor([
            [0, 0, W / 2],
            [0, 0, H / 2],
            [0, 0, 1]
        ], dtype=torch.float32, device=device)
        # if fx, fy not provided, derive them from the field of view
        if not fx or not fy:
            fov_rad = math.radians(fov)
            fx = fy = (W / 2) / math.tan(fov_rad / 2)

        self.intr[0, 0] = fx
        self.intr[1, 1] = fy

    def _apply_poses(self, pts, poses):
        """
        Args:
            pts (torch.Tensor): point cloud coordinates [T, N, 3]
            intr (torch.Tensor): camera intrinsics [T, 3, 3] (built from self.intr)
            poses (numpy.ndarray): camera poses [T, 4, 4]
        """
        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)

        intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1).to(torch.float)
        T, N, _ = pts.shape
        ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
        pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1)  # (T, N, 3)
        pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2))  # (T, N, 3)
        pts_cam[:, :, :3] *= pts[:, :, 2:3]

        # to homogeneous
        pts_cam = torch.cat([pts_cam, ones], dim=-1)  # (T, N, 4)

        if poses.shape[0] == 1:
            poses = poses.repeat(T, 1, 1)
        elif poses.shape[0] != T:
            raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")

        poses = poses.to(torch.float).to(self.device)
        pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3]  # (T, N, 3)
        pts_proj = torch.bmm(pts_world, intr.transpose(1, 2))  # (T, N, 3)
        pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]

        return pts_proj

    def w2s(self, pts, poses):
        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)
        assert poses.shape[0] == self.frame_num
        poses = poses.to(torch.float32).to(self.device)
        T, N, _ = pts.shape  # (T, N, 3)
        intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
        # Step 1: expand the points to (T, N, 4) by appending 1s (homogeneous coordinates)
        ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
        points_world_h = torch.cat([pts, ones], dim=-1)
        points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
        points_camera = points_camera_h[:, :3, :].permute(0, 2, 1)

        points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))

        uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]

        # Step 5: extract depth (Z) and concatenate
        depth = points_camera[:, :, 2:3]  # (T, N, 1)
        uvd = torch.cat([uv, depth], dim=-1)  # (T, N, 3)

        return uvd  # screen coordinates + depth (T, N, 3)

    def apply_motion_on_pts(self, pts, camera_motion):
        tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
        return tracking_pts

    def set_intr(self, K):
        if isinstance(K, np.ndarray):
            K = torch.from_numpy(K)
        self.intr = K.to(self.device)

    def rot_poses(self, angle, axis='y'):
        """
        Args:
            angle (float): total rotation angle in degrees
            axis (str): rotation axis, 'x', 'y', or 'z'
        """
        angle_rad = math.radians(angle)
        angles = torch.linspace(0, angle_rad, self.frame_num)
        rot_mats = torch.zeros(self.frame_num, 4, 4)

        for i, theta in enumerate(angles):
            cos_theta = torch.cos(theta)
            sin_theta = torch.sin(theta)
            if axis == 'x':
                rot_mats[i] = torch.tensor([
                    [1, 0, 0, 0],
                    [0, cos_theta, -sin_theta, 0],
                    [0, sin_theta, cos_theta, 0],
                    [0, 0, 0, 1]
                ], dtype=torch.float32)
            elif axis == 'y':
                rot_mats[i] = torch.tensor([
                    [cos_theta, 0, sin_theta, 0],
                    [0, 1, 0, 0],
                    [-sin_theta, 0, cos_theta, 0],
                    [0, 0, 0, 1]
                ], dtype=torch.float32)
            elif axis == 'z':
                rot_mats[i] = torch.tensor([
                    [cos_theta, -sin_theta, 0, 0],
                    [sin_theta, cos_theta, 0, 0],
                    [0, 0, 1, 0],
                    [0, 0, 0, 1]
                ], dtype=torch.float32)
            else:
                raise ValueError("Invalid axis value. Choose 'x', 'y', or 'z'.")

        return rot_mats.to(self.device)

    def trans_poses(self, dx, dy, dz):
        """
        Args:
            dx (float): displacement along the x axis
            dy (float): displacement along the y axis
            dz (float): displacement along the z axis

        Returns:
            torch.Tensor: translation matrices [frame_num, 4, 4]
        """
        trans_mats = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1)  # (n, 4, 4)

        delta_x = dx / (self.frame_num - 1)
        delta_y = dy / (self.frame_num - 1)
        delta_z = dz / (self.frame_num - 1)

        for i in range(self.frame_num):
            trans_mats[i, 0, 3] = i * delta_x
            trans_mats[i, 1, 3] = i * delta_y
            trans_mats[i, 2, 3] = i * delta_z

        return trans_mats.to(self.device)


    def _look_at(self, camera_position, target_position):
        # look-at direction
        # import ipdb;ipdb.set_trace()
        direction = target_position - camera_position
        direction /= np.linalg.norm(direction)
        # calculate rotation matrix
        up = np.array([0, 1, 0])
        right = np.cross(up, direction)
        right /= np.linalg.norm(right)
        up = np.cross(direction, right)
        rotation_matrix = np.vstack([right, up, direction])
        rotation_matrix = np.linalg.inv(rotation_matrix)
        return rotation_matrix

    def spiral_poses(self, radius, forward_ratio=0.5, backward_ratio=0.5, rotation_times=0.1, look_at_times=0.5):
        """Generate spiral camera poses

        Args:
            radius (float): Base radius of the spiral
            forward_ratio (float): Scale factor for forward motion
            backward_ratio (float): Scale factor for backward motion
            rotation_times (float): Number of rotations to complete
            look_at_times (float): Scale factor for look-at point distance

        Returns:
            torch.Tensor: Camera poses of shape [num_frames, 4, 4]
        """
        # Generate spiral trajectory
        t = np.linspace(0, 1, self.frame_num)
        r = np.sin(np.pi * t) * radius * rotation_times
        theta = 2 * np.pi * t

        # Calculate camera positions
        # Limit y motion for better floor/sky view
        y = r * np.cos(theta) * 0.3
        x = r * np.sin(theta)
        z = -r
        z[z < 0] *= forward_ratio
        z[z > 0] *= backward_ratio

        # Set look-at target
        target_pos = np.array([0, 0, radius * look_at_times])
        cam_pos = np.vstack([x, y, z]).T
        cam_poses = []

        for pos in cam_pos:
            rot_mat = self._look_at(pos, target_pos)
            trans_mat = np.eye(4)
            trans_mat[:3, :3] = rot_mat
            trans_mat[:3, 3] = pos
            cam_poses.append(trans_mat[None])

        camera_poses = np.concatenate(cam_poses, axis=0)
        return torch.from_numpy(camera_poses).to(self.device)

    def rot(self, pts, angle, axis):
        """
        pts: torch.Tensor, (T, N, 2)
        """
        rot_mats = self.rot_poses(angle, axis)
        pts = self.apply_motion_on_pts(pts, rot_mats)
        return pts

    def trans(self, pts, dx, dy, dz):
        if pts.shape[-1] != 3:
            raise ValueError("points should be in 3D coordinates.")
        trans_mats = self.trans_poses(dx, dy, dz)
        pts = self.apply_motion_on_pts(pts, trans_mats)
        return pts

    def spiral(self, pts, radius):
        spiral_poses = self.spiral_poses(radius)
        pts = self.apply_motion_on_pts(pts, spiral_poses)
        return pts

    def get_default_motion(self):
        if self.motion_type == 'none':
            motion = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1).to(self.device)
        elif self.motion_type == 'trans':
            motion = self.trans_poses(0.02, 0, 0)
        elif self.motion_type == 'spiral':
            motion = self.spiral_poses(1)
        elif self.motion_type == 'rot':
            motion = self.rot_poses(-25, 'y')
        else:
            raise ValueError(f"camera_motion must be one of ['none', 'trans', 'spiral', 'rot'], but got {self.motion_type}.")

        return motion


class ObjectMotionGenerator:
    def __init__(self, num_frames=49, device="cuda:0"):
        """Initialize ObjectMotionGenerator

        Args:
            num_frames (int): Number of frames
            device (str): Device to run on
        """
        self.device = device
        self.num_frames = num_frames

    def _get_points_in_mask(self, pred_tracks, mask):
        """Get points that fall within the mask in the first frame

        Args:
            pred_tracks (torch.Tensor): [num_frames, num_points, 3]
            mask (torch.Tensor): [H, W] binary mask

        Returns:
            torch.Tensor: Boolean mask of selected points [num_points]
        """
        first_frame_points = pred_tracks[0]  # [num_points, 3]
        xy_points = first_frame_points[:, :2]  # [num_points, 2]

        # Convert xy coordinates to pixel indices
        xy_pixels = xy_points.round().long()  # Convert to integer pixel coordinates

        # Clamp coordinates to valid range
        xy_pixels[:, 0].clamp_(0, mask.shape[1] - 1)  # x coordinates
        xy_pixels[:, 1].clamp_(0, mask.shape[0] - 1)  # y coordinates

        # Get mask values at point locations
        points_in_mask = mask[xy_pixels[:, 1], xy_pixels[:, 0]]  # Index using y, x order

        return points_in_mask

    def generate_motion(self, mask, motion_type, distance, num_frames=49):
        """Generate motion dictionary for the given parameters

        Args:
            mask (torch.Tensor): [H, W] binary mask
            motion_type (str): Motion direction ('none', 'up', 'down', 'left', 'right', 'front', 'back')
            distance (float): Total distance to move
            num_frames (int): Number of frames

        Returns:
            dict: Motion dictionary containing:
                - mask (torch.Tensor): Binary mask
                - motions (torch.Tensor): Per-frame motion matrices [num_frames, 4, 4]
        """
        self.num_frames = num_frames
        # Define motion template vectors
        template = {
            'none': torch.tensor([0, 0, 0]),
            'up': torch.tensor([0, -1, 0]),
            'down': torch.tensor([0, 1, 0]),
            'left': torch.tensor([-1, 0, 0]),
            'right': torch.tensor([1, 0, 0]),
            'front': torch.tensor([0, 0, 1]),
            'back': torch.tensor([0, 0, -1])
        }

        if motion_type not in template:
            raise ValueError(f"Unknown motion type: {motion_type}")

        # Move mask to device
        mask = mask.to(self.device)

        # Generate per-frame motion matrices
        motions = []
        base_vec = template[motion_type].to(self.device) * distance

        for frame_idx in range(num_frames):
            # Calculate interpolation factor (0 to 1)
            t = frame_idx / (num_frames - 1)

            # Create motion matrix for current frame
            current_motion = torch.eye(4, device=self.device)
            current_motion[:3, 3] = base_vec * t
            motions.append(current_motion)

        motions = torch.stack(motions)  # [num_frames, 4, 4]

        return {
            'mask': mask,
            'motions': motions
        }

    def apply_motion(self, pred_tracks, motion_dict, tracking_method="spatracker"):
        """Apply motion to selected points

        Args:
            pred_tracks (torch.Tensor): [num_frames, num_points, 3] for spatracker
                or [H, W, 3] for moge
            motion_dict (dict): Motion dictionary containing mask and motions
            tracking_method (str): "spatracker" or "moge"

        Returns:
            torch.Tensor: Modified pred_tracks; for spatracker the shape matches the input,
                for moge the points are repeated per frame, giving [num_frames, H*W, 3]
        """
        pred_tracks = pred_tracks.to(self.device).float()

        if tracking_method == "moge":
            H = pred_tracks.shape[0]
            W = pred_tracks.shape[1]

            initial_points = pred_tracks  # [H, W, 3]
            selected_mask = motion_dict['mask']
            valid_selected = ~torch.any(torch.isnan(initial_points), dim=2) & selected_mask
            valid_selected = valid_selected.reshape([-1])
            modified_tracks = pred_tracks.clone().reshape(-1, 3).unsqueeze(0).repeat(self.num_frames, 1, 1)
            # import ipdb;ipdb.set_trace()
            for frame_idx in range(self.num_frames):
                # Get current frame motion
                motion_mat = motion_dict['motions'][frame_idx]
                # MoGe's point cloud is scale-invariant, so scale the translation by the image size
                motion_mat[0, 3] /= W
                motion_mat[1, 3] /= H
                # Apply motion to selected points
                points = modified_tracks[frame_idx, valid_selected]
                # Convert to homogeneous coordinates
                points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
                # Apply transformation
                transformed_points = torch.matmul(points_homo, motion_mat.T)
                # Convert back to 3D coordinates
                modified_tracks[frame_idx, valid_selected] = transformed_points[:, :3]
            return modified_tracks

        else:
            points_in_mask = self._get_points_in_mask(pred_tracks, motion_dict['mask'])
            modified_tracks = pred_tracks.clone()

            for frame_idx in range(pred_tracks.shape[0]):
                motion_mat = motion_dict['motions'][frame_idx]
                points = modified_tracks[frame_idx, points_in_mask]
                points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
                transformed_points = torch.matmul(points_homo, motion_mat.T)
                modified_tracks[frame_idx, points_in_mask] = transformed_points[:, :3]

            return modified_tracks
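Outside ComfyUI, the two generators in motion.py can be exercised on their own. The sketch below mirrors the flow that DAS_MoGeTracker.encode follows (apply object motion to a point map, then project through per-frame camera poses with w2s); the point map, mask, sizes, and motion settings are synthetic placeholders, and the import path is an assumption about where motion.py is importable from.

import torch
from das.motion import CameraMotionGenerator, ObjectMotionGenerator  # adjust to wherever motion.py lives

# Illustrative only: a random point map stands in for MoGe's infer_result["points"].
num_frames, H, W = 49, 480, 720
device = "cuda" if torch.cuda.is_available() else "cpu"
points = torch.rand(H, W, 3)                      # [H, W, 3], x/y in [0,1], z as depth
mask = torch.ones(H, W, dtype=torch.bool)         # move every point

obj_gen = ObjectMotionGenerator(num_frames, device=device)
motion_dict = obj_gen.generate_motion(mask=mask, motion_type="up",
                                      distance=50, num_frames=num_frames)
tracks = obj_gen.apply_motion(points, motion_dict, tracking_method="moge")
# -> [num_frames, H*W, 3]: the masked points drift upward over the clip

cam_gen = CameraMotionGenerator(motion_type="rot", frame_num=num_frames,
                                W=W, H=H, fov=55, device=device)
poses = cam_gen.get_default_motion()              # [num_frames, 4, 4] camera poses
uvd = cam_gen.w2s(tracks, poses)                  # pixel u, v plus depth per frame
print(uvd.shape)                                  # torch.Size([49, 345600, 3])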