mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-01-23 16:34:22 +08:00
Update das_nodes.py
This commit is contained in:
parent
d3601e3fa3
commit
ee7d04d342
@ -71,11 +71,39 @@ class CogVideoDASTrackingEncode:
|
||||
"start_percent": start_percent,
|
||||
"end_percent": end_percent
|
||||
}, )
|
||||
|
||||
|
||||
class DAS_SpaTrackerModelLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model": (folder_paths.get_filename_list("CogVideo"), {"tooltip": "These models are loaded from the 'ComfyUI/models/CogVideo' -folder",}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SPATRACKERMODEL",)
|
||||
RETURN_NAMES = ("spatracker_model",)
|
||||
FUNCTION = "load"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def load(self, model):
|
||||
device = mm.get_torch_device()
|
||||
|
||||
model_path = folder_paths.get_full_path("CogVideo", model)
|
||||
from .spatracker.predictor import SpaTrackerPredictor
|
||||
|
||||
spatracker = SpaTrackerPredictor(
|
||||
checkpoint=model_path,
|
||||
interp_shape=(384, 576),
|
||||
seq_length=12
|
||||
).to(device)
|
||||
|
||||
return (spatracker,)
|
||||
|
||||
class DAS_SpaTracker:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"spatracker": ("SPATRACKERMODEL",),
|
||||
"images": ("IMAGE", ),
|
||||
"depth_images": ("IMAGE", ),
|
||||
"density": ("INT", {"default": 70, "min": 1, "max": 100, "step": 1}),
|
||||
@ -87,29 +115,19 @@ class DAS_SpaTracker:
|
||||
FUNCTION = "encode"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def encode(self, images, depth_images, density):
|
||||
def encode(self, spatracker, images, depth_images, density):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
|
||||
model_path = os.path.join(folder_paths.models_dir, 'spatracker', 'spaT_final.pth')
|
||||
from .spatracker.predictor import SpaTrackerPredictor
|
||||
from .spatracker.utils.visualizer import Visualizer
|
||||
|
||||
if not hasattr(self, "tracker"):
|
||||
self.tracker = SpaTrackerPredictor(
|
||||
checkpoint=model_path,
|
||||
interp_shape=(384, 576),
|
||||
seq_length=12
|
||||
).to(device)
|
||||
|
||||
segm_mask = np.ones((480, 720), dtype=np.uint8)
|
||||
|
||||
video = images.permute(0, 3, 1, 2).to(device).unsqueeze(0)
|
||||
video_depth = depth_images.permute(0, 3, 1, 2).to(device)
|
||||
video_depth = video_depth[:, 0:1, :, :]
|
||||
|
||||
spatracker.to(device)
|
||||
|
||||
pred_tracks, pred_visibility, T_Firsts = self.tracker(
|
||||
pred_tracks, pred_visibility, T_Firsts = spatracker(
|
||||
video * 255,
|
||||
video_depth=video_depth,
|
||||
grid_size=density,
|
||||
@ -121,7 +139,11 @@ class DAS_SpaTracker:
|
||||
progressive_tracking=False
|
||||
)
|
||||
|
||||
spatracker.to(offload_device)
|
||||
|
||||
from .spatracker.utils.visualizer import Visualizer
|
||||
vis = Visualizer(grayscale=False, fps=24, pad_value=0)
|
||||
|
||||
msk_query = (T_Firsts == 0)
|
||||
pred_tracks = pred_tracks[:,:,msk_query.squeeze()]
|
||||
pred_visibility = pred_visibility[:,:,msk_query.squeeze()]
|
||||
@ -139,8 +161,10 @@ class DAS_SpaTracker:
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CogVideoDASTrackingEncode": CogVideoDASTrackingEncode,
|
||||
"DAS_SpaTracker": DAS_SpaTracker,
|
||||
"DAS_SpaTrackerModelLoader": DAS_SpaTrackerModelLoader,
|
||||
}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CogVideoDASTrackingEncode": "CogVideo DAS Tracking Encode",
|
||||
"DAS_SpaTracker": "DAS SpaTracker",
|
||||
"DAS_SpaTrackerModelLoader": "DAS SpaTracker Model Loader",
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user