diff --git a/nodes.py b/nodes.py
index 2d81d76..bd4c79c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -93,9 +93,12 @@ class CogVideoXPABConfig(PABConfig):
         spatial_broadcast: bool = True,
         spatial_threshold: list = [100, 850],
         spatial_range: int = 2,
-        temporal_broadcast: bool = True,
+        temporal_broadcast: bool = False,
         temporal_threshold: list = [100, 850],
-        temporal_range: int = 2,
+        temporal_range: int = 4,
+        cross_broadcast: bool = False,
+        cross_threshold: list = [100, 850],
+        cross_range: int = 6,
     ):
         super().__init__(
             steps=steps,
@@ -104,7 +107,10 @@ class CogVideoXPABConfig(PABConfig):
             spatial_range=spatial_range,
             temporal_broadcast=temporal_broadcast,
             temporal_threshold=temporal_threshold,
-            temporal_range=temporal_range
+            temporal_range=temporal_range,
+            cross_broadcast=cross_broadcast,
+            cross_threshold=cross_threshold,
+            cross_range=cross_range
         )
 
 
@@ -121,7 +127,12 @@ class CogVideoPABConfig:
                 "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB"}),
                 "temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
                 "temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
-                "temporal_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
+                "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
+                "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB"}),
+                "cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
+                "cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
+                "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
+                "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ),
             }
         }
 
@@ -133,8 +144,10 @@ class CogVideoPABConfig:
     DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation"
 
     def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
-               temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range, steps):
+               temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
+               cross_broadcast, cross_threshold_start, cross_threshold_end, cross_range, steps):
 
+        #os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
         pab_config = CogVideoXPABConfig(
             steps=steps,
             spatial_broadcast=spatial_broadcast,
@@ -142,7 +155,10 @@ class CogVideoPABConfig:
             spatial_range=spatial_range,
             temporal_broadcast=temporal_broadcast,
             temporal_threshold=[temporal_threshold_end, temporal_threshold_start],
-            temporal_range=temporal_range
+            temporal_range=temporal_range,
+            cross_broadcast=cross_broadcast,
+            cross_threshold=[cross_threshold_end, cross_threshold_start],
+            cross_range=cross_range
         )
         return (pab_config, )
 
@@ -311,6 +327,9 @@ class DownloadAndLoadCogVideoGGUFModel:
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
                 "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
             },
+            "optional": {
+                "pab_config": ("PAB_CONFIG", {"default": None}),
+            }
         }
 
     RETURN_TYPES = ("COGVIDEOPIPE",)
@@ -318,7 +337,7 @@ class DownloadAndLoadCogVideoGGUFModel:
     FUNCTION = "loadmodel"
    CATEGORY = "CogVideoWrapper"
 
-    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload):
+    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
@@ -370,7 +389,10 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
         else:
             transformer_config["in_channels"] = 16
-            transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
+            if pab_config is not None:
+                transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
+            else:
+                transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
 
         if "2b" in model:
             for name, param in transformer.named_parameters():
@@ -420,7 +442,7 @@ class DownloadAndLoadCogVideoGGUFModel:
         else:
             vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
             vae.load_state_dict(vae_sd)
-            pipe = CogVideoXPipeline(vae, transformer, scheduler)
+            pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
 
         # compilation
         # if compile == "torch":
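For context on how the threshold/range settings added above interact at inference time: PAB caches an attention output and re-broadcasts it for nearby diffusion steps instead of recomputing it, as long as the current timestep lies inside the configured threshold window and the cache is no older than the broadcast range. The sketch below illustrates that gating logic only; it is not code from this repo, and the names (should_broadcast, steps_since_compute) are hypothetical.

# Minimal sketch of the per-attention-type gating PAB applies each step.
# Hypothetical helper; names are illustrative, not from nodes.py.
def should_broadcast(timestep, steps_since_compute, enabled, threshold, broadcast_range):
    """Return True if the cached attention output may be reused this step."""
    if not enabled:
        return False
    end, start = threshold  # node packs this as [end, start], e.g. [100, 850]
    in_window = end < timestep < start        # diffusion timesteps count down toward 0
    fresh_enough = steps_since_compute < broadcast_range
    return in_window and fresh_enough

# Example: with cross_range=6 and cross_threshold=[100, 850], cross attention
# is fully recomputed at most every 6th step while 100 < t < 850 and reused otherwise.

This also suggests why the new defaults form a pyramid (spatial 2, temporal 4, cross 6): per the Pyramid Attention Broadcast paper, attention outputs change less and less across timesteps going from spatial to temporal to cross attention, so prompt-conditioned cross attention tolerates the longest reuse window.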