Update nodes.py

This commit is contained in:
kijai 2024-09-22 21:34:17 +03:00
parent 81c5447ca3
commit 2e22d4fa0c

View File

@ -93,12 +93,19 @@ class CogVideoXPABConfig(PABConfig):
spatial_broadcast: bool = True, spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850], spatial_threshold: list = [100, 850],
spatial_range: int = 2, spatial_range: int = 2,
temporal_broadcast: bool = True,
temporal_threshold: list = [100, 850],
temporal_range: int = 2,
): ):
super().__init__( super().__init__(
steps=steps, steps=steps,
spatial_broadcast=spatial_broadcast, spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold, spatial_threshold=spatial_threshold,
spatial_range=spatial_range, spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=temporal_threshold,
temporal_range=temporal_range
) )
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
@ -108,9 +115,13 @@ class CogVideoPABConfig:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { return {"required": {
"spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}), "spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}),
"pab_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ), "spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
"pab_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ), "spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
"pab_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ), "spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
"temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB"}),
"temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
"temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
"temporal_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
"steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ), "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ),
} }
} }
@ -121,13 +132,17 @@ class CogVideoPABConfig:
CATEGORY = "CogVideoWrapper" CATEGORY = "CogVideoWrapper"
DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation" DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation"
def config(self, spatial_broadcast, pab_threshold_start, pab_threshold_end, pab_range, steps): def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range, steps):
pab_config = CogVideoXPABConfig( pab_config = CogVideoXPABConfig(
steps=steps, steps=steps,
spatial_broadcast=spatial_broadcast, spatial_broadcast=spatial_broadcast,
spatial_threshold=[pab_threshold_end, pab_threshold_start], spatial_threshold=[spatial_threshold_end, spatial_threshold_start],
spatial_range=pab_range spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=[temporal_threshold_end, temporal_threshold_start],
temporal_range=temporal_range
) )
return (pab_config, ) return (pab_config, )