mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 02:20:13 +08:00
[Frontend] API support for beam search for MQLLMEngine (#9117)
This commit is contained in:
parent
e1faa2a598
commit
8c746226c9
@ -495,30 +495,25 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
|
|||||||
assert len(batch.choices) == 2
|
assert len(batch.choices) == 2
|
||||||
assert batch.choices[0].text == batch.choices[1].text
|
assert batch.choices[0].text == batch.choices[1].text
|
||||||
|
|
||||||
try:
|
# test n = 2
|
||||||
# test n = 2
|
batch = await client.completions.create(
|
||||||
batch = await client.completions.create(
|
model=model_name,
|
||||||
model=model_name,
|
prompt=prompts,
|
||||||
prompt=prompts,
|
n=2,
|
||||||
n=2,
|
max_tokens=5,
|
||||||
max_tokens=5,
|
temperature=0.0,
|
||||||
temperature=0.0,
|
extra_body=dict(
|
||||||
extra_body=dict(
|
# NOTE: this has to be true for n > 1 in vLLM, but
|
||||||
# NOTE: this has to be true for n > 1 in vLLM, but
|
# not necessary for official client.
|
||||||
# not necessary for official client.
|
use_beam_search=True),
|
||||||
use_beam_search=True),
|
)
|
||||||
)
|
assert len(batch.choices) == 4
|
||||||
assert len(batch.choices) == 4
|
assert batch.choices[0].text != batch.choices[
|
||||||
assert batch.choices[0].text != batch.choices[
|
1].text, "beam search should be different"
|
||||||
1].text, "beam search should be different"
|
assert batch.choices[0].text == batch.choices[
|
||||||
assert batch.choices[0].text == batch.choices[
|
2].text, "two copies of the same prompt should be the same"
|
||||||
2].text, "two copies of the same prompt should be the same"
|
assert batch.choices[1].text == batch.choices[
|
||||||
assert batch.choices[1].text == batch.choices[
|
3].text, "two copies of the same prompt should be the same"
|
||||||
3].text, "two copies of the same prompt should be the same"
|
|
||||||
except BadRequestError as e:
|
|
||||||
# the only allowed exception is when beam search is not supported
|
|
||||||
# in the default mqllmengine
|
|
||||||
assert "--disable-frontend-multiprocessing" in str(e)
|
|
||||||
|
|
||||||
# test streaming
|
# test streaming
|
||||||
batch = await client.completions.create(
|
batch = await client.completions.create(
|
||||||
|
|||||||
61
vllm/beam_search.py
Normal file
61
vllm/beam_search.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BeamSearchSequence:
|
||||||
|
"""A sequence for beam search.
|
||||||
|
It keeps track of the tokens and the log probability of the sequence.
|
||||||
|
The text field is optional and will only be filled when the sequence is
|
||||||
|
about to be returned to the user.
|
||||||
|
"""
|
||||||
|
# The tokens includes the prompt.
|
||||||
|
tokens: List[int]
|
||||||
|
cum_logprob: float = 0.0
|
||||||
|
text: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BeamSearchOutput:
|
||||||
|
"""The output of beam search.
|
||||||
|
It contains the list of the best beam search sequences.
|
||||||
|
The length of the list is equal to the beam width.
|
||||||
|
"""
|
||||||
|
sequences: List[BeamSearchSequence]
|
||||||
|
|
||||||
|
|
||||||
|
class BeamSearchInstance:
|
||||||
|
|
||||||
|
def __init__(self, prompt_tokens: List[int]):
|
||||||
|
self.beams: List[BeamSearchSequence] = [
|
||||||
|
BeamSearchSequence(tokens=prompt_tokens)
|
||||||
|
]
|
||||||
|
self.completed: List[BeamSearchSequence] = []
|
||||||
|
|
||||||
|
|
||||||
|
def get_beam_search_score(
|
||||||
|
tokens: List[int],
|
||||||
|
cumulative_logprob: float,
|
||||||
|
eos_token_id: int,
|
||||||
|
length_penalty: float = 1.0,
|
||||||
|
) -> float:
|
||||||
|
"""Calculate the beam search score with length penalty.
|
||||||
|
|
||||||
|
Adapted from
|
||||||
|
|
||||||
|
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
|
||||||
|
"""
|
||||||
|
seq_len = len(tokens)
|
||||||
|
if tokens[-1] == eos_token_id:
|
||||||
|
seq_len -= 1
|
||||||
|
|
||||||
|
return cumulative_logprob / (seq_len**length_penalty)
|
||||||
|
|
||||||
|
|
||||||
|
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
|
||||||
|
|
||||||
|
def sort_beams_key(x: BeamSearchSequence) -> float:
|
||||||
|
return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id,
|
||||||
|
length_penalty)
|
||||||
|
|
||||||
|
return sort_beams_key
|
||||||
@ -7,6 +7,7 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
|
|||||||
from weakref import ReferenceType
|
from weakref import ReferenceType
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||||
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
|
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
|
||||||
ParallelConfig, SchedulerConfig)
|
ParallelConfig, SchedulerConfig)
|
||||||
from vllm.core.scheduler import SchedulerOutputs
|
from vllm.core.scheduler import SchedulerOutputs
|
||||||
@ -14,7 +15,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
|||||||
from vllm.engine.async_timeout import asyncio_timeout
|
from vllm.engine.async_timeout import asyncio_timeout
|
||||||
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
|
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
|
||||||
from vllm.engine.metrics_types import StatLoggerBase
|
from vllm.engine.metrics_types import StatLoggerBase
|
||||||
from vllm.entrypoints.llm import BeamSearchSequence
|
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||||
from vllm.executor.gpu_executor import GPUExecutorAsync
|
from vllm.executor.gpu_executor import GPUExecutorAsync
|
||||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||||
@ -33,7 +33,7 @@ from vllm.sequence import ExecuteModelRequest
|
|||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
|
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
|
||||||
get_beam_search_score, random_uuid, weak_bind)
|
random_uuid, weak_bind)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
|
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||||
@ -1052,16 +1052,14 @@ class AsyncLLMEngine:
|
|||||||
temperature = params.temperature
|
temperature = params.temperature
|
||||||
length_penalty = params.length_penalty
|
length_penalty = params.length_penalty
|
||||||
|
|
||||||
def sort_beams_key(x: BeamSearchSequence) -> float:
|
|
||||||
return get_beam_search_score(x.tokens, x.cum_logprob,
|
|
||||||
tokenizer.eos_token_id,
|
|
||||||
length_penalty)
|
|
||||||
|
|
||||||
tokenizer = await self.get_tokenizer()
|
tokenizer = await self.get_tokenizer()
|
||||||
tokenizedPrompt = prompt if isinstance(
|
tokenizedPrompt = prompt if isinstance(
|
||||||
prompt, list) else tokenizer.encode(prompt)
|
prompt, list) else tokenizer.encode(prompt)
|
||||||
tokenizedLength = len(tokenizedPrompt)
|
tokenizedLength = len(tokenizedPrompt)
|
||||||
|
|
||||||
|
sort_beams_key = create_sort_beams_key_function(
|
||||||
|
tokenizer.eos_token_id, length_penalty)
|
||||||
|
|
||||||
beam_search_params = SamplingParams(logprobs=2 * beam_width,
|
beam_search_params = SamplingParams(logprobs=2 * beam_width,
|
||||||
max_tokens=1,
|
max_tokens=1,
|
||||||
temperature=temperature)
|
temperature=temperature)
|
||||||
|
|||||||
@ -2,8 +2,8 @@ import asyncio
|
|||||||
import copy
|
import copy
|
||||||
import pickle
|
import pickle
|
||||||
from contextlib import contextmanager, suppress
|
from contextlib import contextmanager, suppress
|
||||||
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
|
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
|
||||||
Union, overload)
|
Optional, Union, overload)
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import zmq
|
import zmq
|
||||||
@ -12,6 +12,7 @@ from zmq import Frame # type: ignore[attr-defined]
|
|||||||
from zmq.asyncio import Socket
|
from zmq.asyncio import Socket
|
||||||
|
|
||||||
from vllm import PoolingParams
|
from vllm import PoolingParams
|
||||||
|
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||||
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
|
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
@ -27,14 +28,16 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
|||||||
RPCUProfileRequest)
|
RPCUProfileRequest)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.envs import VLLM_RPC_TIMEOUT
|
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptType, TokensPrompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
|
||||||
|
RequestOutput)
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||||
from vllm.utils import deprecate_kwargs
|
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
|
||||||
|
random_uuid)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -441,6 +444,104 @@ class MQLLMEngineClient:
|
|||||||
lora_request, trace_headers,
|
lora_request, trace_headers,
|
||||||
prompt_adapter_request, priority)
|
prompt_adapter_request, priority)
|
||||||
|
|
||||||
|
async def beam_search(
|
||||||
|
self,
|
||||||
|
prompt: Union[PromptType, List[int]],
|
||||||
|
request_id: str,
|
||||||
|
params: BeamSearchParams,
|
||||||
|
) -> AsyncGenerator[RequestOutput, None]:
|
||||||
|
|
||||||
|
beam_width = params.beam_width
|
||||||
|
max_tokens = params.max_tokens
|
||||||
|
ignore_eos = params.ignore_eos
|
||||||
|
temperature = params.temperature
|
||||||
|
length_penalty = params.length_penalty
|
||||||
|
|
||||||
|
tokenizer = await self.get_tokenizer(lora_request=None)
|
||||||
|
tokenizedPrompt = prompt if isinstance(
|
||||||
|
prompt, list) else tokenizer.encode(prompt)
|
||||||
|
tokenizedLength = len(tokenizedPrompt)
|
||||||
|
|
||||||
|
sort_beams_key = create_sort_beams_key_function(
|
||||||
|
tokenizer.eos_token_id, length_penalty)
|
||||||
|
|
||||||
|
beam_search_params = SamplingParams(logprobs=2 * beam_width,
|
||||||
|
max_tokens=1,
|
||||||
|
temperature=temperature)
|
||||||
|
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
|
||||||
|
completed = []
|
||||||
|
|
||||||
|
for _ in range(max_tokens):
|
||||||
|
prompts_batch = [
|
||||||
|
TokensPrompt(prompt_token_ids=beam.tokens)
|
||||||
|
for beam in all_beams
|
||||||
|
]
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
request_id = f"beam_search-{random_uuid()}"
|
||||||
|
for i, individual_prompt in enumerate(prompts_batch):
|
||||||
|
request_id_item = f"{request_id}-{i}"
|
||||||
|
task = asyncio.create_task(
|
||||||
|
collect_from_async_generator(
|
||||||
|
self.generate(individual_prompt, beam_search_params,
|
||||||
|
request_id_item)))
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
output = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
output = [x[0] for x in output]
|
||||||
|
|
||||||
|
logger.info(output)
|
||||||
|
|
||||||
|
new_beams = []
|
||||||
|
for i, current_beam in enumerate(all_beams):
|
||||||
|
result = output[i]
|
||||||
|
|
||||||
|
if result.outputs[0].logprobs is not None:
|
||||||
|
logprobs = result.outputs[0].logprobs[0]
|
||||||
|
for token_id, logprob_obj in logprobs.items():
|
||||||
|
new_beam = BeamSearchSequence(
|
||||||
|
tokens=current_beam.tokens + [token_id],
|
||||||
|
cum_logprob=current_beam.cum_logprob +
|
||||||
|
logprob_obj.logprob)
|
||||||
|
|
||||||
|
if token_id == tokenizer.eos_token_id and \
|
||||||
|
not ignore_eos:
|
||||||
|
completed.append(new_beam)
|
||||||
|
else:
|
||||||
|
new_beams.append(new_beam)
|
||||||
|
|
||||||
|
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
|
||||||
|
all_beams = sorted_beams[:beam_width]
|
||||||
|
|
||||||
|
completed.extend(all_beams)
|
||||||
|
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
|
||||||
|
best_beams = sorted_completed[:beam_width]
|
||||||
|
|
||||||
|
for beam in best_beams:
|
||||||
|
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
|
||||||
|
|
||||||
|
beam_search_output = RequestOutput(
|
||||||
|
request_id=request_id,
|
||||||
|
prompt=prompt,
|
||||||
|
outputs=[
|
||||||
|
CompletionOutput(
|
||||||
|
text=beam.text,
|
||||||
|
cumulative_logprob=beam.cum_logprob,
|
||||||
|
token_ids=beam.tokens,
|
||||||
|
index=i,
|
||||||
|
logprobs=beam.cum_logprob,
|
||||||
|
) for (i, beam) in enumerate(best_beams)
|
||||||
|
],
|
||||||
|
finished=True,
|
||||||
|
prompt_token_ids=tokenizedPrompt,
|
||||||
|
prompt_logprobs=None)
|
||||||
|
|
||||||
|
logger.info(beam_search_output)
|
||||||
|
|
||||||
|
yield beam_search_output
|
||||||
|
|
||||||
@overload # DEPRECATED
|
@overload # DEPRECATED
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
|
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
|
||||||
Union, cast, overload)
|
Union, cast, overload)
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
||||||
|
BeamSearchSequence, get_beam_search_score)
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.llm_engine import LLMEngine
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||||
@ -28,43 +29,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
|||||||
get_cached_tokenizer)
|
get_cached_tokenizer)
|
||||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
|
from vllm.utils import Counter, deprecate_kwargs, is_list_of
|
||||||
is_list_of)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BeamSearchSequence:
|
|
||||||
"""A sequence for beam search.
|
|
||||||
It keeps track of the tokens and the log probability of the sequence.
|
|
||||||
The text field is optional and will only be filled when the sequence is
|
|
||||||
about to be returned to the user.
|
|
||||||
"""
|
|
||||||
# The tokens includes the prompt.
|
|
||||||
tokens: List[int]
|
|
||||||
cum_logprob: float = 0.0
|
|
||||||
text: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BeamSearchOutput:
|
|
||||||
"""The output of beam search.
|
|
||||||
It contains the list of the best beam search sequences.
|
|
||||||
The length of the list is equal to the beam width.
|
|
||||||
"""
|
|
||||||
sequences: List[BeamSearchSequence]
|
|
||||||
|
|
||||||
|
|
||||||
class BeamSearchInstance:
|
|
||||||
|
|
||||||
def __init__(self, prompt_tokens: List[int]):
|
|
||||||
self.beams: List[BeamSearchSequence] = [
|
|
||||||
BeamSearchSequence(tokens=prompt_tokens)
|
|
||||||
]
|
|
||||||
self.completed: List[BeamSearchSequence] = []
|
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM:
|
||||||
"""An LLM for generating texts from given prompts and sampling parameters.
|
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from fastapi import Request
|
|||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||||
apply_hf_chat_template,
|
apply_hf_chat_template,
|
||||||
@ -236,15 +237,16 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
log_tracing_disabled_warning()
|
log_tracing_disabled_warning()
|
||||||
|
|
||||||
if isinstance(sampling_params, BeamSearchParams):
|
if isinstance(sampling_params, BeamSearchParams):
|
||||||
if not isinstance(self.engine_client, AsyncLLMEngine):
|
assert isinstance(self.engine_client,
|
||||||
raise ValueError(
|
(AsyncLLMEngine,
|
||||||
"Beam search in the API server is only supported with"
|
MQLLMEngineClient)), \
|
||||||
" AsyncLLMEngine. please add "
|
"Beam search is only supported with" \
|
||||||
"`--disable-frontend-multiprocessing` to "
|
"AsyncLLMEngine and MQLLMEngineClient."
|
||||||
"use beam search.")
|
|
||||||
result_generator = self.engine_client.beam_search(
|
result_generator = self.engine_client.beam_search(
|
||||||
engine_inputs['prompt_token_ids'], request_id,
|
engine_inputs['prompt_token_ids'],
|
||||||
sampling_params)
|
request_id,
|
||||||
|
sampling_params,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result_generator = self.engine_client.generate(
|
result_generator = self.engine_client.generate(
|
||||||
engine_inputs,
|
engine_inputs,
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from fastapi import Request
|
|||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
@ -150,15 +151,16 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
log_tracing_disabled_warning()
|
log_tracing_disabled_warning()
|
||||||
|
|
||||||
if isinstance(sampling_params, BeamSearchParams):
|
if isinstance(sampling_params, BeamSearchParams):
|
||||||
if not isinstance(self.engine_client, AsyncLLMEngine):
|
assert isinstance(self.engine_client,
|
||||||
raise ValueError(
|
(AsyncLLMEngine,
|
||||||
"Beam search in the API server is only supported"
|
MQLLMEngineClient)), \
|
||||||
" with AsyncLLMEngine. please add "
|
"Beam search is only supported with" \
|
||||||
"`--disable-frontend-multiprocessing` to "
|
"AsyncLLMEngine and MQLLMEngineClient."
|
||||||
"use beam search.")
|
|
||||||
generator = self.engine_client.beam_search(
|
generator = self.engine_client.beam_search(
|
||||||
prompt_inputs["prompt_token_ids"], request_id_item,
|
prompt_inputs["prompt_token_ids"],
|
||||||
sampling_params)
|
request_id_item,
|
||||||
|
sampling_params,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
generator = self.engine_client.generate(
|
generator = self.engine_client.generate(
|
||||||
{
|
{
|
||||||
|
|||||||
@ -1370,22 +1370,3 @@ class AtomicCounter:
|
|||||||
@property
|
@property
|
||||||
def value(self):
|
def value(self):
|
||||||
return self._value
|
return self._value
|
||||||
|
|
||||||
|
|
||||||
def get_beam_search_score(
|
|
||||||
tokens: List[int],
|
|
||||||
cumulative_logprob: float,
|
|
||||||
eos_token_id: int,
|
|
||||||
length_penalty: float = 1.0,
|
|
||||||
) -> float:
|
|
||||||
"""Calculate the beam search score with length penalty.
|
|
||||||
|
|
||||||
Adapted from
|
|
||||||
|
|
||||||
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
|
|
||||||
"""
|
|
||||||
seq_len = len(tokens)
|
|
||||||
if tokens[-1] == eos_token_id:
|
|
||||||
seq_len -= 1
|
|
||||||
|
|
||||||
return cumulative_logprob / (seq_len**length_penalty)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user