mirror of https://git.datalinker.icu/vllm-project/vllm.git
[torch.compile] auto infer dynamic_arg_dims from type annotation (#9589)
commit 17c79f3c36
parent cd5601ac37
vllm/compilation/decorators.py
@@ -1,24 +1,58 @@
 import inspect
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
 
+logger = init_logger(__name__)
+
 
-def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
+def support_torch_compile(
+        cls: Optional[type] = None,
+        dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
     """
     A decorator to add support for compiling the forward method of a class.
 
+    Usage 1: use directly as a decorator without arguments:
+
+    ```python
+    @support_torch_compile
+    class MyModel(nn.Module):
+        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
+            ...
+    ```
+
+    Usage 2: use as a decorator with arguments:
+
+    ```python
+    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
+    class MyModel(nn.Module):
+        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
+            ...
+    ```
+
     `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
     dimensions of the argument. The dynamic dimensions can be either a single
     integer or a list of integers.
 
-    Depending on the value of arguments:
+    If `dynamic_arg_dims` is `None`, it is inferred from the type annotation
+    of the `forward` method, based on the following default rules:
+
+    - if the argument is annotated as `torch.Tensor` or
+      `Optional[torch.Tensor]`, the first dimension will be
+      marked as dynamic.
+    - if the argument is annotated as `IntermediateTensors`, the first
+      dimension of all the tensors in the intermediate tensors
+      will be marked as dynamic.
+
+    During runtime, when we actually mark dimensions of tensors,
+    it depends on the value of arguments:
 
     - if it is a single integer, the corresponding dimension of the argument
       will be marked as dynamic.
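Background for the docstring above: "marking a dimension as dynamic" refers to PyTorch's shape-specialization hint. A minimal standalone sketch, assuming only stock PyTorch 2.x and its public `torch._dynamo.mark_dynamic` helper (this block is editorial illustration, not part of the commit):

```python
import torch


def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


compiled = torch.compile(double)

x = torch.randn(8, 16)
# Hint that dim 0 varies at runtime, so Dynamo traces one
# symbolic-shape graph instead of recompiling per batch size.
torch._dynamo.mark_dynamic(x, 0)
compiled(x)

y = torch.randn(32, 16)
compiled(y)  # different dim-0 size: reuses the same compiled graph
```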
@@ -38,11 +72,35 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
         if not hasattr(cls, 'forward'):
             raise TypeError("decorated class should have a forward method.")
         sig = inspect.signature(cls.forward)
-        for k in dynamic_arg_dims:
+        inferred_dynamic_arg_dims = dynamic_arg_dims
+        if inferred_dynamic_arg_dims is None:
+            inferred_dynamic_arg_dims = {}
+            for k, v in sig.parameters.items():
+                if v.annotation in [
+                        torch.Tensor, Optional[torch.Tensor],
+                        IntermediateTensors, Optional[IntermediateTensors]
+                ]:
+                    inferred_dynamic_arg_dims[k] = 0
+
+            logger.debug(("Inferred dynamic dimensions for "
+                          "forward method of %s: %s"), cls,
+                         list(inferred_dynamic_arg_dims.keys()))
+
+        if len(inferred_dynamic_arg_dims) == 0:
+            raise ValueError(
+                "No dynamic dimensions found in the forward method of "
+                f"{cls}. Please provide dynamic_arg_dims explicitly.")
+
+        for k in inferred_dynamic_arg_dims:
             if k not in sig.parameters:
                 raise ValueError(
                     f"Argument {k} not found in the forward method of {cls}")
-        return _support_torch_compile(cls, dynamic_arg_dims)
+        return _support_torch_compile(cls, inferred_dynamic_arg_dims)
+
+    if cls is not None:
+        # use `support_torch_compile` as a decorator without arguments
+        assert isinstance(cls, type)
+        return cls_decorator_helper(cls)
 
     return cls_decorator_helper
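The hunk above folds two decorator styles into one function: with a bare `@support_torch_compile`, Python passes the class itself as `cls`; with `@support_torch_compile(...)`, `cls` is `None` and the helper is returned to be applied next. A condensed, runnable sketch of that dispatch plus the annotation-based inference (illustrative names, not vllm's code):

```python
import inspect
from typing import Dict, Optional

import torch
import torch.nn as nn


def compile_decorator(cls: Optional[type] = None,
                      dynamic_arg_dims: Optional[Dict[str, int]] = None):

    def helper(cls: type):
        dims = dynamic_arg_dims
        if dims is None:
            # Infer: every forward arg annotated as (Optional)
            # torch.Tensor has its first dimension treated as dynamic.
            sig = inspect.signature(cls.forward)
            dims = {
                k: 0
                for k, v in sig.parameters.items()
                if v.annotation in (torch.Tensor, Optional[torch.Tensor])
            }
        cls._dynamic_arg_dims = dims  # stand-in for the real wrapping step
        return cls

    if cls is not None:
        # bare `@compile_decorator`: Python handed us the class directly
        return helper(cls)
    # `@compile_decorator(...)`: return the actual decorator
    return helper


@compile_decorator
class Toy(nn.Module):

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
        return x


print(Toy._dynamic_arg_dims)  # {'x': 0, 'y': 0}
```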
vllm/model_executor/models/gemma2.py
@@ -241,13 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
         return hidden_states, residual
 
 
-@support_torch_compile(
-    dynamic_arg_dims={
-        "input_ids": 0,
-        "positions": 0,
-        "inputs_embeds": 0,
-        "intermediate_tensors": 0,
-    })
+@support_torch_compile
 class Gemma2Model(nn.Module):
 
     def __init__(
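The bare form is safe here because inference recovers exactly the dictionary being deleted. A representative vllm-style `forward` signature (reconstructed for illustration; consult gemma2.py at this commit for the exact one) shows how each argument maps onto the default rules:

```python
from typing import List, Optional, Union

import torch

from vllm.attention import AttentionMetadata
from vllm.sequence import IntermediateTensors


def forward(
    self,
    input_ids: torch.Tensor,           # rule 1 -> dim 0 dynamic
    positions: torch.Tensor,           # rule 1 -> dim 0 dynamic
    kv_caches: List[torch.Tensor],     # List[...] matches no rule; skipped
    attn_metadata: AttentionMetadata,  # matches no rule; skipped
    intermediate_tensors: Optional[IntermediateTensors],  # rule 2 -> dynamic
    inputs_embeds: Optional[torch.Tensor] = None,         # rule 1 -> dynamic
) -> Union[torch.Tensor, IntermediateTensors]:
    ...
```

The inferred mapping is `{"input_ids": 0, "positions": 0, "inputs_embeds": 0, "intermediate_tensors": 0}`, i.e. the literal dict removed above.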
vllm/model_executor/models/llama.py
@@ -268,13 +268,7 @@ class LlamaDecoderLayer(nn.Module):
         return hidden_states, residual
 
 
-@support_torch_compile(
-    dynamic_arg_dims={
-        "input_ids": 0,
-        "positions": 0,
-        "inputs_embeds": 0,
-        "intermediate_tensors": 0,
-    })
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(
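One behavioral consequence of the inference path added above: if a model's `forward` has no annotations matching the rules, the decorator now fails loudly rather than silently compiling with nothing dynamic, and the explicit argument form remains the escape hatch. A hedged sketch (hypothetical model class, assuming the decorator behaves as in the hunks above):

```python
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile


# A bare `@support_torch_compile` would raise here: the unannotated arg
# matches no inference rule, triggering "No dynamic dimensions found ...
# Please provide dynamic_arg_dims explicitly." The explicit form works:
@support_torch_compile(dynamic_arg_dims={"hidden": 0})
class UnannotatedModel(nn.Module):

    def forward(self, hidden):
        return hidden
```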