[compile] Fix CI for test_gpt2_cache_hit (#30902)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
Zhengxu Chen 2025-12-17 23:22:23 -05:00 committed by GitHub
parent 4a8412f773
commit 5f2f3fba1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 6 deletions

View File

@ -9,6 +9,7 @@ from contextlib import contextmanager
import pytest import pytest
import torch import torch
import vllm.model_executor.layers.activation
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
@ -16,9 +17,12 @@ from vllm.config import (
VllmConfig, VllmConfig,
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.envs import disable_envs_cache
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test
def reference_fn(x: torch.Tensor): def reference_fn(x: torch.Tensor):
assert x.shape[0] <= 42 assert x.shape[0] <= 42
@ -66,6 +70,7 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
torch.compiler.set_stance("fail_on_recompile"), torch.compiler.set_stance("fail_on_recompile"),
): ):
CompiledMod(vllm_config=vllm_config)(*args) CompiledMod(vllm_config=vllm_config)(*args)
disable_envs_cache()
m.setenv("VLLM_USE_AOT_COMPILE", "1") m.setenv("VLLM_USE_AOT_COMPILE", "1")
torch._dynamo.reset() torch._dynamo.reset()
@ -101,6 +106,7 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
with use_vllm_config(vllm_config): with use_vllm_config(vllm_config):
expected = CompiledMod(vllm_config=vllm_config)(*args) expected = CompiledMod(vllm_config=vllm_config)(*args)
disable_envs_cache()
m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
@ -130,6 +136,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
artifacts = compiled_mod.aot_compiled_fn._artifacts artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards() guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
disable_envs_cache()
m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config() vllm_config = make_vllm_config()
@ -144,7 +151,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
@pytest.mark.skipif( @pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
) )
@use_vllm_config(make_vllm_config()) @create_new_process_for_each_test("spawn")
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch): def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
""" """
Test that compiling gpt2 twice results in a cache hit and Test that compiling gpt2 twice results in a cache hit and
@ -186,6 +193,8 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
# Clean up first model # Clean up first model
del llm_model del llm_model
disable_envs_cache()
vllm.model_executor.layers.activation._ACTIVATION_REGISTRY._dict.clear()
# Second compilation - should hit cache # Second compilation - should hit cache
m.setenv("VLLM_FORCE_AOT_LOAD", "1") m.setenv("VLLM_FORCE_AOT_LOAD", "1")