diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py
index ecd71ee1f3f66..a4da5b933e159 100644
--- a/docs/mkdocs/hooks/generate_argparse.py
+++ b/docs/mkdocs/hooks/generate_argparse.py
@@ -22,6 +22,11 @@ sys.modules["vllm._C"] = MagicMock()
 class PydanticMagicMock(MagicMock):
     """`MagicMock` that's able to generate pydantic-core schemas."""
 
+    def __init__(self, *args, **kwargs):
+        name = kwargs.pop("name", None)
+        super().__init__(*args, **kwargs)
+        self.__spec__ = importlib.machinery.ModuleSpec(name, None)
+
     def __get_pydantic_core_schema__(self, source_type, handler):
         return core_schema.any_schema()
 
@@ -42,7 +47,9 @@ def auto_mock(module, attr, max_mocks=50):
             raise e
         except ModuleNotFoundError as e:
             logger.info("Mocking %s for argparse doc generation", e.name)
-            sys.modules[e.name] = PydanticMagicMock()
+            sys.modules[e.name] = PydanticMagicMock(name=e.name)
+        except Exception as e:
+            logger.warning("Failed to import %s.%s: %s", module, attr, e)
 
     raise ImportError(
         f"Failed to import {module}.{attr} after mocking {max_mocks} imports"
diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py
index d1f741479acf4..246239b87d5fe 100644
--- a/tests/compile/piecewise/test_multiple_graphs.py
+++ b/tests/compile/piecewise/test_multiple_graphs.py
@@ -20,6 +20,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer
 
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
@@ -193,9 +194,8 @@ def run_model(
 
 
 @pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
 def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
-    if use_inductor_graph_partition:
-        # FIXME(luka/boyuan): this currently fails
-        pytest.skip("Inductor graph partition not supported with multi-graph")
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
 
     outputs = []
diff --git a/vllm/env_override.py b/vllm/env_override.py
index eb51dee1cf033..f4ac48584cb7e 100644
--- a/vllm/env_override.py
+++ b/vllm/env_override.py
@@ -3,9 +3,9 @@
 import os
 
 import torch
-from packaging import version
 
 from vllm.logger import init_logger
+from vllm.utils import is_torch_equal
 
 logger = init_logger(__name__)
 
@@ -23,6 +23,72 @@
 os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 # see https://github.com/vllm-project/vllm/issues/10619
 torch._inductor.config.compile_threads = 1
+# ===================================================
+# torch 2.9 Inductor PythonWrapperCodegen monkeypatch
+# ===================================================
+# This change monkeypatches memory_plan_reuse in PyTorch 2.9.0 to work around
+# a test failure in test_multi_graph_piecewise_compile_outputs_equal.
+# For more context, see https://github.com/pytorch/pytorch/pull/165514.
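+#
+# Relative to the stock 2.9.0 implementation, the patched copy below (per the
+# PR linked above) resolves output names from the partition signatures when
+# codegen runs inside a SubgraphPythonWrapperCodegen, so trailing
+# MemoryPlanningLines that produce partition outputs are no longer popped as
+# dead lines.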
+
+
+def memory_plan_reuse_patched(self):
+    import torch._inductor.ir as ir
+    from torch._inductor.codegen.wrapper import (
+        EnterSubgraphLine,
+        ExitSubgraphLine,
+        MemoryPlanningLine,
+        MemoryPlanningState,
+        SubgraphPythonWrapperCodegen,
+    )
+    from torch._inductor.virtualized import V
+
+    def get_output_names(graph_outputs) -> list[str]:
+        import itertools
+
+        names = []
+        shape_counter = itertools.count(0)
+        none_counter = itertools.count(0)
+        for node in graph_outputs:
+            if isinstance(node, ir.NoneAsConstantBuffer):
+                names.append(f"{V.graph.name}_none{next(none_counter)}")
+            elif isinstance(node, ir.ShapeAsConstantBuffer):
+                names.append(f"{V.graph.name}_shape{next(shape_counter)}")
+            else:
+                names.append(node.get_name())
+        return names
+
+    if (
+        isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen)
+        and V.graph.wrapper_code.partition_signatures is not None
+    ):
+        out_names = get_output_names(
+            V.graph.wrapper_code.partition_signatures.output_nodes
+        )
+    else:
+        out_names = V.graph.get_output_names()
+
+    while (
+        self.lines
+        and isinstance(self.lines[-1], MemoryPlanningLine)
+        and self.lines[-1].node.name not in out_names  # type: ignore[attr-defined]
+    ):
+        # these lines will be pointless
+        self.lines.pop()
+
+    # codegen allocations in two passes
+    planning_states = [MemoryPlanningState()]
+    past_planning_states = []
+    for i in range(len(self.lines)):
+        line = self.lines[i]
+        if isinstance(line, MemoryPlanningLine):
+            self.lines[i] = line.plan(planning_states[-1])
+        elif isinstance(line, EnterSubgraphLine):
+            planning_states.append(MemoryPlanningState())
+        elif isinstance(line, ExitSubgraphLine):
+            past_planning_states.append(planning_states.pop())
+    past_planning_states.append(planning_states.pop())
+    assert len(planning_states) == 0
+
 
 # ========================================
 # torch 2.9 Inductor Scheduler monkeypatch
@@ -135,7 +201,9 @@ def _update_scheduler_patched(self) -> None:
     self.scheduler = Scheduler(self.operations)
 
 
-if version.parse(str(torch.__version__)) == version.parse("2.9.0"):
+if is_torch_equal("2.9.0"):
+    from torch._inductor.codegen.wrapper import PythonWrapperCodegen
     from torch._inductor.graph import GraphLowering
 
+    PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
     GraphLowering._update_scheduler = _update_scheduler_patched
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index ad0918a6ed8d0..1f01cbeda9686 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -3263,6 +3263,33 @@ def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
     return torch_version >= version.parse(target)
 
 
+def _is_torch_equal(target: str) -> bool:
+    assert target.count(".") == 2
+    torch_version = str(torch.__version__)
+    torch_version = version.parse(torch_version)
+    # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu"
+    # or "2.6.0+cu128" but never "2.6.0.1"
+    return (
+        torch_version >= version.parse(target)
+        and version.parse(target + ".1") > torch_version
+    )
+
+
+def is_torch_equal(target: str) -> bool:
+    """Check if the installed torch version is == the target version.
+
+    Args:
+        target: a version string, like "2.6.0".
+
+    Returns:
+        Whether the installed torch version equals the target.
+    """
+    try:
+        return _is_torch_equal(target)
+    except Exception:
+        return Version(importlib.metadata.version("torch")) == Version(target)
+
+
 @cache
 def _has_module(module_name: str) -> bool:
     """Return True if *module_name* can be found in the current environment.
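# A minimal standalone sketch of the version-window check that
# `_is_torch_equal` above relies on, using packaging's PEP 440 ordering.
# `matches_exact_release` and the candidate strings are illustrative
# placeholders, not part of the diff.
from packaging import version


def matches_exact_release(candidate: str, target: str) -> bool:
    # target <= candidate < target + ".1" -- safe because torch never
    # publishes a four-segment release such as "2.9.0.1"
    v = version.parse(candidate)
    return version.parse(target) <= v < version.parse(target + ".1")


assert matches_exact_release("2.9.0", "2.9.0")
assert matches_exact_release("2.9.0+cu128", "2.9.0")  # local build tags match
assert not matches_exact_release("2.9.1", "2.9.0")
assert not matches_exact_release("2.9.0.dev20250101", "2.9.0")  # dev < release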
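# A minimal sketch of why `PydanticMagicMock` in the first hunk sets a real
# `__spec__`: `importlib.util.find_spec` resolves an already-imported module
# through its `__spec__` attribute, so giving the mock a genuine `ModuleSpec`
# carrying the mocked module's name keeps such import-metadata probes working.
# The module name "fake_dep" is a made-up placeholder.
import importlib.machinery
import importlib.util
import sys
from unittest.mock import MagicMock

fake = MagicMock()
fake.__spec__ = importlib.machinery.ModuleSpec("fake_dep", None)
sys.modules["fake_dep"] = fake

spec = importlib.util.find_spec("fake_dep")
assert spec.name == "fake_dep"  # real metadata, not a MagicMock proxy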