diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index e75e52a..cf957d3 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -22,6 +22,10 @@ except: from .dit.joint_model.asymm_models_joint import AsymmDiTJoint +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + MAX_T5_TOKEN_LENGTH = 256 def unnormalize_latents( @@ -100,7 +104,7 @@ class T2VSynthMochiModel: self.device = device self.offload_device = offload_device - print("Initializing model...") + logging.info("Initializing model...") with (init_empty_weights() if is_accelerate_available else nullcontext()): model = AsymmDiTJoint( depth=48, @@ -124,23 +128,24 @@ class T2VSynthMochiModel: ) params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} - print(f"Loading model state_dict from {dit_checkpoint_path}...") + logging.info(f"Loading model state_dict from {dit_checkpoint_path}...") dit_sd = load_torch_file(dit_checkpoint_path) if "gguf" in dit_checkpoint_path.lower(): + logging.info("Loading GGUF model state_dict...") from .. import mz_gguf_loader import importlib importlib.reload(mz_gguf_loader) with mz_gguf_loader.quantize_lazy_load(): model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu") elif is_accelerate_available: - print("Using accelerate to load and assign model weights to device...") + logging.info("Using accelerate to load and assign model weights to device...") for name, param in model.named_parameters(): if not any(keyword in name for keyword in params_to_keep): set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name]) else: set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name]) else: - print("Loading state_dict without accelerate...") + logging.info("Loading state_dict without accelerate...") model.load_state_dict(dit_sd) for name, param in model.named_parameters(): if not any(keyword in name for keyword in params_to_keep): @@ -349,5 +354,5 @@ class T2VSynthMochiModel: self.dit.to(self.offload_device) samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std) - print("samples: ", samples.shape, samples.dtype, samples.device) + logging.info(f"samples shape: {samples.shape}") return samples diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py index 085e9c4..262f7df 100644 --- a/mz_gguf_loader.py +++ b/mz_gguf_loader.py @@ -26,7 +26,7 @@ def quantize_load_state_dict(model, state_dict, device="cpu"): for name, module in model.named_modules(): if name in Q4_0_qkey: - print(name) + #print(name) q_linear = WQLinear_GGUF.from_linear( linear=module, device=device,