diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py
index 776fd42bbc35..2462f8f9f10c 100644
--- a/tests/v1/entrypoints/openai/test_completion.py
+++ b/tests/v1/entrypoints/openai/test_completion.py
@@ -7,6 +7,7 @@ import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
 import regex as re
+import requests
 from openai import BadRequestError

 from tests.utils import RemoteOpenAIServer
@@ -26,7 +27,8 @@ def default_server_args():
         "2048",
         "--max-num-seqs",
         "128",
-        "--enforce-eager"
+        "--enforce-eager",
+        "--enable-prompt-tokens-details",
     ]
@@ -679,3 +681,17 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
             prompt=prompt,
             extra_body={"guided_grammar": invalid_simplified_sql_grammar},
         )
+
+
+@pytest.mark.asyncio
+async def test_completion_with_empty_prompt_embeds(
+        client: openai.AsyncOpenAI) -> None:
+    """Test completion with empty prompt embeds."""
+    payload: dict[str, list] = {"prompt_embeds": []}
+    headers: dict[str, str] = {"Content-Type": "application/json"}
+    # client.base_url ends with /v1/, so this posts to /v1/completions
+    response = requests.post(f"{client.base_url}completions",
+                             headers=headers,
+                             json=payload)
+    assert response.status_code == 200, (
+        f"Expected status code 200, got {response.status_code}.")
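
The new test bypasses the openai client and posts raw JSON, presumably because prompt_embeds is a vLLM-specific extension of the completions API. Outside pytest, the same check can be run against any live server; the sketch below is illustrative only, and the base URL is an assumption matching vLLM's usual local default rather than anything fixed by this patch.

# Standalone sketch of the request the new test sends, assuming a local
# server started with --enable-prompt-tokens-details as in the fixture above.
import requests

BASE_URL = "http://localhost:8000/v1"  # assumption, not part of the diff

response = requests.post(
    f"{BASE_URL}/completions",
    headers={"Content-Type": "application/json"},
    json={"prompt_embeds": []},  # empty embeds list, no "prompt" field
)
assert response.status_code == 200, response.text
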
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index eb9a35a7a37d..1e1f655022f0 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -60,20 +60,25 @@ class OpenAIServingCompletion(OpenAIServing):
         enable_prompt_tokens_details: bool = False,
         enable_force_include_usage: bool = False,
     ):
-        super().__init__(engine_client=engine_client,
-                         model_config=model_config,
-                         models=models,
-                         request_logger=request_logger,
-                         return_tokens_as_token_ids=return_tokens_as_token_ids,
-                         enable_force_include_usage=enable_force_include_usage)
+        super().__init__(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=request_logger,
+            return_tokens_as_token_ids=return_tokens_as_token_ids,
+            enable_force_include_usage=enable_force_include_usage,
+        )
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
             source = self.model_config.generation_config
             source = "model" if source == "auto" else source
-            logger.info("Using default completion sampling params from %s: %s",
-                        source, self.default_sampling_params)
+            logger.info(
+                "Using default completion sampling params from %s: %s",
+                source,
+                self.default_sampling_params,
+            )

     async def create_completion(
         self,
@@ -172,23 +177,28 @@ class OpenAIServingCompletion(OpenAIServing):
                     max_model_len=self.max_model_len,
                     request=request,
                     input_length=input_length,
-                    default_sampling_params=self.default_sampling_params)
+                    default_sampling_params=self.default_sampling_params,
+                )

                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
-                        max_tokens, self.model_config.logits_processor_pattern,
-                        self.default_sampling_params)
+                        max_tokens,
+                        self.model_config.logits_processor_pattern,
+                        self.default_sampling_params,
+                    )

                 request_id_item = f"{request_id}-{i}"

-                self._log_inputs(request_id_item,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request,
-                                 prompt_adapter_request=prompt_adapter_request)
+                self._log_inputs(
+                    request_id_item,
+                    request_prompts[i],
+                    params=sampling_params,
+                    lora_request=lora_request,
+                    prompt_adapter_request=prompt_adapter_request,
+                )

                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
@@ -245,7 +255,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 num_prompts=num_prompts,
                 tokenizer=tokenizer,
                 request_metadata=request_metadata,
-                enable_force_include_usage=self.enable_force_include_usage)
+                enable_force_include_usage=self.enable_force_include_usage,
+            )

         # Non-streaming response
         final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
@@ -321,10 +332,10 @@ class OpenAIServingCompletion(OpenAIServing):

         stream_options = request.stream_options
         if stream_options:
-            include_usage = stream_options.include_usage or \
-                enable_force_include_usage
-            include_continuous_usage = include_usage and \
-                stream_options.continuous_usage_stats
+            include_usage = (stream_options.include_usage
+                             or enable_force_include_usage)
+            include_continuous_usage = (include_usage and
+                                        stream_options.continuous_usage_stats)
         else:
             include_usage, include_continuous_usage = False, False

@@ -370,7 +381,8 @@ class OpenAIServingCompletion(OpenAIServing):
                             # echo the prompt and first token
                             delta_text = prompt_text + output.text
                             delta_token_ids = [
-                                *prompt_token_ids, *output.token_ids
+                                *prompt_token_ids,
+                                *output.token_ids,
                             ]
                             out_logprobs = [
                                 *(prompt_logprobs or []),
@@ -383,8 +395,8 @@ class OpenAIServingCompletion(OpenAIServing):
                         delta_token_ids = output.token_ids
                         out_logprobs = output.logprobs

-                    if not delta_text and not delta_token_ids \
-                        and not previous_num_tokens[i]:
+                    if (not delta_text and not delta_token_ids
+                            and not previous_num_tokens[i]):
                         # Chunked prefill case, don't return empty chunks
                         continue

@@ -420,7 +432,8 @@ class OpenAIServingCompletion(OpenAIServing):
                                 finish_reason=finish_reason,
                                 stop_reason=stop_reason,
                             )
-                        ])
+                        ],
+                    )
                     if include_continuous_usage:
                         prompt_tokens = num_prompt_tokens[prompt_idx]
                         completion_tokens = previous_num_tokens[i]
@@ -438,7 +451,8 @@ class OpenAIServingCompletion(OpenAIServing):
             final_usage_info = UsageInfo(
                 prompt_tokens=total_prompt_tokens,
                 completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+                total_tokens=total_prompt_tokens + total_completion_tokens,
+            )

             if self.enable_prompt_tokens_details and num_cached_tokens:
                 final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
@@ -452,8 +466,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     choices=[],
                     usage=final_usage_info,
                 )
-                final_usage_data = (final_usage_chunk.model_dump_json(
-                    exclude_unset=False, exclude_none=True))
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=False, exclude_none=True)
                 yield f"data: {final_usage_data}\n\n"

         # report to FastAPI middleware aggregate usage across all choices
@@ -478,8 +492,10 @@ class OpenAIServingCompletion(OpenAIServing):
         choices: list[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
-
+        kv_transfer_params = None
+        last_final_res = None
         for final_res in final_res_batch:
+            last_final_res = final_res
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
             prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
@@ -548,19 +564,22 @@ class OpenAIServingCompletion(OpenAIServing):
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )

-        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+        if (self.enable_prompt_tokens_details and last_final_res
+                and last_final_res.num_cached_tokens):
             usage.prompt_tokens_details = PromptTokenUsageInfo(
-                cached_tokens=final_res.num_cached_tokens)
+                cached_tokens=last_final_res.num_cached_tokens)

         request_metadata.final_usage_info = usage
-
+        if final_res_batch:
+            kv_transfer_params = final_res_batch[0].kv_transfer_params
         return CompletionResponse(
             id=request_id,
             created=created_time,
             model=model_name,
             choices=choices,
             usage=usage,
-            kv_transfer_params=final_res_batch[0].kv_transfer_params)
+            kv_transfer_params=kv_transfer_params,
+        )

     def _create_completion_logprobs(
         self,
@@ -579,8 +598,9 @@ class OpenAIServingCompletion(OpenAIServing):

         last_token_len = 0

-        should_return_as_token_id = return_as_token_id if \
-            return_as_token_id is not None else self.return_tokens_as_token_ids
+        should_return_as_token_id = (return_as_token_id
+                                     if return_as_token_id is not None else
+                                     self.return_tokens_as_token_ids)
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
@@ -612,10 +632,12 @@ class OpenAIServingCompletion(OpenAIServing):
                 out_top_logprobs.append({
                     # Convert float("-inf") to the
                     # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1],
-                                            top_lp[0],
-                                            tokenizer,
-                                            return_as_token_id=should_return_as_token_id):
+                    self._get_decoded_token(
+                        top_lp[1],
+                        top_lp[0],
+                        tokenizer,
+                        return_as_token_id=should_return_as_token_id,
+                    ):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i
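
Taken together, the changes in request_output_to_completion_response implement one guard pattern: an empty prompt_embeds request yields an empty final_res_batch, so the response builder must not index final_res_batch[0] nor rely on the loop variable final_res after the loop. A distilled sketch follows; it is illustrative only and borrows the diff's names outside their real context.

# Why the new test sees HTTP 200 instead of a 500: with an empty batch,
# both guards below fall back to None and response construction proceeds.
from typing import Any, Optional

final_res_batch: list[Any] = []  # what an empty prompt_embeds request yields

kv_transfer_params: Optional[dict] = None
last_final_res: Optional[Any] = None
for final_res in final_res_batch:
    last_final_res = final_res  # never executes for an empty batch

if final_res_batch:  # only index when at least one result exists
    kv_transfer_params = final_res_batch[0].kv_transfer_params

print(kv_transfer_params, last_final_res)  # None None for the empty case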