kijai 2025-02-17 02:48:15 +02:00
parent ee7d04d342
commit 1124c77d56
2 changed files with 637 additions and 4 deletions


@@ -1,9 +1,14 @@
import torch
import comfy.model_management as mm
from comfy.utils import ProgressBar, common_upscale
from ..utils import log
import os
import numpy as np
import folder_paths
from tqdm import tqdm
from PIL import Image, ImageDraw
from .motion import CameraMotionGenerator, ObjectMotionGenerator
class CogVideoDASTrackingEncode:
@classmethod
@@ -139,32 +144,278 @@ class DAS_SpaTracker:
progressive_tracking=False
)
# cam_motion = CameraMotionGenerator(
# motion_type="trans",
# frame_num=49,
# W=720,
# H=480,
# fx=None,
# fy=None,
# fov=55,
# device=device,
# )
# poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
# pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses)
# print("Camera motion applied")
spatracker.to(offload_device)
from .spatracker.utils.visualizer import Visualizer
vis = Visualizer(
grayscale=False,
fps=24,
pad_value=0,
#tracks_leave_trace=-1
)
msk_query = (T_Firsts == 0)
pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
tracking_video = vis.visualize(
video=video,
tracks=pred_tracks,
visibility=pred_visibility,
save_video=False,
)
tracking_video = tracking_video.squeeze(0).permute(0, 2, 3, 1) # [T, H, W, C]
tracking_video = (tracking_video / 255.0).float()
return (tracking_video,)
class DAS_MoGeTracker:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": ("MOGEMODEL",),
"image": ("IMAGE", ),
"num_frames": ("INT", {"default": 49, "min": 1, "max": 100, "step": 1}),
"width": ("INT", {"default": 720, "min": 1, "max": 10000, "step": 1}),
"height": ("INT", {"default": 480, "min": 1, "max": 10000, "step": 1}),
"fov": ("FLOAT", {"default": 55.0, "min": 1.0, "max": 180.0, "step": 1.0}),
"object_motion_type": (["none", "up", "down", "left", "right", "front", "back"],),
"object_motion_distance": ("INT", {"default": 50, "min": 1, "max": 1000, "step": 1}),
"camera_motion_type": (["none","translation", "rotation", "spiral"],),
},
"optional": {
"mask": ("MASK", ),
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("tracking_video",)
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, model, image, num_frames, width, height, fov, object_motion_type, object_motion_distance, camera_motion_type, mask=None):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
B, H, W, C = image.shape
image_resized = common_upscale(image.movedim(-1,1), width, height, "lanczos", "disabled").movedim(1,-1)
# Run MoGe inference on the first resized frame
infer_result = model.infer(image_resized.permute(0, 3, 1, 2)[0].to(device)) # [C, H, W] in range [0,1]
H, W = infer_result["points"].shape[0:2]
motion_generator = ObjectMotionGenerator(num_frames, device=device)
if mask is not None:
mask = mask[0].bool()
mask = torch.nn.functional.interpolate(
mask[None, None].float(),
size=(H, W),
mode='nearest'
)[0, 0].bool()
else:
mask = torch.ones(H, W, dtype=torch.bool)
# Generate motion dictionary
motion_dict = motion_generator.generate_motion(
mask=mask,
motion_type=object_motion_type,
distance=object_motion_distance,
num_frames=num_frames,
)
pred_tracks = motion_generator.apply_motion(
infer_result["points"],
motion_dict,
tracking_method="moge"
)
print("pred_tracks shape: ", pred_tracks.shape)
print("Object motion applied")
camera_motion_type_mapping = {
"none": "none",
"translation": "trans",
"rotation": "rot",
"spiral": "spiral"
}
cam_motion = CameraMotionGenerator(
motion_type=camera_motion_type_mapping[camera_motion_type],
frame_num=num_frames,
W=width,
H=height,
fx=None,
fy=None,
fov=fov,
device=device,
)
# Apply camera motion if specified
cam_motion.set_intr(infer_result["intrinsics"])
poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
pred_tracks_flatten = pred_tracks.reshape(num_frames, H*W, 3)
pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([num_frames, H, W, 3]) # [T, H, W, 3]
print("Camera motion applied")
points = pred_tracks.cpu().numpy()
mask = infer_result["mask"].cpu().numpy()
# Create color array
T, H, W, _ = pred_tracks.shape
print("points shape: ", points.shape)
print("mask shape: ", mask.shape)
colors = np.zeros((H, W, 3), dtype=np.uint8)
# Set R channel - based on x coordinates (smaller on the left)
colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))
# Set G channel - based on y coordinates (smaller on the top)
colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T
# Set B channel - based on depth
z_values = points[0, :, :, 2] # get z values
inv_z = 1 / z_values # calculate 1/z
# Calculate 2% and 98% percentiles
p2 = np.percentile(inv_z, 2)
p98 = np.percentile(inv_z, 98)
# Normalize to [0,1] range
normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
colors = colors.astype(np.uint8)
# First reshape points and colors
points = points.reshape(T, -1, 3) # (T, H*W, 3)
colors = colors.reshape(-1, 3) # (H*W, 3)
# Create mask for each frame
mask = mask.reshape(-1) # Flatten mask to (H*W,)
# Apply mask
points = points[:, mask, :] # (T, masked_points, 3)
colors = colors[mask] # (masked_points, 3)
# Repeat colors for each frame
colors = colors.reshape(1, -1, 3).repeat(T, axis=0) # (T, masked_points, 3)
# Initialize list to store frames
frames = []
pbar = ProgressBar(len(points))
for i, pts_i in enumerate(tqdm(points)):
pixels, depths = pts_i[..., :2], pts_i[..., 2]
pixels[..., 0] = pixels[..., 0] * W
pixels[..., 1] = pixels[..., 1] * H
pixels = pixels.astype(int)
valid = self.valid_mask(pixels, W, H)
frame_rgb = colors[i][valid]
pixels = pixels[valid]
depths = depths[valid]
img = Image.fromarray(np.uint8(np.zeros([H, W, 3])), mode="RGB")
sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
step = 1
sorted_pixels = sorted_pixels[::step]
sorted_rgb = frame_rgb[sort_index][::step]
for j in range(sorted_pixels.shape[0]):
self.draw_rectangle(
img,
coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
side_length=2,
color=sorted_rgb[j],
)
frames.append(np.array(img))
pbar.update(1)
# Convert frames to video tensor in range [0,1]
tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
tracking_video = tracking_video.permute(0, 2, 3, 1) # [B, H, W, C]
print("tracking_video shape: ", tracking_video.shape)
return (tracking_video,)
def valid_mask(self, pixels, W, H):
"""Check if pixels are within valid image bounds
Args:
pixels (numpy.ndarray): Pixel coordinates of shape [N, 2]
W (int): Image width
H (int): Image height
Returns:
numpy.ndarray: Boolean mask of valid pixels
"""
return ((pixels[:, 0] >= 0) & (pixels[:, 0] < W) &
(pixels[:, 1] >= 0) & (pixels[:, 1] < H))
def sort_points_by_depth(self, points, depths):
"""Sort points by depth values
Args:
points (numpy.ndarray): Points array of shape [N, 2]
depths (numpy.ndarray): Depth values of shape [N]
Returns:
tuple: (sorted_points, sorted_depths, sort_index)
"""
# Combine points and depths into a single array for sorting
combined = np.hstack((points, depths[:, None])) # Nx3 (points + depth)
# Sort by depth (last column) in descending order
sort_index = combined[:, -1].argsort()[::-1]
sorted_combined = combined[sort_index]
# Split back into points and depths
sorted_points = sorted_combined[:, :-1]
sorted_depths = sorted_combined[:, -1]
return sorted_points, sorted_depths, sort_index
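# Note: points are later drawn in this descending-depth order (farthest first),
# so nearer points overwrite farther ones in the rendered frame (painter's algorithm).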
def draw_rectangle(self, rgb, coord, side_length, color=(255, 0, 0)):
"""Draw a rectangle on the image
Args:
rgb (PIL.Image): Image to draw on
coord (tuple): Center coordinates (x, y)
side_length (int): Length of rectangle sides
color (tuple): RGB color tuple
"""
draw = ImageDraw.Draw(rgb)
# Calculate the bounding box of the rectangle
left_up_point = (coord[0] - side_length//2, coord[1] - side_length//2)
right_down_point = (coord[0] + side_length//2, coord[1] + side_length//2)
color = tuple(list(color))
draw.rectangle(
[left_up_point, right_down_point],
fill=tuple(color),
outline=tuple(color),
)
NODE_CLASS_MAPPINGS = {
"CogVideoDASTrackingEncode": CogVideoDASTrackingEncode,
"DAS_SpaTracker": DAS_SpaTracker,
"DAS_SpaTrackerModelLoader": DAS_SpaTrackerModelLoader,
"DAS_MoGeTracker": DAS_MoGeTracker,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode",
"DAS_SpaTracker": "DAS SpaTracker",
"DAS_SpaTrackerModelLoader": "DAS SpaTracker Model Loader",
"DAS_MoGeTracker": "DAS MoGe Tracker",
}
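For reference, a minimal standalone sketch (not part of this commit) of the point-color encoding used in DAS_MoGeTracker.encode: R follows x, G follows y, and B follows inverse depth normalized between its 2nd and 98th percentiles. H, W and z_values below are placeholders for the values the node derives from the MoGe inference result.

import numpy as np

H, W = 480, 720                          # placeholder frame size
z_values = np.random.rand(H, W) + 0.5    # placeholder: first-frame z of the point map

colors = np.zeros((H, W, 3), dtype=np.uint8)
colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))    # R: increases left -> right
colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T  # G: increases top -> bottom
inv_z = 1.0 / z_values                                       # nearer points -> larger values
p2, p98 = np.percentile(inv_z, 2), np.percentile(inv_z, 98)
colors[:, :, 2] = (np.clip((inv_z - p2) / (p98 - p2), 0, 1) * 255).astype(np.uint8)  # B: depth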

das/motion.py Normal file

@@ -0,0 +1,382 @@
import torch
import numpy as np
import math
class CameraMotionGenerator:
def __init__(self, motion_type, frame_num=49, H=480, W=720, fx=None, fy=None, fov=55, device='cuda'):
self.motion_type = motion_type
self.frame_num = frame_num
self.fov = fov
self.device = device
self.W = W
self.H = H
self.intr = torch.tensor([
[0, 0, W / 2],
[0, 0, H / 2],
[0, 0, 1]
], dtype=torch.float32, device=device)
# if fx, fy not provided
if not fx or not fy:
fov_rad = math.radians(fov)
fx = fy = (W / 2) / math.tan(fov_rad / 2)
self.intr[0, 0] = fx
self.intr[1, 1] = fy
def _apply_poses(self, pts, poses):
"""
Args:
pts (torch.Tensor): pointclouds coordinates [T, N, 3]
intr (torch.Tensor): camera intrinsics [T, 3, 3]
poses (numpy.ndarray): camera poses [T, 4, 4]
"""
if isinstance(poses, np.ndarray):
poses = torch.from_numpy(poses)
intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1).to(torch.float)
T, N, _ = pts.shape
ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
pts_cam[:,:, :3] *= pts[:, :, 2:3]
# to homogeneous
pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
if poses.shape[0] == 1:
poses = poses.repeat(T, 1, 1)
elif poses.shape[0] != T:
raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
poses = poses.to(torch.float).to(self.device)
pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
return pts_proj
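# Summary of _apply_poses: lift pixels to camera space with K^-1, scale by the stored depth,
# transform with the per-frame pose, then reproject with K; the output keeps (u, v, z).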
def w2s(self, pts, poses):
if isinstance(poses, np.ndarray):
poses = torch.from_numpy(poses)
assert poses.shape[0] == self.frame_num
poses = poses.to(torch.float32).to(self.device)
T, N, _ = pts.shape # (T, N, 3)
intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
# Step 1: expand the points to (T, N, 4) by appending a 1 in the last dimension (homogeneous coordinates)
ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
points_world_h = torch.cat([pts, ones], dim=-1)
points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
points_camera = points_camera_h[:, :3, :].permute(0, 2, 1)
points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))
uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
# Step 5: extract the depth (Z) and concatenate
depth = points_camera[:, :, 2:3] # (T, N, 1)
uvd = torch.cat([uv, depth], dim=-1) # (T, N, 3)
return uvd # screen coordinates + depth (T, N, 3)
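# w2s follows the standard pinhole projection, assuming poses are world-to-camera extrinsics:
#   X_cam = [R | t] @ X_world_homogeneous
#   u = fx * X_cam.x / X_cam.z + cx,  v = fy * X_cam.y / X_cam.z + cy
# The per-point camera depth Z is returned alongside (u, v) so callers can sort by depth.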
def apply_motion_on_pts(self, pts, camera_motion):
tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
return tracking_pts
def set_intr(self, K):
if isinstance(K, np.ndarray):
K = torch.from_numpy(K)
self.intr = K.to(self.device)
def rot_poses(self, angle, axis='y'):
"""
pts (torch.Tensor): [T, N, 3]
angle (int): angle of rotation (degree)
"""
angle_rad = math.radians(angle)
angles = torch.linspace(0, angle_rad, self.frame_num)
rot_mats = torch.zeros(self.frame_num, 4, 4)
for i, theta in enumerate(angles):
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
if axis == 'x':
rot_mats[i] = torch.tensor([
[1, 0, 0, 0],
[0, cos_theta, -sin_theta, 0],
[0, sin_theta, cos_theta, 0],
[0, 0, 0, 1]
], dtype=torch.float32)
elif axis == 'y':
rot_mats[i] = torch.tensor([
[cos_theta, 0, sin_theta, 0],
[0, 1, 0, 0],
[-sin_theta, 0, cos_theta, 0],
[0, 0, 0, 1]
], dtype=torch.float32)
elif axis == 'z':
rot_mats[i] = torch.tensor([
[cos_theta, -sin_theta, 0, 0],
[sin_theta, cos_theta, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]
], dtype=torch.float32)
else:
raise ValueError("Invalid axis value. Choose 'x', 'y', or 'z'.")
return rot_mats.to(self.device)
def trans_poses(self, dx, dy, dz):
"""
params:
- dx: float, displacement along x axis
- dy: float, displacement along y axis
- dz: float, displacement along z axis
ret:
- matrices: torch.Tensor
"""
trans_mats = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1) # (n, 4, 4)
delta_x = dx / (self.frame_num - 1)
delta_y = dy / (self.frame_num - 1)
delta_z = dz / (self.frame_num - 1)
for i in range(self.frame_num):
trans_mats[i, 0, 3] = i * delta_x
trans_mats[i, 1, 3] = i * delta_y
trans_mats[i, 2, 3] = i * delta_z
return trans_mats.to(self.device)
def _look_at(self, camera_position, target_position):
# look at direction
# import ipdb;ipdb.set_trace()
direction = target_position - camera_position
direction /= np.linalg.norm(direction)
# calculate rotation matrix
up = np.array([0, 1, 0])
right = np.cross(up, direction)
right /= np.linalg.norm(right)
up = np.cross(direction, right)
rotation_matrix = np.vstack([right, up, direction])
rotation_matrix = np.linalg.inv(rotation_matrix)
return rotation_matrix
def spiral_poses(self, radius, forward_ratio = 0.5, backward_ratio = 0.5, rotation_times = 0.1, look_at_times = 0.5):
"""Generate spiral camera poses
Args:
radius (float): Base radius of the spiral
forward_ratio (float): Scale factor for forward motion
backward_ratio (float): Scale factor for backward motion
rotation_times (float): Number of rotations to complete
look_at_times (float): Scale factor for look-at point distance
Returns:
torch.Tensor: Camera poses of shape [num_frames, 4, 4]
"""
# Generate spiral trajectory
t = np.linspace(0, 1, self.frame_num)
r = np.sin(np.pi * t) * radius * rotation_times
theta = 2 * np.pi * t
# Calculate camera positions
# Limit y motion for better floor/sky view
y = r * np.cos(theta) * 0.3
x = r * np.sin(theta)
z = -r
z[z < 0] *= forward_ratio
z[z > 0] *= backward_ratio
# Set look-at target
target_pos = np.array([0, 0, radius * look_at_times])
cam_pos = np.vstack([x, y, z]).T
cam_poses = []
for pos in cam_pos:
rot_mat = self._look_at(pos, target_pos)
trans_mat = np.eye(4)
trans_mat[:3, :3] = rot_mat
trans_mat[:3, 3] = pos
cam_poses.append(trans_mat[None])
camera_poses = np.concatenate(cam_poses, axis=0)
return torch.from_numpy(camera_poses).to(self.device)
def rot(self, pts, angle, axis):
"""
pts: torch.Tensor, (T, N, 2)
"""
rot_mats = self.rot_poses(angle, axis)
pts = self.apply_motion_on_pts(pts, rot_mats)
return pts
def trans(self, pts, dx, dy, dz):
if pts.shape[-1] != 3:
raise ValueError("points should be in the 3d coordinate.")
trans_mats = self.trans_poses(dx, dy, dz)
pts = self.apply_motion_on_pts(pts, trans_mats)
return pts
def spiral(self, pts, radius):
spiral_poses = self.spiral_poses(radius)
pts = self.apply_motion_on_pts(pts, spiral_poses)
return pts
def get_default_motion(self):
if self.motion_type == 'none':
motion = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1).to(self.device)
elif self.motion_type == 'trans':
motion = self.trans_poses(0.02, 0, 0)
elif self.motion_type == 'spiral':
motion = self.spiral_poses(1)
elif self.motion_type == 'rot':
motion = self.rot_poses(-25, 'y')
else:
raise ValueError(f'camera_motion must be one of [none, trans, spiral, rot], but got {self.motion_type}.')
return motion
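# Illustrative usage sketch (not part of this commit): build the default camera motion and
# project a flattened point map to screen space; the point tensor here is a stand-in for the
# MoGe points reshaped to [frame_num, H*W, 3].
#
#   cam = CameraMotionGenerator(motion_type="trans", frame_num=49, W=720, H=480, fov=55, device="cuda")
#   poses = cam.get_default_motion()            # [49, 4, 4] per-frame extrinsics
#   pts = torch.rand(49, 720 * 480, 3, device="cuda")
#   uvd = cam.w2s(pts, poses)                   # [49, N, 3] as (u, v, depth)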
class ObjectMotionGenerator:
def __init__(self, num_frames=49, device="cuda:0"):
"""Initialize ObjectMotionGenerator
Args:
device (str): Device to run on
"""
self.device = device
self.num_frames = num_frames
def _get_points_in_mask(self, pred_tracks, mask):
"""Get points that fall within the mask in first frame
Args:
pred_tracks (torch.Tensor): [num_frames, num_points, 3]
mask (torch.Tensor): [H, W] binary mask
Returns:
torch.Tensor: Boolean mask of selected points [num_points]
"""
first_frame_points = pred_tracks[0] # [num_points, 3]
xy_points = first_frame_points[:, :2] # [num_points, 2]
# Convert xy coordinates to pixel indices
xy_pixels = xy_points.round().long() # Convert to integer pixel coordinates
# Clamp coordinates to valid range
xy_pixels[:, 0].clamp_(0, mask.shape[1] - 1) # x coordinates
xy_pixels[:, 1].clamp_(0, mask.shape[0] - 1) # y coordinates
# Get mask values at point locations
points_in_mask = mask[xy_pixels[:, 1], xy_pixels[:, 0]] # Index using y, x order
return points_in_mask
def generate_motion(self, mask, motion_type, distance, num_frames=49):
"""Generate motion dictionary for the given parameters
Args:
mask (torch.Tensor): [H, W] binary mask
motion_type (str): Motion direction ('up', 'down', 'left', 'right')
distance (float): Total distance to move
num_frames (int): Number of frames
Returns:
dict: Motion dictionary containing:
- mask (torch.Tensor): Binary mask
- motions (torch.Tensor): Per-frame motion vectors [num_frames, 4, 4]
"""
self.num_frames = num_frames
# Define motion template vectors
template = {
"none": torch.tensor([0, 0, 0]),
'up': torch.tensor([0, -1, 0]),
'down': torch.tensor([0, 1, 0]),
'left': torch.tensor([-1, 0, 0]),
'right': torch.tensor([1, 0, 0]),
'front': torch.tensor([0, 0, 1]),
'back': torch.tensor([0, 0, -1])
}
if motion_type not in template:
raise ValueError(f"Unknown motion type: {motion_type}")
# Move mask to device
mask = mask.to(self.device)
# Generate per-frame motion matrices
motions = []
base_vec = template[motion_type].to(self.device) * distance
for frame_idx in range(num_frames):
# Calculate interpolation factor (0 to 1)
t = frame_idx / (num_frames - 1)
# Create motion matrix for current frame
current_motion = torch.eye(4, device=self.device)
current_motion[:3, 3] = base_vec * t
motions.append(current_motion)
motions = torch.stack(motions) # [num_frames, 4, 4]
return {
'mask': mask,
'motions': motions
}
def apply_motion(self, pred_tracks, motion_dict, tracking_method="spatracker"):
"""Apply motion to selected points
Args:
pred_tracks (torch.Tensor): [num_frames, num_points, 3] for spatracker
or [T, H, W, 3] for moge
motion_dict (dict): Motion dictionary containing mask and motions
tracking_method (str): "spatracker" or "moge"
Returns:
torch.Tensor: Modified pred_tracks with same shape as input
"""
pred_tracks = pred_tracks.to(self.device).float()
if tracking_method == "moge":
H = pred_tracks.shape[0]
W = pred_tracks.shape[1]
initial_points = pred_tracks # [H, W, 3]
selected_mask = motion_dict['mask']
valid_selected = ~torch.any(torch.isnan(initial_points), dim=2) & selected_mask
valid_selected = valid_selected.reshape([-1])
modified_tracks = pred_tracks.clone().reshape(-1, 3).unsqueeze(0).repeat(self.num_frames, 1, 1)
# import ipdb;ipdb.set_trace()
for frame_idx in range(self.num_frames):
# Get current frame motion
motion_mat = motion_dict['motions'][frame_idx]
# MoGe's point cloud is scale-invariant
motion_mat[0, 3] /= W
motion_mat[1, 3] /= H
# Apply motion to selected points
points = modified_tracks[frame_idx, valid_selected]
# Convert to homogeneous coordinates
points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
# Apply transformation
transformed_points = torch.matmul(points_homo, motion_mat.T)
# Convert back to 3D coordinates
modified_tracks[frame_idx, valid_selected] = transformed_points[:, :3]
return modified_tracks
else:
points_in_mask = self._get_points_in_mask(pred_tracks, motion_dict['mask'])
modified_tracks = pred_tracks.clone()
for frame_idx in range(pred_tracks.shape[0]):
motion_mat = motion_dict['motions'][frame_idx]
points = modified_tracks[frame_idx, points_in_mask]
points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
transformed_points = torch.matmul(points_homo, motion_mat.T)
modified_tracks[frame_idx, points_in_mask] = transformed_points[:, :3]
return modified_tracks
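To close, a minimal usage sketch of the object-motion path under assumed shapes and names (not part of this commit): a MoGe-style point map of shape [H, W, 3] is shifted to the right inside a rectangular mask over 49 frames.

import torch
# from .motion import ObjectMotionGenerator  # import path depends on where this module lives

gen = ObjectMotionGenerator(num_frames=49, device="cpu")  # the node passes the ComfyUI torch device instead
H, W = 480, 720
points = torch.rand(H, W, 3)                 # placeholder for infer_result["points"]
mask = torch.zeros(H, W, dtype=torch.bool)
mask[100:200, 300:400] = True                # region that should move

motion = gen.generate_motion(mask=mask, motion_type="right", distance=50, num_frames=49)
tracks = gen.apply_motion(points, motion, tracking_method="moge")  # [49, H*W, 3]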