diff --git a/__init__.py b/__init__.py
index 104aa60..8785bac 100644
--- a/__init__.py
+++ b/__init__.py
@@ -193,6 +193,7 @@ NODE_CONFIG = {
     "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
     "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
     "CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
+    "ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
 
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 127c7f9..69b83d9 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -351,6 +351,45 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
 
         return (model,)
 
+class ModelPatchTorchSettings:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MODEL",),
+            "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires PyTorch 2.7.0 nightly."}),
+        }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    OUTPUT_NODE = True
+    DESCRIPTION = "Adds callbacks to the model to set torch settings before and after running the model."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch(self, model, enable_fp16_accumulation):
+        model_clone = model.clone()
+
+        def patch_enable_fp16_accum(model):
+            print("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = True")
+            torch.backends.cuda.matmul.allow_fp16_accumulation = True
+        def patch_disable_fp16_accum(model):
+            print("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = False")
+            torch.backends.cuda.matmul.allow_fp16_accumulation = False
+
+        if enable_fp16_accumulation:
+            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
+                model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_enable_fp16_accum)
+                model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_disable_fp16_accum)
+            else:
+                raise RuntimeError("Failed to set fp16 accumulation; this currently requires a PyTorch 2.7.0 nightly build")
+        else:
+            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
+                model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_disable_fp16_accum)
+            else:
+                raise RuntimeError("Failed to set fp16 accumulation; this currently requires a PyTorch 2.7.0 nightly build")
+
+        return (model_clone,)
+
 def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
     with self.use_ejected():
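
For context, the two callbacks simply flip a global PyTorch flag around the model run. Below is a minimal standalone sketch of the same pattern outside ComfyUI; the helper name run_with_fp16_accumulation is hypothetical and not part of this patch.

import torch

def run_with_fp16_accumulation(fn, *args, **kwargs):
    # The flag only exists on builds that ship it (PyTorch 2.7.0 nightly at the
    # time of this patch); mirror the node's hasattr guard and fail loudly otherwise.
    if not hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
        raise RuntimeError("allow_fp16_accumulation is not available in this torch build")
    torch.backends.cuda.matmul.allow_fp16_accumulation = True   # ON_PRE_RUN equivalent
    try:
        return fn(*args, **kwargs)
    finally:
        # ON_CLEANUP equivalent: restore the default so later runs are unaffected.
        torch.backends.cuda.matmul.allow_fp16_accumulation = False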