mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-01-23 20:24:25 +08:00
PAB for GGUF too
This commit is contained in:
parent
2e22d4fa0c
commit
d9abc00d3b
40
nodes.py
40
nodes.py
@ -93,9 +93,12 @@ class CogVideoXPABConfig(PABConfig):
|
||||
spatial_broadcast: bool = True,
|
||||
spatial_threshold: list = [100, 850],
|
||||
spatial_range: int = 2,
|
||||
temporal_broadcast: bool = True,
|
||||
temporal_broadcast: bool = False,
|
||||
temporal_threshold: list = [100, 850],
|
||||
temporal_range: int = 2,
|
||||
temporal_range: int = 4,
|
||||
cross_broadcast: bool = False,
|
||||
cross_threshold: list = [100, 850],
|
||||
cross_range: int = 6,
|
||||
):
|
||||
super().__init__(
|
||||
steps=steps,
|
||||
@ -104,7 +107,10 @@ class CogVideoXPABConfig(PABConfig):
|
||||
spatial_range=spatial_range,
|
||||
temporal_broadcast=temporal_broadcast,
|
||||
temporal_threshold=temporal_threshold,
|
||||
temporal_range=temporal_range
|
||||
temporal_range=temporal_range,
|
||||
cross_broadcast=cross_broadcast,
|
||||
cross_threshold=cross_threshold,
|
||||
cross_range=cross_range
|
||||
|
||||
)
|
||||
|
||||
@ -121,7 +127,12 @@ class CogVideoPABConfig:
|
||||
"temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB"}),
|
||||
"temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
||||
"temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
||||
"temporal_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
|
||||
"temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
|
||||
"cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB"}),
|
||||
"cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"} ),
|
||||
"cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"} ),
|
||||
"cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"} ),
|
||||
|
||||
"steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"} ),
|
||||
}
|
||||
}
|
||||
@ -133,8 +144,10 @@ class CogVideoPABConfig:
|
||||
DESCRIPTION = "EXPERIMENTAL:Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation"
|
||||
|
||||
def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
|
||||
temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range, steps):
|
||||
temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
|
||||
cross_broadcast, cross_threshold_start, cross_threshold_end, cross_range, steps):
|
||||
|
||||
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
pab_config = CogVideoXPABConfig(
|
||||
steps=steps,
|
||||
spatial_broadcast=spatial_broadcast,
|
||||
@ -142,7 +155,10 @@ class CogVideoPABConfig:
|
||||
spatial_range=spatial_range,
|
||||
temporal_broadcast=temporal_broadcast,
|
||||
temporal_threshold=[temporal_threshold_end, temporal_threshold_start],
|
||||
temporal_range=temporal_range
|
||||
temporal_range=temporal_range,
|
||||
cross_broadcast=cross_broadcast,
|
||||
cross_threshold=[cross_threshold_end, cross_threshold_start],
|
||||
cross_range=cross_range
|
||||
)
|
||||
|
||||
return (pab_config, )
|
||||
@ -311,6 +327,9 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
|
||||
"enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
|
||||
},
|
||||
"optional": {
|
||||
"pab_config": ("PAB_CONFIG", {"default": None}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("COGVIDEOPIPE",)
|
||||
@ -318,7 +337,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
|
||||
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload):
|
||||
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
mm.soft_empty_cache()
|
||||
@ -370,7 +389,10 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
|
||||
else:
|
||||
transformer_config["in_channels"] = 16
|
||||
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
|
||||
if pab_config is not None:
|
||||
transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
|
||||
else:
|
||||
transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
|
||||
|
||||
if "2b" in model:
|
||||
for name, param in transformer.named_parameters():
|
||||
@ -420,7 +442,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
else:
|
||||
vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
|
||||
vae.load_state_dict(vae_sd)
|
||||
pipe = CogVideoXPipeline(vae, transformer, scheduler)
|
||||
pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
|
||||
|
||||
# compilation
|
||||
# if compile == "torch":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user