mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 01:05:35 +08:00
[Misc] Improve type annotations for support_torch_compile (#10763)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
133707123e
commit
f877a7d12a
@ -1,7 +1,8 @@
|
||||
import inspect
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||
@ -12,10 +13,27 @@ from vllm.utils import supports_dynamo
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
|
||||
) -> Callable[[_T], _T]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(cls: _T) -> _T:
|
||||
...
|
||||
|
||||
|
||||
def support_torch_compile(
|
||||
cls: Optional[type] = None,
|
||||
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
|
||||
cls: Optional[_T] = None,
|
||||
*,
|
||||
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
|
||||
) -> Union[Callable[[_T], _T], _T]:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
|
||||
@ -66,7 +84,7 @@ def support_torch_compile(
|
||||
computation graph.
|
||||
"""
|
||||
|
||||
def cls_decorator_helper(cls: type):
|
||||
def cls_decorator_helper(cls: _T) -> _T:
|
||||
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
|
||||
# to avoid too much indentation for `_support_torch_compile``
|
||||
if not hasattr(cls, 'forward'):
|
||||
@ -105,8 +123,10 @@ def support_torch_compile(
|
||||
return cls_decorator_helper
|
||||
|
||||
|
||||
def _support_torch_compile(cls: type,
|
||||
dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
|
||||
def _support_torch_compile(
|
||||
cls: _T,
|
||||
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
|
||||
) -> _T:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
"""
|
||||
@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
|
||||
# other than TorchCompileWrapperWithCustomDispatcher
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
|
||||
|
||||
old_init = cls.__init__ # type: ignore
|
||||
old_init = cls.__init__
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
|
||||
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
@ -135,7 +155,7 @@ def _support_torch_compile(cls: type,
|
||||
TorchCompileWrapperWithCustomDispatcher.__init__(
|
||||
self, compilation_level=vllm_config.compilation_config.level)
|
||||
|
||||
cls.__init__ = __init__ # type: ignore
|
||||
cls.__init__ = __init__
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
@ -180,5 +200,5 @@ def _support_torch_compile(cls: type,
|
||||
model_output = self.forward(*args, **kwargs)
|
||||
return model_output
|
||||
|
||||
cls.__call__ = __call__ # type: ignore
|
||||
cls.__call__ = __call__
|
||||
return cls
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user