From 05413b6cfe606c236e529d84aa4405e73a768267 Mon Sep 17 00:00:00 2001
From: Jukka Seppänen <40791699+kijai@users.noreply.github.com>
Date: Tue, 3 Sep 2024 20:52:53 +0300
Subject: [PATCH] fix 2b fp8, for real this time

---
 nodes.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/nodes.py b/nodes.py
index 227545e..9053168 100644
--- a/nodes.py
+++ b/nodes.py
@@ -65,7 +65,14 @@ class DownloadAndLoadCogVideoModel:
                 local_dir_use_symlinks=False,
             )
         if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
-            transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(torch.float8_e4m3fn).to(offload_device)
+            transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(offload_device)
+            if "2b" in model:
+                for name, param in transformer.named_parameters():
+                    if name != "pos_embedding":
+                        param.data = param.data.to(torch.float8_e4m3fn)
+            else:
+                transformer.to(torch.float8_e4m3fn)
+
             if fp8_transformer == "fastmode":
                 from .fp8_optimization import convert_fp8_linear
                 convert_fp8_linear(transformer, dtype)
@@ -84,12 +91,13 @@ class DownloadAndLoadCogVideoModel:
             pipe.transformer.to(memory_format=torch.channels_last)
             pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
         elif compile == "onediff":
-            from onediffx import compile_pipe
-            options = None
+            from onediffx import compile_pipe, quantize_pipe
+            os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
+
             pipe = compile_pipe(
                 pipe,
                 backend="nexfort",
-                options=options,
+                options= {"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
                 ignores=["vae"],
                 fuse_qkv_projections=True,
            )
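
Note: the core of this fix is a selective fp8 cast for the 2b model: every
transformer parameter is moved to torch.float8_e4m3fn except pos_embedding,
which stays in its original dtype (casting it along with the rest appears to
be what broke 2b fp8 previously, per the subject line). A minimal standalone
sketch of that technique, assuming a diffusers release that ships
CogVideoXTransformer3DModel, a torch build with float8_e4m3fn support, and a
placeholder checkpoint path not taken from the patch:

    import torch
    from diffusers import CogVideoXTransformer3DModel

    base_path = "/path/to/CogVideoX-2b"  # placeholder, not from the patch
    offload_device = torch.device("cpu")

    # Load in the checkpoint's native dtype, then cast selectively.
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        base_path, subfolder="transformer"
    ).to(offload_device)

    for name, param in transformer.named_parameters():
        if name != "pos_embedding":  # keep the positional embedding out of fp8
            param.data = param.data.to(torch.float8_e4m3fn)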