[Fix] [torch.compile] Improve UUID system for custom passes (#15249)

Signed-off-by: luka <luka@neuralmagic.com>
Luka Govedič 2025-03-23 21:54:07 -04:00 committed by GitHub
parent dccf535f8e
commit f622dbcf39
5 changed files with 132 additions and 91 deletions

View File

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-import pickle
+import copy
 import pytest
 import torch
@@ -10,32 +9,63 @@ from vllm.compilation.pass_manager import PostGradPassManager
 from vllm.config import CompilationConfig
 # dummy custom pass that doesn't inherit
 def simple_callable(graph: torch.fx.Graph):
     pass
-callable_uuid = CallableInductorPass(simple_callable,
-                                     InductorPass.hash_source(__file__))
-@pytest.mark.parametrize(
-    "works, callable",
-    [
-        (False, simple_callable),
-        (True, callable_uuid),
-        (True, CallableInductorPass(simple_callable)),
-    ],
-)
-def test_pass_manager(works: bool, callable):
+# Should fail to add directly to the pass manager
+def test_bad_callable():
     config = CompilationConfig().pass_config
     pass_manager = PostGradPassManager()
     pass_manager.configure(config)
-    # Try to add the callable to the pass manager
-    if works:
-        pass_manager.add(callable)
-        pickle.dumps(pass_manager)
-    else:
-        with pytest.raises(AssertionError):
-            pass_manager.add(callable)
+    with pytest.raises(AssertionError):
+        pass_manager.add(simple_callable)  # noqa, type wrong on purpose
+# Pass that inherits from InductorPass
+class ProperPass(InductorPass):
+    def __call__(self, graph: torch.fx.graph.Graph) -> None:
+        pass
+@pytest.mark.parametrize(
+    "callable",
+    [
+        ProperPass(),
+        # Can also wrap callables in CallableInductorPass for compliance
+        CallableInductorPass(simple_callable),
+        CallableInductorPass(simple_callable,
+                             InductorPass.hash_source(__file__))
+    ],
+)
+def test_pass_manager_uuid(callable):
+    config = CompilationConfig().pass_config
+    pass_manager = PostGradPassManager()
+    pass_manager.configure(config)
+    # Check that UUID is different if the same pass is added 2x
+    pass_manager.add(callable)
+    uuid1 = pass_manager.uuid()
+    pass_manager.add(callable)
+    uuid2 = pass_manager.uuid()
+    assert uuid1 != uuid2
+    # UUID should be the same as the original one,
+    # as we constructed in the same way.
+    pass_manager2 = PostGradPassManager()
+    pass_manager2.configure(config)
+    pass_manager2.add(callable)
+    assert uuid1 == pass_manager2.uuid()
+    # UUID should be different due to config change
+    config2 = copy.deepcopy(config)
+    config2.enable_fusion = not config2.enable_fusion
+    pass_manager3 = PostGradPassManager()
+    pass_manager3.configure(config2)
+    pass_manager3.add(callable)
+    assert uuid1 != pass_manager3.uuid()

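For orientation, here is a minimal usage sketch of the interface these tests exercise. It uses only names already shown above; the wrapped function and its registration are illustrative, not part of the commit.

import torch

from vllm.compilation.inductor_pass import CallableInductorPass
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig


def my_graph_pass(graph: torch.fx.Graph) -> None:
    # Illustrative no-op; a real pass would rewrite `graph` in place.
    pass


config = CompilationConfig().pass_config
pass_manager = PostGradPassManager()
pass_manager.configure(config)

# Plain callables are rejected (see test_bad_callable above), so wrap them:
pass_manager.add(CallableInductorPass(my_graph_pass))

# The manager's UUID now reflects the pass config and every registered pass.
print(pass_manager.uuid())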
View File

@@ -1,26 +1,29 @@
 # SPDX-License-Identifier: Apache-2.0
 import hashlib
+import importlib.metadata
 import inspect
+import json
 import types
-from abc import ABC, abstractmethod
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union
 import torch
+from packaging.version import Version
 from torch import fx
+if Version(importlib.metadata.version('torch')) >= Version("2.6"):
+    from torch._inductor.custom_graph_pass import CustomGraphPass
+else:
+    # CustomGraphPass is not present in 2.5 or lower, import our version
+    from .torch25_custom_graph_pass import (  # noqa: yapf
+        Torch25CustomGraphPass as CustomGraphPass)
-class InductorPass(ABC):
-    """
-    General custom inductor pass interface.
-    """
-    @abstractmethod
-    def __call__(self, graph: torch.fx.Graph):
-        """
-        Execute the pass on the given graph.
-        """
-        raise NotImplementedError
+class InductorPass(CustomGraphPass):
+    """
+    A custom graph pass that uses a hash of its source as the UUID.
+    This is defined as a convenience and should work in most cases.
+    """
     def uuid(self) -> Any:
         """
@@ -48,7 +51,16 @@ class InductorPass(ABC):
             else:
                 src_str = inspect.getsource(src.__class__)
             hasher.update(src_str.encode("utf-8"))
-        return hasher.digest()
+        return hasher.hexdigest()
+    @staticmethod
+    def hash_dict(dict_: Dict[Any, Any]):
+        """
+        Utility method to hash a dictionary, can alternatively be used for uuid.
+        :return: A sha256 hash of the json rep of the dictionary.
+        """
+        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
+        return hashlib.sha256(encoded).hexdigest()
 class CallableInductorPass(InductorPass):
@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass):
                  callable: Callable[[fx.Graph], None],
                  uuid: Optional[Any] = None):
         self.callable = callable
-        if uuid is None:
-            uuid = InductorPass.hash_source(callable)
-        self._uuid = uuid
+        self._uuid = self.hash_source(callable) if uuid is None else uuid
     def __call__(self, graph: torch.fx.Graph):
         self.callable(graph)
     def uuid(self) -> Any:
         return self._uuid
-    def __getstate__(self):
-        """
-        Pickling occurs in the Inductor code cache if a pass is not given to
-        the pass manager but is instead directly added to config as a pass.
-        See PostGradPassManager for more.
-        TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
-        """
-        return self._uuid
-    def __setstate__(self, state):
-        raise ValueError("Cannot unpickle CallableInductorPass")

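As a rough sketch of how a downstream pass can build on this interface (both pass classes below are hypothetical examples, not part of the commit): inheriting InductorPass yields the default source-hash UUID, while a parameterized pass can fold its settings into the UUID with the new hash_dict helper so that changing a parameter also changes the cache key.

import torch

from vllm.compilation.inductor_pass import InductorPass


class MyRewritePass(InductorPass):
    # Uses the default uuid(): a hash of this class's source, per the docstring above.

    def __call__(self, graph: torch.fx.Graph) -> None:
        pass  # a real pass would transform the graph here


class MyTunablePass(InductorPass):
    # A parameterized pass: the UUID must also change when the parameter does,
    # otherwise Inductor could reuse an artifact compiled with old settings.

    def __init__(self, threshold: int):
        self.threshold = threshold

    def __call__(self, graph: torch.fx.Graph) -> None:
        pass

    def uuid(self):
        return InductorPass.hash_dict({
            "source": InductorPass.hash_source(self),
            "threshold": self.threshold,
        })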
View File

@@ -1,8 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Any, Dict, List
+from typing import List
-import torch
 from torch import fx as fx
 from vllm.config import CompilationConfig
@@ -10,29 +9,18 @@ from vllm.logger import init_logger
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
-from .inductor_pass import InductorPass
+from .inductor_pass import CustomGraphPass, InductorPass
 from .noop_elimination import NoOpEliminationPass
 logger = init_logger(__name__)
-class PlaceHolder:
-    pass
-if torch.__version__ < "2.6":
-    Parent = PlaceHolder  # type: ignore
-else:
-    Parent = torch._inductor.custom_graph_pass.CustomGraphPass  # type: ignore
-class PostGradPassManager(Parent):
+class PostGradPassManager(CustomGraphPass):
     """
     The pass manager for post-grad passes.
     It handles configuration, adding custom passes, and running passes.
-    It also supports pickling, which is used by the Inductor code cache.
-    TODO(torch==2.6), use CustomGraphPass
-    (torch._inductor.custom_graph_pass.CustomGraphPass)
+    It supports uuid for the Inductor code cache. That includes torch<2.6
+    support using pickling (in .inductor_pass.CustomGraphPass).
     The order of the post-grad post-passes is:
     1. passes (constructor parameter)
@@ -67,27 +55,13 @@ class PostGradPassManager(Parent):
         self.passes.append(pass_)
     def uuid(self):
-        return self.__getstate__()
-    def __getstate__(self) -> Dict[str, List[Any]]:
         """
-        Custom pickling for the pass manager, as some passes cannot be pickled.
-        Pickling occurs because the pass manager is set as the value of
-        `config["post_grad_custom_post_pass"]` in the Inductor config.
-        The config is pickled to act as a key in the Inductor code cache.
-        Any other passes in the config are pickled as well.
-        TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
+        The PostGradPassManager is set as a custom pass in the Inductor and
+        affects compilation caching. Its uuid depends on the UUIDs of all
+        dependent passes and the pass config. See InductorPass for more info.
         """
         state = {"pass_config": self.pass_config.uuid(), "passes": []}
         for pass_ in self.passes:
             state["passes"].append(pass_.uuid())
         state["passes"].append(self.fix_functionalization.uuid())
-        return state
-    def __setstate__(self, state):
-        """
-        Do not allow unpickling of the pass manager.
-        If this is needed in the future, it should properly pickle the passes.
-        """
-        raise ValueError("Cannot unpickle PostGradPassManager")
+        return InductorPass.hash_dict(state)

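To make the caching relationship concrete, here is a sketch of how the manager's UUID reaches Inductor. The wiring below is illustrative (vLLM performs the equivalent internally); the post_grad_custom_post_pass key is the one named in the removed docstring above.

from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import CompilationConfig

config = CompilationConfig()
pass_manager = PostGradPassManager()
pass_manager.configure(config.pass_config)

# The manager is installed as the single custom post-grad pass. Inductor then
# folds its uuid() -- a hash over the pass config plus the uuid() of every
# registered pass (and fix_functionalization) -- into the compilation cache key.
inductor_config = {"post_grad_custom_post_pass": pass_manager}
cache_key_component = pass_manager.uuid()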
View File

@@ -0,0 +1,41 @@
+# SPDX-License-Identifier: Apache-2.0
+from abc import ABC, abstractmethod
+from typing import Any, Optional
+import torch
+class Torch25CustomGraphPass(ABC):  # noqa (redefinition)
+    """
+    This class replaces CustomGraphPass from torch==2.6 when using torch<2.6.
+    It conforms to the 2.6 interface but also supports pickling, as that's what
+    the inductor code cache uses to determine the cache key before 2.6.
+    (in 2.6 and above, uuid() is used.)
+    Subclasses can just "pretend" that uuid is used.
+    """
+    @abstractmethod
+    def __call__(self, graph: torch.fx.graph.Graph) -> None:
+        """
+        Implementation of the custom pass.
+        """
+    @abstractmethod
+    def uuid(self) -> Optional[Any]:
+        """
+        Return an ID to uniquely identify your custom pass implementation.
+        Return None to skip inductor code caching entirely.
+        """
+    def __getstate__(self):
+        """
+        Pickling is used instead of uuid() in torch<2.6. Just return uuid()
+        to enable subclasses to only have to implement uuid.
+        """
+        return self.uuid()
+    def __setstate__(self, state):
+        raise ValueError("Cannot unpickle CustomGraphPass because pickling"
+                         " is used for cache key uuid. Use torch>=2.6 with"
+                         " native uuid support for custom passes.")

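A minimal sketch of how the shim behaves under torch<2.6; the subclass below is a hypothetical example.

import pickle

import torch

from vllm.compilation.torch25_custom_graph_pass import Torch25CustomGraphPass


class MyPass(Torch25CustomGraphPass):

    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        pass

    def uuid(self):
        return "my-pass-v1"


# Before torch 2.6 the Inductor code cache pickles custom passes to build its
# cache key; __getstate__ forwards uuid(), so the pickled state is just the UUID.
payload = pickle.dumps(MyPass())

# Unpickling is intentionally unsupported: __setstate__ raises ValueError.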
View File

@@ -4,6 +4,7 @@ import ast
 import copy
 import enum
 import hashlib
+import importlib.metadata
 import json
 import sys
 import warnings
@@ -17,6 +18,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
                     Optional, Protocol, Union)
 import torch
+from packaging.version import Version
 from pydantic import BaseModel, Field, PrivateAttr
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
@@ -52,8 +54,6 @@ if TYPE_CHECKING:
 else:
     QuantizationConfig = None
-from packaging.version import Version
 logger = init_logger(__name__)
 # This value is chosen to have a balance between ITL and TTFT. Note it is
@@ -3088,8 +3088,7 @@ class CompilationConfig(BaseModel):
             compilation.
             """
             dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
-            encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
-            return hashlib.sha256(encoded).digest()
+            return InductorPass.hash_dict(dict_)
         def model_post_init(self, __context: Any) -> None:
             if not self.enable_noop and self.enable_fusion:
@@ -3178,7 +3177,7 @@
         # and it is not yet a priority. RFC here:
         # https://github.com/vllm-project/vllm/issues/14703
-        if Version(torch.__version__) >= Version("2.6"):
+        if Version(importlib.metadata.version('torch')) >= Version("2.6"):
             KEY = 'enable_auto_functionalized_v2'
             if KEY not in self.inductor_compile_config:
                 self.inductor_compile_config[KEY] = False
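
As a closing illustration (a sketch, not part of the diff): because the pass config now hashes through InductorPass.hash_dict, flipping a pass option such as enable_fusion yields a different pass-config UUID, which is what test_pass_manager_uuid above asserts at the pass-manager level.

import copy

from vllm.config import CompilationConfig

config = CompilationConfig().pass_config
config2 = copy.deepcopy(config)
config2.enable_fusion = not config2.enable_fusion

# The pass-config UUID (consumed inside PostGradPassManager.uuid()) changes
# with the option, so Inductor will not reuse a cache entry compiled under
# the other setting.
assert config.uuid() != config2.uuid()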