diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index b8c3aad..93d4caf 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -156,8 +156,11 @@ class CheckpointLoaderKJ(BaseLoaderKJ): def INPUT_TYPES(s): return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}), + "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],), + "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}), "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}), "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}), + "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") @@ -167,13 +170,127 @@ class CheckpointLoaderKJ(BaseLoaderKJ): EXPERIMENTAL = True CATEGORY = "KJNodes/experimental" - def patch(self, ckpt_name, patch_cublaslinear, sage_attention): - from nodes import CheckpointLoaderSimple - model, clip, vae = CheckpointLoaderSimple.load_checkpoint(self, ckpt_name) + def patch(self, ckpt_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation): + DTYPE_MAP = { + "fp8_e4m3fn": torch.float8_e4m3fn, + "fp8_e5m2": torch.float8_e5m2, + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32 + } + model_options = {} + if dtype := DTYPE_MAP.get(weight_dtype): + model_options["dtype"] = dtype + print(f"Setting {ckpt_name} weight dtype to {dtype}") + + if weight_dtype == "fp8_e4m3fn_fast": + model_options["dtype"] = torch.float8_e4m3fn + model_options["fp8_optimizations"] = True + + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) + + model, clip, vae = self.load_state_dict_guess_config( + sd, + output_vae=True, + output_clip=True, + embedding_directory=folder_paths.get_folder_paths("embeddings"), + metadata=metadata, + model_options=model_options) + + if dtype := DTYPE_MAP.get(compute_dtype): + model.set_model_compute_dtype(dtype) + model.force_cast_weights = False + print(f"Setting {ckpt_name} compute dtype to {dtype}") + + if enable_fp16_accumulation: + if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): + torch.backends.cuda.matmul.allow_fp16_accumulation = True + else: + raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently") + else: + if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): + torch.backends.cuda.matmul.allow_fp16_accumulation = False + def patch_attention(model): self._patch_modules(patch_cublaslinear, sage_attention) model.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention) return model, clip, vae + + def load_state_dict_guess_config(self, sd, output_vae=True, output_clip=True, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): + from comfy.sd import load_diffusion_model_state_dict, model_detection, VAE, CLIP + clip = None + vae = None + model = None + model_patcher = None + + diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) + parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) + load_device = mm.get_torch_device() + + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) + if model_config is None: + logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") + diffusion_model = load_diffusion_model_state_dict(sd, model_options={}) + if diffusion_model is None: + return None + return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' + + + unet_weight_dtype = list(model_config.supported_inference_dtypes) + if model_config.scaled_fp8 is not None: + weight_dtype = None + + model_config.custom_operations = model_options.get("custom_operations", None) + unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)) + + if unet_dtype is None: + unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) + + manual_cast_dtype = mm.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + + if output_model: + inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype) + model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) + model.load_model_weights(sd, diffusion_model_prefix) + + if output_vae: + vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) + vae_sd = model_config.process_vae_state_dict(vae_sd) + vae = VAE(sd=vae_sd, metadata=metadata) + + if output_clip: + clip_target = model_config.clip_target(state_dict=sd) + if clip_target is not None: + clip_sd = model_config.process_clip_state_dict(sd) + if len(clip_sd) > 0: + parameters = comfy.utils.calculate_parameters(clip_sd) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options) + m, u = clip.load_sd(clip_sd, full_model=True) + if len(m) > 0: + m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) + if len(m_filter) > 0: + logging.warning("clip missing: {}".format(m)) + else: + logging.debug("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected {}:".format(u)) + else: + logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") + + left_over = sd.keys() + if len(left_over) > 0: + logging.debug("left over keys: {}".format(left_over)) + + if output_model: + model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=mm.unet_offload_device()) + if inital_load_device != torch.device("cpu"): + logging.info("loaded diffusion model directly to GPU") + mm.load_models_gpu([model_patcher], force_full_load=True) + + return (model_patcher, clip, vae) class DiffusionModelLoaderKJ(BaseLoaderKJ): @classmethod