mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 03:25:02 +08:00
[Misc] Improve type annotations for support_torch_compile (#10763)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
133707123e
commit
f877a7d12a
@ -1,7 +1,8 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||||
@ -12,10 +13,27 @@ from vllm.utils import supports_dynamo
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_T = TypeVar("_T", bound=type[nn.Module])
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def support_torch_compile(
|
||||||
|
*,
|
||||||
|
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
|
||||||
|
) -> Callable[[_T], _T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def support_torch_compile(cls: _T) -> _T:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
def support_torch_compile(
|
def support_torch_compile(
|
||||||
cls: Optional[type] = None,
|
cls: Optional[_T] = None,
|
||||||
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
|
*,
|
||||||
|
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
|
||||||
|
) -> Union[Callable[[_T], _T], _T]:
|
||||||
"""
|
"""
|
||||||
A decorator to add support for compiling the forward method of a class.
|
A decorator to add support for compiling the forward method of a class.
|
||||||
|
|
||||||
@ -66,7 +84,7 @@ def support_torch_compile(
|
|||||||
computation graph.
|
computation graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def cls_decorator_helper(cls: type):
|
def cls_decorator_helper(cls: _T) -> _T:
|
||||||
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
|
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
|
||||||
# to avoid too much indentation for `_support_torch_compile``
|
# to avoid too much indentation for `_support_torch_compile``
|
||||||
if not hasattr(cls, 'forward'):
|
if not hasattr(cls, 'forward'):
|
||||||
@ -105,8 +123,10 @@ def support_torch_compile(
|
|||||||
return cls_decorator_helper
|
return cls_decorator_helper
|
||||||
|
|
||||||
|
|
||||||
def _support_torch_compile(cls: type,
|
def _support_torch_compile(
|
||||||
dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
|
cls: _T,
|
||||||
|
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
|
||||||
|
) -> _T:
|
||||||
"""
|
"""
|
||||||
A decorator to add support for compiling the forward method of a class.
|
A decorator to add support for compiling the forward method of a class.
|
||||||
"""
|
"""
|
||||||
@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
|
|||||||
# other than TorchCompileWrapperWithCustomDispatcher
|
# other than TorchCompileWrapperWithCustomDispatcher
|
||||||
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
|
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
|
||||||
|
|
||||||
old_init = cls.__init__ # type: ignore
|
old_init = cls.__init__
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
|
||||||
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
|
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||||
@ -135,7 +155,7 @@ def _support_torch_compile(cls: type,
|
|||||||
TorchCompileWrapperWithCustomDispatcher.__init__(
|
TorchCompileWrapperWithCustomDispatcher.__init__(
|
||||||
self, compilation_level=vllm_config.compilation_config.level)
|
self, compilation_level=vllm_config.compilation_config.level)
|
||||||
|
|
||||||
cls.__init__ = __init__ # type: ignore
|
cls.__init__ = __init__
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
# torch.compiler.is_compiling() means we are inside the compilation
|
# torch.compiler.is_compiling() means we are inside the compilation
|
||||||
@ -180,5 +200,5 @@ def _support_torch_compile(cls: type,
|
|||||||
model_output = self.forward(*args, **kwargs)
|
model_output = self.forward(*args, **kwargs)
|
||||||
return model_output
|
return model_output
|
||||||
|
|
||||||
cls.__call__ = __call__ # type: ignore
|
cls.__call__ = __call__
|
||||||
return cls
|
return cls
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user