mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-21 06:05:54 +08:00
Merge b62e11ea1e4cc8292bbc43f742d9519a562075d1 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
6ab253d1b3
@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.logprobs import FlatLogprobs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODELS = ["distilbert/distilgpt2"]
|
||||
MAX_TOKENS = 5
|
||||
@ -25,7 +26,12 @@ def test_ranks(
|
||||
flat_logprobs,
|
||||
example_prompts,
|
||||
):
|
||||
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
|
||||
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
|
||||
eager_mode = current_platform.is_rocm()
|
||||
|
||||
with vllm_runner(
|
||||
model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=eager_mode
|
||||
) as vllm_model:
|
||||
tokenizer = vllm_model.llm.get_tokenizer()
|
||||
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
|
||||
sampling_params = SamplingParams(
|
||||
|
||||
@ -9,6 +9,7 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def _generate(
|
||||
@ -94,7 +95,11 @@ class TestTwoTokenBadWord:
|
||||
)[0]
|
||||
|
||||
def test_two_token_bad_word(self, vllm_runner):
|
||||
with vllm_runner(self.MODEL, dtype="half") as llm:
|
||||
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
|
||||
|
||||
eager_mode = current_platform.is_rocm()
|
||||
|
||||
with vllm_runner(self.MODEL, dtype="half", enforce_eager=eager_mode) as llm:
|
||||
output_token_ids = self._generate(llm)
|
||||
assert output_token_ids[:2] == [
|
||||
self.target_token_id1,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user