ComfyUI/comfy_extras/nodes_torch_compile.py
2025-10-03 11:43:54 -07:00

40 lines
1.1 KiB
Python

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_api.torch_helpers import set_torch_compile_wrapper
class TorchCompileModel(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="TorchCompileModel",
category="_for_testing",
inputs=[
io.Model.Input("model"),
io.Combo.Input(
"backend",
options=["inductor", "cudagraphs"],
),
],
outputs=[io.Model.Output()],
is_experimental=True,
)
@classmethod
def execute(cls, model, backend) -> io.NodeOutput:
m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend)
return io.NodeOutput(m)
class TorchCompileExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TorchCompileModel,
]
async def comfy_entrypoint() -> TorchCompileExtension:
return TorchCompileExtension()