From 1124c77d56e88b848fcd2cbb7fbf52022b55ce49 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 17 Feb 2025 02:48:15 +0200
Subject: [PATCH] add moge

---
 das/das_nodes.py | 259 +++++++++++++++++++++++++++++++-
 das/motion.py    | 382 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 637 insertions(+), 4 deletions(-)
 create mode 100644 das/motion.py

diff --git a/das/das_nodes.py b/das/das_nodes.py
index 7fac893..529e16e 100644
--- a/das/das_nodes.py
+++ b/das/das_nodes.py
@@ -1,9 +1,14 @@
 import torch
 import comfy.model_management as mm
+from comfy.utils import ProgressBar, common_upscale
 from ..utils import log
 import os
 import numpy as np
 import folder_paths
+from tqdm import tqdm
+from PIL import Image, ImageDraw
+
+from .motion import CameraMotionGenerator, ObjectMotionGenerator
 
 class CogVideoDASTrackingEncode:
     @classmethod
@@ -139,32 +144,278 @@ class DAS_SpaTracker:
             progressive_tracking=False
         )
 
+        # cam_motion = CameraMotionGenerator(
+        #     motion_type="trans",
+        #     frame_num=49,
+        #     W=720,
+        #     H=480,
+        #     fx=None,
+        #     fy=None,
+        #     fov=55,
+        #     device=device,
+        # )
+        # poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
+        # pred_tracks = cam_motion.apply_motion_on_pts(pred_tracks, poses)
+        # print("Camera motion applied")
+
         spatracker.to(offload_device)
 
         from .spatracker.utils.visualizer import Visualizer
-        vis = Visualizer(grayscale=False, fps=24, pad_value=0)
+        vis = Visualizer(
+            grayscale=False,
+            fps=24,
+            pad_value=0,
+            #tracks_leave_trace=-1
+        )
 
         msk_query = (T_Firsts == 0)
         pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
         pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
 
-        tracking_video = vis.visualize(video=video, tracks=pred_tracks,
-                                    visibility=pred_visibility, save_video=False,
-                                    filename="temp")
+        tracking_video = vis.visualize(
+            video=video,
+            tracks=pred_tracks,
+            visibility=pred_visibility,
+            save_video=False,
+        )
 
         tracking_video = tracking_video.squeeze(0).permute(0, 2, 3, 1)  # [T, H, W, C]
         tracking_video = (tracking_video / 255.0).float()
 
         return (tracking_video,)
+
+class DAS_MoGeTracker:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MOGEMODEL",),
+            "image": ("IMAGE", ),
+            "num_frames": ("INT", {"default": 49, "min": 1, "max": 100, "step": 1}),
+            "width": ("INT", {"default": 720, "min": 1, "max": 10000, "step": 1}),
+            "height": ("INT", {"default": 480, "min": 1, "max": 10000, "step": 1}),
+            "fov": ("FLOAT", {"default": 55.0, "min": 1.0, "max": 180.0, "step": 1.0}),
+            "object_motion_type": (["none", "up", "down", "left", "right", "front", "back"],),
+            "object_motion_distance": ("INT", {"default": 50, "min": 1, "max": 1000, "step": 1}),
+            "camera_motion_type": (["none", "translation", "rotation", "spiral"],),
+            },
+            "optional": {
+                "mask": ("MASK", ),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    RETURN_NAMES = ("tracking_video",)
+    FUNCTION = "encode"
+    CATEGORY = "CogVideoWrapper"
+
+    def encode(self, model, image, num_frames, width, height, fov, object_motion_type, object_motion_distance, camera_motion_type, mask=None):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        B, H, W, C = image.shape
+
+        image_resized = common_upscale(image.movedim(-1,1), width, height, "lanczos", "disabled").movedim(1,-1)
+
+        # Run MoGe on the first frame of the resized input
+        infer_result = model.infer(image_resized.permute(0, 3, 1, 2).to(device)[0])  # [C, H, W] in range [0,1]
+        H, W = infer_result["points"].shape[0:2]
+
+        motion_generator = ObjectMotionGenerator(num_frames, device=device)
+
+        if mask is not None:
+            mask = mask[0].bool()
+            mask = torch.nn.functional.interpolate(
+                mask[None, None].float(),
+                size=(H, W),
+                mode='nearest'
+            )[0, 0].bool()
+        else:
+            mask = torch.ones(H, W, dtype=torch.bool)
+
+        # Generate motion dictionary
+        motion_dict = motion_generator.generate_motion(
+            mask=mask,
+            motion_type=object_motion_type,
+            distance=object_motion_distance,
+            num_frames=num_frames,
+        )
+
+        pred_tracks = motion_generator.apply_motion(
+            infer_result["points"],
+            motion_dict,
+            tracking_method="moge"
+        )
+        print("pred_tracks shape: ", pred_tracks.shape)
+        print("Object motion applied")
+
+        camera_motion_type_mapping = {
+            "none": "none",
+            "translation": "trans",
+            "rotation": "rot",
+            "spiral": "spiral"
+        }
+        cam_motion = CameraMotionGenerator(
+            motion_type=camera_motion_type_mapping[camera_motion_type],
+            frame_num=num_frames,
+            W=width,
+            H=height,
+            fx=None,
+            fy=None,
+            fov=fov,
+            device=device,
+        )
+        # Apply camera motion if specified
+        cam_motion.set_intr(infer_result["intrinsics"])
+        poses = cam_motion.get_default_motion()  # shape: [49, 4, 4]
+
+        pred_tracks_flatten = pred_tracks.reshape(num_frames, H*W, 3)
+        pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([num_frames, H, W, 3])  # [T, H, W, 3]
+        print("Camera motion applied")
+
+
+        points = pred_tracks.cpu().numpy()
+        mask = infer_result["mask"].cpu().numpy()
+        # Create color array
+        T, H, W, _ = pred_tracks.shape
+
+        print("points shape: ", points.shape)
+
+        print("mask shape: ", mask.shape)
+        colors = np.zeros((H, W, 3), dtype=np.uint8)
+
+        # Set R channel - based on x coordinates (smaller on the left)
+        colors[:, :, 0] = np.tile(np.linspace(0, 255, W), (H, 1))
+
+        # Set G channel - based on y coordinates (smaller on the top)
+        colors[:, :, 1] = np.tile(np.linspace(0, 255, H), (W, 1)).T
+
+        # Set B channel - based on depth
+        z_values = points[0, :, :, 2]  # get z values
+        inv_z = 1 / z_values  # calculate 1/z
+        # Calculate 2% and 98% percentiles
+        p2 = np.percentile(inv_z, 2)
+        p98 = np.percentile(inv_z, 98)
+        # Normalize to [0,1] range
+        normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
+        colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
+        colors = colors.astype(np.uint8)
+
+        # First reshape points and colors
+        points = points.reshape(T, -1, 3)  # (T, H*W, 3)
+        colors = colors.reshape(-1, 3)  # (H*W, 3)
+
+        # Create mask for each frame
+        mask = mask.reshape(-1)  # Flatten mask to (H*W,)
+
+        # Apply mask
+        points = points[:, mask, :]  # (T, masked_points, 3)
+        colors = colors[mask]  # (masked_points, 3)
+
+        # Repeat colors for each frame
+        colors = colors.reshape(1, -1, 3).repeat(T, axis=0)  # (T, masked_points, 3)
+
+        # Initialize list to store frames
+        frames = []
+        pbar = ProgressBar(len(points))
+
+        for i, pts_i in enumerate(tqdm(points)):
+            pixels, depths = pts_i[..., :2], pts_i[..., 2]
+            pixels[..., 0] = pixels[..., 0] * W
+            pixels[..., 1] = pixels[..., 1] * H
+            pixels = pixels.astype(int)
+
+            valid = self.valid_mask(pixels, W, H)
+
+            frame_rgb = colors[i][valid]
+            pixels = pixels[valid]
+            depths = depths[valid]
+
+            img = Image.fromarray(np.uint8(np.zeros([H, W, 3])), mode="RGB")
+            sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
+            step = 1
+            sorted_pixels = sorted_pixels[::step]
+            sorted_rgb = frame_rgb[sort_index][::step]
+
+            for j in range(sorted_pixels.shape[0]):
+                self.draw_rectangle(
+                    img,
+                    coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
+                    side_length=2,
+                    color=sorted_rgb[j],
+                )
+            frames.append(np.array(img))
+            pbar.update(1)
+
+        # Convert frames to video tensor in range [0,1]
+        tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
+        tracking_video = tracking_video.permute(0, 2, 3, 1)  # [B, H, W, C]
+        print("tracking_video shape: ", tracking_video.shape)
+        return (tracking_video,)
+
+    def valid_mask(self, pixels, W, H):
+        """Check if pixels are within valid image bounds
+
+        Args:
+            pixels (numpy.ndarray): Pixel coordinates of shape [N, 2]
+            W (int): Image width
+            H (int): Image height
+
+        Returns:
+            numpy.ndarray: Boolean mask of valid pixels
+        """
+        return ((pixels[:, 0] >= 0) & (pixels[:, 0] < W) & (pixels[:, 1] >= 0) & \
+                (pixels[:, 1] < H))
+
+    def sort_points_by_depth(self, points, depths):
+        """Sort points by depth values
+
+        Args:
+            points (numpy.ndarray): Points array of shape [N, 2]
+            depths (numpy.ndarray): Depth values of shape [N]
+
+        Returns:
+            tuple: (sorted_points, sorted_depths, sort_index)
+        """
+        # Combine points and depths into a single array for sorting
+        combined = np.hstack((points, depths[:, None]))  # Nx3 (points + depth)
+        # Sort by depth (last column) in descending order
+        sort_index = combined[:, -1].argsort()[::-1]
+        sorted_combined = combined[sort_index]
+        # Split back into points and depths
+        sorted_points = sorted_combined[:, :-1]
+        sorted_depths = sorted_combined[:, -1]
+        return sorted_points, sorted_depths, sort_index
+
+    def draw_rectangle(self, rgb, coord, side_length, color=(255, 0, 0)):
+        """Draw a rectangle on the image
+
+        Args:
+            rgb (PIL.Image): Image to draw on
+            coord (tuple): Center coordinates (x, y)
+            side_length (int): Length of rectangle sides
+            color (tuple): RGB color tuple
+        """
+        draw = ImageDraw.Draw(rgb)
+        # Calculate the bounding box of the rectangle
+        left_up_point = (coord[0] - side_length//2, coord[1] - side_length//2)
+        right_down_point = (coord[0] + side_length//2, coord[1] + side_length//2)
+        color = tuple(list(color))
+
+        draw.rectangle(
+            [left_up_point, right_down_point],
+            fill=tuple(color),
+            outline=tuple(color),
+        )
 
 NODE_CLASS_MAPPINGS = {
     "CogVideoDASTrackingEncode": CogVideoDASTrackingEncode,
     "DAS_SpaTracker": DAS_SpaTracker,
     "DAS_SpaTrackerModelLoader": DAS_SpaTrackerModelLoader,
+    "DAS_MoGeTracker": DAS_MoGeTracker,
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode",
     "DAS_SpaTracker": "DAS SpaTracker",
     "DAS_SpaTrackerModelLoader": "DAS SpaTracker Model Loader",
+    "DAS_MoGeTracker": "DAS MoGe Tracker",
     }
diff --git a/das/motion.py b/das/motion.py
new file mode 100644
index 0000000..5f44a3f
--- /dev/null
+++ b/das/motion.py
@@ -0,0 +1,382 @@
+import torch
+import numpy as np
+import math
+
+class CameraMotionGenerator:
+    def __init__(self, motion_type, frame_num=49, H=480, W=720, fx=None, fy=None, fov=55, device='cuda'):
+        self.motion_type = motion_type
+        self.frame_num = frame_num
+        self.fov = fov
+        self.device = device
+        self.W = W
+        self.H = H
+        self.intr = torch.tensor([
+            [0, 0, W / 2],
+            [0, 0, H / 2],
+            [0, 0, 1]
+        ], dtype=torch.float32, device=device)
+        # if fx, fy not provided
+        if not fx or not fy:
+            fov_rad = math.radians(fov)
+            fx = fy = (W / 2) / math.tan(fov_rad / 2)
+
+        self.intr[0, 0] = fx
+        self.intr[1, 1] = fy
+
+    def _apply_poses(self, pts, poses):
+        """
+        Args:
+            pts (torch.Tensor): pointclouds coordinates [T, N, 3]
+            intr (torch.Tensor): camera intrinsics [T, 3, 3]
+            poses (numpy.ndarray): camera poses [T, 4, 4]
+        """
+        if isinstance(poses, np.ndarray):
+            poses = torch.from_numpy(poses)
+
+        intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1).to(torch.float)
+        T, N, _ = pts.shape
+        ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
+        pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1)  # (T, N, 3)
+        pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2))  # (T, N, 3)
+        pts_cam[:,:, :3] *= pts[:, :, 2:3]
+
+        # to homogeneous
+        pts_cam = torch.cat([pts_cam, ones], dim=-1)  # (T, N, 4)
+
+        if poses.shape[0] == 1:
+            poses = poses.repeat(T, 1, 1)
+        elif poses.shape[0] != T:
+            raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
+
+        poses = poses.to(torch.float).to(self.device)
+        pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3]  # (T, N, 3)
+        pts_proj = torch.bmm(pts_world, intr.transpose(1, 2))  # (T, N, 3)
+        pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
+
+        return pts_proj
+
+    def w2s(self, pts, poses):
+        if isinstance(poses, np.ndarray):
+            poses = torch.from_numpy(poses)
+        assert poses.shape[0] == self.frame_num
+        poses = poses.to(torch.float32).to(self.device)
+        T, N, _ = pts.shape  # (T, N, 3)
+        intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
+        # Expand the points to (T, N, 4) homogeneous coordinates by appending a 1
+        ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
+        points_world_h = torch.cat([pts, ones], dim=-1)
+        points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
+        points_camera = points_camera_h[:, :3, :].permute(0, 2, 1)
+
+        points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))
+
+        uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
+
+        # Extract depth (Z) and concatenate
+        depth = points_camera[:, :, 2:3]  # (T, N, 1)
+        uvd = torch.cat([uv, depth], dim=-1)  # (T, N, 3)
+
+        return uvd  # screen coordinates + depth, (T, N, 3)
+
+    def apply_motion_on_pts(self, pts, camera_motion):
+        tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
+        return tracking_pts
+
+    def set_intr(self, K):
+        if isinstance(K, np.ndarray):
+            K = torch.from_numpy(K)
+        self.intr = K.to(self.device)
+
+    def rot_poses(self, angle, axis='y'):
+        """
+        pts (torch.Tensor): [T, N, 3]
+        angle (int): angle of rotation (degrees)
+        """
+        angle_rad = math.radians(angle)
+        angles = torch.linspace(0, angle_rad, self.frame_num)
+        rot_mats = torch.zeros(self.frame_num, 4, 4)
+
+        for i, theta in enumerate(angles):
+            cos_theta = torch.cos(theta)
+            sin_theta = torch.sin(theta)
+            if axis == 'x':
+                rot_mats[i] = torch.tensor([
+                    [1, 0, 0, 0],
+                    [0, cos_theta, -sin_theta, 0],
+                    [0, sin_theta, cos_theta, 0],
+                    [0, 0, 0, 1]
+                ], dtype=torch.float32)
+            elif axis == 'y':
+                rot_mats[i] = torch.tensor([
+                    [cos_theta, 0, sin_theta, 0],
+                    [0, 1, 0, 0],
+                    [-sin_theta, 0, cos_theta, 0],
+                    [0, 0, 0, 1]
+                ], dtype=torch.float32)
+
+            elif axis == 'z':
+                rot_mats[i] = torch.tensor([
+                    [cos_theta, -sin_theta, 0, 0],
+                    [sin_theta, cos_theta, 0, 0],
+                    [0, 0, 1, 0],
+                    [0, 0, 0, 1]
+                ], dtype=torch.float32)
+            else:
+                raise ValueError("Invalid axis value. Choose 'x', 'y', or 'z'.")
Choose 'x', 'y', or 'z'.") + + return rot_mats.to(self.device) + + def trans_poses(self, dx, dy, dz): + """ + params: + - dx: float, displacement along x axis。 + - dy: float, displacement along y axis。 + - dz: float, displacement along z axis。 + + ret: + - matrices: torch.Tensor + """ + trans_mats = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1) # (n, 4, 4) + + delta_x = dx / (self.frame_num - 1) + delta_y = dy / (self.frame_num - 1) + delta_z = dz / (self.frame_num - 1) + + for i in range(self.frame_num): + trans_mats[i, 0, 3] = i * delta_x + trans_mats[i, 1, 3] = i * delta_y + trans_mats[i, 2, 3] = i * delta_z + + return trans_mats.to(self.device) + + + def _look_at(self, camera_position, target_position): + # look at direction + # import ipdb;ipdb.set_trace() + direction = target_position - camera_position + direction /= np.linalg.norm(direction) + # calculate rotation matrix + up = np.array([0, 1, 0]) + right = np.cross(up, direction) + right /= np.linalg.norm(right) + up = np.cross(direction, right) + rotation_matrix = np.vstack([right, up, direction]) + rotation_matrix = np.linalg.inv(rotation_matrix) + return rotation_matrix + + def spiral_poses(self, radius, forward_ratio = 0.5, backward_ratio = 0.5, rotation_times = 0.1, look_at_times = 0.5): + """Generate spiral camera poses + + Args: + radius (float): Base radius of the spiral + forward_ratio (float): Scale factor for forward motion + backward_ratio (float): Scale factor for backward motion + rotation_times (float): Number of rotations to complete + look_at_times (float): Scale factor for look-at point distance + + Returns: + torch.Tensor: Camera poses of shape [num_frames, 4, 4] + """ + # Generate spiral trajectory + t = np.linspace(0, 1, self.frame_num) + r = np.sin(np.pi * t) * radius * rotation_times + theta = 2 * np.pi * t + + # Calculate camera positions + # Limit y motion for better floor/sky view + y = r * np.cos(theta) * 0.3 + x = r * np.sin(theta) + z = -r + z[z < 0] *= forward_ratio + z[z > 0] *= backward_ratio + + # Set look-at target + target_pos = np.array([0, 0, radius * look_at_times]) + cam_pos = np.vstack([x, y, z]).T + cam_poses = [] + + for pos in cam_pos: + rot_mat = self._look_at(pos, target_pos) + trans_mat = np.eye(4) + trans_mat[:3, :3] = rot_mat + trans_mat[:3, 3] = pos + cam_poses.append(trans_mat[None]) + + camera_poses = np.concatenate(cam_poses, axis=0) + return torch.from_numpy(camera_poses).to(self.device) + + def rot(self, pts, angle, axis): + """ + pts: torch.Tensor, (T, N, 2) + """ + rot_mats = self.rot_poses(angle, axis) + pts = self.apply_motion_on_pts(pts, rot_mats) + return pts + + def trans(self, pts, dx, dy, dz): + if pts.shape[-1] != 3: + raise ValueError("points should be in the 3d coordinate.") + trans_mats = self.trans_poses(dx, dy, dz) + pts = self.apply_motion_on_pts(pts, trans_mats) + return pts + + def spiral(self, pts, radius): + spiral_poses = self.spiral_poses(radius) + pts = self.apply_motion_on_pts(pts, spiral_poses) + return pts + + def get_default_motion(self): + if self.motion_type == 'none': + motion = torch.eye(4).unsqueeze(0).repeat(self.frame_num, 1, 1).to(self.device) + elif self.motion_type == 'trans': + motion = self.trans_poses(0.02, 0, 0) + elif self.motion_type == 'spiral': + motion = self.spiral_poses(1) + elif self.motion_type == 'rot': + motion = self.rot_poses(-25, 'y') + else: + raise ValueError(f'camera_motion must be in [trans, spiral, rot], but get {self.motion_type}.') + + return motion + +class ObjectMotionGenerator: + def __init__(self, 
+        """Initialize ObjectMotionGenerator
+
+        Args:
+            num_frames (int): Number of frames
+            device (str): Device to run on
+        """
+        self.device = device
+        self.num_frames = num_frames
+
+    def _get_points_in_mask(self, pred_tracks, mask):
+        """Get points that fall within the mask in the first frame
+
+        Args:
+            pred_tracks (torch.Tensor): [num_frames, num_points, 3]
+            mask (torch.Tensor): [H, W] binary mask
+
+        Returns:
+            torch.Tensor: Boolean mask of selected points [num_points]
+        """
+        first_frame_points = pred_tracks[0]  # [num_points, 3]
+        xy_points = first_frame_points[:, :2]  # [num_points, 2]
+
+        # Convert xy coordinates to pixel indices
+        xy_pixels = xy_points.round().long()  # Convert to integer pixel coordinates
+
+        # Clamp coordinates to valid range
+        xy_pixels[:, 0].clamp_(0, mask.shape[1] - 1)  # x coordinates
+        xy_pixels[:, 1].clamp_(0, mask.shape[0] - 1)  # y coordinates
+
+        # Get mask values at point locations
+        points_in_mask = mask[xy_pixels[:, 1], xy_pixels[:, 0]]  # Index using y, x order
+
+        return points_in_mask
+
+    def generate_motion(self, mask, motion_type, distance, num_frames=49):
+        """Generate motion dictionary for the given parameters
+
+        Args:
+            mask (torch.Tensor): [H, W] binary mask
+            motion_type (str): Motion direction ('none', 'up', 'down', 'left', 'right', 'front', 'back')
+            distance (float): Total distance to move
+            num_frames (int): Number of frames
+
+        Returns:
+            dict: Motion dictionary containing:
+                - mask (torch.Tensor): Binary mask
+                - motions (torch.Tensor): Per-frame motion matrices [num_frames, 4, 4]
+        """
+
+        self.num_frames = num_frames
+        # Define motion template vectors
+        template = {
+            'none': torch.tensor([0, 0, 0]),
+            'up': torch.tensor([0, -1, 0]),
+            'down': torch.tensor([0, 1, 0]),
+            'left': torch.tensor([-1, 0, 0]),
+            'right': torch.tensor([1, 0, 0]),
+            'front': torch.tensor([0, 0, 1]),
+            'back': torch.tensor([0, 0, -1])
+        }
+
+        if motion_type not in template:
+            raise ValueError(f"Unknown motion type: {motion_type}")
+
+        # Move mask to device
+        mask = mask.to(self.device)
+
+        # Generate per-frame motion matrices
+        motions = []
+        base_vec = template[motion_type].to(self.device) * distance
+
+        for frame_idx in range(num_frames):
+            # Calculate interpolation factor (0 to 1)
+            t = frame_idx / (num_frames - 1)
+
+            # Create motion matrix for current frame
+            current_motion = torch.eye(4, device=self.device)
+            current_motion[:3, 3] = base_vec * t
+            motions.append(current_motion)
+
+        motions = torch.stack(motions)  # [num_frames, 4, 4]
+
+        return {
+            'mask': mask,
+            'motions': motions
+        }
+
+    def apply_motion(self, pred_tracks, motion_dict, tracking_method="spatracker"):
+        """Apply motion to selected points
+
+        Args:
+            pred_tracks (torch.Tensor): [num_frames, num_points, 3] for spatracker,
+                                        or [H, W, 3] for moge
+            motion_dict (dict): Motion dictionary containing mask and motions
+            tracking_method (str): "spatracker" or "moge"
+
+        Returns:
+            torch.Tensor: Modified pred_tracks; same shape as input for spatracker,
+                          [num_frames, H*W, 3] for moge
+        """
+        pred_tracks = pred_tracks.to(self.device).float()
+
+        if tracking_method == "moge":
+
+            H = pred_tracks.shape[0]
+            W = pred_tracks.shape[1]
+
+            initial_points = pred_tracks  # [H, W, 3]
+            selected_mask = motion_dict['mask']
+            valid_selected = ~torch.any(torch.isnan(initial_points), dim=2) & selected_mask
+            valid_selected = valid_selected.reshape([-1])
+            modified_tracks = pred_tracks.clone().reshape(-1, 3).unsqueeze(0).repeat(self.num_frames, 1, 1)
+            for frame_idx in range(self.num_frames):
+                # Get current frame motion
+                motion_mat = motion_dict['motions'][frame_idx]
+                # MoGe's point cloud is scale-invariant
+                motion_mat[0, 3] /= W
+                motion_mat[1, 3] /= H
+                # Apply motion to selected points
+                points = modified_tracks[frame_idx, valid_selected]
+                # Convert to homogeneous coordinates
+                points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
+                # Apply transformation
+                transformed_points = torch.matmul(points_homo, motion_mat.T)
+                # Convert back to 3D coordinates
+                modified_tracks[frame_idx, valid_selected] = transformed_points[:, :3]
+            return modified_tracks
+
+        else:
+            points_in_mask = self._get_points_in_mask(pred_tracks, motion_dict['mask'])
+            modified_tracks = pred_tracks.clone()
+
+            for frame_idx in range(pred_tracks.shape[0]):
+                motion_mat = motion_dict['motions'][frame_idx]
+                points = modified_tracks[frame_idx, points_in_mask]
+                points_homo = torch.cat([points, torch.ones_like(points[:, :1])], dim=1)
+                transformed_points = torch.matmul(points_homo, motion_mat.T)
+                modified_tracks[frame_idx, points_in_mask] = transformed_points[:, :3]
+
+            return modified_tracks
\ No newline at end of file
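
Note (not part of the patch): a minimal standalone sketch of how the pieces added above compose, useful for sanity-checking the motion math outside ComfyUI. It assumes the das package from this patch is importable as `das`; `points` and `K` are hypothetical stand-ins for MoGe's `infer_result["points"]` ([H, W, 3] point map) and `infer_result["intrinsics"]` (normalized 3x3 matrix).

import torch
from das.motion import CameraMotionGenerator, ObjectMotionGenerator

num_frames, H, W = 49, 480, 720
points = torch.rand(H, W, 3)                     # stand-in for infer_result["points"]
K = torch.tensor([[1.0, 0.0, 0.5],
                  [0.0, 1.0, 0.5],
                  [0.0, 0.0, 1.0]])              # stand-in normalized intrinsics

# Per-frame object translation matrices plus a selection mask, applied to the point map
obj_gen = ObjectMotionGenerator(num_frames, device="cpu")
motion_dict = obj_gen.generate_motion(
    mask=torch.ones(H, W, dtype=torch.bool),
    motion_type="up",
    distance=50,
    num_frames=num_frames,
)
tracks = obj_gen.apply_motion(points, motion_dict, tracking_method="moge")  # [T, H*W, 3]

# Default camera trajectory, then world-to-screen projection of the moved points
cam_gen = CameraMotionGenerator("trans", frame_num=num_frames, W=W, H=H, device="cpu")
cam_gen.set_intr(K)
poses = cam_gen.get_default_motion()             # [T, 4, 4]
uvd = cam_gen.w2s(tracks, poses)                 # [T, H*W, 3]: normalized (u, v) plus camera depth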