mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 07:55:01 +08:00
Add evaluate_guards option to DynamicShapesConfig (#27432)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
parent
184076c3fe
commit
87aee9ed2b
@ -2,12 +2,21 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config.compilation import CompilationMode, DynamicShapesType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.config.compilation import (
|
||||
CompilationMode,
|
||||
DynamicShapesConfig,
|
||||
DynamicShapesType,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
@ -29,18 +38,19 @@ def get_test_models():
|
||||
)
|
||||
@pytest.mark.parametrize("use_aot_compile", ["0"])
|
||||
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
|
||||
@pytest.mark.parametrize("evaluate_guards", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
def test_dynamic_shapes_compilation(
|
||||
monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
|
||||
monkeypatch,
|
||||
model_name,
|
||||
shapes_type,
|
||||
use_aot_compile,
|
||||
use_bytecode_hook,
|
||||
evaluate_guards,
|
||||
):
|
||||
"""Test that all dynamic shapes types compile successfully"""
|
||||
print(
|
||||
f"\nTesting model: {model_name} with {shapes_type.name}, "
|
||||
f"AOT compile: {use_aot_compile}, "
|
||||
f"Bytecode hook: {use_bytecode_hook}"
|
||||
)
|
||||
if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
|
||||
pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
|
||||
|
||||
@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
|
||||
"mode": CompilationMode.VLLM_COMPILE,
|
||||
"dynamic_shapes_config": {
|
||||
"type": shapes_type.value,
|
||||
"evaluate_guards": evaluate_guards,
|
||||
},
|
||||
},
|
||||
)
|
||||
@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation(
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
print("GPU memory cleared")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
|
||||
@pytest.mark.parametrize(
|
||||
"dynamic_shapes_type",
|
||||
[
|
||||
DynamicShapesType.BACKED,
|
||||
DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("evaluate_guards", [False, True])
|
||||
def test_model_specialization_with_evaluate_guards(
|
||||
monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
|
||||
):
|
||||
"""Test that evaluate_guards correctly detects shape specialization
|
||||
violations.
|
||||
"""
|
||||
|
||||
if (
|
||||
use_aot_compile == "1"
|
||||
and dynamic_shapes_type == DynamicShapesType.BACKED
|
||||
and evaluate_guards
|
||||
):
|
||||
pytest.skip("evaluate_guards for backed does not work with aot_compile =1")
|
||||
|
||||
@support_torch_compile
|
||||
class ModelWithSizeCheck(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# This will cause specialization - torch.compile will guard on
|
||||
# sx.shape[0]
|
||||
if x.shape[0] >= 10:
|
||||
return x * 10
|
||||
else:
|
||||
return x * 10
|
||||
|
||||
@support_torch_compile
|
||||
class ModelWithOneSizeCheck(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# This will cause 0/1 specializations.
|
||||
if x.shape[0] == 0:
|
||||
return x * 10
|
||||
if x.shape[0] == 1:
|
||||
return x * 10
|
||||
else:
|
||||
return x * 10
|
||||
|
||||
@contextmanager
|
||||
def use_vllm_config(vllm_config: VllmConfig):
|
||||
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
|
||||
yield
|
||||
|
||||
monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
|
||||
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
|
||||
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0")
|
||||
|
||||
# Create vllm config with the desired settings
|
||||
from vllm.config import CompilationMode
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
dynamic_shapes_config=DynamicShapesConfig(
|
||||
type=dynamic_shapes_type,
|
||||
evaluate_guards=evaluate_guards,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def test(model_class, input1, input2, is_01_specialization=False):
|
||||
with (
|
||||
torch.no_grad(),
|
||||
use_vllm_config(vllm_config),
|
||||
tempfile.TemporaryDirectory() as tmpdirname,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||
|
||||
model = model_class(vllm_config=vllm_config).cuda()
|
||||
|
||||
model(input1)
|
||||
|
||||
if evaluate_guards and (
|
||||
not (
|
||||
is_01_specialization
|
||||
and dynamic_shapes_type == DynamicShapesType.BACKED
|
||||
)
|
||||
):
|
||||
# This should fail because guards were added.
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
model(input2)
|
||||
|
||||
# Expected failure - guard was violated
|
||||
error_msg = str(excinfo.value)
|
||||
assert (
|
||||
"GuardManager check failed" in error_msg
|
||||
or "Detected recompile when torch.compile stance" in error_msg
|
||||
), error_msg
|
||||
|
||||
else:
|
||||
model(input2)
|
||||
|
||||
test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda())
|
||||
test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda())
|
||||
test(
|
||||
ModelWithOneSizeCheck,
|
||||
torch.randn(20, 10).cuda(),
|
||||
torch.randn(1, 10).cuda(),
|
||||
is_01_specialization=True,
|
||||
)
|
||||
|
||||
@ -26,6 +26,7 @@ from vllm.compilation.partition_rules import (
|
||||
should_split,
|
||||
)
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.config.utils import Range, hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import lazy
|
||||
@ -722,6 +723,29 @@ class VllmBackend:
|
||||
self.split_gm, submod_names_to_compile, self.vllm_config, self
|
||||
).run(*fake_args)
|
||||
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
fake_mode = detect_fake_mode()
|
||||
|
||||
if (
|
||||
self.compilation_config.dynamic_shapes_config.evaluate_guards
|
||||
and self.compilation_config.dynamic_shapes_config.type
|
||||
== DynamicShapesType.BACKED
|
||||
):
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
# Drop counter-0/1 specializations guards; for backed dynamic shapes,
|
||||
# torch.compile will specialize for 0/1 inputs or otherwise guards that
|
||||
# shape is >= 2. This is because it's really hard not to hit a check
|
||||
# against 0/1. When we evaluate shape guards, we exclude checking those
|
||||
# guards (We would fail always otherwise).
|
||||
|
||||
# We avoid that by updating the ranges of backed sizes when the min is
|
||||
# 2 for any, we assume it's 0.
|
||||
for s, r in fake_mode.shape_env.var_to_range.items():
|
||||
if r.lower == 2:
|
||||
fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)
|
||||
|
||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||
if not os.path.exists(graph_path):
|
||||
# code adapted from
|
||||
@ -749,8 +773,6 @@ class VllmBackend:
|
||||
graph, example_inputs, self.prefix, self.split_gm
|
||||
)
|
||||
|
||||
# if we need to copy input buffers for cudagraph
|
||||
#
|
||||
# index of tensors that have symbolic shapes (batch size)
|
||||
# for weights and static buffers, they will have concrete shapes.
|
||||
# symbolic shape only happens for input tensors.
|
||||
|
||||
@ -392,7 +392,6 @@ def _support_torch_compile(
|
||||
|
||||
factors.append(_model_hash_key(self.forward))
|
||||
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
"torch_aot_compile",
|
||||
@ -413,7 +412,8 @@ def _support_torch_compile(
|
||||
f, f_globals=self.forward.__globals__
|
||||
)
|
||||
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
|
||||
loaded_fn.disable_guard_check()
|
||||
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
|
||||
loaded_fn.disable_guard_check()
|
||||
self.aot_compiled_fn = loaded_fn
|
||||
except Exception as e:
|
||||
if os.path.exists(aot_compilation_path):
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from types import CodeType
|
||||
from typing import Any
|
||||
|
||||
@ -13,6 +13,7 @@ import torch._C._dynamo.guards
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
|
||||
|
||||
@ -125,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
if isinstance(backend, str) and backend == "inductor":
|
||||
options = vllm_config.compilation_config.inductor_compile_config
|
||||
|
||||
if mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
# Drop all the guards.
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
||||
from vllm.compilation.decorators import DynamicShapesType
|
||||
self.first_compile = True
|
||||
self.evaluate_guards = (
|
||||
vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
|
||||
)
|
||||
|
||||
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
|
||||
compiled_ptr: Any = self.forward
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
if envs.VLLM_USE_BYTECODE_HOOK:
|
||||
# reason is that bytecode does this hack torch._dynamo.eval_frame.
|
||||
# remove_from_cache(self.original_code_object()) to force a new
|
||||
# re-compilation.
|
||||
raise ValueError(
|
||||
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
|
||||
|
||||
if mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
# Drop all the guards.
|
||||
if self.evaluate_guards:
|
||||
assert not envs.VLLM_USE_BYTECODE_HOOK, (
|
||||
"compilation_config.dynamic_shapes_config.evaluate_guards "
|
||||
"requires VLLM_USE_BYTECODE_HOOK=0. "
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
# disabled until https://github.com/pytorch/pytorch/pull/169239
|
||||
# is picked up.
|
||||
assert ds_type != DynamicShapesType.BACKED, (
|
||||
"evaluate_guards for backed shapes requires "
|
||||
"VLLM_USE_AOT_COMPILE=False. "
|
||||
)
|
||||
|
||||
options["guard_filter_fn"] = lambda x: [
|
||||
entry.guard_type == "SHAPE_ENV" for entry in x
|
||||
]
|
||||
else:
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
compiled_ptr: Any = self.forward
|
||||
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
||||
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
# reason is that bytecode does torch._dynamo.eval_frame.
|
||||
# remove_from_cache(self.original_code_object()) to force a new
|
||||
# re-compilation. And if we use
|
||||
# compiled_ptr = self.check_invariants_and_forward
|
||||
# it will reset all entries.
|
||||
assert not envs.VLLM_USE_BYTECODE_HOOK, (
|
||||
"UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
|
||||
)
|
||||
assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
|
||||
|
||||
compiled_ptr = self.check_invariants_and_forward
|
||||
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
@ -195,7 +222,13 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
self.forward, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
with _compilation_context():
|
||||
ctx = (
|
||||
nullcontext()
|
||||
if self.first_compile or not self.evaluate_guards
|
||||
else torch.compiler.set_stance("fail_on_recompile")
|
||||
)
|
||||
self.first_compile = False
|
||||
with _compilation_context(), ctx:
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self._compiled_callable, *args, **kwargs
|
||||
)
|
||||
|
||||
@ -344,7 +344,18 @@ class DynamicShapesConfig:
|
||||
backed/unbacked.
|
||||
"""
|
||||
|
||||
# TODO add a debug mode to fail
|
||||
evaluate_guards: bool = False
|
||||
"""
|
||||
A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by
|
||||
guarding on it. When True, dynamic shape guards are not dropped from dynamo.
|
||||
And a failure will be triggered if a recompilation ever happens due to that.
|
||||
This mode requires VLLM_USE_BYTECODE_HOOK to be 0.
|
||||
Enabling this allow observing the dynamic shapes guards in the tlparse
|
||||
artifacts also.
|
||||
When type is backed, aot_compile must be disabled for this mode to work.
|
||||
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
|
||||
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
@ -455,8 +466,8 @@ class CompilationConfig:
|
||||
We use string to avoid serialization issues when using compilation in a
|
||||
distributed setting. When the compilation mode is 1 or 2, the backend is
|
||||
used for the compilation directly (it sees the whole graph). When the
|
||||
compilation mode is 3, the backend supports both whole graph and piecewise
|
||||
compilation, available backends include eager, inductor, and custom backends,
|
||||
compilation mode is 3, the backend supports both whole graph and piecewise
|
||||
compilation, available backends include eager, inductor, and custom backends,
|
||||
the latter of which can be defined via `get_compile_backend`. Furthermore,
|
||||
compilation is only piecewise if splitting ops is set accordingly and
|
||||
use_inductor_graph_partition is off. Note that the default options for
|
||||
|
||||
@ -66,7 +66,7 @@ class OptimizationLevel(IntEnum):
|
||||
"""O0 : No optimization. no compilation, no cudagraphs, no other
|
||||
optimization, just starting up immediately"""
|
||||
O1 = 1
|
||||
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
|
||||
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
|
||||
cudagraphs"""
|
||||
O2 = 2
|
||||
"""O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user