[V1][Frontend] Add Testing For V1 Runtime Parameters (#14159)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
parent: 47d4a7e004
commit: 257e200a25
tests/v1/sample/test_sampling_params_e2e.py  (new file, +150 lines)
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0

import os

import pytest

from vllm import LLM, SamplingParams

if os.getenv("VLLM_USE_V1", "0") != "1":
    pytest.skip("Test package requires V1", allow_module_level=True)

MODEL = "meta-llama/Llama-3.2-1B"
PROMPT = "Hello my name is Robert and I"


@pytest.fixture(scope="module")
def model() -> LLM:
    return LLM(MODEL, enforce_eager=True)


def test_n_gt_1(model):
    """ParallelSampling is supported."""

    params = SamplingParams(n=3)
    outputs = model.generate(PROMPT, params)
    assert len(outputs[0].outputs) == 3


def test_best_of(model):
    """Raise a ValueError since best_of is deprecated."""

    params = SamplingParams(n=2, best_of=3)
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, params)


def test_penalties(model):
    """Check that we do not get errors if applied."""

    params = SamplingParams(
        temperature=1.2,
        presence_penalty=1.2,
        frequency_penalty=1.2,
        repetition_penalty=1.2,
        min_p=0.5,
        top_p=0.5,
        top_k=3,
    )
    _ = model.generate(PROMPT, params)


def test_stop(model):
    """Check that we respect the stop words."""

    output = model.generate(PROMPT, SamplingParams(temperature=0))
    split_text = output[0].outputs[0].text.split()

    STOP_IDX = 5
    params = SamplingParams(temperature=0, stop=split_text[STOP_IDX])
    output = model.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should not contain the stop word.
    assert len(new_split_text) == STOP_IDX

    params = SamplingParams(temperature=0,
                            stop=split_text[STOP_IDX],
                            include_stop_str_in_output=True)
    output = model.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should contain the stop word.
    assert len(new_split_text) == STOP_IDX + 1

def test_stop_token_ids(model):
    """Check that we respect the stop token ids."""

    output = model.generate(PROMPT, SamplingParams(temperature=0))

    stop_token_id_0 = output[0].outputs[0].token_ids[5]
    stop_token_id_1 = output[0].outputs[0].token_ids[6]

    stop_token_ids = [stop_token_id_1, stop_token_id_0]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0

    stop_token_ids = [stop_token_id_0, stop_token_id_1]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0


def test_bad_words(model):
    """Check that we respect bad words."""

    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))


def test_logits_processor(model):
    """Check that we reject logits processor."""

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    with pytest.raises(ValueError):
        _ = model.generate(PROMPT,
                           SamplingParams(logits_processors=[pick_ith]))


def test_allowed_token_ids(model):
    """Check that we can use allowed_token_ids."""

    TOKEN_ID = 10
    allowed_token_ids = [TOKEN_ID]
    output = model.generate(
        PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids))
    assert output[0].outputs[0].token_ids[-1] == TOKEN_ID

    # Reject negative token id.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))

    # Reject out of vocabulary.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT,
                           SamplingParams(allowed_token_ids=[10000000]))


def test_priority(model):
    """Check that we reject requests with priority."""

    # Reject requests with non-zero priority.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, priority=[1])


def test_seed(model):
    """Check that seed impacts randomness."""

    out_1 = model.generate(PROMPT, SamplingParams(seed=42))
    out_2 = model.generate(PROMPT, SamplingParams(seed=42))
    out_3 = model.generate(PROMPT, SamplingParams(seed=43))

    assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
    assert out_1[0].outputs[0].text != out_3[0].outputs[0].text
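A quick usage note on the new test module: it skips itself at import time unless the V1 engine is selected, so VLLM_USE_V1 has to be set before pytest imports the file. A minimal sketch of running just this module programmatically (assuming pytest is installed and the Llama checkpoint is available locally; exporting the variable and invoking pytest from the shell works equally well):

import os
import pytest

# Select the V1 engine before the test module is imported; otherwise the
# module-level pytest.skip(...) above fires during collection.
os.environ["VLLM_USE_V1"] = "1"

# Run only this module, equivalent to invoking pytest on the file directly.
raise SystemExit(
    pytest.main(["tests/v1/sample/test_sampling_params_e2e.py", "-v"]))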
@@ -55,11 +55,8 @@ class Processor:
 
     def _validate_logprobs(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
-
         max_logprobs = self.model_config.max_logprobs
         # Validate sample logprobs.
         if params.logprobs and params.logprobs > max_logprobs:
@@ -79,17 +76,10 @@ class Processor:
             raise ValueError("Prefix caching with prompt logprobs not yet "
                              "supported on VLLM V1.")
 
-    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
-        if lora_request is not None and not self.lora_config:
-            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
-                             "not enabled!")
-
-    def _validate_allowed_token_ids(
+    def _validate_sampling_params(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
         if params.allowed_token_ids is None:
             return
         if not params.allowed_token_ids:
@@ -99,6 +89,42 @@ class Processor:
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
 
+    def _validate_supported_sampling_params(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        # Best of not yet supported.
+        if params.best_of:
+            raise ValueError("VLLM V1 does not yet support best_of.")
+        # Bad words not yet supported.
+        if params.bad_words:
+            raise ValueError("VLLM V1 does not yet support bad_words.")
+        # Logits processors not supported.
+        if params.logits_processors:
+            raise ValueError("VLLM V1 does not support per request "
+                             "user provided logits processors.")
+
+    def _validate_params(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ):
+        """
+        Validate supported SamplingParam.
+        Should raise ValueError if unsupported for API Server.
+        """
+
+        if not isinstance(params, SamplingParams):
+            raise ValueError("V1 does not yet support Pooling models.")
+
+        self._validate_logprobs(params)
+        self._validate_sampling_params(params)
+        self._validate_supported_sampling_params(params)
+
+    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+
     def process_inputs(
         self,
         request_id: str,
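For context, the net effect of the consolidated _validate_params path is that unsupported sampling options fail fast with a ValueError before a request reaches the engine core. A hedged sketch of what that looks like from the caller's side, reusing the model and prompt from the test file above (assumes VLLM_USE_V1=1 and a machine that can load the model):

from vllm import LLM, SamplingParams

llm = LLM("meta-llama/Llama-3.2-1B", enforce_eager=True)

# Each of these is rejected by the new validation path with a ValueError
# instead of reaching the model runner.
unsupported = [
    SamplingParams(n=2, best_of=3),       # best_of not yet supported on V1
    SamplingParams(bad_words=["Hello"]),  # bad_words not yet supported on V1
]
for params in unsupported:
    try:
        llm.generate("Hello my name is Robert and I", params)
    except ValueError as err:
        print(f"rejected: {err}")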
@@ -114,14 +140,17 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
 
-        self._validate_logprobs(params)
         self._validate_lora(lora_request)
-        self._validate_allowed_token_ids(params)
+        self._validate_params(params)
+        if priority != 0:
+            raise ValueError("V1 does not support priority yet.")
+        if trace_headers is not None:
+            raise ValueError("V1 does not support tracing yet.")
+        if prompt_adapter_request is not None:
+            raise ValueError("V1 does not support prompt_adapter_request.")
 
         if arrival_time is None:
             arrival_time = time.time()
-        assert priority == 0, "vLLM V1 does not support priority at the moment."
-        assert trace_headers is None, "vLLM V1 does not support tracing yet."
 
         # Process inputs, which includes:
         # 1. Tokenize text prompt, with LoRA request if one exists.
@@ -298,6 +298,11 @@ class InputBatch:
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias
 
+        # FIXME: this implementation is incorrect. We create this mask
+        # then apply -inf to these specific tokens, which means we never
+        # select the allowed tokens! We cannot do the reverse, since
+        # this will impact the requests that do not have allowed_token_ids.
+        # This feature is currently disabled on V1 (we reject in Processor).
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
             if self.allowed_token_ids_mask_cpu_tensor is None:
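The FIXME above spells out why the existing allowed_token_ids masking cannot work as written: the mask marks the allowed ids and then receives -inf, which bans exactly the tokens the request wanted to keep, while inverting the mask globally would affect requests that never set allowed_token_ids. A standalone illustration of that pitfall (a hypothetical PyTorch sketch, not the vLLM implementation):

import torch

# Illustration of the masking pitfall described in the FIXME above.
vocab_size = 8
logits = torch.zeros(vocab_size)
allowed_token_ids = [2, 5]

# Marking the allowed ids and applying -inf where the mask is True bans the
# very tokens the request asked to keep.
mask = torch.zeros(vocab_size, dtype=torch.bool)
mask[allowed_token_ids] = True
broken = logits.masked_fill(mask, float("-inf"))
assert torch.isinf(broken[allowed_token_ids]).all()  # allowed ids can never win

# What the feature needs is to ban the complement instead, and only for the
# rows that actually set allowed_token_ids (hence the temporary rejection).
fixed = logits.masked_fill(~mask, float("-inf"))
assert fixed.argmax().item() in allowed_token_ids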