mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 21:44:39 +08:00
[Frontend] Optimize beam search loop by sorting and then splicing (#19347)
Signed-off-by: zhangguozhu <zhangguozhu@360.cn> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: zhangguozhu <zhangguozhu@360.cn> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
82b05b15e6
commit
56f45eddaf
@ -10,6 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from http import HTTPStatus
|
||||
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
@ -389,8 +390,9 @@ class OpenAIServing:
|
||||
|
||||
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
|
||||
|
||||
logprobs_num = 2 * beam_width
|
||||
beam_search_params = SamplingParams(
|
||||
logprobs=2 * beam_width,
|
||||
logprobs=logprobs_num,
|
||||
max_tokens=1,
|
||||
temperature=temperature,
|
||||
)
|
||||
@ -443,40 +445,75 @@ class OpenAIServing:
|
||||
output = [x[0] for x in await asyncio.gather(*tasks)]
|
||||
|
||||
new_beams = []
|
||||
for i, current_beam in enumerate(all_beams):
|
||||
result = output[i]
|
||||
|
||||
# Store all new tokens generated by beam
|
||||
all_beams_token_id = []
|
||||
# Store the cumulative probability of all tokens
|
||||
# generated by beam search
|
||||
all_beams_logprob = []
|
||||
# Iterate through all beam inference results
|
||||
for i, result in enumerate(output):
|
||||
current_beam = all_beams[i]
|
||||
if result.outputs[0].logprobs is not None:
|
||||
logprobs = result.outputs[0].logprobs[0]
|
||||
for token_id, logprob_obj in logprobs.items():
|
||||
if token_id == eos_token_id and not ignore_eos:
|
||||
completed.append(
|
||||
BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id]
|
||||
if include_stop_str_in_output
|
||||
else current_beam.tokens,
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
cum_logprob=current_beam.cum_logprob
|
||||
+ logprob_obj.logprob,
|
||||
finish_reason="stop",
|
||||
stop_reason=eos_token_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
new_beams.append(
|
||||
BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=current_beam.cum_logprob
|
||||
+ logprob_obj.logprob,
|
||||
multi_modal_data=current_beam.multi_modal_data,
|
||||
mm_processor_kwargs=current_beam.mm_processor_kwargs,
|
||||
)
|
||||
)
|
||||
all_beams_token_id.extend(list(logprobs.keys()))
|
||||
all_beams_logprob.extend(
|
||||
[
|
||||
current_beam.cum_logprob + obj.logprob
|
||||
for obj in logprobs.values()
|
||||
]
|
||||
)
|
||||
|
||||
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
|
||||
all_beams = sorted_beams[:beam_width]
|
||||
# Handle the token for the end of sentence (EOS)
|
||||
all_beams_token_id = np.array(all_beams_token_id)
|
||||
all_beams_logprob = np.array(all_beams_logprob)
|
||||
|
||||
if not ignore_eos:
|
||||
# Get the index position of eos token in all generated results
|
||||
eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
|
||||
for idx in eos_idx:
|
||||
current_beam = all_beams[idx // logprobs_num]
|
||||
result = output[idx // logprobs_num]
|
||||
assert result.outputs[0].logprobs is not None
|
||||
logprobs_entry = result.outputs[0].logprobs[0]
|
||||
completed.append(
|
||||
BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [eos_token_id]
|
||||
if include_stop_str_in_output
|
||||
else current_beam.tokens,
|
||||
logprobs=current_beam.logprobs + [logprobs_entry],
|
||||
cum_logprob=float(all_beams_logprob[idx]),
|
||||
finish_reason="stop",
|
||||
stop_reason=eos_token_id,
|
||||
)
|
||||
)
|
||||
# After processing, set the log probability of the eos condition
|
||||
# to negative infinity.
|
||||
all_beams_logprob[eos_idx] = -np.inf
|
||||
|
||||
# Processing non-EOS tokens
|
||||
# Get indices of the top beam_width probabilities
|
||||
topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
|
||||
:beam_width
|
||||
]
|
||||
|
||||
for idx in topn_idx:
|
||||
current_beam = all_beams[idx // logprobs_num]
|
||||
result = output[idx // logprobs_num]
|
||||
token_id = int(all_beams_token_id[idx])
|
||||
assert result.outputs[0].logprobs is not None
|
||||
logprobs_entry = result.outputs[0].logprobs[0]
|
||||
new_beams.append(
|
||||
BeamSearchSequence(
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs_entry],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=float(all_beams_logprob[idx]),
|
||||
multi_modal_data=current_beam.multi_modal_data,
|
||||
mm_processor_kwargs=current_beam.mm_processor_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
all_beams = new_beams
|
||||
|
||||
completed.extend(all_beams)
|
||||
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user