vllm/vllm/compilation/partition_rules.py
Luka Govedič 2dcd12d357
[torch.compile] Fix tests for torch==2.9 inductor partition (#26116)
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
2025-10-14 19:55:02 -04:00


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import logging
from typing import TYPE_CHECKING

from torch._library.utils import lookup_op

from vllm.logger import init_logger

if TYPE_CHECKING:
    import torch

logger = init_logger(__name__)
def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
    """Resolve operator names to OpOverload objects.

    Skips operators that fail to resolve (e.g., operators not registered or
    model-specific operators not present in the current model).

    Note: Users should inspect the operator graph before lowering and ensure
    the specified operators are present in the final graph. Built-in PyTorch
    operators (aten::*, torch::*) may be decomposed, fused, or transformed
    during Inductor's compilation passes, so use them with caution.

    Args:
        op_names: List of operator names in PyTorch format
            (e.g., "vllm::unified_attention").

    Returns:
        List of successfully resolved operator overloads.
    """
    resolved = []
    for op_name in op_names:
        try:
            resolved.append(lookup_op(op_name))
        except Exception:
            # Skip operators that don't exist (e.g., model-specific ops).
            # Do not warn for attention ops; do warn for others, since those
            # were most likely specified manually.
            from vllm.config import CompilationConfig

            logger.log(
                logging.DEBUG
                if op_name in CompilationConfig._attention_ops
                else logging.WARNING,
                "Failed to resolve operator for CUDAGraph partition: %s",
                op_name,
            )
            continue

    return resolved
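

# A minimal usage sketch (illustrative, not part of the module): op names that
# fail to resolve are dropped, so the returned list may be shorter than the
# input. "vllm::unified_attention" is the docstring's example op and only
# resolves when that custom op is registered; "vllm::no_such_op" is a
# deliberately unregistered placeholder:
#
#     ops = resolve_defined_ops(["vllm::unified_attention", "vllm::no_such_op"])
#     # -> [<OpOverload ...>]; the second name is skipped with a warning.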


@contextlib.contextmanager
def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]):
    """Context manager to temporarily register Inductor partition rules.

    Registers custom partition rules for specified operators, forcing the
    Inductor scheduler to partition the graph at these operators. The rules
    are automatically restored to their previous state on exit.

    Note: Callers should use resolve_defined_ops() to convert operator names
    to OpOverload objects before calling this function.

    Args:
        overloads: List of resolved operator overload objects.
    """
    if not overloads:
        logger.debug("No partition ops provided; skipping rule registration.")
        yield
        return

    from torch._inductor.scheduler import (  # type: ignore
        _custom_should_partition_fns,
        register_should_partition_rule,
    )

    def _always_partition(*_args, **_kwargs):
        return True
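
    # _custom_should_partition_fns is Inductor's private registry (a dict keyed
    # by OpOverload); snapshotting it lets the `finally` block below restore
    # the exact prior state, including rules registered by other callers.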
    # Save current state before registering
    saved_rules = _custom_should_partition_fns.copy()
    for overload in overloads:
        register_should_partition_rule(
            overload,
            _always_partition,
        )

    logger.debug(
        "Registered inductor partition rules for %d operators", len(overloads)
    )

    try:
        yield
    finally:
        # Clear and restore previous state
        _custom_should_partition_fns.clear()
        _custom_should_partition_fns.update(saved_rules)
        logger.debug("Restored previous partition rules state.")