mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-01 07:01:50 +08:00
117 lines
3.4 KiB
Python
117 lines
3.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import hashlib
|
|
import inspect
|
|
import json
|
|
import types
|
|
from contextlib import contextmanager
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
import torch
|
|
from torch import fx
|
|
|
|
from vllm.utils import is_torch_equal_or_newer
|
|
|
|
if is_torch_equal_or_newer("2.6"):
|
|
from torch._inductor.custom_graph_pass import CustomGraphPass
|
|
else:
|
|
# CustomGraphPass is not present in 2.5 or lower, import our version
|
|
from .torch25_custom_graph_pass import ( # noqa: E501
|
|
Torch25CustomGraphPass as CustomGraphPass)
|
|
|
|
# Module-level slot for the currently active PassContext. It starts out as
# None; pass_context() assigns it on entry and puts the previous value back
# on exit, and get_pass_context() reads it.
_pass_context: Optional["PassContext"] = None
|
|
|
|
|
|
class PassContext:
    """Value holder for the runtime shape tied to a pass context.

    The constructor argument is stored unchanged on the instance as
    ``runtime_shape`` (an int, or None).
    """

    def __init__(self, runtime_shape: Optional[int]):
        # Kept exactly as given; no validation or conversion is done here.
        self.runtime_shape = runtime_shape
|
|
|
|
|
|
def get_pass_context() -> PassContext:
    """Return the pass context that is currently active.

    :return: The object held in the module-level ``_pass_context`` slot.
    :raises AssertionError: If the slot is ``None``, i.e. no context has
        been installed.
    """
    context = _pass_context
    if context is None:
        # Raised explicitly instead of through an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type is
        # kept as AssertionError for callers that already catch it.
        raise AssertionError(
            "No pass context is set; get_pass_context() must be called "
            "inside a `with pass_context(...)` block.")
    return context
|
|
|
|
|
|
@contextmanager
def pass_context(runtime_shape: Optional[int]):
    """Install a fresh PassContext for the duration of a ``with`` block.

    On entry a new ``PassContext(runtime_shape)`` is stored in the
    module-level ``_pass_context`` slot. On exit the value that was there
    before is put back, even if the body raised, so nested uses unwind to
    the enclosing context.

    :param runtime_shape: Value forwarded to the ``PassContext`` constructor.
    """
    global _pass_context
    # The right-hand side is evaluated in full before either name is bound,
    # so the old context is captured before the slot is overwritten.
    outer_context, _pass_context = _pass_context, PassContext(runtime_shape)
    try:
        yield
    finally:
        # Always hand control back to whatever context was active before.
        _pass_context = outer_context
|
|
|
|
|
|
class InductorPass(CustomGraphPass):
|
|
"""
|
|
A custom graph pass that uses a hash of its source as the UUID.
|
|
This is defined as a convenience and should work in most cases.
|
|
"""
|
|
|
|
def uuid(self) -> Any:
|
|
"""
|
|
Provide a unique identifier for the pass, used in Inductor code cache.
|
|
This should depend on the pass implementation, so that changes to the
|
|
pass result in recompilation.
|
|
By default, the object source is hashed.
|
|
"""
|
|
return InductorPass.hash_source(self)
|
|
|
|
@staticmethod
|
|
def hash_source(*srcs: Union[str, Any]):
|
|
"""
|
|
Utility method to hash the sources of functions or objects.
|
|
:param srcs: strings or objects to add to the hash.
|
|
Objects and functions have their source inspected.
|
|
:return:
|
|
"""
|
|
hasher = hashlib.sha256()
|
|
for src in srcs:
|
|
if isinstance(src, str):
|
|
src_str = src
|
|
elif isinstance(src, (types.FunctionType, type)):
|
|
src_str = inspect.getsource(src)
|
|
else:
|
|
# object instance
|
|
src_str = inspect.getsource(src.__class__)
|
|
hasher.update(src_str.encode("utf-8"))
|
|
return hasher.hexdigest()
|
|
|
|
@staticmethod
|
|
def hash_dict(dict_: dict[Any, Any]):
|
|
"""
|
|
Utility method to hash a dictionary, can alternatively be used for uuid.
|
|
:return: A sha256 hash of the json rep of the dictionary.
|
|
"""
|
|
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
|
return hashlib.sha256(encoded).hexdigest()
|
|
|
|
def is_applicable_for_shape(self, shape: Optional[int]):
|
|
return True
|
|
|
|
|
|
class CallableInductorPass(InductorPass):
    """
    Adapter that wraps a plain callable as an InductorPass.

    The UUID is the one supplied by the caller, or, when none is given, the
    result of ``hash_source`` applied to the wrapped callable.
    """

    def __init__(self,
                 callable: Callable[[fx.Graph], None],
                 uuid: Optional[Any] = None):
        # ``callable`` shadows the builtin, but it is part of the public
        # keyword interface and is therefore kept.
        self.callable = callable
        # Only compute the source hash when no explicit UUID was supplied.
        if uuid is None:
            uuid = self.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph):
        """Run the wrapped callable on ``graph``; its result is discarded."""
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the UUID fixed at construction time."""
        return self._uuid
|