[Misc] Improve type annotations for support_torch_compile (#10763)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-01 09:48:35 +08:00 committed by GitHub
parent 133707123e
commit f877a7d12a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1,7 +1,8 @@
import inspect
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
import torch
import torch.nn as nn
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
@@ -12,10 +13,27 @@ from vllm.utils import supports_dynamo
logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[nn.Module])
@overload
def support_torch_compile(
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
) -> Callable[[_T], _T]:
...
@overload
def support_torch_compile(cls: _T) -> _T:
...
def support_torch_compile(
cls: Optional[type] = None,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
cls: Optional[_T] = None,
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
"""
A decorator to add support for compiling the forward method of a class.
@@ -66,7 +84,7 @@ def support_torch_compile(
computation graph.
"""
def cls_decorator_helper(cls: type):
def cls_decorator_helper(cls: _T) -> _T:
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile``
if not hasattr(cls, 'forward'):
@@ -105,8 +123,10 @@ def support_torch_compile(
return cls_decorator_helper
def _support_torch_compile(cls: type,
dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
"""
@@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
old_init = cls.__init__ # type: ignore
old_init = cls.__init__
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
@@ -135,7 +155,7 @@ def _support_torch_compile(cls: type,
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
cls.__init__ = __init__ # type: ignore
cls.__init__ = __init__
def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation
@@ -180,5 +200,5 @@ def _support_torch_compile(cls: type,
model_output = self.forward(*args, **kwargs)
return model_output
cls.__call__ = __call__ # type: ignore
cls.__call__ = __call__
return cls