PAB for GGUF too

This commit is contained in:
kijai 2024-09-23 00:08:37 +03:00
parent 2e22d4fa0c
commit d9abc00d3b


@@ -93,9 +93,12 @@ class CogVideoXPABConfig(PABConfig):
         spatial_broadcast: bool = True,
         spatial_threshold: list = [100, 850],
         spatial_range: int = 2,
-        temporal_broadcast: bool = True,
+        temporal_broadcast: bool = False,
         temporal_threshold: list = [100, 850],
-        temporal_range: int = 2,
+        temporal_range: int = 4,
+        cross_broadcast: bool = False,
+        cross_threshold: list = [100, 850],
+        cross_range: int = 6,
     ):
         super().__init__(
             steps=steps,
@@ -104,7 +107,10 @@ class CogVideoXPABConfig(PABConfig):
             spatial_range=spatial_range,
             temporal_broadcast=temporal_broadcast,
             temporal_threshold=temporal_threshold,
-            temporal_range=temporal_range
+            temporal_range=temporal_range,
+            cross_broadcast=cross_broadcast,
+            cross_threshold=cross_threshold,
+            cross_range=cross_range
         )
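
The new cross-attention fields mirror the existing spatial/temporal triples. As a summary of the effective defaults after this commit, here is a stand-in dataclass (illustrative only; the real class inherits from PABConfig as shown above, and the steps default of 50 is taken from the node below):

from dataclasses import dataclass, field

@dataclass
class CogVideoXPABDefaultsSketch:
    # Stand-in summarizing the post-commit defaults: temporal broadcast
    # is now off by default, its range doubled to 4, and a cross-attention
    # triple is added with a wider range of 6.
    steps: int = 50
    spatial_broadcast: bool = True
    spatial_threshold: list = field(default_factory=lambda: [100, 850])
    spatial_range: int = 2
    temporal_broadcast: bool = False
    temporal_threshold: list = field(default_factory=lambda: [100, 850])
    temporal_range: int = 4
    cross_broadcast: bool = False
    cross_threshold: list = field(default_factory=lambda: [100, 850])
    cross_range: int = 6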
@@ -121,7 +127,12 @@ class CogVideoPABConfig:
                 "temporal_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Temporal PAB"}),
                 "temporal_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"}),
                 "temporal_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"}),
-                "temporal_range": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"}),
+                "temporal_range": ("INT", {"default": 4, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"}),
+                "cross_broadcast": ("BOOLEAN", {"default": False, "tooltip": "Enable Cross Attention PAB"}),
+                "cross_threshold_start": ("INT", {"default": 850, "min": 0, "max": 1000, "tooltip": "PAB Start Timestep"}),
+                "cross_threshold_end": ("INT", {"default": 100, "min": 0, "max": 1000, "tooltip": "PAB End Timestep"}),
+                "cross_range": ("INT", {"default": 6, "min": 0, "max": 10, "tooltip": "Broadcast timesteps range"}),
                 "steps": ("INT", {"default": 50, "min": 0, "max": 1000, "tooltip": "Steps"}),
             }
         }
@@ -133,8 +144,10 @@ class CogVideoPABConfig:
     DESCRIPTION = "EXPERIMENTAL: Pyramid Attention Broadcast (PAB) speeds up inference by mitigating redundant attention computation"
     def config(self, spatial_broadcast, spatial_threshold_start, spatial_threshold_end, spatial_range,
-               temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range, steps):
+               temporal_broadcast, temporal_threshold_start, temporal_threshold_end, temporal_range,
+               cross_broadcast, cross_threshold_start, cross_threshold_end, cross_range, steps):
+        #os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
         pab_config = CogVideoXPABConfig(
             steps=steps,
             spatial_broadcast=spatial_broadcast,
@@ -142,7 +155,10 @@ class CogVideoPABConfig:
             spatial_range=spatial_range,
             temporal_broadcast=temporal_broadcast,
             temporal_threshold=[temporal_threshold_end, temporal_threshold_start],
-            temporal_range=temporal_range
+            temporal_range=temporal_range,
+            cross_broadcast=cross_broadcast,
+            cross_threshold=[cross_threshold_end, cross_threshold_start],
+            cross_range=cross_range
         )
         return (pab_config, )
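
Note the packing order: the node exposes start and end timesteps separately but stores them as [end, start], e.g. [100, 850], matching a diffusion timestep that counts down from about 1000 toward 0. A hypothetical gate consistent with that convention (not code from this repo):

def in_broadcast_window(timestep: int, threshold: list) -> bool:
    # threshold is [end, start], e.g. [100, 850]; broadcasting applies
    # while the (descending) timestep lies inside the window.
    end, start = threshold
    return end <= timestep <= start

assert in_broadcast_window(500, [100, 850])      # mid-denoising: broadcast
assert not in_broadcast_window(950, [100, 850])  # early steps: full attention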
@@ -311,6 +327,9 @@ class DownloadAndLoadCogVideoGGUFModel:
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
                 "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reduces memory usage but slows down inference"}),
             },
+            "optional": {
+                "pab_config": ("PAB_CONFIG", {"default": None}),
+            }
         }
     RETURN_TYPES = ("COGVIDEOPIPE",)
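
Listing pab_config under "optional" follows the usual ComfyUI node convention: an unconnected optional socket is simply not passed to the mapped function, so the function must supply its own Python default. A reduced stand-in sketch of the pattern (not the full loader):

class OptionalInputSketch:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"model": ("STRING", {"default": ""})},
            # Unconnected optional sockets are omitted from the call,
            # so the function's own default (None below) takes effect.
            "optional": {"pab_config": ("PAB_CONFIG", {"default": None})},
        }

    FUNCTION = "loadmodel"

    def loadmodel(self, model, pab_config=None):
        return ("PAB enabled" if pab_config is not None else "plain",)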
@@ -318,7 +337,7 @@ class DownloadAndLoadCogVideoGGUFModel:
     FUNCTION = "loadmodel"
     CATEGORY = "CogVideoWrapper"
-    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload):
+    def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload, pab_config=None):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
         mm.soft_empty_cache()
@@ -370,7 +389,10 @@ class DownloadAndLoadCogVideoGGUFModel:
             transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
         else:
             transformer_config["in_channels"] = 16
-            transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
+            if pab_config is not None:
+                transformer = CogVideoXTransformer3DModelPAB.from_config(transformer_config)
+            else:
+                transformer = CogVideoXTransformer3DModel.from_config(transformer_config)
         if "2b" in model:
             for name, param in transformer.named_parameters():
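
This branch is the heart of the commit: when a PAB config is present, the GGUF path instantiates the PAB-aware transformer from the same config dict, as the title's "too" implies the non-GGUF loader already does. The dispatch reduces to a class swap, sketched here with stand-ins (the real classes construct via a diffusers-style from_config):

class TransformerSketch:
    @classmethod
    def from_config(cls, config: dict):
        # Minimal stand-in for a diffusers-style from_config constructor.
        obj = cls()
        obj.config = dict(config)
        return obj

class TransformerPABSketch(TransformerSketch):
    """Stand-in for CogVideoXTransformer3DModelPAB: built the same way,
    but able to cache and re-broadcast attention outputs across steps."""

def build_transformer(config: dict, pab_config=None):
    # Mirrors the diff: PAB config present -> PAB-aware class, else base.
    cls = TransformerPABSketch if pab_config is not None else TransformerSketch
    return cls.from_config(config)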
@@ -420,7 +442,7 @@ class DownloadAndLoadCogVideoGGUFModel:
         else:
             vae = AutoencoderKLCogVideoX.from_config(vae_config).to(vae_dtype).to(offload_device)
             vae.load_state_dict(vae_sd)
-        pipe = CogVideoXPipeline(vae, transformer, scheduler)
+        pipe = CogVideoXPipeline(vae, transformer, scheduler, pab_config=pab_config)
         # compilation
         # if compile == "torch":
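
End to end, the GGUF loader now opts into PAB the same way: CogVideoPABConfig emits a PAB_CONFIG, the loader threads it through to the transformer choice, and the pipeline takes it as a keyword argument whose None default leaves existing callers untouched. A hypothetical wiring outside ComfyUI's graph executor (the import path, model filename, and vae_precision value below are placeholders, not verified options):

from custom_nodes.ComfyUI_CogVideoXWrapper.nodes import (  # assumed module path
    CogVideoPABConfig, DownloadAndLoadCogVideoGGUFModel)

# Node outputs are 1-tuples, hence the trailing-comma unpacking.
pab_config, = CogVideoPABConfig().config(
    spatial_broadcast=True, spatial_threshold_start=850,
    spatial_threshold_end=100, spatial_range=2,
    temporal_broadcast=False, temporal_threshold_start=850,
    temporal_threshold_end=100, temporal_range=4,
    cross_broadcast=True, cross_threshold_start=850,
    cross_threshold_end=100, cross_range=6,
    steps=50,
)
pipe, = DownloadAndLoadCogVideoGGUFModel().loadmodel(
    model="some_cogvideox_5b_gguf.safetensors",  # placeholder filename
    vae_precision="bf16",                        # placeholder value
    fp8_fastmode=False,
    load_device="main_device",
    enable_sequential_cpu_offload=False,
    pab_config=pab_config,
)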