mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-10 05:15:05 +08:00
update CheckpointLoaderKJ
This commit is contained in:
parent
b7e5b6f1e2
commit
bfb6d973fe
@ -156,8 +156,11 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
|
||||
"compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
|
||||
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
|
||||
"sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
||||
"enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
@ -167,14 +170,128 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch(self, ckpt_name, patch_cublaslinear, sage_attention):
|
||||
from nodes import CheckpointLoaderSimple
|
||||
model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name)
|
||||
def patch(self, ckpt_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
|
||||
DTYPE_MAP = {
|
||||
"fp8_e4m3fn": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
"fp32": torch.float32
|
||||
}
|
||||
model_options = {}
|
||||
if dtype := DTYPE_MAP.get(weight_dtype):
|
||||
model_options["dtype"] = dtype
|
||||
print(f"Setting {ckpt_name} weight dtype to {dtype}")
|
||||
|
||||
if weight_dtype == "fp8_e4m3fn_fast":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
model_options["fp8_optimizations"] = True
|
||||
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
|
||||
model, clip, vae = self.load_state_dict_guess_config(
|
||||
sd,
|
||||
output_vae=True,
|
||||
output_clip=True,
|
||||
embedding_directory=folder_paths.get_folder_paths("embeddings"),
|
||||
metadata=metadata,
|
||||
model_options=model_options)
|
||||
|
||||
if dtype := DTYPE_MAP.get(compute_dtype):
|
||||
model.set_model_compute_dtype(dtype)
|
||||
model.force_cast_weights = False
|
||||
print(f"Setting {ckpt_name} compute dtype to {dtype}")
|
||||
|
||||
if enable_fp16_accumulation:
|
||||
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||
else:
|
||||
raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
|
||||
else:
|
||||
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = False
|
||||
|
||||
def patch_attention(model):
|
||||
self._patch_modules(patch_cublaslinear, sage_attention)
|
||||
model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
|
||||
return model, clip, vae
|
||||
|
||||
def load_state_dict_guess_config(self, sd, output_vae=True, output_clip=True, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
|
||||
from comfy.sd import load_diffusion_model_state_dict, model_detection, VAE, CLIP
|
||||
clip = None
|
||||
vae = None
|
||||
model = None
|
||||
model_patcher = None
|
||||
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
load_device = mm.get_torch_device()
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
|
||||
if diffusion_model is None:
|
||||
return None
|
||||
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
weight_dtype = None
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
|
||||
|
||||
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if output_model:
|
||||
inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix)
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
|
||||
if output_clip:
|
||||
clip_target = model_config.clip_target(state_dict=sd)
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
else:
|
||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
logging.debug("left over keys: {}".format(left_over))
|
||||
|
||||
if output_model:
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=mm.unet_offload_device())
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded diffusion model directly to GPU")
|
||||
mm.load_models_gpu([model_patcher], force_full_load=True)
|
||||
|
||||
return (model_patcher, clip, vae)
|
||||
|
||||
class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user