update CheckpointLoaderKJ

kijai 2025-05-07 17:30:38 +03:00
parent b7e5b6f1e2
commit bfb6d973fe


@@ -156,8 +156,11 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
    def INPUT_TYPES(s):
        return {"required": {
            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
            "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
"sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
"enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
@@ -167,13 +170,127 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
        from nodes import CheckpointLoaderSimple
        model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
    def patch(self, ckpt_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
        DTYPE_MAP = {
            "fp8_e4m3fn": torch.float8_e4m3fn,
            "fp8_e5m2": torch.float8_e5m2,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32
        }
        model_options = {}
        if dtype := DTYPE_MAP.get(weight_dtype):
            model_options["dtype"] = dtype
            print(f"Setting {ckpt_name} weight dtype to {dtype}")

        # the "fast" fp8 variant also turns on fp8 optimizations in the model options
        if weight_dtype == "fp8_e4m3fn_fast":
            model_options["dtype"] = torch.float8_e4m3fn
            model_options["fp8_optimizations"] = True

        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
        model, clip, vae = self.load_state_dict_guess_config(
            sd,
            output_vae=True,
            output_clip=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
            metadata=metadata,
            model_options=model_options)

        if dtype := DTYPE_MAP.get(compute_dtype):
            model.set_model_compute_dtype(dtype)
            model.force_cast_weights = False
            print(f"Setting {ckpt_name} compute dtype to {dtype}")

        if enable_fp16_accumulation:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = True
            else:
                raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
        else:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                torch.backends.cuda.matmul.allow_fp16_accumulation = False

        # cublas/sage attention patching is applied right before the model runs
        def patch_attention(model):
            self._patch_modules(patch_cublaslinear, sage_attention)

        model.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention)

        return model, clip, vae
    def load_state_dict_guess_config(self, sd, output_vae=True, output_clip=True, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
        from comfy.sd import load_diffusion_model_state_dict, model_detection, VAE, CLIP
        clip = None
        vae = None
        model = None
        model_patcher = None

        diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
        parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
        weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)  # dtype of the checkpoint weights, used as a fallback below when no dtype override is given
        load_device = mm.get_torch_device()

        model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
        if model_config is None:
            logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
            diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
            if diffusion_model is None:
                return None
            return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used
        unet_weight_dtype = list(model_config.supported_inference_dtypes)
        if model_config.scaled_fp8 is not None:
            weight_dtype = None

        model_config.custom_operations = model_options.get("custom_operations", None)
        unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))

        if unet_dtype is None:
            unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)

        manual_cast_dtype = mm.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

        if output_model:
            inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
            model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
            model.load_model_weights(sd, diffusion_model_prefix)

        if output_vae:
            vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
            vae_sd = model_config.process_vae_state_dict(vae_sd)
            vae = VAE(sd=vae_sd, metadata=metadata)

        if output_clip:
            clip_target = model_config.clip_target(state_dict=sd)
            if clip_target is not None:
                clip_sd = model_config.process_clip_state_dict(sd)
                if len(clip_sd) > 0:
                    parameters = comfy.utils.calculate_parameters(clip_sd)
                    clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
                    m, u = clip.load_sd(clip_sd, full_model=True)
                    if len(m) > 0:
                        m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
                        if len(m_filter) > 0:
                            logging.warning("clip missing: {}".format(m))
                        else:
                            logging.debug("clip missing: {}".format(m))

                    if len(u) > 0:
                        logging.debug("clip unexpected {}:".format(u))
                else:
                    logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

        left_over = sd.keys()
        if len(left_over) > 0:
            logging.debug("left over keys: {}".format(left_over))

        if output_model:
            model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=mm.unet_offload_device())
            if inital_load_device != torch.device("cpu"):
                logging.info("loaded diffusion model directly to GPU")
                mm.load_models_gpu([model_patcher], force_full_load=True)

        return (model_patcher, clip, vae)

class DiffusionModelLoaderKJ(BaseLoaderKJ):
    @classmethod