[Frontend] merge beam search implementations (#9296)

commit 4d31cd424b
parent 473e7b3606
@@ -7,7 +7,6 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
 from weakref import ReferenceType
 
 import vllm.envs as envs
-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import SchedulerOutputs
@@ -15,25 +14,24 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
+from vllm.engine.protocol import EngineClient
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType, TokensPrompt
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        random_uuid, weak_bind)
+from vllm.utils import deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -541,7 +539,7 @@ async def build_guided_decoding_logits_processor_async(
     return sampling_params
 
 
-class AsyncLLMEngine:
+class AsyncLLMEngine(EngineClient):
     """An asynchronous wrapper for :class:`LLMEngine`.
 
     This class is used to wrap the :class:`LLMEngine` class to make it
@@ -1039,102 +1037,6 @@ class AsyncLLMEngine:
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
-    async def beam_search(
-        self,
-        prompt: Union[PromptType, List[int]],
-        request_id: str,
-        params: BeamSearchParams,
-    ) -> AsyncGenerator[RequestOutput, None]:
-
-        beam_width = params.beam_width
-        max_tokens = params.max_tokens
-        ignore_eos = params.ignore_eos
-        temperature = params.temperature
-        length_penalty = params.length_penalty
-
-        tokenizer = await self.get_tokenizer()
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
-
-        sort_beams_key = create_sort_beams_key_function(
-            tokenizer.eos_token_id, length_penalty)
-
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
-        completed = []
-
-        for _ in range(max_tokens):
-            prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
-            ]
-
-            tasks = []
-
-            request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
-                task = asyncio.create_task(
-                    collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
-                tasks.append(task)
-
-            output = await asyncio.gather(*tasks)
-
-            output = [x[0] for x in output]
-
-            logger.info(output)
-
-            new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
-                if result.outputs[0].logprobs is not None:
-                    logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
-                            completed.append(new_beam)
-                        else:
-                            new_beams.append(new_beam)
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
-
-        completed.extend(all_beams)
-        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
-        best_beams = sorted_completed[:beam_width]
-
-        for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
-
-        beam_search_output = RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
-                    index=i,
-                    logprobs=beam.cum_logprob,
-                ) for (i, beam) in enumerate(best_beams)
-            ],
-            finished=True,
-            prompt_token_ids=tokenizedPrompt,
-            prompt_logprobs=None)
-
-        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
-
     async def encode(
         self,
         prompt: PromptType,
@@ -12,8 +12,8 @@ from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket
 
 from vllm import PoolingParams
-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
+from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -26,18 +26,18 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          RPCError, RPCProcessRequest,
                                          RPCStartupRequest, RPCStartupResponse,
                                          RPCUProfileRequest)
+from vllm.engine.protocol import EngineClient
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptType, TokensPrompt
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        random_uuid)
+from vllm.utils import deprecate_kwargs
 
 logger = init_logger(__name__)
 
@@ -53,7 +53,7 @@ class MQClientClosedError(Exception):
     """
 
 
-class MQLLMEngineClient:
+class MQLLMEngineClient(EngineClient):
     """A client wrapper for MQLLMEngine that conforms to the
     EngineClient protocol.
 
@@ -316,7 +316,7 @@ class MQLLMEngineClient:
                 or response != VLLM_RPC_SUCCESS_STR):
             raise ValueError(error_message)
 
-    async def get_tokenizer(self, lora_request: LoRARequest):
+    async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
         return await self.tokenizer.get_lora_tokenizer_async(lora_request)
 
     async def get_decoding_config(self) -> DecodingConfig:
@@ -344,8 +344,14 @@ class MQLLMEngineClient:
         await self._send_one_way_rpc_request(
             request=RPCAbortRequest(request_id), socket=self.input_socket)
 
-    async def do_log_stats(self):
-        """Ignore do_log_stats (handled on MQLLMEngine polling)"""
+    async def do_log_stats(
+        self,
+        scheduler_outputs: Optional[SchedulerOutputs] = None,
+        model_output: Optional[List[SamplerOutput]] = None,
+    ) -> None:
+        """
+        Ignore do_log_stats (handled on MQLLMEngine polling)
+        """
         pass
 
     async def check_health(self):
@@ -444,104 +450,6 @@ class MQLLMEngineClient:
                                      lora_request, trace_headers,
                                      prompt_adapter_request, priority)
 
-    async def beam_search(
-        self,
-        prompt: Union[PromptType, List[int]],
-        request_id: str,
-        params: BeamSearchParams,
-    ) -> AsyncGenerator[RequestOutput, None]:
-
-        beam_width = params.beam_width
-        max_tokens = params.max_tokens
-        ignore_eos = params.ignore_eos
-        temperature = params.temperature
-        length_penalty = params.length_penalty
-
-        tokenizer = await self.get_tokenizer(lora_request=None)
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
-
-        sort_beams_key = create_sort_beams_key_function(
-            tokenizer.eos_token_id, length_penalty)
-
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
-        completed = []
-
-        for _ in range(max_tokens):
-            prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
-            ]
-
-            tasks = []
-
-            request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
-                task = asyncio.create_task(
-                    collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
-                tasks.append(task)
-
-            output = await asyncio.gather(*tasks)
-
-            output = [x[0] for x in output]
-
-            logger.info(output)
-
-            new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
-                if result.outputs[0].logprobs is not None:
-                    logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
-                            completed.append(new_beam)
-                        else:
-                            new_beams.append(new_beam)
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
-
-        completed.extend(all_beams)
-        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
-        best_beams = sorted_completed[:beam_width]
-
-        for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
-
-        beam_search_output = RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
-                    index=i,
-                    logprobs=beam.cum_logprob,
-                ) for (i, beam) in enumerate(best_beams)
-            ],
-            finished=True,
-            prompt_token_ids=tokenizedPrompt,
-            prompt_logprobs=None)
-
-        logger.info(beam_search_output)
-
-        yield beam_search_output
-
     @overload  # DEPRECATED
     def encode(
         self,
@@ -1,38 +1,49 @@
-from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
-                    runtime_checkable)
+import asyncio
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, List, Mapping, Optional, Union
 
+from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import collect_from_async_generator, random_uuid
 
+logger = init_logger(__name__)
 
 
-@runtime_checkable
-class EngineClient(Protocol):
+class EngineClient(ABC):
     """Protocol class for Clients to Engine"""
 
     @property
+    @abstractmethod
     def is_running(self) -> bool:
         ...
 
     @property
+    @abstractmethod
     def is_stopped(self) -> bool:
         ...
 
     @property
+    @abstractmethod
    def errored(self) -> bool:
         ...
 
     @property
+    @abstractmethod
     def dead_error(self) -> BaseException:
         ...
 
+    @abstractmethod
     def generate(
         self,
         prompt: PromptType,
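The hunk above turns EngineClient from a runtime-checkable Protocol into an abstract base class, so concrete clients now inherit shared behaviour instead of merely matching a structural type. A hypothetical subclass skeleton, purely illustrative and not part of this change, shows what that contract means for implementers:

from vllm.engine.protocol import EngineClient


class MyEngineClient(EngineClient):
    """Hypothetical client used only to illustrate the ABC contract."""

    @property
    def is_running(self) -> bool:
        # One overridden abstract property shown as an example.
        return True

    # The remaining abstract properties (is_stopped, errored, dead_error)
    # and methods (generate, encode, abort, get_model_config,
    # get_decoding_config, get_tokenizer, is_tracing_enabled, do_log_stats,
    # check_health, start_profile, stop_profile) must be overridden as well;
    # instantiating the class before then raises TypeError. The concrete
    # beam_search coroutine is inherited from EngineClient.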
@@ -46,6 +57,101 @@ class EngineClient(Protocol):
         """Generate outputs for a request."""
         ...
 
+    async def beam_search(
+        self,
+        prompt: Union[PromptType, List[int]],
+        request_id: str,
+        params: BeamSearchParams,
+    ) -> AsyncGenerator[RequestOutput, None]:
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+        length_penalty = params.length_penalty
+
+        tokenizer = await self.get_tokenizer(lora_request=None)
+        tokenizedPrompt = prompt if isinstance(
+            prompt, list) else tokenizer.encode(prompt)
+        tokenizedLength = len(tokenizedPrompt)
+
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id, length_penalty)
+
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
+        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        completed = []
+
+        for _ in range(max_tokens):
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            tasks = []
+
+            request_id = f"beam_search-{random_uuid()}"
+            for i, individual_prompt in enumerate(prompts_batch):
+                request_id_item = f"{request_id}-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.generate(individual_prompt, beam_search_params,
+                                      request_id_item)))
+                tasks.append(task)
+
+            output = await asyncio.gather(*tasks)
+
+            output = [x[0] for x in output]
+
+            new_beams = []
+            for i, current_beam in enumerate(all_beams):
+                result = output[i]
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
+                        if token_id == tokenizer.eos_token_id and \
+                            not ignore_eos:
+                            completed.append(new_beam)
+                        else:
+                            new_beams.append(new_beam)
+
+            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
+                    logprobs=beam.cum_logprob,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=tokenizedPrompt,
+            prompt_logprobs=None)
+
+        yield beam_search_output
+
+    @abstractmethod
     def encode(
         self,
         prompt: PromptType,
@@ -58,6 +164,7 @@ class EngineClient(Protocol):
         """Generate outputs for a request from an embedding model."""
         ...
 
+    @abstractmethod
     async def abort(self, request_id: str) -> None:
         """Abort a request.
 
@@ -65,14 +172,17 @@ class EngineClient(Protocol):
             request_id: The unique id of the request.
         """
 
+    @abstractmethod
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
         ...
 
+    @abstractmethod
     async def get_decoding_config(self) -> DecodingConfig:
         ...
         """Get the decoding configuration of the vLLM engine."""
 
+    @abstractmethod
     async def get_tokenizer(
         self,
         lora_request: Optional[LoRARequest] = None,
@@ -80,9 +190,11 @@ class EngineClient(Protocol):
         """Get the appropriate tokenizer for the request"""
         ...
 
+    @abstractmethod
     async def is_tracing_enabled(self) -> bool:
         ...
 
+    @abstractmethod
     async def do_log_stats(
         self,
         scheduler_outputs: Optional[SchedulerOutputs] = None,
@@ -90,14 +202,17 @@ class EngineClient(Protocol):
     ) -> None:
         ...
 
+    @abstractmethod
     async def check_health(self) -> None:
         """Raise if unhealthy"""
         ...
 
+    @abstractmethod
     async def start_profile(self) -> None:
         """Start profiling the engine"""
         ...
 
+    @abstractmethod
     async def stop_profile(self) -> None:
         """Start profiling the engine"""
         ...
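With this change the beam_search coroutine is defined once on EngineClient, so AsyncLLMEngine and MQLLMEngineClient expose identical behaviour through the shared base class. Below is a minimal usage sketch, not taken from the diff: the model name, request id, and parameter values are placeholders, and BeamSearchParams fields beyond beam_width and max_tokens are left at their defaults.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import BeamSearchParams


async def main() -> None:
    # Any EngineClient subclass works the same way; AsyncLLMEngine is used
    # here only because it can be constructed in-process.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    params = BeamSearchParams(beam_width=4, max_tokens=32)

    # The prompt may be plain text or a token-id list; beam_search tokenizes
    # text prompts itself and yields a single final RequestOutput.
    async for output in engine.beam_search("The capital of France is",
                                           "beam-search-demo", params):
        for completion in output.outputs:
            print(completion.cumulative_logprob, completion.text)


asyncio.run(main())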
@@ -9,8 +9,6 @@ from typing import Union
 from fastapi import Request
 
 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_hf_chat_template,
@@ -237,11 +235,6 @@ class OpenAIServingChat(OpenAIServing):
                     log_tracing_disabled_warning()
 
             if isinstance(sampling_params, BeamSearchParams):
-                assert isinstance(self.engine_client,
-                                  (AsyncLLMEngine,
-                                   MQLLMEngineClient)), \
-                    "Beam search is only supported with" \
-                    "AsyncLLMEngine and MQLLMEngineClient."
                 result_generator = self.engine_client.beam_search(
                     engine_inputs['prompt_token_ids'],
                     request_id,
@@ -8,8 +8,6 @@ from typing import Tuple, Union, cast
 from fastapi import Request
 
 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
@@ -151,11 +149,6 @@ class OpenAIServingCompletion(OpenAIServing):
                     log_tracing_disabled_warning()
 
                 if isinstance(sampling_params, BeamSearchParams):
-                    assert isinstance(self.engine_client,
-                                      (AsyncLLMEngine,
-                                       MQLLMEngineClient)), \
-                        "Beam search is only supported with" \
-                        "AsyncLLMEngine and MQLLMEngineClient."
                     generator = self.engine_client.beam_search(
                         prompt_inputs["prompt_token_ids"],
                         request_id_item,