This commit is contained in:
kijai 2024-10-24 02:59:52 +03:00
parent 00a550e81c
commit c673508188
2 changed files with 7 additions and 4 deletions

View File

@ -39,7 +39,10 @@ def fp8_linear_forward(cls, original_dtype, input):
def convert_fp8_linear(module, original_dtype): def convert_fp8_linear(module, original_dtype):
setattr(module, "fp8_matmul_enabled", True) setattr(module, "fp8_matmul_enabled", True)
for name, module in module.named_modules(): for name, module in module.named_modules():
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
original_forward = module.forward if "blocks" in name:
setattr(module, "original_forward", original_forward) #print(module, name)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input)) original_forward = module.forward
setattr(module, "original_forward", original_forward)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))

View File

@ -127,7 +127,7 @@ class T2VSynthMochiModel:
print("Initializing model...") print("Initializing model...")
with (init_empty_weights() if is_accelerate_available else nullcontext()): with (init_empty_weights() if is_accelerate_available else nullcontext()):
model: nn.Module = AsymmDiTJoint( model = AsymmDiTJoint(
depth=48, depth=48,
patch_size=2, patch_size=2,
num_heads=24, num_heads=24,