From dafbcae4e6ad1dfbc6bb33d01dc9559e630ecc67 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 6 Aug 2025 20:09:14 +0300
Subject: [PATCH] Add option to load extra state dict with a diffusion model

Also add a DiffusionModelSelector node to easily select the path.
This can be used to add a standalone VACE module to any WanModel.
---
 __init__.py                       |  1 +
 nodes/model_optimization_nodes.py | 36 ++++++++++++++++++++++++++++---
 2 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/__init__.py b/__init__.py
index 590ed7c..2d379d4 100644
--- a/__init__.py
+++ b/__init__.py
@@ -133,6 +133,7 @@ NODE_CONFIG = {
     "ModelSaveKJ": {"class": ModelSaveKJ, "name": "Model Save KJ"},
     "SetShakkerLabsUnionControlNetType": {"class": SetShakkerLabsUnionControlNetType, "name": "Set Shakker Labs Union ControlNet Type"},
     "StyleModelApplyAdvanced": {"class": StyleModelApplyAdvanced, "name": "Style Model Apply Advanced"},
+    "DiffusionModelSelector": {"class": DiffusionModelSelector, "name": "Diffusion Model Selector"},
     #audioscheduler stuff
     "NormalizedAmplitudeToMask": {"class": NormalizedAmplitudeToMask},
     "NormalizedAmplitudeToFloatList": {"class": NormalizedAmplitudeToFloatList},
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index dcf828d..c1f2e88 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -370,6 +370,25 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
 
         return (model_patcher, clip, vae)
 
+class DiffusionModelSelector():
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The diffusion model to return the full path for."}),
+            },
+        }
+
+    RETURN_TYPES = ("STRING",)
+    RETURN_NAMES = ("model_path",)
+    FUNCTION = "get_path"
+    DESCRIPTION = "Returns the path to the model as a string."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def get_path(self, model_name):
+        model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
+        return (model_path,)
+
 class DiffusionModelLoaderKJ(BaseLoaderKJ):
     @classmethod
     def INPUT_TYPES(s):
@@ -380,7 +399,11 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
             "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
             "sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
             "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
-            }}
+            },
+            "optional": {
+                "extra_state_dict": ("STRING", {"forceInput": True, "tooltip": "The full path to an additional state dict to load; it is merged into the main state dict. Useful, for example, to add a VACE module to a WanVideoModel. You can use DiffusionModelSelector to easily get the path."}),
+            }
+        }
 
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch_and_load"
@@ -388,7 +411,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
     EXPERIMENTAL = True
     CATEGORY = "KJNodes/experimental"
 
-    def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
+    def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation, extra_state_dict=None):
         DTYPE_MAP = {
             "fp8_e4m3fn": torch.float8_e4m3fn,
             "fp8_e5m2": torch.float8_e5m2,
@@ -415,7 +438,14 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
             torch.backends.cuda.matmul.allow_fp16_accumulation = False
 
         unet_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
-        model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
+
+        sd = comfy.utils.load_torch_file(unet_path)
+        if extra_state_dict is not None:
+            extra_sd = comfy.utils.load_torch_file(extra_state_dict)
+            sd.update(extra_sd)
+            del extra_sd
+
+        model = comfy.sd.load_diffusion_model_state_dict(sd, model_options=model_options)
         if dtype := DTYPE_MAP.get(compute_dtype):
             model.set_model_compute_dtype(dtype)
         model.force_cast_weights = False
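
Usage note: a minimal sketch of the merge behaviour the patched loader performs when extra_state_dict is set, assuming a ComfyUI environment where comfy and folder_paths are importable; the file names below are hypothetical placeholders, not files shipped with the repo.

    import comfy.sd
    import comfy.utils
    import folder_paths

    # Resolve paths the same way the nodes do (file names here are made up)
    base_path = folder_paths.get_full_path_or_raise("diffusion_models", "wan2.1_t2v_14B.safetensors")
    extra_path = folder_paths.get_full_path_or_raise("diffusion_models", "wan2.1_vace_module.safetensors")

    sd = comfy.utils.load_torch_file(base_path)         # main diffusion model weights
    extra_sd = comfy.utils.load_torch_file(extra_path)  # e.g. a standalone VACE module
    sd.update(extra_sd)                                  # merge: extra keys extend/override the base dict
    del extra_sd

    # Build the model from the merged state dict, as the loader now does
    model = comfy.sd.load_diffusion_model_state_dict(sd, model_options={})

In a node graph this corresponds to wiring the model_path output of DiffusionModelSelector into the optional extra_state_dict input of DiffusionModelLoaderKJ.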