Add option to use unbacked and backed_size_oblivious dynamic shapes for more sound compilation. (#26199)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in: parent f716a15372, commit 7a228b5305
@@ -151,6 +151,76 @@ To avoid this, please either:

2. wrap the branching logic into a custom operator. TorchDynamo does not
   trace into custom operators.

## Debugging constraint violations and dynamic shapes guards issues

Dynamic-shape guards are a specific category of Dynamo guards. They are constraints that `torch.compile`
attaches to dynamic dimensions (e.g., `seq_len`) to ensure the compiled artifact remains valid.
These guards typically appear when framework code, custom passes, or user code branches on
dynamic shape values.

**Example:**

```python
if x > 10:
    ...  # path A
else:
    ...  # path B
```

This creates a guard `x > 10` or `x <= 10`, depending on which path was traced.

**vLLM's Assumption:**
vLLM assumes that all guards added by `torch.compile` are safe to drop and will not
constrain the compiled graph to specific input shapes. When this assumption is violated,
it can cause issues that users need to debug. Typical symptoms of a violated assumption
are runtime errors or a `ConstraintViolationError`.
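
For context, vLLM drops guards by installing a guard filter that rejects every guard proposed during compilation. The snippet below is a minimal, standalone sketch of that mechanism, not vLLM code; it assumes a recent PyTorch where `torch.compile` accepts a `guard_filter_fn` entry in its `options` dict, and `toy_forward` is a hypothetical function:

```python
import torch


def toy_forward(x: torch.Tensor) -> torch.Tensor:
    # Branching on a shape would normally make Dynamo install a shape guard.
    if x.shape[0] > 10:
        return x * 2
    return x + 1


# Reject every proposed guard, mirroring vLLM's "no guards" wrapper.
compiled = torch.compile(
    toy_forward,
    fullgraph=True,
    options={"guard_filter_fn": lambda entries: [False for _ in entries]},
)
compiled(torch.randn(16, 8))
```
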
A `ConstraintViolationError` will be thrown if a dynamic shape gets constrained to
a single value. If you encounter a constraint violation error or suspect that a dynamic
shapes guard is being added incorrectly, you can use stricter dynamic shape modes to
help debug the issue:

```sh
# Online - using unbacked mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked

# Online - using backed_size_oblivious mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=backed_size_oblivious
```

```py
# Offline - using unbacked mode
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
LLM(model, compilation_config=CompilationConfig(
    dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.UNBACKED)
))

# Offline - using backed_size_oblivious mode
LLM(model, compilation_config=CompilationConfig(
    dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS)
))
```

These modes are stricter and reduce or eliminate dynamic shapes guarding, which can help isolate issues:

- `unbacked`: Uses unbacked symints, which do not allow guards, making it easier to identify where guards are being incorrectly added.
- `backed_size_oblivious`: Treats backed symbols as unbacked wherever explicit unbacked handling exists, so it is stricter about guarding than the default `backed` mode.

For more details on dynamic shapes modes, see [Dynamic shapes and vLLM guard dropping](torch_compile.md#dynamic-shapes-and-vllm-guard-dropping).

### Printing guards

To see all guards that are being added during compilation, you can use `TORCH_LOGS=+dynamic`:

```sh
TORCH_LOGS=+dynamic vllm serve meta-llama/Llama-3.2-1B
```

Look for `[guard added]` in the logs to see where guards are being added. This can help you identify which operations are
causing guards to be added incorrectly.
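
If you are driving vLLM from Python (offline inference), the same logging can be enabled programmatically instead of via the environment variable. This is a small sketch, assuming a recent PyTorch where `torch._logging.set_logs` accepts the `dynamic` keyword:

```python
import logging

import torch

# Roughly equivalent to TORCH_LOGS=+dynamic: emit dynamic-shape / guard
# logs at DEBUG level; call this before constructing the LLM engine.
torch._logging.set_logs(dynamic=logging.DEBUG)
```
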

## Debugging TorchInductor

TorchInductor takes a captured graph and then compiles it down to some Python code

@@ -29,6 +29,109 @@ A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all

By default, the cache saves compiled artifacts as binary files. If you would like to interact with the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit this and set the env variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.
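
For example, either of the following keeps the cache in an unpacked, human-readable layout. This is a sketch: the model name is a placeholder, and the dot-notation form assumes the `-O.` shortcut applies to this field the same way it does for `dynamic_shapes_config` below:

```sh
# Via the environment variable
VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked vllm serve meta-llama/Llama-3.2-1B

# Or via the compilation config dot notation
vllm serve meta-llama/Llama-3.2-1B -O.compile_cache_save_format=unpacked
```
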
## Dynamic shapes and vLLM guard dropping

`torch.compile` is designed to guard on dynamic shapes without hesitation
whenever it needs to. This conflicts with vLLM's approach of dropping those
guards, since some of the dropped guards could be material.

`torch.compile` provides two kinds of dynamic shapes: `backed` and `unbacked`.
`torch.compile` guards on `backed` dynamic shapes and does not guarantee
that no guards will be added to them. User code, Dynamo, Inductor, and
autograd can all add guards. Moreover, due to 0/1 specialization, backed
symbols are specialized unconditionally to 0, 1, or >=2 even without
encountering a branch on those ranges.

By contrast, `unbacked` dynamic shapes are guaranteed not to be guarded on
and are not 0/1 specialized. However, a data dependent error (DDE) may be
thrown when a branch that requires their value is encountered and no explicit
unbacked handling is defined. The framework is converging toward a state where
it picks general paths instead of throwing DDEs. One downside of using
unbacked is missed optimization opportunities, due to perf bugs, to picking
general paths, or to using a fixed hint that is not based on the example
inputs (this will be addressed soon with the override_hint API). An example of
picking a general path is assuming the input is not contiguous in calls to
contiguous() and reshape() when contiguity cannot be proven symbolically, at
the cost of introducing a clone.

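As a point of reference, the two flavors map to different Dynamo marking APIs. The snippet below is a minimal standalone sketch (not vLLM code) of how an input dimension is marked in each mode; vLLM's `@support_torch_compile` decorator selects between the same two calls internally:

```python
import torch
import torch._dynamo

x = torch.randn(8, 16)
y = torch.randn(8, 16)

# backed: dim 0 is traced as a backed symbol; guards may still be added on it,
# and sizes 0/1 would specialize.
torch._dynamo.mark_dynamic(x, 0)

# unbacked: dim 0 is traced as an unbacked symbol; no guards and no 0/1
# specialization, but data dependent branches on it may error.
torch._dynamo.decorators.mark_unbacked(y, 0)
```
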
`backed_size_oblivious` is a flag that treats backed symbols as unbacked
wherever explicit unbacked handling is defined. With this mode, 0/1
specializations are mostly avoided in framework code and the default 0/1
specialization does not happen. However, there is still no guarantee that
`torch.compile` won't guard, especially due to user code or custom passes.
`backed_size_oblivious` is experimental in PyTorch and could be deprecated.
That said, it is a safer option than `backed`, and it is less likely to
reduce performance than `unbacked`.

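Under the hood, this mode is applied by patching a PyTorch symbolic-shapes config flag around compilation, the same flag vLLM patches in its decorator. A minimal sketch, assuming the internal `torch.fx.experimental._config` knob keeps its current name (`forward_like` is purely illustrative):

```python
import torch
import torch.fx.experimental._config as fx_config


def forward_like(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a model forward; purely illustrative.
    return x.reshape(-1, 2) + 1


# Treat backed symbols size-obliviously while this function is compiled.
with fx_config.patch(backed_size_oblivious=True):
    compiled = torch.compile(forward_like, dynamic=True)
    out = compiled(torch.randn(8, 2))
```
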
### Configuring Dynamic Shapes

The `DynamicShapesConfig` allows you to control the dynamic shapes behavior by
setting the `type` field. You can choose between three modes:
`BACKED` (default), `UNBACKED`, and `BACKED_SIZE_OBLIVIOUS`.

#### Offline Inference Example (Using LLM class)

When using the `LLM` class for offline inference, you can configure dynamic
shapes through the `compilation_config` parameter:

```python
from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType

# Example: Using backed_size_oblivious (experimental, safer than backed)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS
        )
    )
)

# Example: Using unbacked (strongest guarantee against guards)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.UNBACKED
        )
    )
)

# Generate outputs
prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
```

#### Online Serving Example (Using vllm serve)

When using `vllm serve` for online serving, you can configure dynamic shapes
through the `--compilation-config` flag:

```bash
# Example: Using unbacked
vllm serve meta-llama/Llama-3.2-1B \
    --compilation-config '{"dynamic_shapes_config": {"type": "unbacked"}}'

# Alternative: Using dot notation (simpler for single values)
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
```

#### Choosing the Right Mode

- **BACKED** (default): Use when you are willing to accept potentially unsafe dropping of guards
  for maximal performance. Guards could be unsoundly added and then ignored.

- **UNBACKED**: Use when you need the strongest guarantee against guards.
  This is the most conservative option but may miss some optimization opportunities.

- **BACKED_SIZE_OBLIVIOUS**: Use when you want a balance between avoiding guards
  and performance. This experimental mode is safer than BACKED but not as
  conservative as UNBACKED.

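The mode can also be selected with a plain dict instead of the config dataclasses, mirroring how the test added in this PR configures it (the model name and mode value here are illustrative):

```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config={
        "dynamic_shapes_config": {"type": "backed_size_oblivious"},
    },
)
```
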
## Python Code Compilation

In the very verbose logs, we can see:

@@ -122,7 +225,7 @@ When all the shapes are known, `torch.compile` can compare different configs, an
triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
mm 0.0160 ms 81.6%
mm 0.0160 ms 81.6%
triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2

tests/compile/test_dynamic_shapes_compilation.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationMode, DynamicShapesType
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.torch_utils import is_torch_equal_or_newer


def get_test_models():
    """Get list of models to test based on PyTorch version"""
    # TODO: "Qwen/Qwen3-4B-Instruct-2507" fails; fix the issue and support it.
    return ["gpt2", "Qwen/Qwen2-7B-Instruct", "meta-llama/Llama-3.1-8B"]


@pytest.mark.parametrize("model_name", get_test_models())
@pytest.mark.parametrize(
    "shapes_type",
    [
        DynamicShapesType.BACKED,
        DynamicShapesType.UNBACKED,
        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
    ],
)
@pytest.mark.parametrize("use_aot_compile", ["0"])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_dynamic_shapes_compilation(
    monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
):
    """Test that all dynamic shapes types compile successfully"""
    print(
        f"\nTesting model: {model_name} with {shapes_type.name}, "
        f"AOT compile: {use_aot_compile}, "
        f"Bytecode hook: {use_bytecode_hook}"
    )
    if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
        pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")

    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

    prompt = "Hello, my name is"

    print(f"Testing {shapes_type.name} dynamic shapes...")

    # Initialize the model with specific dynamic shapes configuration
    model = LLM(
        model=model_name,
        compilation_config={
            "mode": CompilationMode.VLLM_COMPILE,
            "dynamic_shapes_config": {
                "type": shapes_type.value,
            },
        },
    )

    output = model.generate(prompt)
    result = output[0].outputs[0].text
    # Restrict the follow-up judgment step to "yes"/"no" tokens
    tokenizer = get_tokenizer(model_name)
    yes_tokens = tokenizer.encode("yes", add_special_tokens=False)
    no_tokens = tokenizer.encode("no", add_special_tokens=False)
    allowed_ids = list(set(yes_tokens + no_tokens))
    sampling_params = SamplingParams(
        max_tokens=1, temperature=0, allowed_token_ids=allowed_ids
    )

    output = model.generate(
        "answer with yes or no is " + result + " rubbish for prompt " + prompt + "?",
        sampling_params=sampling_params,
    )
    result = output[0].outputs[0].text
    assert result == "yes"

    # Clean up GPU memory
    del model
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print("GPU memory cleared")

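To exercise the new modes locally, the test can be run directly with pytest; this assumes a CUDA machine and a PyTorch nightly >= 2.10, per the skip condition above:

```sh
pytest -s tests/compile/test_dynamic_shapes_compilation.py
```
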
@@ -24,6 +24,7 @@ from vllm.config import (
    get_current_vllm_config,
    set_current_vllm_config,
)
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname

@@ -104,6 +105,7 @@ def support_torch_compile(
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T] | _T:
    """
    A decorator to add support for compiling the forward method of a class.

@@ -161,6 +163,14 @@ def support_torch_compile(
    dim to be decorated with `mark_unbacked`. This is useful if we would like to
    enforce that dynamo does not specialize on 0/1 values in the case of dummy input
    such as for vision model compilation

    `shape_invariants` is a function that gets compiled right before forward.
    The function should have the torch._check calls that are needed to set
    the relationships between different input sizes. For example:
    torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
    This enforces constraints on the symbolic shapes without hardcoding
    specific values. It is needed for some models to avoid data dependent
    errors.
    """

    def cls_decorator_helper(cls: _T) -> _T:

@@ -199,7 +209,11 @@ def support_torch_compile(
                    f"Argument {k} not found in the forward method of {cls}"
                )
        return _support_torch_compile(
            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
            cls,
            inferred_dynamic_arg_dims,
            mark_unbacked_dims,
            enable_if,
            shape_invariants,
        )

    if cls is not None:

@@ -242,6 +256,7 @@ def _support_torch_compile(
    dynamic_arg_dims: dict[str, int | list[int]],
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

@@ -276,11 +291,12 @@ def _support_torch_compile(
        old_init(self, **kwargs)

        self.vllm_config = vllm_config
        self.compilation_config = self.vllm_config.compilation_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE, the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            vllm_config.compilation_config.mode
            self.compilation_config.mode
            in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
            or not supports_dynamo()
            or _should_ignore_torch_compile(self.__class__)

@@ -289,29 +305,38 @@ def _support_torch_compile(
        if self.do_not_compile:
            return

        self._check_shape_invariants = shape_invariants

        compilation_counter.num_models_seen += 1
        self.compiled = False
        TorchCompileWithNoGuardsWrapper.__init__(self)

    cls.__init__ = __init__

    def _mark_dynamic_inputs(mod, *args, **kwargs):
    def _mark_dynamic_inputs(mod, type, *args, **kwargs):
        def mark_dynamic(arg, dims):
            if type == DynamicShapesType.UNBACKED:
                torch._dynamo.decorators.mark_unbacked(arg, dims)
            else:
                torch._dynamo.mark_dynamic(arg, dims)

        sig = inspect.signature(mod.__class__.forward)
        bound_args = sig.bind(mod, *args, **kwargs)
        bound_args.apply_defaults()
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)

            if arg is not None:
                dims = [dims] if isinstance(dims, int) else dims
                if isinstance(arg, torch.Tensor):
                    # In case dims is specified with negative indexing
                    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                    torch._dynamo.mark_dynamic(arg, dims)
                    mark_dynamic(arg, dims)
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        # In case dims is specified with negative indexing
                        dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
                        torch._dynamo.mark_dynamic(tensor, dims)
                        mark_dynamic(tensor, dims)
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"

@@ -338,6 +363,7 @@ def _support_torch_compile(
        if getattr(self, "aot_compiled_fn", None) is not None:
            return self.aot_compiled_fn(self, *args, **kwargs)

        ds_type = self.compilation_config.dynamic_shapes_config.type
        cache_dir = None
        aot_compilation_path = None
        if envs.VLLM_USE_AOT_COMPILE:

@@ -352,6 +378,14 @@ def _support_torch_compile(
            serialized backend artifacts), then we need to generate a new AOT
            compile artifact from scratch.
            """
            # Validate that AOT compile is not used with unbacked dynamic
            # shapes. aot_compile re-allocates backed symbols post dynamo!
            if ds_type == DynamicShapesType.UNBACKED:
                raise ValueError(
                    "AOT compilation is not compatible with UNBACKED dynamic shapes. "
                    "Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
                    "when VLLM_USE_AOT_COMPILE is enabled."
                )
            from .caching import compilation_config_hash_factors

            factors: list[str] = compilation_config_hash_factors(self.vllm_config)

@@ -401,7 +435,12 @@ def _support_torch_compile(
        # This is the path for the first compilation.

        # the first compilation needs to have dynamic shapes marked
        _mark_dynamic_inputs(self, *args, **kwargs)
        _mark_dynamic_inputs(
            self,
            ds_type,
            *args,
            **kwargs,
        )

        # here, it is the starting point of the `torch.compile` process
        start_monitoring_torch_compile(self.vllm_config)

@@ -417,9 +456,7 @@ def _support_torch_compile(
        # properly when any of these files change.

        # 1. the file containing the top-level forward function
        self.vllm_config.compilation_config.traced_files.add(
            original_code_object.co_filename
        )
        self.compilation_config.traced_files.add(original_code_object.co_filename)

        # 2. every time Dynamo sees a function call, it will inline
        # the function by calling InliningInstructionTranslator.inline_call_

@@ -429,7 +466,7 @@ def _support_torch_compile(

        def patched_inline_call(self_):
            code = self_.f_code
            self.vllm_config.compilation_config.traced_files.add(code.co_filename)
            self.compilation_config.traced_files.add(code.co_filename)
            return inline_call(self_)

        # Disable the C++ compilation of symbolic shape guards. C++-fication

@@ -445,12 +482,18 @@ def _support_torch_compile(
            # if the config doesn't exist
            logger.debug("enable_cpp_symbolic_shape_guards config not available")

        # Prepare backed_size_oblivious config patch if needed
        fx_config_patches = {}
        if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
            fx_config_patches["backed_size_oblivious"] = True

        with (
            patch.object(
                InliningInstructionTranslator, "inline_call_", patched_inline_call
            ),
            torch._dynamo.config.patch(**dynamo_config_patches),
            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
            torch.fx.experimental._config.patch(**fx_config_patches),
            _torch27_patch_tensor_subclasses(),
        ):
            if envs.VLLM_USE_AOT_COMPILE:

@@ -6,6 +6,7 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Any

import torch
import torch._C._dynamo.guards

@@ -85,6 +86,12 @@ class TorchCompileWithNoGuardsWrapper:
    since we drop all guards.
    """

    def check_invariants_and_forward(self, *args, **kwargs):
        assert hasattr(self, "_check_shape_invariants")
        self._check_shape_invariants(*args, **kwargs)

        return self.forward(*args, **kwargs)

    def __init__(self):
        self.compiled = False

@@ -104,6 +111,21 @@ class TorchCompileWithNoGuardsWrapper:
            # Drop all the guards.
            options["guard_filter_fn"] = lambda x: [False for _ in x]

        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
        from vllm.compilation.decorators import DynamicShapesType

        ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
        compiled_ptr: Any = self.forward
        if ds_type == DynamicShapesType.UNBACKED:
            if envs.VLLM_USE_BYTECODE_HOOK:
                # reason is that bytecode does this hack torch._dynamo.eval_frame.
                # remove_from_cache(self.original_code_object()) to force a new
                # re-compilation.
                raise ValueError(
                    "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
                )
            compiled_ptr = self.check_invariants_and_forward

        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                torch._dynamo.config.enable_aot_compile = True

@@ -114,7 +136,7 @@ class TorchCompileWithNoGuardsWrapper:
            logger.warning(msg)

        self._compiled_callable = torch.compile(
            self.forward,
            compiled_ptr,
            fullgraph=True,
            dynamic=False,
            backend=backend,

@@ -192,6 +192,54 @@ class PassConfig:
            self.enable_qk_norm_rope_fusion = False


class DynamicShapesType(str, enum.Enum):
    """Types of dynamic shapes handling in torch.compile().
    see Dynamic shapes and vllm guard dropping in torch_compile.md
    for more details."""

    BACKED = "backed"
    """Use backed dynamic shapes. torch.compile() guards on backed dynamic
    shapes and may add guards. Symbols are specialized to 0, 1, or >=2 even
    without encountering branching on those ranges."""

    UNBACKED = "unbacked"
    """Use unbacked dynamic shapes. Guaranteed not to be guarded on and not
    0/1 specialized, but may throw data dependent errors when branches require
    their value without explicit unbacked handling."""

    BACKED_SIZE_OBLIVIOUS = "backed_size_oblivious"
    """Experimental flag that treats backed symbols as unbacked when explicit
    unbacked handling is defined."""


@config
@dataclass
class DynamicShapesConfig:
    """Configuration to control/debug torch compile dynamic shapes."""

    type: DynamicShapesType = DynamicShapesType.BACKED
    """Controls the type of dynamic shapes handling to use with torch.compile().

    - BACKED: Default PyTorch behavior with potential guards ignored.
    - UNBACKED: No guards guaranteed (most sound) but may throw
      data dependent errors.
    - BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to
      backed/unbacked.
    """

    # TODO add a debug mode to fail

    def compute_hash(self) -> str:
        """
        Provide a hash for DynamicShapesConfig
        """
        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, {})
        return hash_factors(factors)


@config
@dataclass
class CompilationConfig:

@@ -322,7 +370,7 @@ class CompilationConfig:
    If empty list [], no ops are excluded (suitable for full cudagraphs)."""
    compile_mm_encoder: bool = False
    """Whether or not to compile the multimodal encoder.
    Currently, this only works for `Qwen2_5_vl` on selected platforms.
    Currently, this only works for `Qwen2_5_vl` on selected platforms.
    Disabled by default until more models are supported/tested to work."""

    # Inductor capture

@@ -348,9 +396,11 @@ class CompilationConfig:
    """Sizes to compile for inductor. In addition
    to integers, it also supports "cudagraph_capture_sizes" to
    specify the sizes for cudagraph capture."""

    inductor_compile_config: dict = field(default_factory=dict)
    """Additional configurations for inductor.
    - None: use default configurations."""

    inductor_passes: dict[str, str] = field(default_factory=dict)
    """Additional passes for inductor. It is a dictionary
    from pass name to pass function qualified name. We use function

@@ -460,8 +510,15 @@ class CompilationConfig:
    max_num_seqs, and prevents capture of many large graphs (>512) that would
    greatly increase startup time with limited performance benefit.
    """

    dynamic_shapes_config: DynamicShapesConfig = field(
        default_factory=DynamicShapesConfig
    )
    """Configuration for dynamic shapes options"""

    local_cache_dir: str = field(default=None, init=False)  # type: ignore
    """local cache dir for each rank"""

    bs_to_padded_graph_size: list[int] = field(
        default=None,  # type: ignore
        init=False,

@@ -530,6 +587,7 @@ class CompilationConfig:
        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)

        factors["pass_config"] = self.pass_config.compute_hash()
        return hash_factors(factors)

@@ -354,7 +354,17 @@ class LlamaDecoderLayer(nn.Module):
        return vllm_config.quant_config


@support_torch_compile
def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])


@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
    def __init__(
        self,

@@ -274,6 +274,38 @@ class Qwen2DecoderLayer(nn.Module):
        return hidden_states, residual


def qwen_2_model_invariants(
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
):
    """Shape invariants for Qwen2Model Model, those are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed"""
    # All these should be equal.
    # input_ids.size()[0]
    # positions.size()[-1]
    # intermediate_tensors["hidden_states"].size()[0]
    # inputs_embeds.size()[0]
    torch._check(input_ids.size()[0] == positions.size()[-1])
    if intermediate_tensors is not None:
        torch._check(
            input_ids.size()[0] == intermediate_tensors["hidden_states"].size()[0]
        )

    if inputs_embeds is not None:
        torch._check(input_ids.size()[0] == inputs_embeds.size()[0])

    # Hidden dimensions should match (hidden_size)
    # intermediate_tensors["hidden_states"].size()[1]
    # inputs_embeds.size()[1]
    if inputs_embeds is not None and intermediate_tensors is not None:
        torch._check(
            inputs_embeds.size()[1] == intermediate_tensors["hidden_states"].size()[1]
        )


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,

@@ -282,7 +314,8 @@ class Qwen2DecoderLayer(nn.Module):
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
    },
    shape_invariants=qwen_2_model_invariants,
)
class Qwen2Model(nn.Module):
    def __init__(
        self,