This commit is contained in:
kijai 2024-10-24 02:59:52 +03:00
parent 00a550e81c
commit c673508188
2 changed files with 7 additions and 4 deletions

View File

@ -39,7 +39,10 @@ def fp8_linear_forward(cls, original_dtype, input):
def convert_fp8_linear(module, original_dtype):
setattr(module, "fp8_matmul_enabled", True)
for name, module in module.named_modules():
if isinstance(module, nn.Linear):
original_forward = module.forward
setattr(module, "original_forward", original_forward)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
if "blocks" in name:
#print(module, name)
original_forward = module.forward
setattr(module, "original_forward", original_forward)
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))

View File

@ -127,7 +127,7 @@ class T2VSynthMochiModel:
print("Initializing model...")
with (init_empty_weights() if is_accelerate_available else nullcontext()):
model: nn.Module = AsymmDiTJoint(
model = AsymmDiTJoint(
depth=48,
patch_size=2,
num_heads=24,