[Graph Partition][Cache] Use inductor partition ops config (#27702)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
Boyuan Feng 2025-11-05 05:04:48 -08:00 committed by GitHub
parent 6b7a81185d
commit 6ab183813c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 37 additions and 63 deletions

View File

@@ -97,10 +97,9 @@ class CompilerManager:
compilation (e.g. partition rules, pass context)."""
with pass_context(runtime_shape):
if self.compilation_config.use_inductor_graph_partition:
inductor_partition_ops = resolve_defined_ops(
with inductor_partition_rule_context(
self.compilation_config.splitting_ops
)
with inductor_partition_rule_context(inductor_partition_ops):
):
yield
else:
yield

View File

@@ -3,15 +3,12 @@
import contextlib
import logging
from typing import TYPE_CHECKING
import torch
from torch._library.utils import lookup_op
from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
logger = init_logger(__name__)
@@ -56,47 +53,35 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
@contextlib.contextmanager
def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]):
def inductor_partition_rule_context(splitting_ops: list[str]):
"""Context manager to temporarily register Inductor partition rules.
Registers custom partition rules for specified operators, forcing the
Inductor scheduler to partition the graph at these operators. The rules
are automatically restored to their previous state on exit.
Note: Callers should use resolve_defined_ops() to convert operator names
to OpOverload objects before calling this function.
Args:
overloads: List of resolved operator overload objects.
splitting_ops: List of operator names to partition on.
"""
if not overloads:
if not splitting_ops:
logger.debug("No partition ops provided; skipping rule registration.")
yield
return
from torch._inductor.scheduler import ( # type: ignore
_custom_should_partition_fns,
register_should_partition_rule,
)
def _always_partition(*_args, **_kwargs):
return True
# Save current state before registering
saved_rules = _custom_should_partition_fns.copy()
for overload in overloads:
register_should_partition_rule(
overload,
_always_partition,
)
saved_splitting_ops: list[str] = list(
torch._inductor.config.custom_should_partition_ops
)
torch._inductor.config.custom_should_partition_ops = splitting_ops
logger.debug("Registered inductor partition rules for %d operators", len(overloads))
logger.debug(
"Registered inductor partition rules for %d operators", len(splitting_ops)
)
try:
yield
finally:
# Clear and restore previous state
_custom_should_partition_fns.clear()
_custom_should_partition_fns.update(saved_rules)
torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
logger.debug("Restored previous partition rules state.")

View File

@@ -113,27 +113,6 @@ class PostGradPassManager(CustomGraphPass):
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph
# partition is not taken into account during caching.
# Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
# Inductor graph partition, and VLLM_COMPILE implies there
# is a PostGradPassManager, we put the list of operators to graph
# partition into the PostGradPassManager's uuid (which
# then gets incorporated into Inductor's FX graph cache key).
# Remove this hack whenever torch.compile fixes it.
# This is the list of operators that vLLM asks Inductor to split.
self.inductor_splitting_ops = []
if (
config.compilation_config.use_inductor_graph_partition
and config.compilation_config.splitting_ops is not None
):
# Sort them so we're not dependent on the ordering.
self.inductor_splitting_ops = sorted(
config.compilation_config.splitting_ops
)
def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
@@ -144,16 +123,9 @@
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
state = {
"pass_config": self.pass_config.uuid(),
"passes": [],
"inductor_splitting_ops": [],
}
state = {"pass_config": self.pass_config.uuid(), "passes": []}
for pass_ in self.passes:
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
# See [HACK: Bug with Inductor graph partition and torch.compile cache]
state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
return InductorPass.hash_dict(state)

View File

@@ -272,7 +272,6 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
from torch._inductor.scheduler import (
BaseSchedulerNode,
FusedSchedulerNode,
_custom_should_partition_fns,
)
from torch._inductor.utils import (
_unstable_customized_partition_wrapper,
@@ -283,9 +282,21 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
# Allow users to manually specify if a node should be partitioned
# Can only do this for FallbackKernels
ir_node = node.node
if isinstance(ir_node, ir.FallbackKernel):
operator = ir_node.op_overload
if operator is not None and operator in _custom_should_partition_fns:
if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
op := ir_node.op_overload
):
op_overload_packet_name = op.name()
op_overload_name = (
f"{op_overload_packet_name}.{op._overloadname}"
if isinstance(op, torch._ops.OpOverload)
else op_overload_packet_name
)
if (
op_overload_packet_name
in torch._inductor.config.custom_should_partition_ops
or op_overload_name in torch._inductor.config.custom_should_partition_ops
):
assert isinstance(op, torch._ops.OpOverload)
return True
# When not using cudagraphs, keep all kernels in the `call` function
@@ -355,6 +366,13 @@ def _update_scheduler_patched(self) -> None:
if is_torch_equal("2.9.0"):
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.graph import GraphLowering
from torch.utils._config_module import _Config, _ConfigEntry
# `custom_should_partition_ops` is a new config after 2.9.0. So this would
# not overwrite any user configs.
torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry(
_Config(default=[])
)
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
GraphLowering._update_scheduler = _update_scheduler_patched