Add evaluate_guards option to DynamicShapesConfig (#27432)

Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
Laith Sakka 2025-12-08 07:46:15 -08:00 committed by GitHub
parent 184076c3fe
commit 87aee9ed2b
6 changed files with 222 additions and 31 deletions
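
For orientation, a minimal usage sketch of the new option (hedged: this mirrors the test added in this commit; the model name and the explicit env-var setup are illustrative and not taken from the diff):

import os
from vllm import LLM
from vllm.config.compilation import CompilationMode, DynamicShapesType

# evaluate_guards requires the bytecode hook to be off (see the
# DynamicShapesConfig docstring below); vLLM reads this env var lazily.
os.environ["VLLM_USE_BYTECODE_HOOK"] = "0"

llm = LLM(
    model="facebook/opt-125m",  # illustrative; any small model works
    compilation_config={
        "mode": CompilationMode.VLLM_COMPILE,
        "dynamic_shapes_config": {
            # Keep Dynamo's dynamic shape guards and fail if a recompile
            # (i.e. a shape specialization) ever happens because of them.
            "type": DynamicShapesType.BACKED.value,
            "evaluate_guards": True,
        },
    },
)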

View File

@ -2,12 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import tempfile
from contextlib import contextmanager
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationMode, DynamicShapesType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.config.compilation import (
CompilationMode,
DynamicShapesConfig,
DynamicShapesType,
)
from vllm.forward_context import set_forward_context
from vllm.tokenizers import get_tokenizer
from vllm.utils.torch_utils import is_torch_equal_or_newer
@ -29,18 +38,19 @@ def get_test_models():
)
@pytest.mark.parametrize("use_aot_compile", ["0"])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
@pytest.mark.parametrize("evaluate_guards", [False, True])
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_dynamic_shapes_compilation(
monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
monkeypatch,
model_name,
shapes_type,
use_aot_compile,
use_bytecode_hook,
evaluate_guards,
):
"""Test that all dynamic shapes types compile successfully"""
print(
f"\nTesting model: {model_name} with {shapes_type.name}, "
f"AOT compile: {use_aot_compile}, "
f"Bytecode hook: {use_bytecode_hook}"
)
if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
"mode": CompilationMode.VLLM_COMPILE,
"dynamic_shapes_config": {
"type": shapes_type.value,
"evaluate_guards": evaluate_guards,
},
},
)
@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation(
torch.cuda.empty_cache()
torch.cuda.synchronize()
print("GPU memory cleared")
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
@pytest.mark.parametrize(
"dynamic_shapes_type",
[
DynamicShapesType.BACKED,
DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
],
)
@pytest.mark.parametrize("evaluate_guards", [False, True])
def test_model_specialization_with_evaluate_guards(
monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
):
"""Test that evaluate_guards correctly detects shape specialization
violations.
"""
if (
use_aot_compile == "1"
and dynamic_shapes_type == DynamicShapesType.BACKED
and evaluate_guards
):
pytest.skip("evaluate_guards for backed does not work with aot_compile =1")
@support_torch_compile
class ModelWithSizeCheck(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
def forward(self, x: torch.Tensor):
# This will cause specialization - torch.compile will guard on
# x.shape[0]
if x.shape[0] >= 10:
return x * 10
else:
return x * 10
@support_torch_compile
class ModelWithOneSizeCheck(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
def forward(self, x: torch.Tensor):
# This will cause 0/1 specializations.
if x.shape[0] == 0:
return x * 10
if x.shape[0] == 1:
return x * 10
else:
return x * 10
@contextmanager
def use_vllm_config(vllm_config: VllmConfig):
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
yield
monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0")
# Create vllm config with the desired settings
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
dynamic_shapes_config=DynamicShapesConfig(
type=dynamic_shapes_type,
evaluate_guards=evaluate_guards,
),
)
)
def test(model_class, input1, input2, is_01_specialization=False):
with (
torch.no_grad(),
use_vllm_config(vllm_config),
tempfile.TemporaryDirectory() as tmpdirname,
):
monkeypatch.setenv("VLLM_CACHE_ROOT", tmpdirname)
model = model_class(vllm_config=vllm_config).cuda()
model(input1)
if evaluate_guards and (
not (
is_01_specialization
and dynamic_shapes_type == DynamicShapesType.BACKED
)
):
# This should fail because guards were added.
with pytest.raises(RuntimeError) as excinfo:
model(input2)
# Expected failure - guard was violated
error_msg = str(excinfo.value)
assert (
"GuardManager check failed" in error_msg
or "Detected recompile when torch.compile stance" in error_msg
), error_msg
else:
model(input2)
test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda())
test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda())
test(
ModelWithOneSizeCheck,
torch.randn(20, 10).cuda(),
torch.randn(1, 10).cuda(),
is_01_specialization=True,
)

View File

@ -26,6 +26,7 @@ from vllm.compilation.partition_rules import (
should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import DynamicShapesType
from vllm.config.utils import Range, hash_factors
from vllm.logger import init_logger
from vllm.logging_utils import lazy
@ -722,6 +723,29 @@ class VllmBackend:
self.split_gm, submod_names_to_compile, self.vllm_config, self
).run(*fake_args)
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode()
if (
self.compilation_config.dynamic_shapes_config.evaluate_guards
and self.compilation_config.dynamic_shapes_config.type
== DynamicShapesType.BACKED
):
from torch.utils._sympy.value_ranges import ValueRanges
# Drop the 0/1 counter-specialization guards: for backed dynamic shapes,
# torch.compile either specializes for 0/1 inputs or otherwise guards
# that the shape is >= 2, because it is very hard not to hit a check
# against 0/1. When we evaluate shape guards, we exclude those guards
# (we would always fail otherwise) by widening the range of any backed
# size whose lower bound is 2 down to 0.
for s, r in fake_mode.shape_env.var_to_range.items():
if r.lower == 2:
fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
if not os.path.exists(graph_path):
# code adapted from
@ -749,8 +773,6 @@ class VllmBackend:
graph, example_inputs, self.prefix, self.split_gm
)
# if we need to copy input buffers for cudagraph
#
# index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.
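
As a side note on the range-widening above, here is a small self-contained sketch of the same idea, using plain sympy symbols in place of a real fake_mode.shape_env.var_to_range (the symbol names, the finite upper bound, and the 16-specialization are made up for illustration):

import sympy
from torch.utils._sympy.value_ranges import ValueRanges

s0 = sympy.Symbol("s0", integer=True)  # a backed dynamic size (e.g. batch size)
s1 = sympy.Symbol("s1", integer=True)  # a size that was genuinely specialized
var_to_range = {
    s0: ValueRanges(2, 8192),  # backed sizes default to >= 2 (0/1 counter-specialization)
    s1: ValueRanges(16, 16),   # a real specialization we do want to keep checking
}

# Widen only the 0/1 counter-specialization bound so that guard evaluation
# does not reject batch sizes 0 and 1; real specializations stay untouched.
for s, r in var_to_range.items():
    if r.lower == 2:
        var_to_range[s] = ValueRanges(0, r.upper)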

View File

@ -392,7 +392,6 @@ def _support_torch_compile(
factors.append(_model_hash_key(self.forward))
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT,
"torch_aot_compile",
@ -413,7 +412,8 @@ def _support_torch_compile(
f, f_globals=self.forward.__globals__
)
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
loaded_fn.disable_guard_check()
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
loaded_fn.disable_guard_check()
self.aot_compiled_fn = loaded_fn
except Exception as e:
if os.path.exists(aot_compilation_path):

View File

@ -4,7 +4,7 @@
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from types import CodeType
from typing import Any
@ -13,6 +13,7 @@ import torch._C._dynamo.guards
import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
@ -125,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper:
if isinstance(backend, str) and backend == "inductor":
options = vllm_config.compilation_config.inductor_compile_config
if mode != CompilationMode.STOCK_TORCH_COMPILE:
# Drop all the guards.
options["guard_filter_fn"] = lambda x: [False for _ in x]
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
from vllm.compilation.decorators import DynamicShapesType
self.first_compile = True
self.evaluate_guards = (
vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
)
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
compiled_ptr: Any = self.forward
if ds_type == DynamicShapesType.UNBACKED:
if envs.VLLM_USE_BYTECODE_HOOK:
# The reason is that the bytecode hook calls torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation.
raise ValueError(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
if mode != CompilationMode.STOCK_TORCH_COMPILE:
# Drop all the guards.
if self.evaluate_guards:
assert not envs.VLLM_USE_BYTECODE_HOOK, (
"compilation_config.dynamic_shapes_config.evaluate_guards "
"requires VLLM_USE_BYTECODE_HOOK=0. "
)
if envs.VLLM_USE_AOT_COMPILE:
# disabled until https://github.com/pytorch/pytorch/pull/169239
# is picked up.
assert ds_type != DynamicShapesType.BACKED, (
"evaluate_guards for backed shapes requires "
"VLLM_USE_AOT_COMPILE=False. "
)
options["guard_filter_fn"] = lambda x: [
entry.guard_type == "SHAPE_ENV" for entry in x
]
else:
options["guard_filter_fn"] = lambda x: [False for _ in x]
compiled_ptr: Any = self.forward
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
if ds_type == DynamicShapesType.UNBACKED:
# The reason is that the bytecode hook calls torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation, and if we use
# compiled_ptr = self.check_invariants_and_forward
# it will reset all entries.
assert not envs.VLLM_USE_BYTECODE_HOOK, (
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0."
)
assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
compiled_ptr = self.check_invariants_and_forward
if envs.VLLM_USE_AOT_COMPILE:
@ -195,7 +222,13 @@ class TorchCompileWithNoGuardsWrapper:
self.forward, *args, **kwargs
)
else:
with _compilation_context():
ctx = (
nullcontext()
if self.first_compile or not self.evaluate_guards
else torch.compiler.set_stance("fail_on_recompile")
)
self.first_compile = False
with _compilation_context(), ctx:
return self._call_with_optional_nvtx_range(
self._compiled_callable, *args, **kwargs
)
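
For readers unfamiliar with the two torch pieces being combined here, a rough standalone sketch outside vLLM, assuming a recent PyTorch that exposes the guard_filter_fn compile option and torch.compiler.set_stance (the toy function and shapes are made up):

import torch

def fn(x: torch.Tensor) -> torch.Tensor:
    # The branch below makes Dynamo guard on x.shape[0].
    return x * 10 if x.shape[0] >= 10 else x * 5

compiled = torch.compile(
    fn,
    dynamic=True,
    options={
        # Keep only shape-env guards, same predicate as the wrapper above;
        # every other guard is dropped.
        "guard_filter_fn": lambda entries: [
            e.guard_type == "SHAPE_ENV" for e in entries
        ],
    },
)

compiled(torch.randn(20, 8))  # first compile, installs the shape guard
# A violated shape guard would normally just trigger a silent recompile;
# fail_on_recompile turns it into a hard error instead.
with torch.compiler.set_stance("fail_on_recompile"):
    compiled(torch.randn(5, 8))  # expected to raise: x.shape[0] >= 10 guard fails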

View File

@ -344,7 +344,18 @@ class DynamicShapesConfig:
backed/unbacked.
"""
# TODO add a debug mode to fail
evaluate_guards: bool = False
"""
A debug mode that detects and fails if Dynamo ever specializes a dynamic shape
by guarding on it. When True, dynamic shape guards are not dropped from Dynamo,
and a failure is triggered if a recompilation ever happens because of them.
This mode requires VLLM_USE_BYTECODE_HOOK to be 0.
Enabling this also makes the dynamic shape guards visible in the tlparse
artifacts.
When type is backed, aot_compile must be disabled for this mode to work,
until https://github.com/pytorch/pytorch/pull/169239 is picked up.
"""
def compute_hash(self) -> str:
"""
@ -455,8 +466,8 @@ class CompilationConfig:
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation mode is 3, the backend supports both whole graph and piecewise
compilation, available backends include eager, inductor, and custom backends,
the latter of which can be defined via `get_compile_backend`. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for

View File

@ -66,7 +66,7 @@ class OptimizationLevel(IntEnum):
"""O0 : No optimization. no compilation, no cudagraphs, no other
optimization, just starting up immediately"""
O1 = 1
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
cudagraphs"""
O2 = 2
"""O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""