mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 16:57:10 +08:00
Merge b62e11ea1e4cc8292bbc43f742d9519a562075d1 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
6ab253d1b3
@ -5,6 +5,7 @@ import pytest
|
|||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.logprobs import FlatLogprobs
|
from vllm.logprobs import FlatLogprobs
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
MODELS = ["distilbert/distilgpt2"]
|
MODELS = ["distilbert/distilgpt2"]
|
||||||
MAX_TOKENS = 5
|
MAX_TOKENS = 5
|
||||||
@ -25,7 +26,12 @@ def test_ranks(
|
|||||||
flat_logprobs,
|
flat_logprobs,
|
||||||
example_prompts,
|
example_prompts,
|
||||||
):
|
):
|
||||||
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
|
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
|
||||||
|
eager_mode = current_platform.is_rocm()
|
||||||
|
|
||||||
|
with vllm_runner(
|
||||||
|
model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=eager_mode
|
||||||
|
) as vllm_model:
|
||||||
tokenizer = vllm_model.llm.get_tokenizer()
|
tokenizer = vllm_model.llm.get_tokenizer()
|
||||||
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
|
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
|
|||||||
@ -9,6 +9,7 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
@ -94,7 +95,11 @@ class TestTwoTokenBadWord:
|
|||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
def test_two_token_bad_word(self, vllm_runner):
|
def test_two_token_bad_word(self, vllm_runner):
|
||||||
with vllm_runner(self.MODEL, dtype="half") as llm:
|
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
|
||||||
|
|
||||||
|
eager_mode = current_platform.is_rocm()
|
||||||
|
|
||||||
|
with vllm_runner(self.MODEL, dtype="half", enforce_eager=eager_mode) as llm:
|
||||||
output_token_ids = self._generate(llm)
|
output_token_ids = self._generate(llm)
|
||||||
assert output_token_ids[:2] == [
|
assert output_token_ids[:2] == [
|
||||||
self.target_token_id1,
|
self.target_token_id1,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user