diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index c50b0c4a23e17..127b8e6dcb87c 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -10,6 +10,7 @@ from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus from typing import Any, ClassVar, Generic, TypeAlias, TypeVar +import numpy as np import torch from fastapi import Request from pydantic import BaseModel, ConfigDict, Field, TypeAdapter @@ -389,8 +390,9 @@ class OpenAIServing: sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) + logprobs_num = 2 * beam_width beam_search_params = SamplingParams( - logprobs=2 * beam_width, + logprobs=logprobs_num, max_tokens=1, temperature=temperature, ) @@ -443,40 +445,75 @@ class OpenAIServing: output = [x[0] for x in await asyncio.gather(*tasks)] new_beams = [] - for i, current_beam in enumerate(all_beams): - result = output[i] - + # Store all new tokens generated by beam + all_beams_token_id = [] + # Store the cumulative probability of all tokens + # generated by beam search + all_beams_logprob = [] + # Iterate through all beam inference results + for i, result in enumerate(output): + current_beam = all_beams[i] if result.outputs[0].logprobs is not None: logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - if token_id == eos_token_id and not ignore_eos: - completed.append( - BeamSearchSequence( - tokens=current_beam.tokens + [token_id] - if include_stop_str_in_output - else current_beam.tokens, - logprobs=current_beam.logprobs + [logprobs], - cum_logprob=current_beam.cum_logprob - + logprob_obj.logprob, - finish_reason="stop", - stop_reason=eos_token_id, - ) - ) - else: - new_beams.append( - BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], - lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob - + logprob_obj.logprob, - multi_modal_data=current_beam.multi_modal_data, - mm_processor_kwargs=current_beam.mm_processor_kwargs, - ) - ) + all_beams_token_id.extend(list(logprobs.keys())) + all_beams_logprob.extend( + [ + current_beam.cum_logprob + obj.logprob + for obj in logprobs.values() + ] + ) - sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) - all_beams = sorted_beams[:beam_width] + # Handle the token for the end of sentence (EOS) + all_beams_token_id = np.array(all_beams_token_id) + all_beams_logprob = np.array(all_beams_logprob) + + if not ignore_eos: + # Get the index position of eos token in all generated results + eos_idx = np.where(all_beams_token_id == eos_token_id)[0] + for idx in eos_idx: + current_beam = all_beams[idx // logprobs_num] + result = output[idx // logprobs_num] + assert result.outputs[0].logprobs is not None + logprobs_entry = result.outputs[0].logprobs[0] + completed.append( + BeamSearchSequence( + tokens=current_beam.tokens + [eos_token_id] + if include_stop_str_in_output + else current_beam.tokens, + logprobs=current_beam.logprobs + [logprobs_entry], + cum_logprob=float(all_beams_logprob[idx]), + finish_reason="stop", + stop_reason=eos_token_id, + ) + ) + # After processing, set the log probability of the eos condition + # to negative infinity. + all_beams_logprob[eos_idx] = -np.inf + + # Processing non-EOS tokens + # Get indices of the top beam_width probabilities + topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[ + :beam_width + ] + + for idx in topn_idx: + current_beam = all_beams[idx // logprobs_num] + result = output[idx // logprobs_num] + token_id = int(all_beams_token_id[idx]) + assert result.outputs[0].logprobs is not None + logprobs_entry = result.outputs[0].logprobs[0] + new_beams.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs_entry], + lora_request=current_beam.lora_request, + cum_logprob=float(all_beams_logprob[idx]), + multi_modal_data=current_beam.multi_modal_data, + mm_processor_kwargs=current_beam.mm_processor_kwargs, + ) + ) + + all_beams = new_beams completed.extend(all_beams) sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)