mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-10 03:29:09 +08:00
[BugFix] Patch inductor memory plan logic (#26878)
Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
5d598680e3
commit
f57438338d
@ -22,6 +22,11 @@ sys.modules["vllm._C"] = MagicMock()
|
|||||||
class PydanticMagicMock(MagicMock):
|
class PydanticMagicMock(MagicMock):
|
||||||
"""`MagicMock` that's able to generate pydantic-core schemas."""
|
"""`MagicMock` that's able to generate pydantic-core schemas."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
name = kwargs.pop("name", None)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.__spec__ = importlib.machinery.ModuleSpec(name, None)
|
||||||
|
|
||||||
def __get_pydantic_core_schema__(self, source_type, handler):
|
def __get_pydantic_core_schema__(self, source_type, handler):
|
||||||
return core_schema.any_schema()
|
return core_schema.any_schema()
|
||||||
|
|
||||||
@ -42,7 +47,9 @@ def auto_mock(module, attr, max_mocks=50):
|
|||||||
raise e
|
raise e
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
logger.info("Mocking %s for argparse doc generation", e.name)
|
logger.info("Mocking %s for argparse doc generation", e.name)
|
||||||
sys.modules[e.name] = PydanticMagicMock()
|
sys.modules[e.name] = PydanticMagicMock(name=e.name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to import %s.%s: %s", module, attr, e)
|
||||||
|
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
|
f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
|
||||||
|
|||||||
@ -20,6 +20,7 @@ from vllm.config import (
|
|||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||||
|
from vllm.utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
# This import automatically registers `torch.ops.silly.attention`
|
# This import automatically registers `torch.ops.silly.attention`
|
||||||
from .. import silly_attention # noqa: F401
|
from .. import silly_attention # noqa: F401
|
||||||
@ -193,9 +194,8 @@ def run_model(
|
|||||||
|
|
||||||
@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
|
@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
|
||||||
def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
|
def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
|
||||||
if use_inductor_graph_partition:
|
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
# FIXME(luka/boyuan): this currently fails
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
pytest.skip("Inductor graph partition not supported with multi-graph")
|
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
|
|||||||
@ -3,9 +3,9 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import is_torch_equal
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -23,6 +23,72 @@ os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
|
|||||||
# see https://github.com/vllm-project/vllm/issues/10619
|
# see https://github.com/vllm-project/vllm/issues/10619
|
||||||
torch._inductor.config.compile_threads = 1
|
torch._inductor.config.compile_threads = 1
|
||||||
|
|
||||||
|
# ===================================================
|
||||||
|
# torch 2.9 Inductor PythonWrapperCodegen monkeypatch
|
||||||
|
# ===================================================
|
||||||
|
# This change monkeypatches memory_plan_reuse in pytorch 2.9.0 to work around
|
||||||
|
# a test failure for test_multi_graph_piecewise_compile_outputs_equal.
|
||||||
|
# For more context, see https://github.com/pytorch/pytorch/pull/165514.
|
||||||
|
|
||||||
|
|
||||||
|
def memory_plan_reuse_patched(self):
|
||||||
|
import torch._inductor.ir as ir
|
||||||
|
from torch._inductor.codegen.wrapper import (
|
||||||
|
EnterSubgraphLine,
|
||||||
|
ExitSubgraphLine,
|
||||||
|
MemoryPlanningLine,
|
||||||
|
MemoryPlanningState,
|
||||||
|
SubgraphPythonWrapperCodegen,
|
||||||
|
)
|
||||||
|
from torch._inductor.virtualized import V
|
||||||
|
|
||||||
|
def get_output_names(graph_outputs) -> list[str]:
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
names = []
|
||||||
|
shape_counter = itertools.count(0)
|
||||||
|
none_counter = itertools.count(0)
|
||||||
|
for node in graph_outputs:
|
||||||
|
if isinstance(node, ir.NoneAsConstantBuffer):
|
||||||
|
names.append(f"{V.graph.name}_none{next(none_counter)}")
|
||||||
|
elif isinstance(node, ir.ShapeAsConstantBuffer):
|
||||||
|
names.append(f"{V.graph.name}_shape{next(shape_counter)}")
|
||||||
|
else:
|
||||||
|
names.append(node.get_name())
|
||||||
|
return names
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen)
|
||||||
|
and V.graph.wrapper_code.partition_signatures is not None
|
||||||
|
):
|
||||||
|
out_names = get_output_names(
|
||||||
|
V.graph.wrapper_code.partition_signatures.output_nodes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
out_names = V.graph.get_output_names()
|
||||||
|
|
||||||
|
while (
|
||||||
|
self.lines
|
||||||
|
and isinstance(self.lines[-1], MemoryPlanningLine)
|
||||||
|
and self.lines[-1].node.name not in out_names # type: ignore[attr-defined]
|
||||||
|
):
|
||||||
|
# these lines will be pointless
|
||||||
|
self.lines.pop()
|
||||||
|
|
||||||
|
# codegen allocations in two passes
|
||||||
|
planning_states = [MemoryPlanningState()]
|
||||||
|
past_planning_states = []
|
||||||
|
for i in range(len(self.lines)):
|
||||||
|
line = self.lines[i]
|
||||||
|
if isinstance(line, MemoryPlanningLine):
|
||||||
|
self.lines[i] = line.plan(planning_states[-1])
|
||||||
|
elif isinstance(line, EnterSubgraphLine):
|
||||||
|
planning_states.append(MemoryPlanningState())
|
||||||
|
elif isinstance(line, ExitSubgraphLine):
|
||||||
|
past_planning_states.append(planning_states.pop())
|
||||||
|
past_planning_states.append(planning_states.pop())
|
||||||
|
assert len(planning_states) == 0
|
||||||
|
|
||||||
|
|
||||||
# ========================================
|
# ========================================
|
||||||
# torch 2.9 Inductor Scheduler monkeypatch
|
# torch 2.9 Inductor Scheduler monkeypatch
|
||||||
@ -135,7 +201,9 @@ def _update_scheduler_patched(self) -> None:
|
|||||||
self.scheduler = Scheduler(self.operations)
|
self.scheduler = Scheduler(self.operations)
|
||||||
|
|
||||||
|
|
||||||
if version.parse(str(torch.__version__)) == version.parse("2.9.0"):
|
if is_torch_equal("2.9.0"):
|
||||||
|
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
||||||
from torch._inductor.graph import GraphLowering
|
from torch._inductor.graph import GraphLowering
|
||||||
|
|
||||||
|
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
|
||||||
GraphLowering._update_scheduler = _update_scheduler_patched
|
GraphLowering._update_scheduler = _update_scheduler_patched
|
||||||
|
|||||||
@ -3263,6 +3263,33 @@ def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
|
|||||||
return torch_version >= version.parse(target)
|
return torch_version >= version.parse(target)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_torch_equal(target: str) -> bool:
|
||||||
|
assert target.count(".") == 2
|
||||||
|
torch_version = str(torch.__version__)
|
||||||
|
torch_version = version.parse(torch_version)
|
||||||
|
# torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu"
|
||||||
|
# or "2.6.0+cu128" but never "2.6.0.1"
|
||||||
|
return (
|
||||||
|
torch_version >= version.parse(target)
|
||||||
|
and version.parse(target + ".1") > torch_version
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_equal(target: str) -> bool:
|
||||||
|
"""Check if the installed torch version is == the target version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: a version string, like "2.6.0".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the condition meets.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return _is_torch_equal(target)
|
||||||
|
except Exception:
|
||||||
|
return Version(importlib.metadata.version("torch")) == Version(target)
|
||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def _has_module(module_name: str) -> bool:
|
def _has_module(module_name: str) -> bool:
|
||||||
"""Return True if *module_name* can be found in the current environment.
|
"""Return True if *module_name* can be found in the current environment.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user