From 2eb9b81d277c07738a7bf147176381f8cffcb5e5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com>
Date: Sat, 9 Nov 2024 04:41:07 +0200
Subject: [PATCH] fp8

---
 fp8_optimization.py | 14 +++++++++-----
 model_loading.py    |  7 ++++---
 nodes.py            |  3 +--
 3 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/fp8_optimization.py b/fp8_optimization.py
index b01ac91..05b0146 100644
--- a/fp8_optimization.py
+++ b/fp8_optimization.py
@@ -36,10 +36,14 @@ def fp8_linear_forward(cls, original_dtype, input):
     else:
         return cls.original_forward(input)
 
-def convert_fp8_linear(module, original_dtype):
+def convert_fp8_linear(module, original_dtype, params_to_keep={}):
     setattr(module, "fp8_matmul_enabled", True)
+
+
     for name, module in module.named_modules():
-        if isinstance(module, nn.Linear):
-            original_forward = module.forward
-            setattr(module, "original_forward", original_forward)
-            setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
+        if not any(keyword in name for keyword in params_to_keep):
+            if isinstance(module, nn.Linear):
+                print(name)
+                original_forward = module.forward
+                setattr(module, "original_forward", original_forward)
+                setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
diff --git a/model_loading.py b/model_loading.py
index 6df4c55..7adc9d6 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -153,7 +153,6 @@ class DownloadAndLoadCogVideoModel:
             base_path = os.path.join(download_path, (model.split("/")[-1]))
             download_path = base_path
             repo_id = model
-
         if "2b" in model:
             scheduler_path = os.path.join(script_directory, 'configs', 'scheduler_config_2b.json')
 
@@ -193,13 +192,15 @@ class DownloadAndLoadCogVideoModel:
         #fp8
         if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
             for name, param in transformer.named_parameters():
-                params_to_keep = {"patch_embed", "lora", "pos_embedding"}
+                params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
                 if not any(keyword in name for keyword in params_to_keep):
                     param.data = param.data.to(torch.float8_e4m3fn)
 
             if fp8_transformer == "fastmode":
                 from .fp8_optimization import convert_fp8_linear
-                convert_fp8_linear(transformer, dtype)
+                if "1.5" in model:
+                    params_to_keep = {"norm","ff"}
+                convert_fp8_linear(transformer, dtype, params_to_keep=params_to_keep)
 
         with open(scheduler_path) as f:
             scheduler_config = json.load(f)
diff --git a/nodes.py b/nodes.py
index fe4d367..fa5e3ad 100644
--- a/nodes.py
+++ b/nodes.py
@@ -861,8 +861,7 @@ class CogVideoSampler:
             pipe.transformer.fastercache_counter = 0
 
         autocastcondition = not pipeline["onediff"] or not dtype == torch.float32
-        autocastcondition = False ##todo
-        autocast_context = torch.autocast(mm.get_autocast_device(device)) if autocastcondition else nullcontext()
+        autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
         with autocast_context:
             latents = pipeline["pipe"](
                 num_inference_steps=steps,