remove prints

This commit is contained in:
kijai 2024-10-24 17:16:26 +03:00
parent f4c13b1ef4
commit 2c67025577
2 changed files with 11 additions and 6 deletions

View File

@ -22,6 +22,10 @@ except:
from .dit.joint_model.asymm_models_joint import AsymmDiTJoint
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
MAX_T5_TOKEN_LENGTH = 256
def unnormalize_latents(
@ -100,7 +104,7 @@ class T2VSynthMochiModel:
self.device = device
self.offload_device = offload_device
print("Initializing model...")
logging.info("Initializing model...")
with (init_empty_weights() if is_accelerate_available else nullcontext()):
model = AsymmDiTJoint(
depth=48,
@ -124,23 +128,24 @@ class T2VSynthMochiModel:
)
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
print(f"Loading model state_dict from {dit_checkpoint_path}...")
logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
dit_sd = load_torch_file(dit_checkpoint_path)
if "gguf" in dit_checkpoint_path.lower():
logging.info("Loading GGUF model state_dict...")
from .. import mz_gguf_loader
import importlib
importlib.reload(mz_gguf_loader)
with mz_gguf_loader.quantize_lazy_load():
model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
elif is_accelerate_available:
print("Using accelerate to load and assign model weights to device...")
logging.info("Using accelerate to load and assign model weights to device...")
for name, param in model.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
else:
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
else:
print("Loading state_dict without accelerate...")
logging.info("Loading state_dict without accelerate...")
model.load_state_dict(dit_sd)
for name, param in model.named_parameters():
if not any(keyword in name for keyword in params_to_keep):
@ -349,5 +354,5 @@ class T2VSynthMochiModel:
self.dit.to(self.offload_device)
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
print("samples: ", samples.shape, samples.dtype, samples.device)
logging.info(f"samples shape: {samples.shape}")
return samples

View File

@ -26,7 +26,7 @@ def quantize_load_state_dict(model, state_dict, device="cpu"):
for name, module in model.named_modules():
if name in Q4_0_qkey:
print(name)
#print(name)
q_linear = WQLinear_GGUF.from_linear(
linear=module,
device=device,