fix 2b fp8, for real this time

Jukka Seppänen 2024-09-03 20:52:53 +03:00
parent 3d536025e0
commit 05413b6cfe


@@ -65,7 +65,14 @@ class DownloadAndLoadCogVideoModel:
                 local_dir_use_symlinks=False,
             )
         if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
-            transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(torch.float8_e4m3fn).to(offload_device)
+            transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder="transformer").to(offload_device)
+            if "2b" in model:
+                for name, param in transformer.named_parameters():
+                    if name != "pos_embedding":
+                        param.data = param.data.to(torch.float8_e4m3fn)
+            else:
+                transformer.to(torch.float8_e4m3fn)
+
             if fp8_transformer == "fastmode":
                 from .fp8_optimization import convert_fp8_linear
                 convert_fp8_linear(transformer, dtype)
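
The hunk above stops blanket-casting the whole transformer to fp8: for 2b models, every parameter except pos_embedding is cast to float8_e4m3fn, presumably because the positional embedding does not survive the fp8 cast. A minimal standalone sketch of that selective cast (the helper name and skip tuple are illustrative, not part of this repo; torch.float8_e4m3fn requires PyTorch 2.1 or newer):

import torch
import torch.nn as nn

def cast_to_fp8_except(module: nn.Module, skip=("pos_embedding",)):
    # Cast parameter storage to float8_e4m3fn, leaving any parameter
    # whose name appears in `skip` (e.g. a precision-sensitive
    # positional embedding) in its original dtype.
    for name, param in module.named_parameters():
        if name not in skip:
            param.data = param.data.to(torch.float8_e4m3fn)
    return module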
@@ -84,12 +91,13 @@ class DownloadAndLoadCogVideoModel:
             pipe.transformer.to(memory_format=torch.channels_last)
             pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
         elif compile == "onediff":
-            from onediffx import compile_pipe
-            options = None
+            from onediffx import compile_pipe, quantize_pipe
+            os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
+
             pipe = compile_pipe(
                 pipe,
                 backend="nexfort",
-                options=options,
+                options={"mode": "max-optimize:max-autotune:max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": False, "triton.fuse_attention_allow_fp16_reduction": False}},
                 ignores=["vae"],
                 fuse_qkv_projections=True,
             )
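
The second hunk replaces the `options = None` placeholder with an explicit nexfort tuning dict and forces Triton SDPA through an environment variable. A sketch of the same call with the dict pulled out into a named variable (an illustrative refactor only; it assumes `pipe` is the already-loaded CogVideoX pipeline):

import os
from onediffx import compile_pipe

# Identical settings to the hunk above, just named for readability.
nexfort_options = {
    "mode": "max-optimize:max-autotune:max-autotune",
    "memory_format": "channels_last",
    "options": {
        "inductor.optimize_linear_epilogue": False,
        "triton.fuse_attention_allow_fp16_reduction": False,
    },
}

os.environ["NEXFORT_FX_FORCE_TRITON_SDPA"] = "1"  # set before compiling, per the hunk above
pipe = compile_pipe(
    pipe,                       # the loaded CogVideoX pipeline (assumed)
    backend="nexfort",
    options=nexfort_options,
    ignores=["vae"],            # leave the VAE uncompiled
    fuse_qkv_projections=True,
)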