diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py
index ea40c48027205..27188c230d810 100644
--- a/tests/samplers/test_logprobs.py
+++ b/tests/samplers/test_logprobs.py
@@ -5,6 +5,7 @@ import pytest
 
 from vllm import SamplingParams
 from vllm.logprobs import FlatLogprobs
+from vllm.platforms import current_platform
 
 MODELS = ["distilbert/distilgpt2"]
 MAX_TOKENS = 5
@@ -25,7 +26,12 @@ def test_ranks(
     flat_logprobs,
     example_prompts,
 ):
-    with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
+    # TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
+    eager_mode = current_platform.is_rocm()
+
+    with vllm_runner(
+        model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=eager_mode
+    ) as vllm_model:
         tokenizer = vllm_model.llm.get_tokenizer()
         example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
         sampling_params = SamplingParams(
diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py
index 74047d2f03558..5721efcdeaf7e 100644
--- a/tests/samplers/test_no_bad_words.py
+++ b/tests/samplers/test_no_bad_words.py
@@ -9,6 +9,7 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
 
 
 def _generate(
@@ -94,7 +95,11 @@ class TestTwoTokenBadWord:
     )[0]
 
     def test_two_token_bad_word(self, vllm_runner):
-        with vllm_runner(self.MODEL, dtype="half") as llm:
+        # TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
+
+        eager_mode = current_platform.is_rocm()
+
+        with vllm_runner(self.MODEL, dtype="half", enforce_eager=eager_mode) as llm:
             output_token_ids = self._generate(llm)
             assert output_token_ids[:2] == [
                 self.target_token_id1,