diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 3065ea4..b97142a 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -166,10 +166,12 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {
-            "ckpt_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
-            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],),
+            "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "fp16_fast", "bf16", "fp32"],),
+            "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "fp16", "tooltip": "The compute dtype to use for the model."}),
             "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
             "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+            "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
             }}

     RETURN_TYPES = ("MODEL",)
@@ -179,10 +181,36 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
     EXPERIMENTAL = True
     CATEGORY = "KJNodes/experimental"

-    def patch_and_load(self, ckpt_name, weight_dtype, patch_cublaslinear, sage_attention):
-        from nodes import UNETLoader
-        model, = UNETLoader.load_unet(self, ckpt_name, weight_dtype)
+    def patch_and_load(self, model_name, weight_dtype, compute_dtype, patch_cublaslinear, sage_attention, enable_fp16_accumulation):
+        DTYPE_MAP = {
+            "fp8_e4m3fn": torch.float8_e4m3fn,
+            "fp8_e5m2": torch.float8_e5m2,
+            "fp16": torch.float16,
+            "bf16": torch.bfloat16,
+            "fp32": torch.float32
+        }
+        model_options = {}
+        if dtype := DTYPE_MAP.get(weight_dtype):
+            model_options["dtype"] = dtype
+
+        if weight_dtype == "fp8_e4m3fn_fast":
+            model_options["dtype"] = torch.float8_e4m3fn
+            model_options["fp8_optimizations"] = True
+
+        try:
+            if enable_fp16_accumulation:
+                torch.backends.cuda.matmul.allow_fp16_accumulation = True
+            else:
+                torch.backends.cuda.matmul.allow_fp16_accumulation = False
+        except:
+            raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
+
+        unet_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
+        model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
+        if dtype := DTYPE_MAP.get(compute_dtype):
+            model.set_model_compute_dtype(dtype)
         self._patch_modules(patch_cublaslinear, sage_attention)
+        return (model,)

     def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
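
Not part of the patch: a minimal standalone sketch of the pattern the new code relies on, i.e. mapping the node's dtype strings onto torch dtypes for ComfyUI's `model_options` and toggling fp16 matmul accumulation. The helper names (`build_model_options`, `set_fp16_accumulation`) are hypothetical; only the torch attributes mirror the diff, and `allow_fp16_accumulation` assumes a PyTorch 2.7.0 nightly build.

```python
import torch

# Same mapping the patch uses: UI strings -> torch dtypes ("default"/unknown -> None).
DTYPE_MAP = {
    "fp8_e4m3fn": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}

def build_model_options(weight_dtype: str) -> dict:
    """Hypothetical helper: translate a weight_dtype string into a model_options dict."""
    model_options = {}
    if dtype := DTYPE_MAP.get(weight_dtype):
        model_options["dtype"] = dtype
    if weight_dtype == "fp8_e4m3fn_fast":
        # The "fast" variant keeps fp8_e4m3fn weights and enables the fp8 optimization path.
        model_options["dtype"] = torch.float8_e4m3fn
        model_options["fp8_optimizations"] = True
    return model_options

def set_fp16_accumulation(enabled: bool) -> None:
    """Hypothetical helper: the flag only exists on PyTorch 2.7.0 nightly and newer."""
    try:
        torch.backends.cuda.matmul.allow_fp16_accumulation = enabled
    except Exception as exc:
        raise RuntimeError("allow_fp16_accumulation requires a PyTorch 2.7.0 nightly") from exc

print(build_model_options("fp8_e4m3fn_fast"))
# -> {'dtype': torch.float8_e4m3fn, 'fp8_optimizations': True}
```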