[CI][ROCm] Fix test_correctness_sliding_window (#29243)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Divakar Verma 2025-12-01 22:53:27 -06:00 committed by GitHub
parent 81fe3f82af
commit a690fb5bd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
from ...utils import check_answers, prep_prompts from ...utils import check_answers, prep_prompts
@ -40,10 +41,17 @@ def test_sliding_window_retrieval(
If we tell it upfront which we are going to be looking for, then If we tell it upfront which we are going to be looking for, then
it answers correctly (mostly). it answers correctly (mostly).
""" """
# NOTE: For ROCm, we have to enforce eager mode to use custom kernel
# implementation of GELU with tanh approximation, as PyTorch's native
# implementation is currently unstable with torch.compile and produces garbage.
enforce_eager = current_platform.is_rocm()
test_config = model_config[model] test_config = model_config[model]
llm = LLM( llm = LLM(
model=model, disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager model=model,
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
enforce_eager=enforce_eager,
) )
sampling_params = SamplingParams(temperature=0.0, max_tokens=100) sampling_params = SamplingParams(temperature=0.0, max_tokens=100)