Add ModelPatchTorchSettings

kijai 2025-05-07 19:00:45 +03:00
parent cc043fcac7
commit ca07b9dadc
2 changed files with 40 additions and 0 deletions

View File

@@ -193,6 +193,7 @@ NODE_CONFIG = {
    "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
    "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
    "CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
    "ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
    #instance diffusion
    "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

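For context, the first hunk only registers the new node in NODE_CONFIG. Below is a minimal, self-contained sketch of how such an entry typically ends up in the class/display-name mappings ComfyUI reads at load time; the generate_node_mappings helper, its fallback logic, and the stand-in class are illustrative assumptions, not part of this diff.

```python
# Minimal sketch of how a NODE_CONFIG entry maps to ComfyUI's registration dicts.
# Helper name and fallback behavior are assumptions for illustration only.
class ModelPatchTorchSettings:  # stand-in for the class added in this commit
    pass

NODE_CONFIG = {
    "ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
}

def generate_node_mappings(node_config):
    class_mappings = {}
    display_name_mappings = {}
    for node_name, node_info in node_config.items():
        class_mappings[node_name] = node_info["class"]
        # Entries without an explicit "name" fall back to the class name.
        display_name_mappings[node_name] = node_info.get("name", node_info["class"].__name__)
    return class_mappings, display_name_mappings

NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)
print(NODE_DISPLAY_NAME_MAPPINGS["ModelPatchTorchSettings"])  # "Model Patch Torch Settings"
```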
View File

@@ -351,6 +351,45 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
        return (model,)
class ModelPatchTorchSettings:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
            "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    OUTPUT_NODE = True
    DESCRIPTION = "Adds callbacks to the model to set torch settings before and after running the model."
    EXPERIMENTAL = True
    CATEGORY = "KJNodes/experimental"

    def patch(self, model, enable_fp16_accumulation):
        model_clone = model.clone()

        # Callbacks that flip the global matmul fp16 accumulation flag around model execution.
        def patch_enable_fp16_accum(model):
            print("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = True")
            torch.backends.cuda.matmul.allow_fp16_accumulation = True

        def patch_disable_fp16_accum(model):
            print("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = False")
            torch.backends.cuda.matmul.allow_fp16_accumulation = False

        if enable_fp16_accumulation:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                # Enable the flag before the run and restore the default on cleanup.
                model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_enable_fp16_accum)
                model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_disable_fp16_accum)
            else:
                raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
        else:
            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
                # Explicitly keep the flag disabled before the run.
                model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_disable_fp16_accum)
            else:
                raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")

        return (model_clone,)
def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
    with self.use_ejected():
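The node works by cloning the model patcher and attaching ON_PRE_RUN / ON_CLEANUP callbacks that toggle the global torch.backends.cuda.matmul.allow_fp16_accumulation flag. The following is a standalone sketch of the same toggle pattern outside ComfyUI, with the hasattr guard taken from the diff; the context-manager name and usage snippet are illustrative, not from the commit.

```python
import contextlib
import torch

@contextlib.contextmanager
def fp16_accumulation(enabled: bool):
    """Temporarily toggle torch.backends.cuda.matmul.allow_fp16_accumulation.

    Mirrors what the node's ON_PRE_RUN / ON_CLEANUP callbacks do; the attribute
    only exists on recent PyTorch builds (2.7.0 nightly at the time of this commit).
    """
    matmul = torch.backends.cuda.matmul
    if not hasattr(matmul, "allow_fp16_accumulation"):
        raise RuntimeError("This PyTorch build does not expose allow_fp16_accumulation")
    previous = matmul.allow_fp16_accumulation
    matmul.allow_fp16_accumulation = enabled
    try:
        yield
    finally:
        # Restore the previous value, analogous to the ON_CLEANUP callback.
        matmul.allow_fp16_accumulation = previous

# Usage sketch: run a matmul with fp16 accumulation enabled, then restore the flag.
if torch.cuda.is_available() and hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
    a = torch.randn(512, 512, dtype=torch.float16, device="cuda")
    with fp16_accumulation(True):
        b = a @ a
```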