mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 04:39:09 +08:00
Disable graph-mode only on ROCm platform
Signed-off-by: qli88 <qiang.li2@amd.com>
This commit is contained in:
parent
a1033a9542
commit
23b4fd1715
@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.logprobs import FlatLogprobs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODELS = ["distilbert/distilgpt2"]
|
||||
MAX_TOKENS = 5
|
||||
@ -25,8 +26,11 @@ def test_ranks(
|
||||
flat_logprobs,
|
||||
example_prompts,
|
||||
):
|
||||
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
|
||||
eager_mode = current_platform.is_rocm()
|
||||
|
||||
with vllm_runner(
|
||||
model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=True
|
||||
model, dtype=dtype, max_logprobs=MAX_LOGPROBS, enforce_eager=eager_mode
|
||||
) as vllm_model:
|
||||
tokenizer = vllm_model.llm.get_tokenizer()
|
||||
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
|
||||
|
||||
@ -9,6 +9,7 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def _generate(
|
||||
@ -94,7 +95,10 @@ class TestTwoTokenBadWord:
|
||||
)[0]
|
||||
|
||||
def test_two_token_bad_word(self, vllm_runner):
|
||||
with vllm_runner(self.MODEL, dtype="half", enforce_eager=True) as llm:
|
||||
# TODO: Remove once graph mode is fixed for distilbert/distilgpt2 on ROCm.
|
||||
eager_mode = current_platform.is_rocm()
|
||||
|
||||
with vllm_runner(self.MODEL, dtype="half", enforce_eager=eager_mode) as llm:
|
||||
output_token_ids = self._generate(llm)
|
||||
assert output_token_ids[:2] == [
|
||||
self.target_token_id1,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user