[Bugfix]: Fix final_res_batch list index out of range error (#21055)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Author: Chauncey
Date: 2025-07-17 15:29:09 +08:00 (committed by GitHub)
parent c5b8b5953a
commit fdc5b43d20
2 changed files with 78 additions and 40 deletions
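
The root cause is easy to see in isolation: a completions request whose `prompt_embeds` list is empty yields zero prompts, so `final_res_batch` ends up empty and the previous unconditional `final_res_batch[0].kv_transfer_params` access raises `IndexError: list index out of range`. The minimal sketch below is not vLLM code; `build_response` and its argument are hypothetical stand-ins that only illustrate the guarded-access pattern the fix adopts.

from typing import Any, Optional


def build_response(final_res_batch: list[Any]) -> Optional[Any]:
    # Old pattern (crashes on an empty batch):
    #     return final_res_batch[0].kv_transfer_params
    # Guarded pattern used by the fix: only index when there is something
    # to index, otherwise fall back to None.
    kv_transfer_params: Optional[Any] = None
    if final_res_batch:
        kv_transfer_params = final_res_batch[0].kv_transfer_params
    return kv_transfer_params


print(build_response([]))  # prints None instead of raising IndexError

The fix applies the same idea in the serving code: default `kv_transfer_params` to `None` and only index the batch when it is non-empty.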


@@ -7,6 +7,7 @@ import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
 import regex as re
+import requests
 from openai import BadRequestError
 
 from tests.utils import RemoteOpenAIServer
@@ -26,7 +27,8 @@ def default_server_args():
         "2048",
         "--max-num-seqs",
         "128",
-        "--enforce-eager"
+        "--enforce-eager",
+        "--enable-prompt-tokens-details",
     ]
@@ -679,3 +681,17 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
             prompt=prompt,
             extra_body={"guided_grammar": invalid_simplified_sql_grammar},
         )
+
+
+@pytest.mark.asyncio
+async def test_completion_with_empty_prompt_embeds(
+        client: openai.AsyncOpenAI) -> None:
+    """Test completion with empty prompt embeds."""
+    payload: dict[str, list] = {"prompt_embeds": []}
+    headers: dict[str, str] = {"Content-Type": "application/json"}
+    # base_url = http://localhost:8000/v1/completions
+    response = requests.post(f"{client.base_url}completions",
+                             headers=headers,
+                             json=payload)
+    assert response.status_code == 200, (
+        f"Expected status code 200, got {response.status_code}. ")


@@ -60,20 +60,25 @@ class OpenAIServingCompletion(OpenAIServing):
         enable_prompt_tokens_details: bool = False,
         enable_force_include_usage: bool = False,
     ):
-        super().__init__(engine_client=engine_client,
-                         model_config=model_config,
-                         models=models,
-                         request_logger=request_logger,
-                         return_tokens_as_token_ids=return_tokens_as_token_ids,
-                         enable_force_include_usage=enable_force_include_usage)
+        super().__init__(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=request_logger,
+            return_tokens_as_token_ids=return_tokens_as_token_ids,
+            enable_force_include_usage=enable_force_include_usage,
+        )
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
             source = self.model_config.generation_config
             source = "model" if source == "auto" else source
-            logger.info("Using default completion sampling params from %s: %s",
-                        source, self.default_sampling_params)
+            logger.info(
+                "Using default completion sampling params from %s: %s",
+                source,
+                self.default_sampling_params,
+            )
 
     async def create_completion(
         self,
@@ -172,23 +177,28 @@ class OpenAIServingCompletion(OpenAIServing):
                     max_model_len=self.max_model_len,
                     request=request,
                     input_length=input_length,
-                    default_sampling_params=self.default_sampling_params)
+                    default_sampling_params=self.default_sampling_params,
+                )
 
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         max_tokens, self.default_sampling_params)
                 else:
-                    sampling_params = request.to_sampling_params(
-                        max_tokens, self.model_config.logits_processor_pattern,
-                        self.default_sampling_params)
+                    sampling_params = request.to_sampling_params(
+                        max_tokens,
+                        self.model_config.logits_processor_pattern,
+                        self.default_sampling_params,
+                    )
 
                 request_id_item = f"{request_id}-{i}"
 
-                self._log_inputs(request_id_item,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request,
-                                 prompt_adapter_request=prompt_adapter_request)
+                self._log_inputs(
+                    request_id_item,
+                    request_prompts[i],
+                    params=sampling_params,
+                    lora_request=lora_request,
+                    prompt_adapter_request=prompt_adapter_request,
+                )
 
                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
@@ -245,7 +255,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 num_prompts=num_prompts,
                 tokenizer=tokenizer,
                 request_metadata=request_metadata,
-                enable_force_include_usage=self.enable_force_include_usage)
+                enable_force_include_usage=self.enable_force_include_usage,
+            )
 
         # Non-streaming response
         final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
@@ -321,10 +332,10 @@ class OpenAIServingCompletion(OpenAIServing):
         stream_options = request.stream_options
         if stream_options:
-            include_usage = stream_options.include_usage or \
-                enable_force_include_usage
-            include_continuous_usage = include_usage and \
-                stream_options.continuous_usage_stats
+            include_usage = (stream_options.include_usage
+                             or enable_force_include_usage)
+            include_continuous_usage = (include_usage and
+                                        stream_options.continuous_usage_stats)
         else:
             include_usage, include_continuous_usage = False, False
@@ -370,7 +381,8 @@ class OpenAIServingCompletion(OpenAIServing):
                         # echo the prompt and first token
                         delta_text = prompt_text + output.text
                         delta_token_ids = [
-                            *prompt_token_ids, *output.token_ids
+                            *prompt_token_ids,
+                            *output.token_ids,
                         ]
                         out_logprobs = [
                             *(prompt_logprobs or []),
@@ -383,8 +395,8 @@ class OpenAIServingCompletion(OpenAIServing):
                         delta_token_ids = output.token_ids
                         out_logprobs = output.logprobs
 
-                    if not delta_text and not delta_token_ids \
-                            and not previous_num_tokens[i]:
+                    if (not delta_text and not delta_token_ids
+                            and not previous_num_tokens[i]):
                         # Chunked prefill case, don't return empty chunks
                         continue
@@ -420,7 +432,8 @@ class OpenAIServingCompletion(OpenAIServing):
                             finish_reason=finish_reason,
                             stop_reason=stop_reason,
                         )
-                    ])
+                    ],
+                )
 
                 if include_continuous_usage:
                     prompt_tokens = num_prompt_tokens[prompt_idx]
                     completion_tokens = previous_num_tokens[i]
@@ -438,7 +451,8 @@ class OpenAIServingCompletion(OpenAIServing):
             final_usage_info = UsageInfo(
                 prompt_tokens=total_prompt_tokens,
                 completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+                total_tokens=total_prompt_tokens + total_completion_tokens,
+            )
 
             if self.enable_prompt_tokens_details and num_cached_tokens:
                 final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
@@ -452,8 +466,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     choices=[],
                     usage=final_usage_info,
                 )
-                final_usage_data = (final_usage_chunk.model_dump_json(
-                    exclude_unset=False, exclude_none=True))
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=False, exclude_none=True)
                 yield f"data: {final_usage_data}\n\n"
 
         # report to FastAPI middleware aggregate usage across all choices
@@ -478,8 +492,10 @@ class OpenAIServingCompletion(OpenAIServing):
         choices: list[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
+        kv_transfer_params = None
+        last_final_res = None
 
         for final_res in final_res_batch:
+            last_final_res = final_res
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
             prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
@@ -548,19 +564,22 @@ class OpenAIServingCompletion(OpenAIServing):
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
 
-        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+        if (self.enable_prompt_tokens_details and last_final_res
+                and last_final_res.num_cached_tokens):
             usage.prompt_tokens_details = PromptTokenUsageInfo(
-                cached_tokens=final_res.num_cached_tokens)
+                cached_tokens=last_final_res.num_cached_tokens)
 
         request_metadata.final_usage_info = usage
 
+        if final_res_batch:
+            kv_transfer_params = final_res_batch[0].kv_transfer_params
         return CompletionResponse(
             id=request_id,
             created=created_time,
             model=model_name,
             choices=choices,
             usage=usage,
-            kv_transfer_params=final_res_batch[0].kv_transfer_params)
+            kv_transfer_params=kv_transfer_params,
+        )
 
     def _create_completion_logprobs(
         self,
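
Besides guarding the `final_res_batch[0]` lookup, this hunk stops dereferencing the loop variable after the loop: with an empty batch, `final_res` is never bound, so the old `final_res.num_cached_tokens` check could not succeed either. A minimal sketch of the `last_final_res` pattern follows; the function name and the objects passed in are illustrative, not vLLM's API.

from types import SimpleNamespace


def last_cached_tokens(final_res_batch: list) -> int:
    # Track the last processed element explicitly so the code after the loop
    # never depends on the loop variable having been bound.
    last_final_res = None
    for final_res in final_res_batch:
        last_final_res = final_res
    if last_final_res and last_final_res.num_cached_tokens:
        return last_final_res.num_cached_tokens
    return 0


print(last_cached_tokens([]))                                       # 0
print(last_cached_tokens([SimpleNamespace(num_cached_tokens=7)]))   # 7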
@@ -579,8 +598,9 @@ class OpenAIServingCompletion(OpenAIServing):
         last_token_len = 0
 
-        should_return_as_token_id = return_as_token_id if \
-            return_as_token_id is not None else self.return_tokens_as_token_ids
+        should_return_as_token_id = (return_as_token_id
+                                     if return_as_token_id is not None else
+                                     self.return_tokens_as_token_ids)
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
@@ -612,10 +632,12 @@ class OpenAIServingCompletion(OpenAIServing):
                 out_top_logprobs.append({
                     # Convert float("-inf") to the
                     # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1],
-                                            top_lp[0],
-                                            tokenizer,
-                                            return_as_token_id=should_return_as_token_id):
+                    self._get_decoded_token(
+                        top_lp[1],
+                        top_lp[0],
+                        tokenizer,
+                        return_as_token_id=should_return_as_token_id,
+                    ):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i