diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index bdbd104f3b23..2c1ee4dc7480 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - -import pickle +import copy import pytest import torch @@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager from vllm.config import CompilationConfig +# dummy custom pass that doesn't inherit def simple_callable(graph: torch.fx.Graph): pass -callable_uuid = CallableInductorPass(simple_callable, - InductorPass.hash_source(__file__)) - - -@pytest.mark.parametrize( - "works, callable", - [ - (False, simple_callable), - (True, callable_uuid), - (True, CallableInductorPass(simple_callable)), - ], -) -def test_pass_manager(works: bool, callable): +# Should fail to add directly to the pass manager +def test_bad_callable(): config = CompilationConfig().pass_config pass_manager = PostGradPassManager() pass_manager.configure(config) - # Try to add the callable to the pass manager - if works: - pass_manager.add(callable) - pickle.dumps(pass_manager) - else: - with pytest.raises(AssertionError): - pass_manager.add(callable) + with pytest.raises(AssertionError): + pass_manager.add(simple_callable) # noqa, type wrong on purpose + + +# Pass that inherits from InductorPass +class ProperPass(InductorPass): + + def __call__(self, graph: torch.fx.graph.Graph) -> None: + pass + + +@pytest.mark.parametrize( + "callable", + [ + ProperPass(), + # Can also wrap callables in CallableInductorPass for compliance + CallableInductorPass(simple_callable), + CallableInductorPass(simple_callable, + InductorPass.hash_source(__file__)) + ], +) +def test_pass_manager_uuid(callable): + config = CompilationConfig().pass_config + + pass_manager = PostGradPassManager() + pass_manager.configure(config) + + # Check that UUID is different if the same pass is added 2x + pass_manager.add(callable) + uuid1 = pass_manager.uuid() + pass_manager.add(callable) + uuid2 = pass_manager.uuid() + assert uuid1 != uuid2 + + # UUID should be the same as the original one, + # as we constructed in the same way. + pass_manager2 = PostGradPassManager() + pass_manager2.configure(config) + pass_manager2.add(callable) + assert uuid1 == pass_manager2.uuid() + + # UUID should be different due to config change + config2 = copy.deepcopy(config) + config2.enable_fusion = not config2.enable_fusion + pass_manager3 = PostGradPassManager() + pass_manager3.configure(config2) + pass_manager3.add(callable) + assert uuid1 != pass_manager3.uuid() diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 1fea927aac31..08dd8c8e1ea2 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -1,26 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 import hashlib +import importlib.metadata import inspect +import json import types -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import torch +from packaging.version import Version from torch import fx +if Version(importlib.metadata.version('torch')) >= Version("2.6"): + from torch._inductor.custom_graph_pass import CustomGraphPass +else: + # CustomGraphPass is not present in 2.5 or lower, import our version + from .torch25_custom_graph_pass import ( # noqa: yapf + Torch25CustomGraphPass as CustomGraphPass) -class InductorPass(ABC): - """ - General custom inductor pass interface. - """ - @abstractmethod - def __call__(self, graph: torch.fx.Graph): - """ - Execute the pass on the given graph. - """ - raise NotImplementedError +class InductorPass(CustomGraphPass): + """ + A custom graph pass that uses a hash of its source as the UUID. + This is defined as a convenience and should work in most cases. + """ def uuid(self) -> Any: """ @@ -48,7 +51,16 @@ class InductorPass(ABC): else: src_str = inspect.getsource(src.__class__) hasher.update(src_str.encode("utf-8")) - return hasher.digest() + return hasher.hexdigest() + + @staticmethod + def hash_dict(dict_: Dict[Any, Any]): + """ + Utility method to hash a dictionary, can alternatively be used for uuid. + :return: A sha256 hash of the json rep of the dictionary. + """ + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() class CallableInductorPass(InductorPass): @@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass): callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None): self.callable = callable - if uuid is None: - uuid = InductorPass.hash_source(callable) - self._uuid = uuid + self._uuid = self.hash_source(callable) if uuid is None else uuid def __call__(self, graph: torch.fx.Graph): self.callable(graph) def uuid(self) -> Any: return self._uuid - - def __getstate__(self): - """ - Pickling occurs in the Inductor code cache if a pass is not given to - the pass manager but is instead directly added to config as a pass. - See PostGradPassManager for more. - - TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. - """ - return self._uuid - - def __setstate__(self, state): - raise ValueError("Cannot unpickle CallableInductorPass") diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index b012346c353e..530a88b2b09a 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List +from typing import List -import torch from torch import fx as fx from vllm.config import CompilationConfig @@ -10,29 +9,18 @@ from vllm.logger import init_logger from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass -from .inductor_pass import InductorPass +from .inductor_pass import CustomGraphPass, InductorPass from .noop_elimination import NoOpEliminationPass logger = init_logger(__name__) -class PlaceHolder: - pass - - -if torch.__version__ < "2.6": - Parent = PlaceHolder # type: ignore -else: - Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore - - -class PostGradPassManager(Parent): +class PostGradPassManager(CustomGraphPass): """ The pass manager for post-grad passes. It handles configuration, adding custom passes, and running passes. - It also supports pickling, which is used by the Inductor code cache. - TODO(torch==2.6), use CustomGraphPass - (torch._inductor.custom_graph_pass.CustomGraphPass) + It supports uuid for the Inductor code cache. That includes torch<2.6 + support using pickling (in .inductor_pass.CustomGraphPass). The order of the post-grad post-passes is: 1. passes (constructor parameter) @@ -67,27 +55,13 @@ class PostGradPassManager(Parent): self.passes.append(pass_) def uuid(self): - return self.__getstate__() - - def __getstate__(self) -> Dict[str, List[Any]]: """ - Custom pickling for the pass manager, as some passes cannot be pickled. - Pickling occurs because the pass manager is set as the value of - `config["post_grad_custom_post_pass"]` in the Inductor config. - The config is pickled to act as a key in the Inductor code cache. - Any other passes in the config are pickled as well. - - TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. + The PostGradPassManager is set as a custom pass in the Inductor and + affects compilation caching. Its uuid depends on the UUIDs of all + dependent passes and the pass config. See InductorPass for more info. """ state = {"pass_config": self.pass_config.uuid(), "passes": []} for pass_ in self.passes: state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) - return state - - def __setstate__(self, state): - """ - Do not allow unpickling of the pass manager. - If this is needed in the future, it should properly pickle the passes. - """ - raise ValueError("Cannot unpickle PostGradPassManager") + return InductorPass.hash_dict(state) diff --git a/vllm/compilation/torch25_custom_graph_pass.py b/vllm/compilation/torch25_custom_graph_pass.py new file mode 100644 index 000000000000..4b881d0b6f2d --- /dev/null +++ b/vllm/compilation/torch25_custom_graph_pass.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch + + +class Torch25CustomGraphPass(ABC): # noqa (redefinition) + """ + This class replaces CustomGraphPass from torch==2.6 when using torch<2.6. + It conforms to the 2.6 interface but also supports pickling, as that's what + the inductor code cache uses to determine the cache key before 2.6. + (in 2.6 and above, uuid() is used.) + + Subclasses can just "pretend" that uuid is used. + """ + + @abstractmethod + def __call__(self, graph: torch.fx.graph.Graph) -> None: + """ + Implementation of the custom pass. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom pass implementation. + Return None to skip inductor code caching entirely. + """ + + def __getstate__(self): + """ + Pickling is used instead of uuid() in torch<2.6. Just return uuid() + to enable subclasses to only have to implement uuid. + """ + return self.uuid() + + def __setstate__(self, state): + raise ValueError("Cannot unpickle CustomGraphPass because pickling" + " is used for cache key uuid. Use torch>=2.6 with" + " native uuid support for custom passes.") diff --git a/vllm/config.py b/vllm/config.py index e486889b5855..2fd0db4ee942 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4,6 +4,7 @@ import ast import copy import enum import hashlib +import importlib.metadata import json import sys import warnings @@ -17,6 +18,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, Optional, Protocol, Union) import torch +from packaging.version import Version from pydantic import BaseModel, Field, PrivateAttr from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig @@ -52,8 +54,6 @@ if TYPE_CHECKING: else: QuantizationConfig = None -from packaging.version import Version - logger = init_logger(__name__) # This value is chosen to have a balance between ITL and TTFT. Note it is @@ -3088,8 +3088,7 @@ class CompilationConfig(BaseModel): compilation. """ dict_ = self.model_dump(include={"enable_fusion", "enable_noop"}) - encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") - return hashlib.sha256(encoded).digest() + return InductorPass.hash_dict(dict_) def model_post_init(self, __context: Any) -> None: if not self.enable_noop and self.enable_fusion: @@ -3178,7 +3177,7 @@ class CompilationConfig(BaseModel): # and it is not yet a priority. RFC here: # https://github.com/vllm-project/vllm/issues/14703 - if Version(torch.__version__) >= Version("2.6"): + if Version(importlib.metadata.version('torch')) >= Version("2.6"): KEY = 'enable_auto_functionalized_v2' if KEY not in self.inductor_compile_config: self.inductor_compile_config[KEY] = False