[Misc] Improve type annotations for support_torch_compile (#10763)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-01 09:48:35 +08:00 committed by GitHub
parent 133707123e
commit f877a7d12a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,8 @@
import inspect import inspect
from typing import Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
import torch import torch
import torch.nn as nn
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
@ -12,10 +13,27 @@ from vllm.utils import supports_dynamo
logger = init_logger(__name__) logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[nn.Module])
@overload
def support_torch_compile(
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
) -> Callable[[_T], _T]:
...
@overload
def support_torch_compile(cls: _T) -> _T:
...
def support_torch_compile( def support_torch_compile(
cls: Optional[type] = None, cls: Optional[_T] = None,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None): *,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
""" """
A decorator to add support for compiling the forward method of a class. A decorator to add support for compiling the forward method of a class.
@ -66,7 +84,7 @@ def support_torch_compile(
computation graph. computation graph.
""" """
def cls_decorator_helper(cls: type): def cls_decorator_helper(cls: _T) -> _T:
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` # helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile`` # to avoid too much indentation for `_support_torch_compile``
if not hasattr(cls, 'forward'): if not hasattr(cls, 'forward'):
@ -105,8 +123,10 @@ def support_torch_compile(
return cls_decorator_helper return cls_decorator_helper
def _support_torch_compile(cls: type, def _support_torch_compile(
dynamic_arg_dims: Dict[str, Union[int, List[int]]]): cls: _T,
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
) -> _T:
""" """
A decorator to add support for compiling the forward method of a class. A decorator to add support for compiling the forward method of a class.
""" """
@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher # other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
old_init = cls.__init__ # type: ignore old_init = cls.__init__
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
@ -135,7 +155,7 @@ def _support_torch_compile(cls: type,
TorchCompileWrapperWithCustomDispatcher.__init__( TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level) self, compilation_level=vllm_config.compilation_config.level)
cls.__init__ = __init__ # type: ignore cls.__init__ = __init__
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation # torch.compiler.is_compiling() means we are inside the compilation
@ -180,5 +200,5 @@ def _support_torch_compile(cls: type,
model_output = self.forward(*args, **kwargs) model_output = self.forward(*args, **kwargs)
return model_output return model_output
cls.__call__ = __call__ # type: ignore cls.__call__ = __call__
return cls return cls