Update das_nodes.py

This commit is contained in:
kijai 2025-02-15 00:45:50 +02:00
parent d3601e3fa3
commit ee7d04d342

View File

@ -71,11 +71,39 @@ class CogVideoDASTrackingEncode:
"start_percent": start_percent,
"end_percent": end_percent
}, )
class DAS_SpaTrackerModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": (folder_paths.get_filename_list("CogVideo"), {"tooltip": "These models are loaded from the 'ComfyUI/models/CogVideo' -folder",}),
},
}
RETURN_TYPES = ("SPATRACKERMODEL",)
RETURN_NAMES = ("spatracker_model",)
FUNCTION = "load"
CATEGORY = "CogVideoWrapper"
def load(self, model):
device = mm.get_torch_device()
model_path = folder_paths.get_full_path("CogVideo", model)
from .spatracker.predictor import SpaTrackerPredictor
spatracker = SpaTrackerPredictor(
checkpoint=model_path,
interp_shape=(384, 576),
seq_length=12
).to(device)
return (spatracker,)
class DAS_SpaTracker:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"spatracker": ("SPATRACKERMODEL",),
"images": ("IMAGE", ),
"depth_images": ("IMAGE", ),
"density": ("INT", {"default": 70, "min": 1, "max": 100, "step": 1}),
@ -87,29 +115,19 @@ class DAS_SpaTracker:
FUNCTION = "encode"
CATEGORY = "CogVideoWrapper"
def encode(self, images, depth_images, density):
def encode(self, spatracker, images, depth_images, density):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
generator = torch.Generator(device=device).manual_seed(0)
model_path = os.path.join(folder_paths.models_dir, 'spatracker', 'spaT_final.pth')
from .spatracker.predictor import SpaTrackerPredictor
from .spatracker.utils.visualizer import Visualizer
if not hasattr(self, "tracker"):
self.tracker = SpaTrackerPredictor(
checkpoint=model_path,
interp_shape=(384, 576),
seq_length=12
).to(device)
segm_mask = np.ones((480, 720), dtype=np.uint8)
video = images.permute(0, 3, 1, 2).to(device).unsqueeze(0)
video_depth = depth_images.permute(0, 3, 1, 2).to(device)
video_depth = video_depth[:, 0:1, :, :]
spatracker.to(device)
pred_tracks, pred_visibility, T_Firsts = self.tracker(
pred_tracks, pred_visibility, T_Firsts = spatracker(
video * 255,
video_depth=video_depth,
grid_size=density,
@ -121,7 +139,11 @@ class DAS_SpaTracker:
progressive_tracking=False
)
spatracker.to(offload_device)
from .spatracker.utils.visualizer import Visualizer
vis = Visualizer(grayscale=False, fps=24, pad_value=0)
msk_query = (T_Firsts == 0)
pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
@ -139,8 +161,10 @@ class DAS_SpaTracker:
NODE_CLASS_MAPPINGS = {
"CogVideoDASTrackingEncode": CogVideoDASTrackingEncode,
"DAS_SpaTracker": DAS_SpaTracker,
"DAS_SpaTrackerModelLoader": DAS_SpaTrackerModelLoader,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode",
"DAS_SpaTracker": "DAS SpaTracker",
"DAS_SpaTrackerModelLoader": "DAS SpaTracker Model Loader",
}