mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 21:17:13 +08:00
[compile] Fix CI for test_gpt2_cache_hit (#30902)
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
parent
4a8412f773
commit
5f2f3fba1d
@ -9,6 +9,7 @@ from contextlib import contextmanager
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.activation
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CompilationConfig,
|
CompilationConfig,
|
||||||
@ -16,9 +17,12 @@ from vllm.config import (
|
|||||||
VllmConfig,
|
VllmConfig,
|
||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
|
from vllm.envs import disable_envs_cache
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
|
from ..utils import create_new_process_for_each_test
|
||||||
|
|
||||||
|
|
||||||
def reference_fn(x: torch.Tensor):
|
def reference_fn(x: torch.Tensor):
|
||||||
assert x.shape[0] <= 42
|
assert x.shape[0] <= 42
|
||||||
@ -66,6 +70,7 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
|
|||||||
torch.compiler.set_stance("fail_on_recompile"),
|
torch.compiler.set_stance("fail_on_recompile"),
|
||||||
):
|
):
|
||||||
CompiledMod(vllm_config=vllm_config)(*args)
|
CompiledMod(vllm_config=vllm_config)(*args)
|
||||||
|
disable_envs_cache()
|
||||||
|
|
||||||
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
m.setenv("VLLM_USE_AOT_COMPILE", "1")
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
@ -101,6 +106,7 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
|
|||||||
vllm_config = make_vllm_config()
|
vllm_config = make_vllm_config()
|
||||||
with use_vllm_config(vllm_config):
|
with use_vllm_config(vllm_config):
|
||||||
expected = CompiledMod(vllm_config=vllm_config)(*args)
|
expected = CompiledMod(vllm_config=vllm_config)(*args)
|
||||||
|
disable_envs_cache()
|
||||||
|
|
||||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||||
vllm_config = make_vllm_config()
|
vllm_config = make_vllm_config()
|
||||||
@ -130,6 +136,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
|||||||
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
||||||
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
||||||
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||||
|
disable_envs_cache()
|
||||||
|
|
||||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||||
vllm_config = make_vllm_config()
|
vllm_config = make_vllm_config()
|
||||||
@ -144,7 +151,7 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
|||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
|
||||||
)
|
)
|
||||||
@use_vllm_config(make_vllm_config())
|
@create_new_process_for_each_test("spawn")
|
||||||
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
||||||
"""
|
"""
|
||||||
Test that compiling gpt2 twice results in a cache hit and
|
Test that compiling gpt2 twice results in a cache hit and
|
||||||
@ -186,6 +193,8 @@ def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
|
|||||||
|
|
||||||
# Clean up first model
|
# Clean up first model
|
||||||
del llm_model
|
del llm_model
|
||||||
|
disable_envs_cache()
|
||||||
|
vllm.model_executor.layers.activation._ACTIVATION_REGISTRY._dict.clear()
|
||||||
|
|
||||||
# Second compilation - should hit cache
|
# Second compilation - should hit cache
|
||||||
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user