diff --git a/fp8_optimization.py b/fp8_optimization.py
index b01ac91..9d6d712 100644
--- a/fp8_optimization.py
+++ b/fp8_optimization.py
@@ -39,7 +39,10 @@ def fp8_linear_forward(cls, original_dtype, input):
 def convert_fp8_linear(module, original_dtype):
     setattr(module, "fp8_matmul_enabled", True)
     for name, module in module.named_modules():
+        if isinstance(module, nn.Linear):
-        original_forward = module.forward
-        setattr(module, "original_forward", original_forward)
-        setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
+            if "blocks" in name:
+                #print(module, name)
+                original_forward = module.forward
+                setattr(module, "original_forward", original_forward)
+                setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
 
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 3561859..45805a7 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -127,7 +127,7 @@ class T2VSynthMochiModel:
 
         print("Initializing model...")
         with (init_empty_weights() if is_accelerate_available else nullcontext()):
-            model: nn.Module = AsymmDiTJoint(
+            model = AsymmDiTJoint(
                 depth=48,
                 patch_size=2,
                 num_heads=24,
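
Reviewer note: below is a minimal runnable sketch of what the first hunk does after this change, i.e. only nn.Linear submodules whose qualified name contains "blocks" get their forward patched; anything outside the transformer blocks keeps its original forward. The pass-through fp8_linear_forward stub and the toy model are hypothetical stand-ins, not the repo's real implementations, and the loop variable is renamed to submodule to avoid the name shadowing in the original.

import torch
import torch.nn as nn

def fp8_linear_forward(module, original_dtype, input):
    # Stand-in for the real fp8 matmul path: just call the saved
    # original forward so this sketch runs without fp8 support.
    return module.original_forward(input)

def convert_fp8_linear(module, original_dtype):
    setattr(module, "fp8_matmul_enabled", True)
    for name, submodule in module.named_modules():
        # Post-change behavior: patch only Linear layers inside "blocks".
        if isinstance(submodule, nn.Linear):
            if "blocks" in name:
                setattr(submodule, "original_forward", submodule.forward)
                # Bind submodule via a default arg so each lambda keeps
                # a reference to its own layer, not the loop variable.
                setattr(submodule, "forward",
                        lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input))

# Hypothetical toy model: one Linear inside "blocks", one outside.
model = nn.ModuleDict({
    "blocks": nn.ModuleList([nn.Linear(4, 4)]),
    "final_layer": nn.Linear(4, 4),
})
convert_fp8_linear(model, torch.bfloat16)
print(hasattr(model["blocks"][0], "original_forward"))   # True: patched
print(hasattr(model["final_layer"], "original_forward")) # False: skipped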