diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 71b0e86c75c18..b6a78eaa09209 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import pytest from vllm import LLM, SamplingParams +from vllm.platforms import current_platform from ...utils import check_answers, prep_prompts @@ -40,10 +41,17 @@ def test_sliding_window_retrieval( If we tell it upfront which we are going to be looking for, then it answers correctly (mostly). """ + # NOTE: For ROCm, we have to enforce eager mode to use custom kernel + # implementation of GELU with tanh approximation, as PyTorch's native + # implementation is currently unstable with torch.compile and produces garbage. + enforce_eager = current_platform.is_rocm() + test_config = model_config[model] llm = LLM( - model=model, disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager + model=model, + disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager, + enforce_eager=enforce_eager, ) sampling_params = SamplingParams(temperature=0.0, max_tokens=100)