diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 61941da..86cc463 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -391,6 +391,14 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ): sd = comfy.utils.load_torch_file(unet_path) if extra_state_dict is not None: + # If the model is a checkpoint, strip additional non-diffusion model entries before adding extra state dict + from comfy import model_detection + diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) + if diffusion_model_prefix == "model.diffusion_model.": + temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) + if len(temp_sd) > 0: + sd = temp_sd + extra_sd = comfy.utils.load_torch_file(extra_state_dict) sd.update(extra_sd) del extra_sd