mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-15 07:44:30 +08:00
Add ModelPatchTorchSettings
This commit is contained in:
parent
cc043fcac7
commit
ca07b9dadc
@ -193,6 +193,7 @@ NODE_CONFIG = {
|
|||||||
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
|
"TimerNodeKJ": {"class": TimerNodeKJ, "name": "Timer Node KJ"},
|
||||||
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
|
"HunyuanVideoEncodeKeyframesToCond": {"class": HunyuanVideoEncodeKeyframesToCond, "name": "HunyuanVideo Encode Keyframes To Cond"},
|
||||||
"CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
|
"CFGZeroStarAndInit": {"class": CFGZeroStarAndInit, "name": "CFG Zero Star/Init"},
|
||||||
|
"ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
|
||||||
|
|
||||||
#instance diffusion
|
#instance diffusion
|
||||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||||
|
|||||||
@ -351,6 +351,45 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
|
|||||||
|
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
|
class ModelPatchTorchSettings:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"model": ("MODEL",),
|
||||||
|
"enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
DESCRIPTION = "Adds callbacks to model to set torch settings before and after running the model."
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
CATEGORY = "KJNodes/experimental"
|
||||||
|
|
||||||
|
def patch(self, model, enable_fp16_accumulation):
|
||||||
|
model_clone = model.clone()
|
||||||
|
|
||||||
|
def patch_enable_fp16_accum(model):
|
||||||
|
print("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = True")
|
||||||
|
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||||
|
def patch_disable_fp16_accum(model):
|
||||||
|
print("Patching torch settings: torch.backends.cuda.matmul.allow_fp16_accumulation = False")
|
||||||
|
torch.backends.cuda.matmul.allow_fp16_accumulation = False
|
||||||
|
|
||||||
|
if enable_fp16_accumulation:
|
||||||
|
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
||||||
|
model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_enable_fp16_accum)
|
||||||
|
model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_disable_fp16_accum)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
|
||||||
|
else:
|
||||||
|
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
|
||||||
|
model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_disable_fp16_accum)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Failed to set fp16 accumulation, this requires pytorch 2.7.0 nightly currently")
|
||||||
|
|
||||||
|
return (model_clone,)
|
||||||
|
|
||||||
def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
def patched_patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user