Add option to use unbacked and backed size-oblivious dynamic shapes for more sound compilation. (#26199)

Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
Laith Sakka 2025-11-24 07:12:41 -08:00 committed by GitHub
parent f716a15372
commit 7a228b5305
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 442 additions and 15 deletions

View File

@ -151,6 +151,76 @@ To avoid this, please either:
2. wrap the branching logic into a custom operator. TorchDynamo does not
trace into custom operators.
## Debugging constraint violations and dynamic shape guard issues
Dynamic-shape guards are a specific category of Dynamo guards. They are constraints that `torch.compile`
attaches to dynamic dimensions (e.g., `seq_len`) to ensure the compiled artifact remains valid.
These guards typically appear when framework code, custom passes, or user code branches based on
dynamic shape values.
**Example:**
```python
if x > 10:
    ...  # path A
else:
    ...  # path B
```
This creates a guard `x > 10` or `x <= 10` depending on which path was traced.
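For a concrete sense of how such a guard arises, here is a minimal sketch (not part of this patch) that branches on a dynamic dimension under `torch.compile`:
```python
import torch

@torch.compile(dynamic=True)
def f(x: torch.Tensor) -> torch.Tensor:
    # Branching on a dynamic dimension makes Dynamo install a shape guard
    # (x.shape[0] > 10 or x.shape[0] <= 10, depending on the traced path).
    if x.shape[0] > 10:
        return x + 1  # path A
    return x - 1      # path B

f(torch.randn(16))  # traces path A and guards on shape[0] > 10
```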
**vLLM's Assumption:**
vLLM assumes that all guards added by torch.compile are safe to drop and will not
constrain the compiled graph to specific input shapes. When this assumption is violated,
it can cause issues that users need to debug.
Side effects that indicate this assumption is violated include runtime errors
or `ConstraintViolationError`s.
A `ConstraintViolationError` will be thrown if a dynamic shape gets constrained to
a single value. If you encounter a constraint violation error or suspect that a dynamic
shapes guard is being added incorrectly, you can use stricter dynamic shape modes to
help debug the issue:
```sh
# Online - using unbacked mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
# Online - using backed_size_oblivious mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=backed_size_oblivious
```
```py
# Offline - using unbacked mode
from vllm import LLM
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType

LLM(model, compilation_config=CompilationConfig(
    dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.UNBACKED)
))

# Offline - using backed_size_oblivious mode
LLM(model, compilation_config=CompilationConfig(
    dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS)
))
```
These modes are stricter and reduce or eliminate the need for dynamic shape guards, which can help isolate issues:
- `unbacked`: Uses unbacked symints, which do not allow guards, making it easier to identify where guards are being incorrectly added.
- `backed_size_oblivious`: Treats backed symbols as unbacked wherever explicit unbacked handling exists, so it is much stricter about guarding.
For more details on dynamic shapes modes, see [Dynamic shapes and vLLM guard dropping](torch_compile.md#dynamic-shapes-and-vllm-guard-dropping).
### Printing guards
To see all guards that are being added during compilation, you can use `TORCH_LOGS=+dynamic`:
```sh
TORCH_LOGS=+dynamic vllm serve meta-llama/Llama-3.2-1B
```
Look for `[guard added]` in the logs to see where guards are being added. This can help you identify which operations are
causing guards to be added incorrectly.
## Debugging TorchInductor
TorchInductor takes a captured graph and then compiles it down to some Python code

View File

@ -29,6 +29,109 @@ A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all
By default, the cache saves compiled artifacts as binary files. If you would like to interact with the generated code for debugging purposes, set the field `compile_cache_save_format=unpacked` in the compilation config, or omit this and set the env variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.
## Dynamic shapes and vLLM guard dropping
`torch.compile` is designed to guard on dynamic shapes without hesitation
when needed. This conflicts with vLLM's approach of dropping those guards,
since many of them could be material.
`torch.compile` provides two kinds of dynamic shapes: `backed` and `unbacked`.
`torch.compile` guards on `backed` dynamic shapes and does not guarantee
that no guards will be added to them. User code, Dynamo, Inductor, and
autograd can all add guards. Moreover, backed symbols are specialized
unconditionally to 0, 1, or >=2 even without encountering branching on
those ranges.
On the contrary, `unbacked` dynamic shapes are guaranteed not to be guarded
on and are not 0/1 specialized. However, a data-dependent error may be
thrown when a branch that requires their value is encountered and no
explicit unbacked handling is defined. The framework is converging to a
state where it will not throw a data-dependent error but rather pick a
general path. The downsides of using unbacked are missed optimization
opportunities, due either to performance bugs or to picking general paths,
and the use of a fixed hint that is not based on the example input (this
will be fixed soon with the override_hint API). An example of picking a
general path is assuming the input is not contiguous in calls to
contiguous() and reshape() when contiguity cannot be proven symbolically,
at the cost of possibly introducing a clone.
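As an illustration only (a minimal sketch, not part of this patch), marking a dimension unbacked and then branching on it without explicit handling is expected to raise a data-dependent error:
```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # Without explicit unbacked handling (e.g. a torch._check on the range),
    # this branch cannot be resolved for an unbacked size.
    if x.shape[0] > 10:
        return x + 1
    return x - 1

x = torch.randn(16)
torch._dynamo.decorators.mark_unbacked(x, 0)
compiled = torch.compile(f, fullgraph=True)
# compiled(x)  # expected to raise a data-dependent error during tracing
```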
`backed_size_oblivious` is a flag that enables treating backed symbols as
unbacked wherever explicit handling for unbacked is defined. With this
mode, 0/1 specializations are mostly avoided in framework code and the
default 0/1 specialization does not happen. However, there is still no
guarantee that torch.compile won't guard, especially due to user code or
custom passes. `backed_size_oblivious` is experimental in PyTorch and
could be deprecated. That said, it is a safer option than `backed`, and
the chance of reduced performance is lower than with `unbacked`.
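Under the hood, this commit enables the mode by patching `torch.fx.experimental._config` around compilation (see the `decorators.py` change below). A rough sketch of the mechanism, using a placeholder model and input and relying on this private, experimental PyTorch API:
```python
import torch
import torch.fx.experimental._config as fx_config  # private, experimental PyTorch API

# Placeholder model and input, for illustration only.
model = torch.nn.Linear(8, 8)
example = torch.randn(4, 8)

compiled = torch.compile(model, dynamic=True)
# Treat backed symbols as unbacked wherever explicit unbacked handling exists,
# which avoids most 0/1 specializations in framework code.
with fx_config.patch(backed_size_oblivious=True):
    out = compiled(example)
```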
### Configuring Dynamic Shapes
The `DynamicShapesConfig` allows you to control the dynamic shapes behavior by
setting the `type` field. You can choose between three modes:
`BACKED` (default), `UNBACKED`, and `BACKED_SIZE_OBLIVIOUS`.
#### Offline Inference Example (Using LLM class)
When using the `LLM` class for offline inference, you can configure dynamic
shapes through the `compilation_config` parameter:
```python
from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
# Example: Using backed_size_oblivious (experimental, safer than backed)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS
        )
    )
)

# Example: Using unbacked (strongest guarantee against guards)
llm = LLM(
    model="meta-llama/Llama-3.2-1B",
    compilation_config=CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.UNBACKED
        )
    )
)
# Generate outputs
prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
```
#### Online Serving Example (Using vllm serve)
When using `vllm serve` for online serving, you can configure dynamic shapes
through the `--compilation-config` flag:
```bash
# Example: Using unbacked
vllm serve meta-llama/Llama-3.2-1B \
    --compilation-config '{"dynamic_shapes_config": {"type": "unbacked"}}'
# Alternative: Using dot notation (simpler for single values)
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
```
#### Choosing the Right Mode
- **BACKED** (default): Use when you are willing to accept potentially unsafe dropping of guards
for maximal performance. Guards could be unsoundly added and then ignored.
- **UNBACKED**: Use when you need the strongest guarantee against guards.
This is the most conservative option but may miss some optimization opportunities.
- **BACKED_SIZE_OBLIVIOUS**: Use when you want a balance between avoiding guards
and performance. This experimental mode is safer than BACKED but still not as
conservative as UNBACKED.
## Python Code Compilation
In the very verbose logs, we can see:
@ -122,7 +225,7 @@ When all the shapes are known, `torch.compile` can compare different configs, an
triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
mm 0.0160 ms 81.6%
mm 0.0160 ms 81.6%
triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2

View File

@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationMode, DynamicShapesType
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.torch_utils import is_torch_equal_or_newer
def get_test_models():
    """Get the list of models to test."""
    # TODO: "Qwen/Qwen3-4B-Instruct-2507" fails; fix the issue and support it.
    return ["gpt2", "Qwen/Qwen2-7B-Instruct", "meta-llama/Llama-3.1-8B"]


@pytest.mark.parametrize("model_name", get_test_models())
@pytest.mark.parametrize(
    "shapes_type",
    [
        DynamicShapesType.BACKED,
        DynamicShapesType.UNBACKED,
        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
    ],
)
@pytest.mark.parametrize("use_aot_compile", ["0"])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_dynamic_shapes_compilation(
    monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
):
    """Test that all dynamic shapes types compile successfully."""
    print(
        f"\nTesting model: {model_name} with {shapes_type.name}, "
        f"AOT compile: {use_aot_compile}, "
        f"Bytecode hook: {use_bytecode_hook}"
    )
    if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
        pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")

    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")

    prompt = "Hello, my name is"
    print(f"Testing {shapes_type.name} dynamic shapes...")

    # Initialize the model with the specific dynamic shapes configuration
    model = LLM(
        model=model_name,
        compilation_config={
            "mode": CompilationMode.VLLM_COMPILE,
            "dynamic_shapes_config": {
                "type": shapes_type.value,
            },
        },
    )
    output = model.generate(prompt)
    result = output[0].outputs[0].text

    # Ask the model to judge its own output, constrained to yes/no tokens
    tokenizer = get_tokenizer(model_name)
    yes_tokens = tokenizer.encode("yes", add_special_tokens=False)
    no_tokens = tokenizer.encode("no", add_special_tokens=False)
    allowed_ids = list(set(yes_tokens + no_tokens))
    sampling_params = SamplingParams(
        max_tokens=1, temperature=0, allowed_token_ids=allowed_ids
    )
    output = model.generate(
        "answer with yes or no is " + result + " rubbish for prompt " + prompt + "?",
        sampling_params=sampling_params,
    )
    result = output[0].outputs[0].text
    assert result == "yes"

    # Clean up GPU memory
    del model
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print("GPU memory cleared")

View File

@ -24,6 +24,7 @@ from vllm.config import (
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
@ -104,6 +105,7 @@ def support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T] | _T:
"""
A decorator to add support for compiling the forward method of a class.
@ -161,6 +163,14 @@ def support_torch_compile(
dim to be decorated with `mark_unbacked`. This is useful if we would like to
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
such as for vision model compilation
`shape_invariants` is a function that gets compiled right before forward.
The function should contain the torch._check calls that are needed to set
the relationships between different input sizes. For example:
torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
This enforces constraints on the symbolic shapes without hardcoding
specific values. It is needed for some models to avoid data-dependent
errors.
"""
def cls_decorator_helper(cls: _T) -> _T:
@ -199,7 +209,11 @@ def support_torch_compile(
f"Argument {k} not found in the forward method of {cls}"
)
return _support_torch_compile(
cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
cls,
inferred_dynamic_arg_dims,
mark_unbacked_dims,
enable_if,
shape_invariants,
)
if cls is not None:
@ -242,6 +256,7 @@ def _support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
@ -276,11 +291,12 @@ def _support_torch_compile(
old_init(self, **kwargs)
self.vllm_config = vllm_config
self.compilation_config = self.vllm_config.compilation_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = (
vllm_config.compilation_config.mode
self.compilation_config.mode
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
or not supports_dynamo()
or _should_ignore_torch_compile(self.__class__)
@ -289,29 +305,38 @@ def _support_torch_compile(
if self.do_not_compile:
return
self._check_shape_invariants = shape_invariants
compilation_counter.num_models_seen += 1
self.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(self)
cls.__init__ = __init__
def _mark_dynamic_inputs(mod, *args, **kwargs):
def _mark_dynamic_inputs(mod, type, *args, **kwargs):
def mark_dynamic(arg, dims):
if type == DynamicShapesType.UNBACKED:
torch._dynamo.decorators.mark_unbacked(arg, dims)
else:
torch._dynamo.mark_dynamic(arg, dims)
sig = inspect.signature(mod.__class__.forward)
bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
torch._dynamo.mark_dynamic(arg, dims)
mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
torch._dynamo.mark_dynamic(tensor, dims)
mark_dynamic(tensor, dims)
else:
raise ValueError(
"Unsupported dynamic dimensions"
@ -338,6 +363,7 @@ def _support_torch_compile(
if getattr(self, "aot_compiled_fn", None) is not None:
return self.aot_compiled_fn(self, *args, **kwargs)
ds_type = self.compilation_config.dynamic_shapes_config.type
cache_dir = None
aot_compilation_path = None
if envs.VLLM_USE_AOT_COMPILE:
@ -352,6 +378,14 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
# Validate that AOT compile is not used with unbacked dynamic
# shapes. aot_compile re-allocates backed symbols post dynamo!
if ds_type == DynamicShapesType.UNBACKED:
raise ValueError(
"AOT compilation is not compatible with UNBACKED dynamic shapes. "
"Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
"when VLLM_USE_AOT_COMPILE is enabled."
)
from .caching import compilation_config_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
@ -401,7 +435,12 @@ def _support_torch_compile(
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
_mark_dynamic_inputs(self, *args, **kwargs)
_mark_dynamic_inputs(
self,
ds_type,
*args,
**kwargs,
)
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
@ -417,9 +456,7 @@ def _support_torch_compile(
# properly when any of these files change.
# 1. the file containing the top-level forward function
self.vllm_config.compilation_config.traced_files.add(
original_code_object.co_filename
)
self.compilation_config.traced_files.add(original_code_object.co_filename)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call_
@ -429,7 +466,7 @@ def _support_torch_compile(
def patched_inline_call(self_):
code = self_.f_code
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_)
# Disable the C++ compilation of symbolic shape guards. C++-fication
@ -445,12 +482,18 @@ def _support_torch_compile(
# if the config doesn't exist
logger.debug("enable_cpp_symbolic_shape_guards config not available")
# Prepare backed_size_oblivious config patch if needed
fx_config_patches = {}
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
fx_config_patches["backed_size_oblivious"] = True
with (
patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call
),
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(),
):
if envs.VLLM_USE_AOT_COMPILE:

View File

@ -6,6 +6,7 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Any
import torch
import torch._C._dynamo.guards
@ -85,6 +86,12 @@ class TorchCompileWithNoGuardsWrapper:
since we drop all guards.
"""
def check_invariants_and_forward(self, *args, **kwargs):
assert hasattr(self, "_check_shape_invariants")
self._check_shape_invariants(*args, **kwargs)
return self.forward(*args, **kwargs)
def __init__(self):
self.compiled = False
@ -104,6 +111,21 @@ class TorchCompileWithNoGuardsWrapper:
# Drop all the guards.
options["guard_filter_fn"] = lambda x: [False for _ in x]
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
from vllm.compilation.decorators import DynamicShapesType
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
compiled_ptr: Any = self.forward
if ds_type == DynamicShapesType.UNBACKED:
if envs.VLLM_USE_BYTECODE_HOOK:
# The bytecode hook uses torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation, which does not work with unbacked symbols.
raise ValueError(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
)
compiled_ptr = self.check_invariants_and_forward
if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"):
torch._dynamo.config.enable_aot_compile = True
@ -114,7 +136,7 @@ class TorchCompileWithNoGuardsWrapper:
logger.warning(msg)
self._compiled_callable = torch.compile(
self.forward,
compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,

View File

@ -192,6 +192,54 @@ class PassConfig:
self.enable_qk_norm_rope_fusion = False
class DynamicShapesType(str, enum.Enum):
"""Types of dynamic shapes handling in torch.compile().
see Dynamic shapes and vllm guard dropping in torch_compile.md
for more details."""
BACKED = "backed"
"""Use backed dynamic shapes. torch.compile() guards on backed dynamic
shapes and may add guards. Symbols are specialized to 0, 1, or >=2 even
without encountering branching on those ranges."""
UNBACKED = "unbacked"
"""Use unbacked dynamic shapes. Guaranteed not to be guarded on and not
0/1 specialized, but may throw data dependent errors when branches require
their value without explicit unbacked handling."""
BACKED_SIZE_OBLIVIOUS = "backed_size_oblivious"
"""Experimental flag that treats backed symbols as unbacked when explicit
unbacked handling is defined."""
@config
@dataclass
class DynamicShapesConfig:
"""Configuration to control/debug torch compile dynamic shapes."""
type: DynamicShapesType = DynamicShapesType.BACKED
"""Controls the type of dynamic shapes handling to use with torch.compile().
- BACKED: Default PyTorch behavior with potential guards ignored.
- UNBACKED: No guards guaranteed (most sound) but may throw
data dependent errors.
- BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to
backed/unbacked.
"""
# TODO add a debug mode to fail
def compute_hash(self) -> str:
"""
Provide a hash for DynamicShapesConfig
"""
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, {})
return hash_factors(factors)
@config
@dataclass
class CompilationConfig:
@ -322,7 +370,7 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder.
Currently, this only works for `Qwen2_5_vl` on selected platforms.
Currently, this only works for `Qwen2_5_vl` on selected platforms.
Disabled by default until more models are supported/tested to work."""
# Inductor capture
@ -348,9 +396,11 @@ class CompilationConfig:
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
inductor_compile_config: dict = field(default_factory=dict)
"""Additional configurations for inductor.
- None: use default configurations."""
inductor_passes: dict[str, str] = field(default_factory=dict)
"""Additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
@ -460,8 +510,15 @@ class CompilationConfig:
max_num_seqs, and prevents capture of many large graphs (>512) that would
greatly increase startup time with limited performance benefit.
"""
dynamic_shapes_config: DynamicShapesConfig = field(
default_factory=DynamicShapesConfig
)
"""Configuration for dynamic shapes options"""
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False,
@ -530,6 +587,7 @@ class CompilationConfig:
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
factors["pass_config"] = self.pass_config.compute_hash()
return hash_factors(factors)

View File

@ -354,7 +354,17 @@ class LlamaDecoderLayer(nn.Module):
return vllm_config.quant_config
@support_torch_compile
def llama_model_invariants(
    input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
    """Shape invariants for Llama model compilation; these are translated to
    runtime assertions for unbacked dynamic shapes and are compiled away for
    backed shapes."""
    if input_ids is not None:
        torch._check(positions.size()[0] == input_ids.size()[0])
@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
def __init__(
self,

View File

@ -274,6 +274,38 @@ class Qwen2DecoderLayer(nn.Module):
return hidden_states, residual
def qwen_2_model_invariants(
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
):
    """Shape invariants for Qwen2Model; these are translated to runtime
    assertions for unbacked dynamic shapes and are compiled away for
    backed shapes."""
    # All of these should be equal:
    #   input_ids.size()[0]
    #   positions.size()[-1]
    #   intermediate_tensors["hidden_states"].size()[0]
    #   inputs_embeds.size()[0]
    torch._check(input_ids.size()[0] == positions.size()[-1])
    if intermediate_tensors is not None:
        torch._check(
            input_ids.size()[0] == intermediate_tensors["hidden_states"].size()[0]
        )
    if inputs_embeds is not None:
        torch._check(input_ids.size()[0] == inputs_embeds.size()[0])

    # Hidden dimensions should match (hidden_size):
    #   intermediate_tensors["hidden_states"].size()[1]
    #   inputs_embeds.size()[1]
    if inputs_embeds is not None and intermediate_tensors is not None:
        torch._check(
            inputs_embeds.size()[1] == intermediate_tensors["hidden_states"].size()[1]
        )
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
@ -282,7 +314,8 @@ class Qwen2DecoderLayer(nn.Module):
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
},
shape_invariants=qwen_2_model_invariants,
)
class Qwen2Model(nn.Module):
def __init__(