Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (#6954)
commit 3c10591ef2
parent 0437492ea9
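In short: when an OpenAI-compatible chat or completions request omits max_tokens, the server now fills SamplingParams.max_tokens with the remaining context budget (max_model_len minus the tokenized prompt length), and the guided-decoding logits processor is passed into to_sampling_params instead of being patched onto the SamplingParams afterwards. A minimal sketch of the fallback, using stand-in names rather than vllm code:

    # Sketch only: how the default budget is derived in this change.
    def resolve_max_tokens(requested_max_tokens, max_model_len, prompt_token_ids):
        default_max_tokens = max_model_len - len(prompt_token_ids)
        if requested_max_tokens is None:
            return default_max_tokens
        return requested_max_tokens

    assert resolve_max_tokens(None, 2048, [0] * 100) == 1948   # default applied
    assert resolve_max_tokens(16, 2048, [0] * 100) == 16       # client value wins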
@@ -1,7 +1,12 @@
 import asyncio
+from contextlib import suppress
 from dataclasses import dataclass
+from unittest.mock import MagicMock

+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.transformers_utils.tokenizer import get_tokenizer

 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -42,3 +47,37 @@ async def _async_serving_chat_init():
 def test_async_serving_chat_init():
     serving_completion = asyncio.run(_async_serving_chat_init())
     assert serving_completion.chat_template == CHAT_TEMPLATE
+
+
+def test_serving_chat_should_set_correct_max_tokens():
+    mock_engine = MagicMock(spec=AsyncLLMEngine)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     served_model_names=[MODEL_NAME],
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     lora_modules=None,
+                                     prompt_adapters=None,
+                                     request_logger=None)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    # AsyncLLMEngine.generate(inputs, sampling_params, ...)
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    req.max_tokens = 10
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
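The new test above drives OpenAIServingChat against a MagicMock(spec=AsyncLLMEngine) and inspects the SamplingParams positional argument that create_chat_completion hands to engine.generate; the expected 93 presumably equals the mocked model config's max_model_len minus the tokenized prompt length. A standalone sketch of the call_args inspection pattern (unittest.mock only; the names and values here are illustrative, not vllm code):

    from unittest.mock import MagicMock

    engine = MagicMock()
    engine.generate("some prompt", {"max_tokens": 93}, "request-id")
    # call_args.args holds the positional arguments of the most recent call.
    assert engine.generate.call_args.args[1]["max_tokens"] == 93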
@@ -11,7 +11,7 @@ from typing_extensions import Annotated
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.utils import random_uuid

@@ -215,15 +215,22 @@ class ChatCompletionRequest(OpenAIBaseModel):

     # doc: end-chat-completion-extra-params

-    def to_sampling_params(self,
-                           tokenizer: PreTrainedTokenizer) -> SamplingParams:
-        # We now allow logprobs being true without top_logrobs.
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens

+        # We now allow logprobs being true without top_logrobs.
         logits_processors = get_logits_processors(
             logit_bias=self.logit_bias,
             allowed_token_ids=None,
             tokenizer=tokenizer,
         )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)

         return SamplingParams(
             n=self.n,
@@ -241,7 +248,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
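The two hunks above change ChatCompletionRequest.to_sampling_params in the OpenAI protocol module (vllm.entrypoints.openai.protocol): the caller now supplies the resolved guided-decoding logits processor and a default_max_tokens budget, max_tokens falls back to that budget when the client omitted it, and the guided-decoding processor is appended to the request's logits processors here rather than in the handler. A rough sketch of that wiring, with stand-in names:

    # Illustrative only: appending an optional guided-decoding processor.
    def build_logits_processors(base_processors, guided_decode_logits_processor=None):
        processors = list(base_processors)
        if guided_decode_logits_processor:
            processors.append(guided_decode_logits_processor)
        return processors

    assert build_logits_processors([]) == []
    assert len(build_logits_processors([], lambda token_ids, logits: logits)) == 1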
@@ -395,7 +402,14 @@ class CompletionRequest(OpenAIBaseModel):

     # doc: end-completion-extra-params

-    def to_sampling_params(self, tokenizer: PreTrainedTokenizer):
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
         echo_without_generation = self.echo and self.max_tokens == 0

         logits_processors = get_logits_processors(
@@ -403,6 +417,8 @@ class CompletionRequest(OpenAIBaseModel):
             allowed_token_ids=self.allowed_token_ids,
             tokenizer=tokenizer,
         )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)

         return SamplingParams(
             n=self.n,
@@ -419,7 +435,7 @@ class CompletionRequest(OpenAIBaseModel):
             stop_token_ids=self.stop_token_ids,
             logprobs=self.logprobs,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            max_tokens=max_tokens if not echo_without_generation else 1,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
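CompletionRequest.to_sampling_params gets the same treatment above. Note that echo_without_generation is still keyed to an explicit max_tokens == 0, so the server-side default never turns a normal request into an echo-only one. A hedged illustration of the combined behaviour, outside the Pydantic model:

    # Simplified stand-in for the CompletionRequest logic after this change.
    def effective_max_tokens(requested, default_max_tokens, echo=False):
        echo_without_generation = echo and requested == 0
        max_tokens = requested if requested is not None else default_max_tokens
        return max_tokens if not echo_without_generation else 1

    assert effective_max_tokens(None, 50) == 50          # server default applied
    assert effective_max_tokens(0, 50, echo=True) == 1   # echo-only request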
@@ -25,8 +25,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
 from vllm.multimodal import MultiModalDataDict
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
@@ -134,28 +132,23 @@ class OpenAIServingChat(OpenAIServing):

         request_id = f"chat-{random_uuid()}"
         try:
-            sampling_params = request.to_sampling_params(tokenizer)
-            decoding_config = await self.engine.get_decoding_config()
-            guided_decoding_backend = request.guided_decoding_backend \
-                or decoding_config.guided_decoding_backend
             guided_decode_logits_processor = (
-                await
-                get_guided_decoding_logits_processor(guided_decoding_backend,
-                                                     request, tokenizer))
-            if guided_decode_logits_processor:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
-                sampling_params.logits_processors.append(
-                    guided_decode_logits_processor)
+                await self._guided_decode_logits_processor(request, tokenizer))

             prompt_inputs = self._tokenize_prompt_input(
                 request,
                 tokenizer,
                 prompt,
-                truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )

+            sampling_params = request.to_sampling_params(
+                tokenizer,
+                guided_decode_logits_processor,
+                default_max_tokens=self.max_model_len -
+                len(prompt_inputs["prompt_token_ids"]))
+
             self._log_inputs(request_id,
                              prompt_inputs,
                              params=sampling_params,
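In OpenAIServingChat.create_chat_completion above, the prompt is now tokenized before SamplingParams is built, so the handler can pass default_max_tokens = self.max_model_len - len(prompt_inputs["prompt_token_ids"]); truncate_prompt_tokens is read directly from the request rather than from the SamplingParams, and the guided-decoding setup moves to the shared _guided_decode_logits_processor helper. A tiny sketch of the budget arithmetic, with made-up numbers:

    max_model_len = 4096              # assumed engine context length
    prompt_token_ids = [101] * 1500   # assumed tokenized chat prompt
    default_max_tokens = max_model_len - len(prompt_token_ids)
    assert default_max_tokens == 2596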
@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing,
                                                     PromptAdapterPath)
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -95,31 +93,24 @@ class OpenAIServingCompletion(OpenAIServing):

             tokenizer = await self.engine.get_tokenizer(lora_request)

-            sampling_params = request.to_sampling_params(tokenizer)
-            decoding_config = await self.engine.get_decoding_config()
-            guided_decoding_backend = request.guided_decoding_backend \
-                or decoding_config.guided_decoding_backend
-            guided_decode_logit_processor = (
-                await
-                get_guided_decoding_logits_processor(guided_decoding_backend,
-                                                     request, tokenizer))
-            if guided_decode_logit_processor is not None:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
-                sampling_params.logits_processors.append(
-                    guided_decode_logit_processor)
-
+            guided_decode_logits_processor = (
+                await self._guided_decode_logits_processor(request, tokenizer))
             prompts = list(
                 self._tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
                     request.prompt,
-                    truncate_prompt_tokens=sampling_params.
-                    truncate_prompt_tokens,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 ))

             for i, prompt_inputs in enumerate(prompts):
+                sampling_params = request.to_sampling_params(
+                    tokenizer,
+                    guided_decode_logits_processor,
+                    default_max_tokens=self.max_model_len -
+                    len(prompt_inputs["prompt_token_ids"]))
+
                 request_id_item = f"{request_id}-{i}"

                 self._log_inputs(request_id_item,
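The completions handler (OpenAIServingCompletion) above is refactored the same way, but because a single request can carry several prompts, to_sampling_params is now called inside the per-prompt loop so each prompt gets its own default budget. Illustrative only:

    max_model_len = 2048
    batch = [[1] * 100, [1] * 1900]   # two pretend tokenized prompts
    budgets = [max_model_len - len(ids) for ids in batch]
    assert budgets == [1948, 148]     # each prompt gets its own default max_tokens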
@@ -25,9 +25,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.inputs import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.model_executor.guided_decoding import (
+    get_guided_decoding_logits_processor)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer_group import AnyTokenizer

@@ -150,6 +152,15 @@ class OpenAIServing:
         })
         return json_str

+    async def _guided_decode_logits_processor(
+            self, request: Union[ChatCompletionRequest, CompletionRequest],
+            tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
+        decoding_config = await self.engine.get_decoding_config()
+        guided_decoding_backend = request.guided_decoding_backend \
+            or decoding_config.guided_decoding_backend
+        return await get_guided_decoding_logits_processor(
+            guided_decoding_backend, request, tokenizer)
+
     async def _check_model(
         self,
         request: AnyRequest,
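The guided-decoding setup duplicated across the chat and completions handlers is consolidated above into OpenAIServing._guided_decode_logits_processor, which prefers the request's guided_decoding_backend and falls back to the engine's decoding config. A hedged illustration of that fallback (backend names are just example values):

    def pick_backend(request_backend, config_backend):
        return request_backend or config_backend

    assert pick_backend(None, "outlines") == "outlines"
    assert pick_backend("lm-format-enforcer", "outlines") == "lm-format-enforcer"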
@@ -254,9 +265,7 @@ class OpenAIServing:
                     f"{self.max_model_len} tokens. However, you requested "
                     f"{token_num} tokens in the messages, "
                     f"Please reduce the length of the messages.")
-            request.max_tokens = self.max_model_len - token_num
-
-        if token_num + request.max_tokens > self.max_model_len:
+        elif token_num + request.max_tokens > self.max_model_len:
             raise ValueError(
                 f"This model's maximum context length is "
                 f"{self.max_model_len} tokens. However, you requested "
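Finally, the length check above no longer rewrites request.max_tokens in place when the client omitted it; it only raises when the prompt alone fills the context or when an explicit max_tokens would overflow it, leaving the default to to_sampling_params. A simplified sketch of the resulting logic (not the vllm signature):

    def check_length(token_num, max_model_len, requested_max_tokens):
        if requested_max_tokens is None:
            if token_num >= max_model_len:
                raise ValueError("prompt alone fills the model's context window")
        elif token_num + requested_max_tokens > max_model_len:
            raise ValueError("prompt plus max_tokens exceeds the context window")

    check_length(token_num=100, max_model_len=2048, requested_max_tokens=None)   # ok
    check_length(token_num=100, max_model_len=2048, requested_max_tokens=1948)   # ok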