Add evaluate_guards option to DynamicShapesConfig (#27432)
Signed-off-by: Laith Sakka <lsakka@meta.com>
parent 184076c3fe
commit 87aee9ed2b
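For orientation, a minimal usage sketch of the new option, mirroring the dict that the updated test below passes to LLM; the model name is illustrative and not part of this commit, and evaluate_guards requires VLLM_USE_BYTECODE_HOOK=0 (see the DynamicShapesConfig docstring at the end of the diff).

from vllm import LLM
from vllm.config.compilation import CompilationMode, DynamicShapesType

# Sketch only: the model name is a placeholder; evaluate_guards needs
# VLLM_USE_BYTECODE_HOOK=0 (and, for backed shapes, VLLM_USE_AOT_COMPILE=0).
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative, not from this commit
    compilation_config={
        "mode": CompilationMode.VLLM_COMPILE,
        "dynamic_shapes_config": {
            "type": DynamicShapesType.BACKED.value,
            "evaluate_guards": True,  # the new option
        },
    },
)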
@@ -2,12 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import gc
+import tempfile
+from contextlib import contextmanager
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
-from vllm.config.compilation import CompilationMode, DynamicShapesType
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
+from vllm.config.compilation import (
+    CompilationMode,
+    DynamicShapesConfig,
+    DynamicShapesType,
+)
+from vllm.forward_context import set_forward_context
 from vllm.tokenizers import get_tokenizer
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
@@ -29,18 +38,19 @@ def get_test_models():
 )
 @pytest.mark.parametrize("use_aot_compile", ["0"])
 @pytest.mark.parametrize("use_bytecode_hook", [True, False])
+@pytest.mark.parametrize("evaluate_guards", [False, True])
 @pytest.mark.skipif(
     not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
 )
 def test_dynamic_shapes_compilation(
-    monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
+    monkeypatch,
+    model_name,
+    shapes_type,
+    use_aot_compile,
+    use_bytecode_hook,
+    evaluate_guards,
 ):
     """Test that all dynamic shapes types compile successfully"""
-    print(
-        f"\nTesting model: {model_name} with {shapes_type.name}, "
-        f"AOT compile: {use_aot_compile}, "
-        f"Bytecode hook: {use_bytecode_hook}"
-    )
     if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
         pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
 
@@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
             "mode": CompilationMode.VLLM_COMPILE,
             "dynamic_shapes_config": {
                 "type": shapes_type.value,
+                "evaluate_guards": evaluate_guards,
             },
         },
     )
@@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation(
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
     print("GPU memory cleared")
+
+
+@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
+@pytest.mark.parametrize(
+    "dynamic_shapes_type",
+    [
+        DynamicShapesType.BACKED,
+        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
+    ],
+)
+@pytest.mark.parametrize("evaluate_guards", [False, True])
+def test_model_specialization_with_evaluate_guards(
+    monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
+):
+    """Test that evaluate_guards correctly detects shape specialization
+    violations.
+    """
+
+    if (
+        use_aot_compile == "1"
+        and dynamic_shapes_type == DynamicShapesType.BACKED
+        and evaluate_guards
+    ):
+        pytest.skip("evaluate_guards for backed does not work with aot_compile=1")
+
+    @support_torch_compile
+    class ModelWithSizeCheck(torch.nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+
+        def forward(self, x: torch.Tensor):
+            # This will cause specialization - torch.compile will guard on
+            # x.shape[0]
+            if x.shape[0] >= 10:
+                return x * 10
+            else:
+                return x * 10
+
+    @support_torch_compile
+    class ModelWithOneSizeCheck(torch.nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+
+        def forward(self, x: torch.Tensor):
+            # This will cause 0/1 specializations.
+            if x.shape[0] == 0:
+                return x * 10
+            if x.shape[0] == 1:
+                return x * 10
+            else:
+                return x * 10
+
+    @contextmanager
+    def use_vllm_config(vllm_config: VllmConfig):
+        with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
+            yield
+
+    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
+    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0")
+
+    # Create vllm config with the desired settings
+    from vllm.config import CompilationMode
+
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            dynamic_shapes_config=DynamicShapesConfig(
+                type=dynamic_shapes_type,
+                evaluate_guards=evaluate_guards,
+            ),
+        )
+    )
+
+    def test(model_class, input1, input2, is_01_specialization=False):
+        with (
+            torch.no_grad(),
+            use_vllm_config(vllm_config),
+            tempfile.TemporaryDirectory() as tmpdirname,
+        ):
+            monkeypatch.setenv("VLLM_CACHE_ROOT", tmpdirname)
+
+            model = model_class(vllm_config=vllm_config).cuda()
+
+            model(input1)
+
+            if evaluate_guards and (
+                not (
+                    is_01_specialization
+                    and dynamic_shapes_type == DynamicShapesType.BACKED
+                )
+            ):
+                # This should fail because guards were added.
+                with pytest.raises(RuntimeError) as excinfo:
+                    model(input2)
+
+                # Expected failure - guard was violated
+                error_msg = str(excinfo.value)
+                assert (
+                    "GuardManager check failed" in error_msg
+                    or "Detected recompile when torch.compile stance" in error_msg
+                ), error_msg
+
+            else:
+                model(input2)
+
+    test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda())
+    test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda())
+    test(
+        ModelWithOneSizeCheck,
+        torch.randn(20, 10).cuda(),
+        torch.randn(1, 10).cuda(),
+        is_01_specialization=True,
+    )
@@ -26,6 +26,7 @@ from vllm.compilation.partition_rules import (
     should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config.compilation import DynamicShapesType
 from vllm.config.utils import Range, hash_factors
 from vllm.logger import init_logger
 from vllm.logging_utils import lazy
@@ -722,6 +723,29 @@ class VllmBackend:
             self.split_gm, submod_names_to_compile, self.vllm_config, self
         ).run(*fake_args)
 
+        from torch._guards import detect_fake_mode
+
+        fake_mode = detect_fake_mode()
+
+        if (
+            self.compilation_config.dynamic_shapes_config.evaluate_guards
+            and self.compilation_config.dynamic_shapes_config.type
+            == DynamicShapesType.BACKED
+        ):
+            from torch.utils._sympy.value_ranges import ValueRanges
+
+            # Drop the 0/1-specialization guards: for backed dynamic shapes,
+            # torch.compile either specializes for 0/1 inputs or guards that
+            # the shape is >= 2, because it is very hard not to hit a check
+            # against 0/1. When we evaluate shape guards, we exclude those
+            # guards (otherwise we would always fail).
+
+            # We do that by updating the ranges of backed sizes: whenever the
+            # minimum is 2, we assume it is 0.
+            for s, r in fake_mode.shape_env.var_to_range.items():
+                if r.lower == 2:
+                    fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)
+
         graph_path = os.path.join(local_cache_dir, "computation_graph.py")
         if not os.path.exists(graph_path):
             # code adapted from
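The range widening above can be illustrated with a small self-contained sketch; a plain dict and tuples stand in for shape_env.var_to_range and torch's ValueRanges, so this is conceptual rather than the real ShapeEnv API.

# Conceptual sketch of the 0/1-specialization workaround above. Backed symbols
# get a default lower bound of 2 (torch.compile specializes 0 and 1), so
# range-derived guards like "s0 >= 2" would always fail for tiny batches once
# guards are evaluated; widening the lower bound to 0 avoids that.
def widen_backed_ranges(var_to_range: dict[str, tuple[int, int]]) -> None:
    for sym, (lo, hi) in var_to_range.items():
        if lo == 2:
            var_to_range[sym] = (0, hi)

ranges = {"s0": (2, 2**31 - 1), "s1": (16, 16)}  # s1: a static/concrete size
widen_backed_ranges(ranges)
assert ranges == {"s0": (0, 2**31 - 1), "s1": (16, 16)}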
@@ -749,8 +773,6 @@ class VllmBackend:
             graph, example_inputs, self.prefix, self.split_gm
         )
 
-        # if we need to copy input buffers for cudagraph
-        #
         # index of tensors that have symbolic shapes (batch size)
         # for weights and static buffers, they will have concrete shapes.
         # symbolic shape only happens for input tensors.
@@ -392,7 +392,6 @@ def _support_torch_compile(
 
         factors.append(_model_hash_key(self.forward))
         hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
-
         cache_dir = os.path.join(
             envs.VLLM_CACHE_ROOT,
             "torch_aot_compile",
@@ -413,6 +412,7 @@ def _support_torch_compile(
                     f, f_globals=self.forward.__globals__
                 )
                 _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
-                loaded_fn.disable_guard_check()
+                if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
+                    loaded_fn.disable_guard_check()
                 self.aot_compiled_fn = loaded_fn
             except Exception as e:
@@ -4,7 +4,7 @@
 import os
 import sys
 from abc import abstractmethod
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from types import CodeType
 from typing import Any
 
@@ -13,6 +13,7 @@ import torch._C._dynamo.guards
 
 import vllm.envs as envs
 from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
+from vllm.config.compilation import DynamicShapesType
 from vllm.logger import init_logger
 from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
 
@@ -125,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper:
         if isinstance(backend, str) and backend == "inductor":
             options = vllm_config.compilation_config.inductor_compile_config
 
-        if mode != CompilationMode.STOCK_TORCH_COMPILE:
-            # Drop all the guards.
-            options["guard_filter_fn"] = lambda x: [False for _ in x]
-        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
-        from vllm.compilation.decorators import DynamicShapesType
+        self.first_compile = True
+        self.evaluate_guards = (
+            vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
+        )
 
         ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
-        compiled_ptr: Any = self.forward
-        if ds_type == DynamicShapesType.UNBACKED:
-            if envs.VLLM_USE_BYTECODE_HOOK:
-                # reason is that bytecode does this hack torch._dynamo.eval_frame.
-                # remove_from_cache(self.original_code_object()) to force a new
-                # re-compilation.
-                raise ValueError(
-                    "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
+        if mode != CompilationMode.STOCK_TORCH_COMPILE:
+            # Keep only shape guards when evaluating guards; otherwise drop all guards.
+            if self.evaluate_guards:
+                assert not envs.VLLM_USE_BYTECODE_HOOK, (
+                    "compilation_config.dynamic_shapes_config.evaluate_guards "
+                    "requires VLLM_USE_BYTECODE_HOOK=0. "
                 )
 
+                if envs.VLLM_USE_AOT_COMPILE:
+                    # Disabled until https://github.com/pytorch/pytorch/pull/169239
+                    # is picked up.
+                    assert ds_type != DynamicShapesType.BACKED, (
+                        "evaluate_guards for backed shapes requires "
+                        "VLLM_USE_AOT_COMPILE=False. "
+                    )
+
+                options["guard_filter_fn"] = lambda x: [
+                    entry.guard_type == "SHAPE_ENV" for entry in x
+                ]
+            else:
+                options["guard_filter_fn"] = lambda x: [False for _ in x]
+
+        compiled_ptr: Any = self.forward
+
+        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
+        if ds_type == DynamicShapesType.UNBACKED:
+            # The reason is that the bytecode hook calls torch._dynamo.eval_frame.
+            # remove_from_cache(self.original_code_object()) to force a new
+            # re-compilation, and if we use
+            # compiled_ptr = self.check_invariants_and_forward
+            # it will reset all entries.
+            assert not envs.VLLM_USE_BYTECODE_HOOK, (
+                "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
+            )
+            assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
+
             compiled_ptr = self.check_invariants_and_forward
 
         if envs.VLLM_USE_AOT_COMPILE:
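The guard_filter_fn installed above keeps only Dynamo's SHAPE_ENV guards when evaluate_guards is enabled and drops every guard otherwise. Below is a standalone sketch of the same filter, assuming guard_filter_fn is honored when passed through torch.compile's options the way the wrapper forwards inductor_compile_config.

import torch

def keep_only_shape_guards(entries):
    # One bool per guard entry: True keeps the guard, False drops it. Only the
    # guard_type attribute is relied on, matching the wrapper code above.
    return [entry.guard_type == "SHAPE_ENV" for entry in entries]

@torch.compile(backend="inductor", options={"guard_filter_fn": keep_only_shape_guards})
def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2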
@@ -195,7 +222,13 @@ class TorchCompileWithNoGuardsWrapper:
                 self.forward, *args, **kwargs
             )
         else:
-            with _compilation_context():
+            ctx = (
+                nullcontext()
+                if self.first_compile or not self.evaluate_guards
+                else torch.compiler.set_stance("fail_on_recompile")
+            )
+            self.first_compile = False
+            with _compilation_context(), ctx:
                 return self._call_with_optional_nvtx_range(
                     self._compiled_callable, *args, **kwargs
                 )
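For reference, a tiny sketch of what the fail_on_recompile stance does outside vLLM: after the first successful compile, any further recompilation raises instead of happening silently, which is what turns a violated shape guard into a hard failure when evaluate_guards is on. The toy function and shapes are illustrative.

import torch

@torch.compile(dynamic=True)
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

add_one(torch.randn(8))  # first call compiles

# Later calls run under the stricter stance: if a guard would force a
# recompile, torch raises instead of recompiling, mirroring the wrapper above.
with torch.compiler.set_stance("fail_on_recompile"):
    add_one(torch.randn(16))  # fine as long as no recompilation is needed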
@@ -344,7 +344,18 @@ class DynamicShapesConfig:
     backed/unbacked.
     """
 
-    # TODO add a debug mode to fail
+    evaluate_guards: bool = False
+    """
+    A debug mode that detects whether Dynamo ever specializes a dynamic shape
+    by guarding on it. When True, dynamic-shape guards are not dropped from
+    Dynamo, and a failure is triggered if a recompilation ever happens because
+    of such a guard.
+    This mode requires VLLM_USE_BYTECODE_HOOK to be 0.
+    Enabling it also makes the dynamic-shape guards observable in the tlparse
+    artifacts.
+    When type is backed, aot_compile must be disabled for this mode to work,
+    until https://github.com/pytorch/pytorch/pull/169239 is picked up.
+    """
 
     def compute_hash(self) -> str:
         """
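A minimal construction sketch mirroring the new test, for code that builds VllmConfig directly instead of passing a dict to LLM:

from vllm.config import CompilationConfig, VllmConfig
from vllm.config.compilation import (
    CompilationMode,
    DynamicShapesConfig,
    DynamicShapesType,
)

# evaluate_guards keeps Dynamo's shape guards, so a specialization surfaces as
# a hard failure ("GuardManager check failed" or a fail_on_recompile error)
# rather than a silent recompile.
vllm_config = VllmConfig(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.BACKED,
            evaluate_guards=True,
        ),
    )
)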