mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
Add ModelPatchTorchSettings
parent cc043fcac7
commit ca07b9dadc
@@ -193,6 +193,7 @@ NODE_CONFIG = {
     "TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
     "HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
     "CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
+    "ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
 
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
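For context: NODE_CONFIG is the single registry this pack keeps its nodes in, so registering the new node is a one-line entry. A minimal sketch of how such a registry can be expanded into the NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS dicts that ComfyUI looks up at load time (illustrative only; this expansion is an assumption, not necessarily KJNodes' exact code):

    # Sketch: expand a NODE_CONFIG-style registry into the two module-level
    # dicts ComfyUI expects. Falls back to the key when no display name is
    # given, as with "CreateInstanceDiffusionTracking" above.
    NODE_CLASS_MAPPINGS = {key: cfg["class"] for key, cfg in NODE_CONFIG.items()}
    NODE_DISPLAY_NAME_MAPPINGS = {key: cfg.get("name", key) for key, cfg in NODE_CONFIG.items()}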
@@ -351,6 +351,45 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
 
         return (model,)
 
+class ModelPatchTorchSettings:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MODEL",),
+            "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
+        }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    OUTPUT_NODE = True
+    DESCRIPTION = "Adds callbacks to model to set torch settings before and after running the model."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch(self, model, enable_fp16_accumulation):
+        model_clone = model.clone()
+
+        def patch_enable_fp16_accum(model):
+            print("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = True")
+            torch.backends.cuda.matmul.allow_fp16_accumulation = True
+
+        def patch_disable_fp16_accum(model):
+            print("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = False")
+            torch.backends.cuda.matmul.allow_fp16_accumulation = False
+
+        if enable_fp16_accumulation:
+            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
+                model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_enable_fp16_accum)
+                model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_disable_fp16_accum)
+            else:
+                raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
+        else:
+            if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
+                model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_disable_fp16_accum)
+            else:
+                raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
+
+        return (model_clone,)
+
 def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
     with self.use_ejected():
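The added node flips a process-global torch flag around each model run: an ON_PRE_RUN callback sets torch.backends.cuda.matmul.allow_fp16_accumulation, and an ON_CLEANUP callback restores it, so later workflows are not left running with reduced-precision accumulation. A minimal standalone sketch of the same toggle outside ComfyUI (assumes a CUDA build; the attribute only exists on PyTorch 2.7.0 nightly at the time of this commit, and run_model is a hypothetical stand-in):

    import torch

    # Guard with hasattr, as the node does: the flag is missing on stable builds.
    if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
        torch.backends.cuda.matmul.allow_fp16_accumulation = True   # faster fp16 matmuls, accumulate in fp16
        run_model()  # hypothetical stand-in for the sampling run
        torch.backends.cuda.matmul.allow_fp16_accumulation = False  # restore the default
    else:
        raise RuntimeError("this PyTorch build does not expose allow_fp16_accumulation")

Pairing the set/restore on the model's own callbacks, rather than setting the flag once, is what keeps the global state scoped to the patched model's runs.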