Update t2v_synth_mochi.py
This commit is contained in:
parent
1cd5409295
commit
508eaa22df
@ -23,6 +23,8 @@ except:
|
|||||||
is_accelerate_available = False
|
is_accelerate_available = False
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
from .dit.joint_model.asymm_models_joint import AsymmDiTJoint
|
||||||
|
|
||||||
MAX_T5_TOKEN_LENGTH = 256
|
MAX_T5_TOKEN_LENGTH = 256
|
||||||
|
|
||||||
def unnormalize_latents(
|
def unnormalize_latents(
|
||||||
@ -125,9 +127,7 @@ class T2VSynthMochiModel:
|
|||||||
self.offload_device = offload_device
|
self.offload_device = offload_device
|
||||||
|
|
||||||
with t("construct_dit"):
|
with t("construct_dit"):
|
||||||
from .dit.joint_model.asymm_models_joint import (
|
|
||||||
AsymmDiTJoint,
|
|
||||||
)
|
|
||||||
model: nn.Module = torch.nn.utils.skip_init(
|
model: nn.Module = torch.nn.utils.skip_init(
|
||||||
AsymmDiTJoint,
|
AsymmDiTJoint,
|
||||||
depth=48,
|
depth=48,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user