mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-14 10:24:33 +08:00
73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import contextlib
|
|
|
|
import torch
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
# Module-level logger for the debug messages emitted by this file.
logger = init_logger(__name__)
|
|
|
|
|
|
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
    """
    Check if a node should be split for dynamo graph partition.

    It operates on dynamo graph, so the node.target can be anything.
    We need to check and split only on OpOverload and OpOverloadPacket.

    Args:
        node: FX node from the dynamo graph.
        splitting_ops: Qualified operator names to split on. An entry may be
            a packet-level name (e.g. ``"aten::add"``), which matches every
            overload of that operator, or an overload-level name (e.g.
            ``"aten::add.Tensor"`` or ``"aten::add.default"``), which matches
            only that overload.

    Returns:
        True if ``node`` is a ``call_function`` node whose target is one of
        the listed operators, False otherwise.
    """
    if node.op != "call_function":
        return False

    target = node.target

    if isinstance(target, torch._ops.OpOverloadPacket):
        # Example: "aten::add"
        return target._qualified_op_name in splitting_ops

    if isinstance(target, torch._ops.OpOverload):
        # Example: "aten::add". Take the overload-free schema name instead of
        # ``target.name()``: for non-default overloads the latter already
        # includes the overload suffix (e.g. "aten::add.Tensor"), which made
        # packet-level entries miss and produced "aten::add.Tensor.Tensor"
        # below.
        packet_name = target._schema.name

        # Example: "aten::add.default" or "aten::add.Tensor"
        op_overload_name = f"{packet_name}.{target._overloadname}"
        return op_overload_name in splitting_ops or packet_name in splitting_ops

    return False
|
|
|
|
|
|
@contextlib.contextmanager
def inductor_partition_rule_context(splitting_ops: list[str]):
    """Context manager to temporarily register Inductor partition rules.

    Registers custom partition rules for specified operators, forcing the
    Inductor scheduler to partition the graph at these operators. The rules
    are automatically restored to their previous state on exit.

    If ``splitting_ops`` is empty, the Inductor config is left untouched.

    Args:
        splitting_ops: List of operator names to partition on. A copy is
            installed, so mutating the caller's list after entering the
            context does not change the registered rules (and vice versa).
    """
    if not splitting_ops:
        logger.debug("No partition ops provided; skipping rule registration.")
        yield
        return

    # Snapshot the current rules so they can be restored on exit.
    saved_splitting_ops: list[str] = list(
        torch._inductor.config.custom_should_partition_ops
    )
    # Install a private copy: assigning the caller's list directly would make
    # the global config alias it.
    torch._inductor.config.custom_should_partition_ops = list(splitting_ops)

    logger.debug(
        "Registered inductor partition rules for %d operators", len(splitting_ops)
    )

    try:
        yield
    finally:
        # Restore the rules that were active before entering the context,
        # even if the body raised.
        torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
        logger.debug("Restored previous partition rules state.")
|