[compile] Fix CI for test_gpt2_cache_hit (#30902)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
Zhengxu Chen 2025-12-17 23:22:23 -05:00 committed by GitHub
parent 4a8412f773
commit 5f2f3fba1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 6 deletions

View File

@ -9,6 +9,7 @@ from contextlib import contextmanager
import pytest
import torch
import vllm.model_executor.layers.activation
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
@ -16,9 +17,12 @@ from vllm.config import (
VllmConfig,
set_current_vllm_config,
)
from vllm.envs import disable_envs_cache
from vllm.forward_context import set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test
def reference_fn(x: torch.Tensor):
assert x.shape[0] <= 42
@ -66,6 +70,7 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
torch.compiler.set_stance("fail_on_recompile"),
):
CompiledMod(vllm_config=vllm_config)(*args)
disable_envs_cache()
m.setenv("VLLM_USE_AOT_COMPILE", "1")
torch._dynamo.reset()
@ -101,6 +106,7 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
expected = CompiledMod(vllm_config=vllm_config)(*args)
disable_envs_cache()
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
@ -130,6 +136,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
disable_envs_cache()
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
@ -144,7 +151,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
@use_vllm_config(make_vllm_config())
@create_new_process_for_each_test("spawn")
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
"""
Test that compiling gpt2 twice results in a cache hit and
@ -186,6 +193,8 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
# Clean up first model
del llm_model
disable_envs_cache()
vllm.model_executor.layers.activation._ACTIVATION_REGISTRY._dict.clear()
# Second compilation - should hit cache
m.setenv("VLLM_FORCE_AOT_LOAD", "1")

View File

@ -437,14 +437,14 @@ class CompilationConfig:
compile_ranges_split_points: list[int] | None = None
"""Split points that represent compile ranges for inductor.
The compile ranges are
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
The compile ranges are
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens].
Compile sizes are also used single element ranges,
the range is represented as [compile_sizes[i], compile_sizes[i]].
If a range overlaps with the compile size, graph for compile size
If a range overlaps with the compile size, graph for compile size
will be prioritized, i.e. if we have a range [1, 8] and a compile size 4,
graph for compile size 4 will be compiled and used instead of the graph
for range [1, 8].