mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-23 20:44:33 +08:00
Add option to load extra state dict with a diffusion model
Also add node DiffusionModelSelector to easily select the path. This can be used to add standalone VACE module to any WanModel
This commit is contained in:
parent
331260d908
commit
dafbcae4e6
@ -133,6 +133,7 @@ NODE_CONFIG = {
|
||||
"ModelSaveKJ": {"class": ModelSaveKJ, "name": "Model Save KJ"},
|
||||
"SetShakkerLabsUnionControlNetType": {"class": SetShakkerLabsUnionControlNetType, "name": "Set Shakker Labs Union ControlNet Type"},
|
||||
"StyleModelApplyAdvanced": {"class": StyleModelApplyAdvanced, "name": "Style Model Apply Advanced"},
|
||||
"DiffusionModelSelector": {"class": DiffusionModelSelector, "name": "Diffusion Model Selector"},
|
||||
#audioscheduler stuff
|
||||
"NormalizedAmplitudeToMask": {"class": NormalizedAmplitudeToMask},
|
||||
"NormalizedAmplitudeToFloatList": {"class": NormalizedAmplitudeToFloatList},
|
||||
|
||||
@ -370,6 +370,25 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
|
||||
return (model_patcher, clip, vae)
|
||||
|
||||
class DiffusionModelSelector():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("model_path",)
|
||||
FUNCTION = "get_path"
|
||||
DESCRIPTION = "Returns the path to the model as a string."
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def get_path(self, model_name):
|
||||
model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
|
||||
return (model_path,)
|
||||
|
||||
class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -380,7 +399,11 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
|
||||
"sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
||||
"enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
|
||||
}}
|
||||
},
|
||||
"optional": {
|
||||
"extra_state_dict": ("STRING", {"forceInput": True, "tooltip": "The full path to an additional state dict to load, this will be merged with the main state dict. Useful for example to add VACE module to a WanVideoModel. You can use DiffusionModelSelector to easily get the path."}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch_and_load"
|
||||
@ -388,7 +411,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
|
||||
def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation, extra_state_dict=None):
|
||||
DTYPE_MAP = {
|
||||
"fp8_e4m3fn": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
@ -415,7 +438,14 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = False
|
||||
|
||||
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
|
||||
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
if extra_state_dict is not None:
|
||||
extra_sd = comfy.utils.load_torch_file(extra_state_dict)
|
||||
sd.update(extra_sd)
|
||||
del extra_sd
|
||||
|
||||
model = comfy.sd.load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
if dtype := DTYPE_MAP.get(compute_dtype):
|
||||
model.set_model_compute_dtype(dtype)
|
||||
model.force_cast_weights = False
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user