Disable graph-mode only on ROCm platform

Signed-off-by: qli88 <qiang.li2@amd.com>
This commit is contained in:
qli88 2025-12-16 16:40:19 +00:00
parent a1033a9542
commit 23b4fd1715
2 changed files with 10 additions and 2 deletions

View File

@@ -5,6 +5,7 @@ import pytest
from vllm import SamplingParams
from vllm.logprobs import FlatLogprobs
from vllm.platforms import current_platform
MODELS = ["distilbert/distilgpt2"]
MAX_TOKENS = 5
@@ -25,8 +26,11 @@ def test_ranks(
flat_logprobs,
example_prompts,
):
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
eager_mode = current_platform.is_rocm()
with vllm_runner(
model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=True
model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=eager_mode
) as vllm_model:
tokenizer = vllm_model.llm.get_tokenizer()
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]

View File

@@ -9,6 +9,7 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
def _generate(
@@ -94,7 +95,10 @@ class TestTwoTokenBadWord:
)[0]
def test_two_token_bad_word(self, vllm_runner):
with vllm_runner(self.MODEL, dtype="half", enforce_eager=True) as llm:
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
eager_mode = current_platform.is_rocm()
with vllm_runner(self.MODEL, dtype="half", enforce_eager=eager_mode) as llm:
output_token_ids = self._generate(llm)
assert output_token_ids[:2] == [
self.target_token_id1,