[Frontend] merge beam search implementations (#9296)

Brendan Wong authored 2024-10-14 15:05:52 -07:00, committed by GitHub
parent 473e7b3606
commit 4d31cd424b
5 changed files with 145 additions and 234 deletions
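
In effect, this commit deletes the two near-identical beam_search coroutines on AsyncLLMEngine and MQLLMEngineClient and keeps a single copy on EngineClient, which changes from a @runtime_checkable Protocol into an ABC. A minimal sketch of the resulting structure (signatures abbreviated and bodies elided; not the full code from the diff):

    from abc import ABC, abstractmethod

    class EngineClient(ABC):
        """Base class for clients to the engine."""

        @abstractmethod
        def generate(self, prompt, sampling_params, request_id):
            """Each concrete client supplies its own generate()."""
            ...

        async def beam_search(self, prompt, request_id, params):
            # Concrete coroutine defined once on the base class; it is
            # written purely in terms of the abstract generate(), so
            # every subclass inherits it instead of duplicating it.
            ...

    class AsyncLLMEngine(EngineClient):  # inherits beam_search
        def generate(self, prompt, sampling_params, request_id):
            ...

    class MQLLMEngineClient(EngineClient):  # inherits beam_search
        def generate(self, prompt, sampling_params, request_id):
            ...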

vllm/engine/async_llm_engine.py (View File)

@@ -7,7 +7,6 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
 from weakref import ReferenceType
 
 import vllm.envs as envs
-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import SchedulerOutputs
@@ -15,25 +14,24 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
+from vllm.engine.protocol import EngineClient
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType, TokensPrompt
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        random_uuid, weak_bind)
+from vllm.utils import deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -541,7 +539,7 @@ async def build_guided_decoding_logits_processor_async(
     return sampling_params
 
-class AsyncLLMEngine:
+class AsyncLLMEngine(EngineClient):
     """An asynchronous wrapper for :class:`LLMEngine`.
 
     This class is used to wrap the :class:`LLMEngine` class to make it
@@ -1039,102 +1037,6 @@ class AsyncLLMEngine:
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
-    async def beam_search(
-        self,
-        prompt: Union[PromptType, List[int]],
-        request_id: str,
-        params: BeamSearchParams,
-    ) -> AsyncGenerator[RequestOutput, None]:
-        beam_width = params.beam_width
-        max_tokens = params.max_tokens
-        ignore_eos = params.ignore_eos
-        temperature = params.temperature
-        length_penalty = params.length_penalty
-
-        tokenizer = await self.get_tokenizer()
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
-
-        sort_beams_key = create_sort_beams_key_function(
-            tokenizer.eos_token_id, length_penalty)
-
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
-        completed = []
-
-        for _ in range(max_tokens):
-            prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
-            ]
-
-            tasks = []
-            request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
-                task = asyncio.create_task(
-                    collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
-                tasks.append(task)
-
-            output = await asyncio.gather(*tasks)
-            output = [x[0] for x in output]
-
-            logger.info(output)
-
-            new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
-                if result.outputs[0].logprobs is not None:
-                    logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
-                            completed.append(new_beam)
-                        else:
-                            new_beams.append(new_beam)
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
-
-        completed.extend(all_beams)
-        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
-        best_beams = sorted_completed[:beam_width]
-
-        for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
-
-        beam_search_output = RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
-                    index=i,
-                    logprobs=beam.cum_logprob,
-                ) for (i, beam) in enumerate(best_beams)
-            ],
-            finished=True,
-            prompt_token_ids=tokenizedPrompt,
-            prompt_logprobs=None)
-
-        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
-
     async def encode(
         self,
         prompt: PromptType,

vllm/engine/multiprocessing/client.py (View File)

@@ -12,8 +12,8 @@ from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket
 
 from vllm import PoolingParams
-from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
+from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -26,18 +26,18 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          RPCError, RPCProcessRequest,
                                          RPCStartupRequest, RPCStartupResponse,
                                          RPCUProfileRequest)
+from vllm.engine.protocol import EngineClient
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptType, TokensPrompt
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
-                        random_uuid)
+from vllm.utils import deprecate_kwargs
 
 logger = init_logger(__name__)
@@ -53,7 +53,7 @@ class MQClientClosedError(Exception):
     """
 
-class MQLLMEngineClient:
+class MQLLMEngineClient(EngineClient):
     """A client wrapper for MQLLMEngine that conforms to the
     EngineClient protocol.
@@ -316,7 +316,7 @@ class MQLLMEngineClient:
                 or response != VLLM_RPC_SUCCESS_STR):
             raise ValueError(error_message)
 
-    async def get_tokenizer(self, lora_request: LoRARequest):
+    async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
         return await self.tokenizer.get_lora_tokenizer_async(lora_request)
 
     async def get_decoding_config(self) -> DecodingConfig:
@@ -344,8 +344,14 @@ class MQLLMEngineClient:
         await self._send_one_way_rpc_request(
             request=RPCAbortRequest(request_id), socket=self.input_socket)
 
-    async def do_log_stats(self):
-        """Ignore do_log_stats (handled on MQLLMEngine polling)"""
+    async def do_log_stats(
+        self,
+        scheduler_outputs: Optional[SchedulerOutputs] = None,
+        model_output: Optional[List[SamplerOutput]] = None,
+    ) -> None:
+        """
+        Ignore do_log_stats (handled on MQLLMEngine polling)
+        """
         pass
 
     async def check_health(self):
@@ -444,104 +450,6 @@ class MQLLMEngineClient:
                                              lora_request, trace_headers,
                                              prompt_adapter_request, priority)
 
-    async def beam_search(
-        self,
-        prompt: Union[PromptType, List[int]],
-        request_id: str,
-        params: BeamSearchParams,
-    ) -> AsyncGenerator[RequestOutput, None]:
-        beam_width = params.beam_width
-        max_tokens = params.max_tokens
-        ignore_eos = params.ignore_eos
-        temperature = params.temperature
-        length_penalty = params.length_penalty
-
-        tokenizer = await self.get_tokenizer(lora_request=None)
-        tokenizedPrompt = prompt if isinstance(
-            prompt, list) else tokenizer.encode(prompt)
-        tokenizedLength = len(tokenizedPrompt)
-
-        sort_beams_key = create_sort_beams_key_function(
-            tokenizer.eos_token_id, length_penalty)
-
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
-        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
-        completed = []
-
-        for _ in range(max_tokens):
-            prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
-                for beam in all_beams
-            ]
-
-            tasks = []
-            request_id = f"beam_search-{random_uuid()}"
-            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
-                task = asyncio.create_task(
-                    collect_from_async_generator(
-                        self.generate(individual_prompt, beam_search_params,
-                                      request_id_item)))
-                tasks.append(task)
-
-            output = await asyncio.gather(*tasks)
-            output = [x[0] for x in output]
-
-            logger.info(output)
-
-            new_beams = []
-            for i, current_beam in enumerate(all_beams):
-                result = output[i]
-
-                if result.outputs[0].logprobs is not None:
-                    logprobs = result.outputs[0].logprobs[0]
-                    for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
-                            completed.append(new_beam)
-                        else:
-                            new_beams.append(new_beam)
-
-            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
-            all_beams = sorted_beams[:beam_width]
-
-        completed.extend(all_beams)
-        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
-        best_beams = sorted_completed[:beam_width]
-
-        for beam in best_beams:
-            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
-
-        beam_search_output = RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens,
-                    index=i,
-                    logprobs=beam.cum_logprob,
-                ) for (i, beam) in enumerate(best_beams)
-            ],
-            finished=True,
-            prompt_token_ids=tokenizedPrompt,
-            prompt_logprobs=None)
-
-        logger.info(beam_search_output)
-        yield beam_search_output
-
     @overload  # DEPRECATED
     def encode(
         self,

vllm/engine/protocol.py (View File)

@@ -1,38 +1,49 @@
-from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
-                    runtime_checkable)
+import asyncio
+from abc import ABC, abstractmethod
+from typing import AsyncGenerator, List, Mapping, Optional, Union
 
+from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import collect_from_async_generator, random_uuid
 
+logger = init_logger(__name__)
 
-@runtime_checkable
-class EngineClient(Protocol):
+class EngineClient(ABC):
     """Protocol class for Clients to Engine"""
 
     @property
+    @abstractmethod
     def is_running(self) -> bool:
         ...
 
     @property
+    @abstractmethod
     def is_stopped(self) -> bool:
         ...
 
     @property
+    @abstractmethod
     def errored(self) -> bool:
         ...
 
     @property
+    @abstractmethod
     def dead_error(self) -> BaseException:
         ...
 
+    @abstractmethod
     def generate(
         self,
         prompt: PromptType,
@@ -46,6 +57,101 @@ class EngineClient(Protocol):
         """Generate outputs for a request."""
         ...
 
+    async def beam_search(
+        self,
+        prompt: Union[PromptType, List[int]],
+        request_id: str,
+        params: BeamSearchParams,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+        length_penalty = params.length_penalty
+
+        tokenizer = await self.get_tokenizer(lora_request=None)
+        tokenizedPrompt = prompt if isinstance(
+            prompt, list) else tokenizer.encode(prompt)
+        tokenizedLength = len(tokenizedPrompt)
+
+        sort_beams_key = create_sort_beams_key_function(
+            tokenizer.eos_token_id, length_penalty)
+
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
+        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        completed = []
+
+        for _ in range(max_tokens):
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            tasks = []
+            request_id = f"beam_search-{random_uuid()}"
+            for i, individual_prompt in enumerate(prompts_batch):
+                request_id_item = f"{request_id}-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.generate(individual_prompt, beam_search_params,
+                                      request_id_item)))
+                tasks.append(task)
+
+            output = await asyncio.gather(*tasks)
+            output = [x[0] for x in output]
+
+            new_beams = []
+            for i, current_beam in enumerate(all_beams):
+                result = output[i]
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
+                        if token_id == tokenizer.eos_token_id and \
+                            not ignore_eos:
+                            completed.append(new_beam)
+                        else:
+                            new_beams.append(new_beam)
+
+            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
+                    logprobs=beam.cum_logprob,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=tokenizedPrompt,
+            prompt_logprobs=None)
+
+        yield beam_search_output
+
+    @abstractmethod
     def encode(
         self,
         prompt: PromptType,
@@ -58,6 +164,7 @@ class EngineClient(Protocol):
         """Generate outputs for a request from an embedding model."""
         ...
 
+    @abstractmethod
     async def abort(self, request_id: str) -> None:
         """Abort a request.
@@ -65,14 +172,17 @@ class EngineClient(Protocol):
             request_id: The unique id of the request.
         """
 
+    @abstractmethod
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""
         ...
 
+    @abstractmethod
     async def get_decoding_config(self) -> DecodingConfig:
         ...
+        """Get the decoding configuration of the vLLM engine."""
 
+    @abstractmethod
     async def get_tokenizer(
         self,
         lora_request: Optional[LoRARequest] = None,
@@ -80,9 +190,11 @@ class EngineClient(Protocol):
         """Get the appropriate tokenizer for the request"""
         ...
 
+    @abstractmethod
     async def is_tracing_enabled(self) -> bool:
         ...
 
+    @abstractmethod
     async def do_log_stats(
         self,
         scheduler_outputs: Optional[SchedulerOutputs] = None,
@@ -90,14 +202,17 @@ class EngineClient(Protocol):
     ) -> None:
         ...
 
+    @abstractmethod
     async def check_health(self) -> None:
         """Raise if unhealthy"""
         ...
 
+    @abstractmethod
     async def start_profile(self) -> None:
         """Start profiling the engine"""
         ...
 
+    @abstractmethod
     async def stop_profile(self) -> None:
         """Start profiling the engine"""
         ...
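
The shared implementation ranks beams with the key returned by create_sort_beams_key_function(tokenizer.eos_token_id, length_penalty). That helper's body is outside this diff; the sketch below shows the conventional length-penalized score it is assumed to compute, i.e. the cumulative log-probability normalized by sequence length (an assumption, not code from this commit):

    from typing import Callable, List

    def get_beam_search_score(tokens: List[int],
                              cumulative_logprob: float,
                              eos_token_id: int,
                              length_penalty: float = 1.0) -> float:
        # Normalize by len(tokens) ** length_penalty so that longer
        # hypotheses are not automatically out-scored by shorter ones.
        seq_len = len(tokens)
        if tokens and tokens[-1] == eos_token_id:
            seq_len -= 1  # assumed: EOS does not count toward the length
        return cumulative_logprob / (seq_len**length_penalty)

    def create_sort_beams_key_function(
            eos_token_id: int, length_penalty: float) -> Callable:
        # The returned key is what sorted(..., reverse=True) uses above.
        return lambda beam: get_beam_search_score(
            beam.tokens, beam.cum_logprob, eos_token_id, length_penalty)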

vllm/entrypoints/openai/serving_chat.py (View File)

@@ -9,8 +9,6 @@ from typing import Union
 from fastapi import Request
 
 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_hf_chat_template,
@@ -237,11 +235,6 @@ class OpenAIServingChat(OpenAIServing):
                 log_tracing_disabled_warning()
 
             if isinstance(sampling_params, BeamSearchParams):
-                assert isinstance(self.engine_client,
-                                  (AsyncLLMEngine,
-                                   MQLLMEngineClient)), \
-                    "Beam search is only supported with" \
-                    "AsyncLLMEngine and MQLLMEngineClient."
                 result_generator = self.engine_client.beam_search(
                     engine_inputs['prompt_token_ids'],
                     request_id,

vllm/entrypoints/openai/serving_completion.py (View File)

@@ -8,8 +8,6 @@ from typing import Tuple, Union, cast
 from fastapi import Request
 
 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
@@ -151,11 +149,6 @@ class OpenAIServingCompletion(OpenAIServing):
                 log_tracing_disabled_warning()
 
             if isinstance(sampling_params, BeamSearchParams):
-                assert isinstance(self.engine_client,
-                                  (AsyncLLMEngine,
-                                   MQLLMEngineClient)), \
-                    "Beam search is only supported with" \
-                    "AsyncLLMEngine and MQLLMEngineClient."
                 generator = self.engine_client.beam_search(
                     prompt_inputs["prompt_token_ids"],
                     request_id_item,
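
With the isinstance asserts gone, the serving layer can drive beam search through any EngineClient implementation. A usage sketch under assumed defaults (the model name, prompt, and parameter values are illustrative only):

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.sampling_params import BeamSearchParams
    from vllm.utils import random_uuid

    async def main() -> None:
        # AsyncLLMEngine is used here, but any EngineClient subclass
        # (e.g. MQLLMEngineClient) now exposes the same beam_search().
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))
        params = BeamSearchParams(beam_width=4, max_tokens=32)
        async for output in engine.beam_search("The capital of France is",
                                               random_uuid(), params):
            for i, beam in enumerate(output.outputs):
                print(i, repr(beam.text), beam.cumulative_logprob)

    asyncio.run(main())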