Update t2v_synth_mochi.py
This commit is contained in:
parent
1cd5409295
commit
508eaa22df
@ -23,6 +23,8 @@ except:
|
||||
is_accelerate_available = False
|
||||
pass
|
||||
|
||||
from .dit.joint_model.asymm_models_joint import AsymmDiTJoint
|
||||
|
||||
MAX_T5_TOKEN_LENGTH = 256
|
||||
|
||||
def unnormalize_latents(
|
||||
@ -125,9 +127,7 @@ class T2VSynthMochiModel:
|
||||
self.offload_device = offload_device
|
||||
|
||||
with t("construct_dit"):
|
||||
from .dit.joint_model.asymm_models_joint import (
|
||||
AsymmDiTJoint,
|
||||
)
|
||||
|
||||
model: nn.Module = torch.nn.utils.skip_init(
|
||||
AsymmDiTJoint,
|
||||
depth=48,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user