DiffusionModelLoaderKJ: Allow model checkpoints to be used with extra state dict input

This commit is contained in:
ozbayb 2025-11-08 11:48:31 -07:00
parent c661baadd9
commit 3a8786c206

View File

@ -391,6 +391,14 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
sd = comfy.utils.load_torch_file(unet_path)
if extra_state_dict is not None:
# If the model is a checkpoint, strip additional non-diffusion model entries before adding extra state dict
from comfy import model_detection
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
if diffusion_model_prefix == "model.diffusion_model.":
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0:
sd = temp_sd
extra_sd = comfy.utils.load_torch_file(extra_state_dict)
sd.update(extra_sd)
del extra_sd