tweaks
This commit is contained in:
parent
00a550e81c
commit
c673508188
@ -39,7 +39,10 @@ def fp8_linear_forward(cls, original_dtype, input):
|
|||||||
def convert_fp8_linear(module, original_dtype):
|
def convert_fp8_linear(module, original_dtype):
|
||||||
setattr(module, "fp8_matmul_enabled", True)
|
setattr(module, "fp8_matmul_enabled", True)
|
||||||
for name, module in module.named_modules():
|
for name, module in module.named_modules():
|
||||||
|
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
original_forward = module.forward
|
if "blocks" in name:
|
||||||
setattr(module, "original_forward", original_forward)
|
#print(module, name)
|
||||||
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
|
original_forward = module.forward
|
||||||
|
setattr(module, "original_forward", original_forward)
|
||||||
|
setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
|
||||||
|
|||||||
@ -127,7 +127,7 @@ class T2VSynthMochiModel:
|
|||||||
|
|
||||||
print("Initializing model...")
|
print("Initializing model...")
|
||||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||||
model: nn.Module = AsymmDiTJoint(
|
model = AsymmDiTJoint(
|
||||||
depth=48,
|
depth=48,
|
||||||
patch_size=2,
|
patch_size=2,
|
||||||
num_heads=24,
|
num_heads=24,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user