mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 19:17:02 +08:00
[aot_compile]change VLLM backend to read fake args from example_value (#29104)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
parent
c8ab988b15
commit
1f0d184590
@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import multiprocessing
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
||||
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
||||
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
||||
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||
)
|
||||
@use_vllm_config(make_vllm_config())
|
||||
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Test that compiling gpt2 twice results in a cache hit and
|
||||
capture torch dynamic symbol creations to ensure make_symbol
|
||||
not called on cache hit.
|
||||
"""
|
||||
|
||||
import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module
|
||||
from torch.utils._sympy.symbol import make_symbol
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
create_symbol_counter = multiprocessing.Value("i", 0)
|
||||
original_make_symbol = make_symbol
|
||||
|
||||
@functools.wraps(original_make_symbol)
|
||||
def counting_make_symbol(prefix, idx, **kwargs):
|
||||
with create_symbol_counter.get_lock():
|
||||
create_symbol_counter.value += 1
|
||||
return original_make_symbol(prefix, idx, **kwargs)
|
||||
|
||||
symbolic_shapes_module.make_symbol = counting_make_symbol
|
||||
try:
|
||||
with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname:
|
||||
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
|
||||
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||
# First compilation - initialize model and generate
|
||||
llm_model = LLM(
|
||||
model="gpt2",
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
),
|
||||
max_model_len=256,
|
||||
)
|
||||
|
||||
llm_model.generate("Hello, my name is")
|
||||
assert create_symbol_counter.value == 2
|
||||
create_symbol_counter.value = 0
|
||||
|
||||
# Clean up first model
|
||||
del llm_model
|
||||
|
||||
# Second compilation - should hit cache
|
||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||
llm_model = LLM(
|
||||
model="gpt2",
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
),
|
||||
max_model_len=256,
|
||||
)
|
||||
llm_model.generate("Hello, my name is")
|
||||
|
||||
assert create_symbol_counter.value == 0
|
||||
|
||||
finally:
|
||||
# Restore original method
|
||||
symbolic_shapes_module.make_symbol = original_make_symbol
|
||||
|
||||
@ -402,6 +402,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
self.extra_traceback = False
|
||||
|
||||
def run(self, *args):
|
||||
# maybe instead just assert inputs are fake?
|
||||
fake_args = [
|
||||
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
for t in args
|
||||
@ -416,11 +417,13 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
assert isinstance(target, str)
|
||||
|
||||
output = super().call_module(target, args, kwargs)
|
||||
|
||||
if target in self.compile_submod_names:
|
||||
index = self.compile_submod_names.index(target)
|
||||
submod = self.fetch_attr(target)
|
||||
|
||||
sym_shape_indices = [
|
||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||
]
|
||||
@ -746,11 +749,21 @@ class VllmBackend:
|
||||
if not item.is_splitting_graph
|
||||
]
|
||||
|
||||
# Extract fake values from the graph to use them when needed.
|
||||
all_fake_values = []
|
||||
for i in graph.graph.find_nodes(op="placeholder"):
|
||||
all_fake_values.append(i.meta["example_value"])
|
||||
|
||||
fake_args = [
|
||||
all_fake_values[i] if isinstance(t, torch.Tensor) else t
|
||||
for i, t in enumerate(example_inputs)
|
||||
]
|
||||
|
||||
# propagate the split graph to the piecewise backend,
|
||||
# compile submodules with symbolic shapes
|
||||
PiecewiseCompileInterpreter(
|
||||
self.split_gm, submod_names_to_compile, self.vllm_config, self
|
||||
).run(*example_inputs)
|
||||
).run(*fake_args)
|
||||
|
||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||
if not os.path.exists(graph_path):
|
||||
@ -780,14 +793,7 @@ class VllmBackend:
|
||||
)
|
||||
|
||||
# if we need to copy input buffers for cudagraph
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
fake_mode = detect_fake_mode()
|
||||
fake_args = [
|
||||
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||
for t in example_inputs
|
||||
]
|
||||
|
||||
#
|
||||
# index of tensors that have symbolic shapes (batch size)
|
||||
# for weights and static buffers, they will have concrete shapes.
|
||||
# symbolic shape only happens for input tensors.
|
||||
|
||||
@ -433,7 +433,6 @@ def _support_torch_compile(
|
||||
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
||||
|
||||
# This is the path for the first compilation.
|
||||
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
_mark_dynamic_inputs(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user