diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 8bc43b1f2e03c..27188c230d810 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -5,6 +5,7 @@ import pytest from vllm import SamplingParams from vllm.logprobs import FlatLogprobs +from vllm.platforms import current_platform MODELS = ["distilbert/distilgpt2"] MAX_TOKENS = 5 @@ -25,8 +26,11 @@ def test_ranks( flat_logprobs, example_prompts, ): + # TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm. + eager_mode = current_platform.is_rocm() + with vllm_runner( - model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=True + model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=eager_mode ) as vllm_model: tokenizer = vllm_model.llm.get_tokenizer() example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts] diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index 4171cb0d6cc48..9d58c8f1a613e 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -9,6 +9,7 @@ Run `pytest tests/samplers/test_no_bad_words.py`. from transformers import AutoTokenizer from vllm import LLM, SamplingParams +from vllm.platforms import current_platform def _generate( @@ -94,7 +95,10 @@ class TestTwoTokenBadWord: )[0] def test_two_token_bad_word(self, vllm_runner): - with vllm_runner(self.MODEL, dtype="half", enforce_eager=True) as llm: + # TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm. + eager_mode = current_platform.is_rocm() + + with vllm_runner(self.MODEL, dtype="half", enforce_eager=eager_mode) as llm: output_token_ids = self._generate(llm) assert output_token_ids[:2] == [ self.target_token_id1,