tweaks
This commit is contained in:
parent
00a550e81c
commit
c673508188
@ -39,7 +39,10 @@ def fp8_linear_forward(cls, original_dtype, input):
|
||||
def convert_fp8_linear(module, original_dtype):
|
||||
setattr(module, "fp8_matmul_enabled", True)
|
||||
for name, module in module.named_modules():
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
original_forward = module.forward
|
||||
setattr(module, "original_forward", original_forward)
|
||||
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
|
||||
if "blocks" in name:
|
||||
#print(module, name)
|
||||
original_forward = module.forward
|
||||
setattr(module, "original_forward", original_forward)
|
||||
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
|
||||
|
||||
@ -127,7 +127,7 @@ class T2VSynthMochiModel:
|
||||
|
||||
print("Initializing model...")
|
||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||
model: nn.Module = AsymmDiTJoint(
|
||||
model = AsymmDiTJoint(
|
||||
depth=48,
|
||||
patch_size=2,
|
||||
num_heads=24,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user