diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 77381ae..dc40c33 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -109,42 +109,6 @@ def get_sage_func(sage_attention, allow_compile=False):
         return out
     return attention_sage

-class BaseLoaderKJ:
-    original_linear = None
-    cublas_patched = False
-
-    def _patch_modules(self, patch_cublaslinear, sage_attention):
-        from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
-
-        if patch_cublaslinear:
-            if not BaseLoaderKJ.cublas_patched:
-                BaseLoaderKJ.original_linear = disable_weight_init.Linear
-                try:
-                    from cublas_ops import CublasLinear
-                except ImportError:
-                    raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
-
-                class PatchedLinear(CublasLinear, CastWeightBiasOp):
-                    def reset_parameters(self):
-                        pass
-
-                    def forward_comfy_cast_weights(self, input):
-                        weight, bias = cast_bias_weight(self, input)
-                        return torch.nn.functional.linear(input, weight, bias)
-
-                    def forward(self, *args, **kwargs):
-                        if self.comfy_cast_weights:
-                            return self.forward_comfy_cast_weights(*args, **kwargs)
-                        else:
-                            return super().forward(*args, **kwargs)
-
-                disable_weight_init.Linear = PatchedLinear
-                BaseLoaderKJ.cublas_patched = True
-        else:
-            if BaseLoaderKJ.cublas_patched:
-                disable_weight_init.Linear = BaseLoaderKJ.original_linear
-                BaseLoaderKJ.cublas_patched = False
-
 from comfy.patcher_extension import CallbacksMP

 class PathchSageAttentionKJ():
@@ -179,26 +143,27 @@ class PathchSageAttentionKJ():
             model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage

         return model_clone,
-
-class CheckpointLoaderKJ(BaseLoaderKJ):
+
+
+class CheckpointLoaderKJ():
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
             "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
             "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
-            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the cublas_ops arg"}),
             "sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
             "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, required minimum pytorch version 2.7.1"}),
             }}

     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
-    FUNCTION = "patch"
+    FUNCTION = "load"
     DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
     EXPERIMENTAL = True
     CATEGORY = "KJNodes/experimental"

-    def patch(self, ckpt_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
+    def load(self, ckpt_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
         DTYPE_MAP = {
             "fp8_e4m3fn": torch.float8_e4m3fn,
             "fp8_e5m2": torch.float8_e5m2,
@@ -215,13 +180,18 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
             model_options["dtype"] = torch.float8_e4m3fn
             model_options["fp8_optimizations"] = True

+        if patch_cublaslinear:
+            args.fast.add("cublas_ops")
+        else:
+            args.fast.discard("cublas_ops")
+
         ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
         sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
-
-        model, clip, vae = self.load_state_dict_guess_config(
+
+        model, clip, vae, _ = comfy.sd.load_state_dict_guess_config(
             sd,
-            output_vae=True,
-            output_clip=True,
+            output_vae=True,
+            output_clip=True,
             embedding_directory=folder_paths.get_folder_paths("embeddings"),
             metadata=metadata,
             model_options=model_options)
@@ -249,82 +219,7 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
             model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage

         return model, clip, vae
-
-    def load_state_dict_guess_config(self, sd, output_vae=True, output_clip=True, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
-        from comfy.sd import load_diffusion_model_state_dict, model_detection, VAE, CLIP
-        clip = None
-        vae = None
-        model = None
-        model_patcher = None
-        diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
-        parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
-        weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
-        load_device = mm.get_torch_device()
-
-        model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
-        if model_config is None:
-            logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
-            diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
-            if diffusion_model is None:
-                return None
-            return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
-
-
-        unet_weight_dtype = list(model_config.supported_inference_dtypes)
-        if model_config.scaled_fp8 is not None:
-            weight_dtype = None
-
-        model_config.custom_operations = model_options.get("custom_operations", None)
-        unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
-
-        if unet_dtype is None:
-            unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
-
-        manual_cast_dtype = mm.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
-        model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
-
-        if output_model:
-            inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
-            model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
-            model.load_model_weights(sd, diffusion_model_prefix)
-
-        if output_vae:
-            vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
-            vae_sd = model_config.process_vae_state_dict(vae_sd)
-            vae = VAE(sd=vae_sd, metadata=metadata)
-
-        if output_clip:
-            clip_target = model_config.clip_target(state_dict=sd)
-            if clip_target is not None:
-                clip_sd = model_config.process_clip_state_dict(sd)
-                if len(clip_sd) > 0:
-                    parameters = comfy.utils.calculate_parameters(clip_sd)
-                    clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
-                    m, u = clip.load_sd(clip_sd, full_model=True)
-                    if len(m) > 0:
-                        m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
-                        if len(m_filter) > 0:
-                            logging.warning("clip missing: {}".format(m))
-                        else:
-                            logging.debug("clip missing: {}".format(m))
-
-                    if len(u) > 0:
-                        logging.debug("clip unexpected {}:".format(u))
-                else:
-                    logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
-
-        left_over = sd.keys()
-        if len(left_over) > 0:
-            logging.debug("left over keys: {}".format(left_over))
-
-        if output_model:
-            model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=mm.unet_offload_device())
-            if inital_load_device != torch.device("cpu"):
-                logging.info("loaded diffusion model directly to GPU")
-                mm.load_models_gpu([model_patcher], force_full_load=True)
-
-        return (model_patcher, clip, vae)


 class DiffusionModelSelector():
     @classmethod
@@ -341,18 +236,18 @@ class DiffusionModelSelector():
     EXPERIMENTAL = True
     CATEGORY = "KJNodes/experimental"

-    def get_path(self, model_name):
+    def get_path(self, model_name):
         model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
         return (model_path,)

-class DiffusionModelLoaderKJ(BaseLoaderKJ):
+class DiffusionModelLoaderKJ():
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
             "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
             "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
             "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
-            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
+            "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the cublas_ops arg"}),
             "sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
             "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
             },
@@ -367,7 +262,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
     EXPERIMENTAL = True
     CATEGORY = "KJNodes/experimental"

-    def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation, extra_state_dict=None):
+    def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation, extra_state_dict=None):
         DTYPE_MAP = {
             "fp8_e4m3fn": torch.float8_e4m3fn,
             "fp8_e5m2": torch.float8_e5m2,
@@ -379,11 +274,11 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
         if dtype := DTYPE_MAP.get(weight_dtype):
             model_options["dtype"] = dtype
             logging.info(f"Setting {model_name} weight dtype to {dtype}")
-
+
         if weight_dtype == "fp8_e4m3fn_fast":
             model_options["dtype"] = torch.float8_e4m3fn
             model_options["fp8_optimizations"] = True
-
+
         if enable_fp16_accumulation:
             if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                 torch.backends.cuda.matmul.allow_fp16_accumulation = True
@@ -393,8 +288,13 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
             if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                 torch.backends.cuda.matmul.allow_fp16_accumulation = False

+        if patch_cublaslinear:
+            args.fast.add("cublas_ops")
+        else:
+            args.fast.discard("cublas_ops")
+
         unet_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
-
+
         sd = comfy.utils.load_torch_file(unet_path)
         if extra_state_dict is not None:
             # If the model is a checkpoint, strip additional non-diffusion model entries before adding extra state dict
@@ -404,7 +304,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
             temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
             if len(temp_sd) > 0:
                 sd = temp_sd
-
+
             extra_sd = comfy.utils.load_torch_file(extra_state_dict)
             sd.update(extra_sd)
             del extra_sd
@@ -422,7 +322,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):

         # attention override
         model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
-
+
         return (model,)

 class ModelPatchTorchSettings:
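
Reviewer note (not part of the patch): both loader nodes now toggle the "cublas_ops" entry in ComfyUI's fast-feature set instead of monkeypatching comfy.ops.disable_weight_init.Linear through the removed BaseLoaderKJ, and CheckpointLoaderKJ now calls comfy.sd.load_state_dict_guess_config directly (its extra return value is discarded via `_`) instead of carrying a local copy of that function. The sketch below only illustrates the set add/discard pattern the patch relies on; it assumes `args` in this file is ComfyUI's parsed CLI namespace (typically `from comfy.cli_args import args`) and that `args.fast` behaves like a plain Python set. The helper name and the stand-in feature set are hypothetical.

# Minimal sketch of the toggle used in both loaders, with a plain set
# standing in for args.fast (assumption: args.fast is set-like).
def toggle_cublas_ops(fast_features: set, enabled: bool) -> set:
    """Add or remove the "cublas_ops" flag in place."""
    if enabled:
        fast_features.add("cublas_ops")
    else:
        fast_features.discard("cublas_ops")  # discard() is a no-op if the flag is absent
    return fast_features

features = {"fp16_accumulation"}            # hypothetical starting state
toggle_cublas_ops(features, True)           # now contains "cublas_ops" as well
toggle_cublas_ops(features, False)          # "cublas_ops" removed again
print(features)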