Remove print statements; use logging instead
This commit is contained in:
parent
f4c13b1ef4
commit
2c67025577
@ -22,6 +22,10 @@ except:
|
|||||||
|
|
||||||
from .dit.joint_model.asymm_models_joint import AsymmDiTJoint
|
from .dit.joint_model.asymm_models_joint import AsymmDiTJoint
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
MAX_T5_TOKEN_LENGTH = 256
|
MAX_T5_TOKEN_LENGTH = 256
|
||||||
|
|
||||||
def unnormalize_latents(
|
def unnormalize_latents(
|
||||||
@ -100,7 +104,7 @@ class T2VSynthMochiModel:
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.offload_device = offload_device
|
self.offload_device = offload_device
|
||||||
|
|
||||||
print("Initializing model...")
|
logging.info("Initializing model...")
|
||||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||||
model = AsymmDiTJoint(
|
model = AsymmDiTJoint(
|
||||||
depth=48,
|
depth=48,
|
||||||
@ -124,23 +128,24 @@ class T2VSynthMochiModel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
|
||||||
print(f"Loading model state_dict from {dit_checkpoint_path}...")
|
logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
|
||||||
dit_sd = load_torch_file(dit_checkpoint_path)
|
dit_sd = load_torch_file(dit_checkpoint_path)
|
||||||
if "gguf" in dit_checkpoint_path.lower():
|
if "gguf" in dit_checkpoint_path.lower():
|
||||||
|
logging.info("Loading GGUF model state_dict...")
|
||||||
from .. import mz_gguf_loader
|
from .. import mz_gguf_loader
|
||||||
import importlib
|
import importlib
|
||||||
importlib.reload(mz_gguf_loader)
|
importlib.reload(mz_gguf_loader)
|
||||||
with mz_gguf_loader.quantize_lazy_load():
|
with mz_gguf_loader.quantize_lazy_load():
|
||||||
model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
|
model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
|
||||||
elif is_accelerate_available:
|
elif is_accelerate_available:
|
||||||
print("Using accelerate to load and assign model weights to device...")
|
logging.info("Using accelerate to load and assign model weights to device...")
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
|
set_module_tensor_to_device(model, name, dtype=weight_dtype, device=self.device, value=dit_sd[name])
|
||||||
else:
|
else:
|
||||||
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
|
set_module_tensor_to_device(model, name, dtype=torch.bfloat16, device=self.device, value=dit_sd[name])
|
||||||
else:
|
else:
|
||||||
print("Loading state_dict without accelerate...")
|
logging.info("Loading state_dict without accelerate...")
|
||||||
model.load_state_dict(dit_sd)
|
model.load_state_dict(dit_sd)
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
@ -349,5 +354,5 @@ class T2VSynthMochiModel:
|
|||||||
self.dit.to(self.offload_device)
|
self.dit.to(self.offload_device)
|
||||||
|
|
||||||
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
|
samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
|
||||||
print("samples: ", samples.shape, samples.dtype, samples.device)
|
logging.info(f"samples shape: {samples.shape}")
|
||||||
return samples
|
return samples
|
||||||
|
|||||||
@ -26,7 +26,7 @@ def quantize_load_state_dict(model, state_dict, device="cpu"):
|
|||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if name in Q4_0_qkey:
|
if name in Q4_0_qkey:
|
||||||
print(name)
|
#print(name)
|
||||||
q_linear = WQLinear_GGUF.from_linear(
|
q_linear = WQLinear_GGUF.from_linear(
|
||||||
linear=module,
|
linear=module,
|
||||||
device=device,
|
device=device,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user