remove prints
This commit is contained in:
parent
f4c13b1ef4
commit
2c67025577
@ -22,6 +22,10 @@ except:
|
||||
|
||||
from .dit.joint_model.asymm_models_joint import AsymmDiTJoint
|
||||
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
MAX_T5_TOKEN_LENGTH = 256
|
||||
|
||||
def unnormalize_latents(
|
||||
@ -100,7 +104,7 @@ class T2VSynthMochiModel:
|
||||
self.device = device
|
||||
self.offload_device = offload_device
|
||||
|
||||
print("Initializing model...")
|
||||
logging.info("Initializing model...")
|
||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||
model = AsymmDiTJoint(
|
||||
depth=48,
|
||||
@ -124,23 +128,24 @@ class T2VSynthMochiModel:
|
||||
)
|
||||
|
||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||
print(f"Loading model state_dict from {dit_checkpoint_path}...")
|
||||
logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
|
||||
dit_sd = load_torch_file(dit_checkpoint_path)
|
||||
if "gguf" in dit_checkpoint_path.lower():
|
||||
logging.info("Loading GGUF model state_dict...")
|
||||
from .. import mz_gguf_loader
|
||||
import importlib
|
||||
importlib.reload(mz_gguf_loader)
|
||||
with mz_gguf_loader.quantize_lazy_load():
|
||||
model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
|
||||
elif is_accelerate_available:
|
||||
print("Using accelerate to load and assign model weights to device...")
|
||||
logging.info("Using accelerate to load and assign model weights to device...")
|
||||
for name, param in model.named_parameters():
|
||||
if not any(keyword in name for keyword in params_to_keep):
|
||||
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
|
||||
else:
|
||||
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
|
||||
else:
|
||||
print("Loading state_dict without accelerate...")
|
||||
logging.info("Loading state_dict without accelerate...")
|
||||
model.load_state_dict(dit_sd)
|
||||
for name, param in model.named_parameters():
|
||||
if not any(keyword in name for keyword in params_to_keep):
|
||||
@ -349,5 +354,5 @@ class T2VSynthMochiModel:
|
||||
self.dit.to(self.offload_device)
|
||||
|
||||
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
|
||||
print("samples: ", samples.shape, samples.dtype, samples.device)
|
||||
logging.info(f"samples shape: {samples.shape}")
|
||||
return samples
|
||||
|
||||
@ -26,7 +26,7 @@ def quantize_load_state_dict(model, state_dict, device="cpu"):
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if name in Q4_0_qkey:
|
||||
print(name)
|
||||
#print(name)
|
||||
q_linear = WQLinear_GGUF.from_linear(
|
||||
linear=module,
|
||||
device=device,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user