Update nodes.py

This commit is contained in:
kijai 2024-09-22 21:34:17 +03:00
parent 81c5447ca3
commit 2e22d4fa0c

View File

@ -93,12 +93,19 @@ class CogVideoXPABConfig(PABConfig):
spatial_broadcast: bool = True,
spatial_threshold: list = [100, 850],
spatial_range: int = 2,
temporal_broadcast: bool = True,
temporal_threshold: list = [100, 850],
temporal_range: int = 2,
):
super().__init__(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=spatial_threshold,
spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=temporal_threshold,
temporal_range=temporal_range
)
from .videosys.cogvideox_transformer_3d import CogVideoXTransformer3DModel as CogVideoXTransformer3DModelPAB
@ -108,9 +115,13 @@ class CogVideoPABConfig:
def INPUT_TYPES(s):
return {"required": {
"spatial_broadcast": ("BOOLEAN", {"default": True, "tooltip": "Enable Spatial PAB"}),
"pab_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
"pab_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
"pab_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
"spatial_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
"spatial_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
"spatial_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
"temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB"}),
"temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
"temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
"temporal_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
"steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ),
}
}
@ -121,13 +132,17 @@ class CogVideoPABConfig:
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation"
def config(self, spatial_broadcast, pab_threshold_start, pab_threshold_end, pab_range, steps):
def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range, steps):
pab_config = CogVideoXPABConfig(
steps=steps,
spatial_broadcast=spatial_broadcast,
spatial_threshold=[pab_threshold_end, pab_threshold_start],
spatial_range=pab_range
spatial_threshold=[spatial_threshold_end, spatial_threshold_start],
spatial_range=spatial_range,
temporal_broadcast=temporal_broadcast,
temporal_threshold=[temporal_threshold_end, temporal_threshold_start],
temporal_range=temporal_range
)
return (pab_config, )