mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-01-23 10:24:29 +08:00
Update CheckpointLoaderKJ for latest Comfy changes
This commit is contained in:
parent
3160b36474
commit
409eda5e29
@ -109,42 +109,6 @@ def get_sage_func(sage_attention, allow_compile=False):
|
||||
return out
|
||||
return attention_sage
|
||||
|
||||
class BaseLoaderKJ:
|
||||
original_linear = None
|
||||
cublas_patched = False
|
||||
|
||||
def _patch_modules(self, patch_cublaslinear, sage_attention):
|
||||
from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight
|
||||
|
||||
if patch_cublaslinear:
|
||||
if not BaseLoaderKJ.cublas_patched:
|
||||
BaseLoaderKJ.original_linear = disable_weight_init.Linear
|
||||
try:
|
||||
from cublas_ops import CublasLinear
|
||||
except ImportError:
|
||||
raise Exception("Can't import 'torch-cublas-hgemm', install it from here https://github.com/aredden/torch-cublas-hgemm")
|
||||
|
||||
class PatchedLinear(CublasLinear, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
pass
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
disable_weight_init.Linear = PatchedLinear
|
||||
BaseLoaderKJ.cublas_patched = True
|
||||
else:
|
||||
if BaseLoaderKJ.cublas_patched:
|
||||
disable_weight_init.Linear = BaseLoaderKJ.original_linear
|
||||
BaseLoaderKJ.cublas_patched = False
|
||||
|
||||
|
||||
from comfy.patcher_extension import CallbacksMP
|
||||
class PathchSageAttentionKJ():
|
||||
@ -179,26 +143,27 @@ class PathchSageAttentionKJ():
|
||||
model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
|
||||
|
||||
return model_clone,
|
||||
|
||||
class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
|
||||
|
||||
class CheckpointLoaderKJ():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
|
||||
"compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
|
||||
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
|
||||
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the cublas_ops arg"}),
|
||||
"sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
||||
"enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, required minimum pytorch version 2.7.1"}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "patch"
|
||||
FUNCTION = "load"
|
||||
DESCRIPTION = "Experimental node for patching torch.nn.Linear with CublasLinear."
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch(self, ckpt_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
|
||||
def load(self, ckpt_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
|
||||
DTYPE_MAP = {
|
||||
"fp8_e4m3fn": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
@ -215,13 +180,18 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
model_options["fp8_optimizations"] = True
|
||||
|
||||
if patch_cublaslinear:
|
||||
args.fast.add("cublas_ops")
|
||||
else:
|
||||
args.fast.discard("cublas_ops")
|
||||
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
|
||||
model, clip, vae = self.load_state_dict_guess_config(
|
||||
|
||||
model, clip, vae, _ = comfy.sd.load_state_dict_guess_config(
|
||||
sd,
|
||||
output_vae=True,
|
||||
output_clip=True,
|
||||
output_vae=True,
|
||||
output_clip=True,
|
||||
embedding_directory=folder_paths.get_folder_paths("embeddings"),
|
||||
metadata=metadata,
|
||||
model_options=model_options)
|
||||
@ -249,82 +219,7 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
|
||||
model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
|
||||
|
||||
return model, clip, vae
|
||||
|
||||
def load_state_dict_guess_config(self, sd, output_vae=True, output_clip=True, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
|
||||
from comfy.sd import load_diffusion_model_state_dict, model_detection, VAE, CLIP
|
||||
clip = None
|
||||
vae = None
|
||||
model = None
|
||||
model_patcher = None
|
||||
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = mm.get_torch_device()
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
|
||||
if diffusion_model is None:
|
||||
return None
|
||||
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
weight_dtype = None
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = mm.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
|
||||
|
||||
manual_cast_dtype = mm.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if output_model:
|
||||
inital_load_device = mm.unet_inital_load_device(parameters, unet_dtype)
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix)
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
|
||||
if output_clip:
|
||||
clip_target = model_config.clip_target(state_dict=sd)
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
else:
|
||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
logging.debug("left over keys: {}".format(left_over))
|
||||
|
||||
if output_model:
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=mm.unet_offload_device())
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded diffusion model directly to GPU")
|
||||
mm.load_models_gpu([model_patcher], force_full_load=True)
|
||||
|
||||
return (model_patcher, clip, vae)
|
||||
|
||||
class DiffusionModelSelector():
|
||||
@classmethod
|
||||
@ -341,18 +236,18 @@ class DiffusionModelSelector():
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def get_path(self, model_name):
|
||||
def get_path(self, model_name):
|
||||
model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
|
||||
return (model_path,)
|
||||
|
||||
class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
class DiffusionModelLoaderKJ():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
|
||||
"compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
|
||||
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
|
||||
"patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the cublas_ops arg"}),
|
||||
"sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
|
||||
"enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
|
||||
},
|
||||
@ -367,7 +262,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation, extra_state_dict=None):
|
||||
def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation, extra_state_dict=None):
|
||||
DTYPE_MAP = {
|
||||
"fp8_e4m3fn": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
@ -379,11 +274,11 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
if dtype := DTYPE_MAP.get(weight_dtype):
|
||||
model_options["dtype"] = dtype
|
||||
logging.info(f"Setting {model_name} weight dtype to {dtype}")
|
||||
|
||||
|
||||
if weight_dtype == "fp8_e4m3fn_fast":
|
||||
model_options["dtype"] = torch.float8_e4m3fn
|
||||
model_options["fp8_optimizations"] = True
|
||||
|
||||
|
||||
if enable_fp16_accumulation:
|
||||
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||
@ -393,8 +288,13 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = False
|
||||
|
||||
if patch_cublaslinear:
|
||||
args.fast.add("cublas_ops")
|
||||
else:
|
||||
args.fast.discard("cublas_ops")
|
||||
|
||||
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
|
||||
|
||||
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
if extra_state_dict is not None:
|
||||
# If the model is a checkpoint, strip additional non-diffusion model entries before adding extra state dict
|
||||
@ -404,7 +304,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
|
||||
|
||||
extra_sd = comfy.utils.load_torch_file(extra_state_dict)
|
||||
sd.update(extra_sd)
|
||||
del extra_sd
|
||||
@ -422,7 +322,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
||||
|
||||
# attention override
|
||||
model.model_options["transformer_options"]["optimized_attention_override"] = attention_override_sage
|
||||
|
||||
|
||||
return (model,)
|
||||
|
||||
class ModelPatchTorchSettings:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user