Merge b62e11ea1e4cc8292bbc43f742d9519a562075d1 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
qli88 2025-12-25 00:06:51 +00:00 committed by GitHub
commit 6ab253d1b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 2 deletions

View File

@@ -5,6 +5,7 @@ import pytest
from vllm import SamplingParams
from vllm.logprobs import FlatLogprobs
from vllm.platforms import current_platform
MODELS = ["distilbert/distilgpt2"]
MAX_TOKENS = 5
@@ -25,7 +26,12 @@ def test_ranks(
flat_logprobs,
example_prompts,
):
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
eager_mode = current_platform.is_rocm()
with vllm_runner(
model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=eager_mode
) as vllm_model:
tokenizer = vllm_model.llm.get_tokenizer()
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
sampling_params = SamplingParams(

View File

@@ -9,6 +9,7 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
def _generate(
@@ -94,7 +95,11 @@ class TestTwoTokenBadWord:
)[0]
def test_two_token_bad_word(self, vllm_runner):
with vllm_runner(self.MODEL, dtype="half") as llm:
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
eager_mode = current_platform.is_rocm()
with vllm_runner(self.MODEL, dtype="half", enforce_eager=eager_mode) as llm:
output_token_ids = self._generate(llm)
assert output_token_ids[:2] == [
self.target_token_id1,