diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 543da07..a87b8e8 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -868,7 +868,10 @@ class TorchCompileModelAdvanced: "backend": (["inductor","cudagraphs"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), - "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), + "dynamic": ( + ["auto", "true", "false"], + {"default": "auto", "tooltip": "Use dynamic shape tracing."}, + ), "compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}), "debug_compile_keys": ("BOOLEAN", {"default": False, "tooltip": "Print the compile keys used for torch.compile"}), @@ -903,6 +906,12 @@ class TorchCompileModelAdvanced: logging.info(f" - {key}") if not compile_key_list: compile_key_list =["diffusion_model"] + + dynamic_kv = {"true": True, "false": False, "auto": None} + try: + dynamic = dynamic_kv[dynamic] + except KeyError: + raise ValueError(f"Invalid dynamic arg {dynamic}") set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph) except: