mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 14:30:09 +08:00
[Frontend] Optimize beam search loop by sorting and then splicing (#19347)
Signed-off-by: zhangguozhu <zhangguozhu@360.cn>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: zhangguozhu <zhangguozhu@360.cn>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
82b05b15e6
commit
56f45eddaf
@ -10,6 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
|
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||||
@ -389,8 +390,9 @@ class OpenAIServing:
|
|||||||
|
|
||||||
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
|
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
|
||||||
|
|
||||||
|
logprobs_num = 2 * beam_width
|
||||||
beam_search_params = SamplingParams(
|
beam_search_params = SamplingParams(
|
||||||
logprobs=2 * beam_width,
|
logprobs=logprobs_num,
|
||||||
max_tokens=1,
|
max_tokens=1,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
@ -443,40 +445,75 @@ class OpenAIServing:
|
|||||||
output = [x[0] for x in await asyncio.gather(*tasks)]
|
output = [x[0] for x in await asyncio.gather(*tasks)]
|
||||||
|
|
||||||
new_beams = []
|
new_beams = []
|
||||||
for i, current_beam in enumerate(all_beams):
|
# Store all new tokens generated by beam
|
||||||
result = output[i]
|
all_beams_token_id = []
|
||||||
|
# Store the cumulative log probability of all tokens
|
||||||
|
# generated by beam search
|
||||||
|
all_beams_logprob = []
|
||||||
|
# Iterate through all beam inference results
|
||||||
|
for i, result in enumerate(output):
|
||||||
|
current_beam = all_beams[i]
|
||||||
if result.outputs[0].logprobs is not None:
|
if result.outputs[0].logprobs is not None:
|
||||||
logprobs = result.outputs[0].logprobs[0]
|
logprobs = result.outputs[0].logprobs[0]
|
||||||
for token_id, logprob_obj in logprobs.items():
|
all_beams_token_id.extend(list(logprobs.keys()))
|
||||||
if token_id == eos_token_id and not ignore_eos:
|
all_beams_logprob.extend(
|
||||||
completed.append(
|
[
|
||||||
BeamSearchSequence(
|
current_beam.cum_logprob + obj.logprob
|
||||||
tokens=current_beam.tokens + [token_id]
|
for obj in logprobs.values()
|
||||||
if include_stop_str_in_output
|
]
|
||||||
else current_beam.tokens,
|
)
|
||||||
logprobs=current_beam.logprobs + [logprobs],
|
|
||||||
cum_logprob=current_beam.cum_logprob
|
|
||||||
+ logprob_obj.logprob,
|
|
||||||
finish_reason="stop",
|
|
||||||
stop_reason=eos_token_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
new_beams.append(
|
|
||||||
BeamSearchSequence(
|
|
||||||
tokens=current_beam.tokens + [token_id],
|
|
||||||
logprobs=current_beam.logprobs + [logprobs],
|
|
||||||
lora_request=current_beam.lora_request,
|
|
||||||
cum_logprob=current_beam.cum_logprob
|
|
||||||
+ logprob_obj.logprob,
|
|
||||||
multi_modal_data=current_beam.multi_modal_data,
|
|
||||||
mm_processor_kwargs=current_beam.mm_processor_kwargs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
|
# Handle the token for the end of sentence (EOS)
|
||||||
all_beams = sorted_beams[:beam_width]
|
all_beams_token_id = np.array(all_beams_token_id)
|
||||||
|
all_beams_logprob = np.array(all_beams_logprob)
|
||||||
|
|
||||||
|
if not ignore_eos:
|
||||||
|
# Get the index position of eos token in all generated results
|
||||||
|
eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
|
||||||
|
for idx in eos_idx:
|
||||||
|
current_beam = all_beams[idx // logprobs_num]
|
||||||
|
result = output[idx // logprobs_num]
|
||||||
|
assert result.outputs[0].logprobs is not None
|
||||||
|
logprobs_entry = result.outputs[0].logprobs[0]
|
||||||
|
completed.append(
|
||||||
|
BeamSearchSequence(
|
||||||
|
tokens=current_beam.tokens + [eos_token_id]
|
||||||
|
if include_stop_str_in_output
|
||||||
|
else current_beam.tokens,
|
||||||
|
logprobs=current_beam.logprobs + [logprobs_entry],
|
||||||
|
cum_logprob=float(all_beams_logprob[idx]),
|
||||||
|
finish_reason="stop",
|
||||||
|
stop_reason=eos_token_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# After processing, set the log probability of the EOS candidates
|
||||||
|
# to negative infinity.
|
||||||
|
all_beams_logprob[eos_idx] = -np.inf
|
||||||
|
|
||||||
|
# Processing non-EOS tokens
|
||||||
|
# Get indices of the top beam_width probabilities
|
||||||
|
topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
|
||||||
|
:beam_width
|
||||||
|
]
|
||||||
|
|
||||||
|
for idx in topn_idx:
|
||||||
|
current_beam = all_beams[idx // logprobs_num]
|
||||||
|
result = output[idx // logprobs_num]
|
||||||
|
token_id = int(all_beams_token_id[idx])
|
||||||
|
assert result.outputs[0].logprobs is not None
|
||||||
|
logprobs_entry = result.outputs[0].logprobs[0]
|
||||||
|
new_beams.append(
|
||||||
|
BeamSearchSequence(
|
||||||
|
tokens=current_beam.tokens + [token_id],
|
||||||
|
logprobs=current_beam.logprobs + [logprobs_entry],
|
||||||
|
lora_request=current_beam.lora_request,
|
||||||
|
cum_logprob=float(all_beams_logprob[idx]),
|
||||||
|
multi_modal_data=current_beam.multi_modal_data,
|
||||||
|
mm_processor_kwargs=current_beam.mm_processor_kwargs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
all_beams = new_beams
|
||||||
|
|
||||||
completed.extend(all_beams)
|
completed.extend(all_beams)
|
||||||
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user