From ee7d04d3420f2e8d2ccb986b1eaabf7f16a20457 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 15 Feb 2025 00:45:50 +0200 Subject: [PATCH] Update das_nodes.py --- das/das_nodes.py | 54 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/das/das_nodes.py b/das/das_nodes.py index 17b9b88..7fac893 100644 --- a/das/das_nodes.py +++ b/das/das_nodes.py @@ -71,11 +71,39 @@ class CogVideoDASTrackingEncode: "start_percent": start_percent, "end_percent": end_percent }, ) - + +class DAS_SpaTrackerModelLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "model": (folder_paths.get_filename_list("CogVideo"), {"tooltip": "These models are loaded from the 'ComfyUI/models/CogVideo' -folder",}), + }, + } + + RETURN_TYPES = ("SPATRACKERMODEL",) + RETURN_NAMES = ("spatracker_model",) + FUNCTION = "load" + CATEGORY = "CogVideoWrapper" + + def load(self, model): + device = mm.get_torch_device() + + model_path = folder_paths.get_full_path("CogVideo", model) + from .spatracker.predictor import SpaTrackerPredictor + + spatracker = SpaTrackerPredictor( + checkpoint=model_path, + interp_shape=(384, 576), + seq_length=12 + ).to(device) + + return (spatracker,) + class DAS_SpaTracker: @classmethod def INPUT_TYPES(s): return {"required": { + "spatracker": ("SPATRACKERMODEL",), "images": ("IMAGE", ), "depth_images": ("IMAGE", ), "density": ("INT", {"default": 70, "min": 1, "max": 100, "step": 1}), @@ -87,29 +115,19 @@ class DAS_SpaTracker: FUNCTION = "encode" CATEGORY = "CogVideoWrapper" - def encode(self, images, depth_images, density): + def encode(self, spatracker, images, depth_images, density): device = mm.get_torch_device() offload_device = mm.unet_offload_device() - generator = torch.Generator(device=device).manual_seed(0) - - model_path = os.path.join(folder_paths.models_dir, 'spatracker', 'spaT_final.pth') - from .spatracker.predictor import SpaTrackerPredictor - from .spatracker.utils.visualizer import Visualizer - - if not hasattr(self, "tracker"): - self.tracker = SpaTrackerPredictor( - checkpoint=model_path, - interp_shape=(384, 576), - seq_length=12 - ).to(device) segm_mask = np.ones((480, 720), dtype=np.uint8) video = images.permute(0, 3, 1, 2).to(device).unsqueeze(0) video_depth = depth_images.permute(0, 3, 1, 2).to(device) video_depth = video_depth[:, 0:1, :, :] + + spatracker.to(device) - pred_tracks, pred_visibility, T_Firsts = self.tracker( + pred_tracks, pred_visibility, T_Firsts = spatracker( video * 255, video_depth=video_depth, grid_size=density, @@ -121,7 +139,11 @@ class DAS_SpaTracker: progressive_tracking=False ) + spatracker.to(offload_device) + + from .spatracker.utils.visualizer import Visualizer vis = Visualizer(grayscale=False, fps=24, pad_value=0) + msk_query = (T_Firsts == 0) pred_tracks = pred_tracks[:,:,msk_query.squeeze()] pred_visibility = pred_visibility[:,:,msk_query.squeeze()] @@ -139,8 +161,10 @@ class DAS_SpaTracker: NODE_CLASS_MAPPINGS = { "CogVideoDASTrackingEncode": CogVideoDASTrackingEncode, "DAS_SpaTracker": DAS_SpaTracker, + "DAS_SpaTrackerModelLoader": DAS_SpaTrackerModelLoader, } NODE_DISPLAY_NAME_MAPPINGS = { "CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode", "DAS_SpaTracker": "DAS SpaTracker", + "DAS_SpaTrackerModelLoader": "DAS SpaTracker Model Loader", }