vllm/tests/compile/test_aot_compile.py
Laith Sakka 1f0d184590
[aot_compile]change VLLM backend to read fake args from example_value (#29104)
Signed-off-by: Laith Sakka <lsakka@meta.com>
2025-12-04 17:33:45 -05:00

206 lines
7.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import multiprocessing
import tempfile
from contextlib import contextmanager
import pytest
import torch
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationMode,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
def reference_fn(x: torch.Tensor):
    """Eager reference computation used by the tests in this file.

    Checks that the leading dimension of ``x`` is at most 42 and even, then
    adds that dimension's size to every element 3000 times. The input tensor
    is not modified; a new tensor is returned.

    Raises:
        AssertionError: if ``x.shape[0]`` is greater than 42 or is odd.
    """
    leading = x.shape[0]
    # Plain asserts are kept deliberately: these two conditions match the
    # guard strings that test_shape_env compares against (presumably they are
    # what produces those guards when this function is traced -- confirm).
    assert leading <= 42
    assert leading % 2 == 0
    acc = x
    for _ in range(3000):
        # Out-of-place add; the shape (and so ``leading``) never changes.
        acc = acc + leading
    return acc
@support_torch_compile
class CompiledMod(torch.nn.Module):
    """Minimal module whose forward pass simply delegates to ``reference_fn``.

    The class is wrapped with ``support_torch_compile``; the tests in this
    file construct it with ``vllm_config=...``.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for the caller's / decorator's sake
        # but are not used by this class itself.
        super().__init__()

    def forward(self, x: torch.Tensor):
        """Return ``reference_fn(x)``."""
        result = reference_fn(x)
        return result
def make_vllm_config() -> VllmConfig:
    """Return a new ``VllmConfig`` whose compilation mode is ``VLLM_COMPILE``."""
    compilation = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
    )
    return VllmConfig(compilation_config=compilation)
@contextmanager
def use_vllm_config(vllm_config: VllmConfig):
    """Context manager that activates ``vllm_config`` for the enclosed code.

    Enters ``set_forward_context({}, vllm_config)`` first and then
    ``set_current_vllm_config(vllm_config)``; both are exited (in reverse
    order) when the enclosed block finishes.
    """
    with set_forward_context({}, vllm_config):
        with set_current_vllm_config(vllm_config):
            yield
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
    """
    Compare calling a freshly built CompiledMod under the
    "fail_on_recompile" stance with AOT compile disabled and enabled.

    With VLLM_USE_AOT_COMPILE=0 the call is expected to raise a RuntimeError
    matching "Detected recompile"; with VLLM_USE_AOT_COMPILE=1 (after a dynamo
    reset) it is expected to succeed and match the eager reference output.
    """
    with monkeypatch.context() as env:
        config = make_vllm_config()
        inputs = (torch.randn(10, 10),)
        eager_out = reference_fn(*inputs)

        with use_vllm_config(config):
            # AOT compile off: the call must trip the recompile check.
            env.setenv("VLLM_USE_AOT_COMPILE", "0")
            with pytest.raises(RuntimeError, match="Detected recompile"):
                with torch.compiler.set_stance("fail_on_recompile"):
                    CompiledMod(vllm_config=config)(*inputs)

            # AOT compile on: same stance, but the call must go through.
            env.setenv("VLLM_USE_AOT_COMPILE", "1")
            torch._dynamo.reset()
            with torch.compiler.set_stance("fail_on_recompile"):
                compiled_out = CompiledMod(vllm_config=config)(*inputs)
            assert torch.allclose(compiled_out, eager_out)
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
    """
    With VLLM_FORCE_AOT_LOAD=1 and VLLM_CACHE_ROOT pointing at a freshly
    created (empty) temporary directory, calling CompiledMod is expected to
    raise FileNotFoundError.
    """
    with tempfile.TemporaryDirectory() as cache_dir:
        with monkeypatch.context() as env:
            inputs = (torch.randn(10, 10),)
            env.setenv("VLLM_USE_AOT_COMPILE", "1")
            env.setenv("VLLM_FORCE_AOT_LOAD", "1")
            env.setenv("VLLM_CACHE_ROOT", cache_dir)
            config = make_vllm_config()
            with use_vllm_config(config):
                with pytest.raises(FileNotFoundError):
                    CompiledMod(vllm_config=config)(*inputs)
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
    """
    Run CompiledMod once with AOT compile enabled, then again with
    VLLM_FORCE_AOT_LOAD=1 against the same cache directory, and check that
    both runs return matching outputs.
    """
    with monkeypatch.context() as env:
        inputs = (torch.randn(10, 10),)

        def run_fresh_module():
            # Build a new config and module each time so the second run
            # does not reuse the first run's Python objects.
            config = make_vllm_config()
            with use_vllm_config(config):
                return CompiledMod(vllm_config=config)(*inputs)

        with tempfile.TemporaryDirectory() as cache_dir:
            env.setenv("VLLM_CACHE_ROOT", cache_dir)
            env.setenv("VLLM_USE_AOT_COMPILE", "1")
            first_out = run_fresh_module()

            # Second pass must happen while the cache directory still exists.
            env.setenv("VLLM_FORCE_AOT_LOAD", "1")
            second_out = run_fresh_module()
            assert torch.allclose(second_out, first_out)
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
    """
    Check that the shape environment's formatted guards are the same after a
    first AOT-compiled run and after a second run with VLLM_FORCE_AOT_LOAD=1
    against the same cache directory.
    """
    # The two conditions mirror the asserts in reference_fn.
    expected_guards = " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"

    with monkeypatch.context() as env:
        inputs = (torch.randn(10, 10),)

        def run_and_check_guards():
            # Build a new module, call it once, then inspect the guards
            # recorded on its AOT artifacts.
            config = make_vllm_config()
            with use_vllm_config(config):
                module = CompiledMod(vllm_config=config)
                module(*inputs)
                shape_env = module.aot_compiled_fn._artifacts.compiled_fn.shape_env
                formatted = shape_env.format_guards()
                assert formatted == expected_guards

        with tempfile.TemporaryDirectory() as cache_dir:
            env.setenv("VLLM_CACHE_ROOT", cache_dir)
            env.setenv("VLLM_USE_AOT_COMPILE", "1")
            run_and_check_guards()

            env.setenv("VLLM_FORCE_AOT_LOAD", "1")
            run_and_check_guards()
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
@use_vllm_config(make_vllm_config())
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
    """
    Build and run gpt2 twice against one cache directory while counting the
    calls made to ``make_symbol`` through
    ``torch.fx.experimental.symbolic_shapes``.

    The first run is expected to make exactly two such calls; the second
    run, with VLLM_FORCE_AOT_LOAD=1, is expected to make none.
    """
    import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module
    from torch.utils._sympy.symbol import make_symbol

    from vllm import LLM

    original_make_symbol = make_symbol
    # Lock-protected shared integer from multiprocessing (presumably so that
    # increments made outside this process are also counted -- confirm).
    symbol_calls = multiprocessing.Value("i", 0)

    @functools.wraps(original_make_symbol)
    def counting_make_symbol(prefix, idx, **kwargs):
        with symbol_calls.get_lock():
            symbol_calls.value += 1
        return original_make_symbol(prefix, idx, **kwargs)

    def build_llm():
        # Both runs use exactly the same model and compilation settings.
        return LLM(
            model="gpt2",
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
            ),
            max_model_len=256,
        )

    prompt = "Hello, my name is"

    symbolic_shapes_module.make_symbol = counting_make_symbol
    try:
        with monkeypatch.context() as env, tempfile.TemporaryDirectory() as cache_dir:
            env.setenv("VLLM_CACHE_ROOT", cache_dir)
            env.setenv("VLLM_USE_AOT_COMPILE", "1")

            # First compilation - initialize model and generate
            llm_model = build_llm()
            llm_model.generate(prompt)
            assert symbol_calls.value == 2
            symbol_calls.value = 0

            # Clean up first model
            del llm_model

            # Second compilation - should hit cache
            env.setenv("VLLM_FORCE_AOT_LOAD", "1")
            llm_model = build_llm()
            llm_model.generate(prompt)
            assert symbol_calls.value == 0
    finally:
        # Restore original function on the patched module
        symbolic_shapes_module.make_symbol = original_make_symbol