diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 956b0ea..ce3a3b6 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -337,7 +337,11 @@ class TorchCompileModelFluxAdvanced:
                 "double_blocks": ("STRING", {"default": "0-18", "multiline": True}),
                 "single_blocks": ("STRING", {"default": "0-37", "multiline": True}),
                 "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
-                }}
+            },
+            "optional": {
+                "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
+            }
+        }
 
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
@@ -355,11 +359,12 @@ class TorchCompileModelFluxAdvanced:
             blocks.append(int(part))
         return blocks
 
-    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic):
+    def patch(self, model, backend, mode, fullgraph, single_blocks, double_blocks, dynamic, dynamo_cache_size_limit=64):
         single_block_list = self.parse_blocks(single_blocks)
         double_block_list = self.parse_blocks(double_blocks)
         m = model.clone()
         diffusion_model = m.get_model_object("diffusion_model")
+        torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
 
         if not self._compiled:
             try: