[Frontend] Use MQLLMEngine for embeddings models too (#8584)

Nick Hill 2024-09-19 17:51:06 +01:00 committed by GitHub
parent 855c8ae2c9
commit 76515f303b
3 changed files with 90 additions and 46 deletions

vllm/engine/multiprocessing/__init__.py

@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import List, Mapping, Optional, Union

+from vllm import PoolingParams
 from vllm.inputs import PromptInputs
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
@@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError):

 @dataclass
-class RPCGenerateRequest:
+class RPCProcessRequest:
     inputs: PromptInputs
-    sampling_params: SamplingParams
+    params: Union[SamplingParams, PoolingParams]
     request_id: str
     lora_request: Optional[LoRARequest] = None
     trace_headers: Optional[Mapping[str, str]] = None
@@ -55,7 +56,7 @@ class RPCStartupResponse:
     tracing_enabled: bool

-RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest,
+RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
                       RPCStartupRequest]

 REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
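
For orientation, a minimal sketch of how the renamed dataclass is used after this change: a single RPCProcessRequest now carries either SamplingParams (completion) or PoolingParams (embedding). The prompts and request ids below are illustrative only.

from vllm import PoolingParams, SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest

# Completion request: params is a SamplingParams instance.
completion_req = RPCProcessRequest(
    inputs="The capital of France is",
    params=SamplingParams(max_tokens=16),
    request_id="cmpl-0",
)

# Embedding request: the same dataclass, but params is a PoolingParams instance.
embedding_req = RPCProcessRequest(
    inputs="The capital of France is",
    params=PoolingParams(),
    request_id="embd-0",
)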

vllm/engine/multiprocessing/client.py

@@ -11,6 +11,7 @@ import zmq.asyncio
 from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket

+from vllm import PoolingParams
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
@@ -19,8 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                           IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                           IPC_OUTPUT_EXT, RPC_REQUEST_T,
                                           VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                          RPCError, RPCGenerateRequest,
-                                          RPCHealthRequest, RPCStartupRequest,
+                                          RPCError, RPCHealthRequest,
+                                          RPCProcessRequest, RPCStartupRequest,
                                           RPCStartupResponse)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
@@ -111,20 +112,8 @@ class MQLLMEngineClient:
     @staticmethod
     def is_unsupported_config(engine_args: AsyncEngineArgs):
-        if engine_args.pipeline_parallel_size > 1:
-            return True
-
-        is_embedding = ModelConfig(
-            model=engine_args.model,
-            revision=engine_args.revision,
-            tokenizer=engine_args.model,
-            tokenizer_mode="auto",
-            trust_remote_code=engine_args.trust_remote_code,
-            quantization=engine_args.quantization,
-            seed=0,
-            dtype="auto").embedding_mode
-
-        return is_embedding
+        # Pipeline parallel not yet supported
+        return engine_args.pipeline_parallel_size > 1

     @contextmanager
     def get_data_socket(self) -> Iterator[Socket]:
@@ -382,12 +371,9 @@
     @property
     def dead_error(self) -> BaseException:
-        if self._errored_with is not None:
-            return ENGINE_DEAD_ERROR(self._errored_with)
-        else:
-            return ENGINE_DEAD_ERROR()
+        return ENGINE_DEAD_ERROR(self._errored_with)

-    async def generate(
+    def generate(
         self,
         inputs: PromptInputs,
         sampling_params: SamplingParams,
@@ -396,6 +382,67 @@
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncGenerator[RequestOutput, None]:
+        """Generate outputs for a request.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            sampling_params: The sampling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+            prompt_adapter_request: Prompt Adapter request to use
+                for generation, if any.
+        """
+        return self._process_request(inputs, sampling_params, request_id,
+                                     lora_request, trace_headers,
+                                     prompt_adapter_request)
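
As an aside, a hedged usage sketch of the reworked generate() entry point (no longer declared async, but still returning an async generator); it assumes an already-connected MQLLMEngineClient named `client`, whose construction is outside the scope of this diff.

from vllm import SamplingParams

async def run_completion(client, prompt: str) -> str:
    # generate() returns an async generator that streams incremental
    # RequestOutput objects; the last one holds the full completion.
    final = None
    async for output in client.generate(
            inputs=prompt,
            sampling_params=SamplingParams(max_tokens=32),
            request_id="cmpl-usage-0"):
        final = output
    return final.outputs[0].text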
+    def encode(
+        self,
+        inputs: PromptInputs,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+        """Generate outputs for a request from an embedding model.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+        """
+        return self._process_request(inputs, pooling_params, request_id,
+                                     lora_request, trace_headers)
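
Similarly, a sketch of driving the new encode() path against an embedding model; again `client` is assumed to be an already-connected MQLLMEngineClient.

from vllm import PoolingParams

async def run_embedding(client, text: str):
    # encode() returns an async generator of EmbeddingRequestOutput; the
    # final item carries the pooled embedding for the prompt.
    final = None
    async for output in client.encode(
            inputs=text,
            pooling_params=PoolingParams(),
            request_id="embd-usage-0"):
        final = output
    return final.outputs.embedding  # list of floats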
+    async def _process_request(
+        self,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
+            EmbeddingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""

         # If already dead, error out.
@@ -410,19 +457,19 @@
         try:
             # 2) Detach logits processors so that they can be pickled
             # separately (may require cloudpickle which is slower)
-            if sampling_params.logits_processors:
+            if isinstance(params, SamplingParams) and params.logits_processors:
                 # Defensive shallow copy
-                sampling_params = copy.copy(sampling_params)
-                logits_processors = sampling_params.logits_processors
-                sampling_params.logits_processors = None
+                params = copy.copy(params)
+                logits_processors = params.logits_processors
+                params.logits_processors = None
                 lp_bytes = cloudpickle.dumps(logits_processors)
             else:
                 lp_bytes = None

             request_bytes = pickle.dumps(
-                RPCGenerateRequest(
+                RPCProcessRequest(
                     inputs=inputs,
-                    sampling_params=sampling_params,
+                    params=params,
                     request_id=request_id,
                     lora_request=lora_request,
                     trace_headers=trace_headers,
@@ -452,8 +499,3 @@
                 await self.abort(request_id)
         finally:
             self.output_queues.pop(request_id)
-
-    async def encode(self, *args,
-                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
-        raise NotImplementedError(
-            "Embeddings not supported with multiprocessing backend")

vllm/engine/multiprocessing/engine.py

@@ -6,7 +6,7 @@ from typing import Iterator, List, Optional, Union
 import cloudpickle
 import zmq

-from vllm import AsyncEngineArgs, LLMEngine
+from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 # yapf conflicts with isort for this block
@@ -15,8 +15,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                           IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                           IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                           VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                          RPCError, RPCGenerateRequest,
-                                          RPCHealthRequest, RPCStartupRequest,
+                                          RPCError, RPCHealthRequest,
+                                          RPCProcessRequest, RPCStartupRequest,
                                           RPCStartupResponse)
 # yapf: enable
 from vllm.logger import init_logger
@@ -39,8 +39,8 @@ class MQLLMEngine:
     in a concurrent manner. It runs a background loop and uses zeromq to
     receive new requests and stream outputs incrementally via ipc.

-    The :class:`LLMEngine.generate` is kicked off when a new
-    RPCGenerateRequest is received by the input_socket.
+    The :class:`LLMEngine` generate or encode process is kicked off when a new
+    RPCProcessRequest is received by the input_socket.

     The self.engine_loop checks the input_socket for new requests,
     adds them to the LLMEngine if there are any, calls the internal
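
A rough, simplified sketch of the receive/dispatch loop this docstring describes (a toy illustration, not the actual MQLLMEngine implementation; the real loop also handles aborts, health checks and error propagation).

import pickle

import zmq

def toy_engine_loop(engine, input_socket: zmq.Socket, output_socket: zmq.Socket):
    # Drain any pending RPCProcessRequest frames, hand them to the wrapped
    # LLMEngine, then step the engine and stream outputs back over ipc.
    while True:
        while input_socket.poll(timeout=0) != 0:
            frames = input_socket.recv_multipart(copy=False)
            request = pickle.loads(frames[0].buffer)
            engine.add_request(request_id=request.request_id,
                               inputs=request.inputs,
                               params=request.params)
        request_outputs = engine.step()
        if request_outputs:
            output_socket.send_multipart((pickle.dumps(request_outputs), ),
                                         copy=False)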
@@ -213,12 +213,13 @@
                 frames = self.input_socket.recv_multipart(copy=False)
                 request = pickle.loads(frames[0].buffer)

-                if isinstance(request, RPCGenerateRequest):
+                if isinstance(request, RPCProcessRequest):
                     if len(frames) > 1:
                         # Use cloudpickle for logits processors
+                        assert isinstance(request.params, SamplingParams)
                         lprocs = cloudpickle.loads(frames[1].buffer)
-                        request.sampling_params.logits_processors = lprocs
-                    self._handle_generate_request(request)
+                        request.params.logits_processors = lprocs
+                    self._handle_process_request(request)
                 elif isinstance(request, RPCAbortRequest):
                     self._handle_abort_request(request)
                 elif isinstance(request, RPCHealthRequest):
@@ -231,8 +232,8 @@
             self._send_unhealthy(e)
             raise e

-    def _handle_generate_request(self, request: RPCGenerateRequest):
-        """Handle RPCGenerateRequest by adding it to the LLMEngine."""
+    def _handle_process_request(self, request: RPCProcessRequest):
+        """Handle RPCProcessRequest by adding it to the LLMEngine."""
         request_id = request.request_id

         if self._errored_with is not None:
@@ -245,7 +246,7 @@
         self.engine.add_request(
             request_id=request_id,
             inputs=request.inputs,
-            params=request.sampling_params,
+            params=request.params,
             lora_request=request.lora_request,
             trace_headers=request.trace_headers,
             prompt_adapter_request=request.prompt_adapter_request)
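
Finally, an illustrative counterpart to the client-side frame packing: rebuilding the request from the two zmq frames on the engine side. This is a sketch under assumed names, not the actual handle_new_input code.

import pickle

import cloudpickle
from vllm import SamplingParams

def toy_decode_frames(frames):
    # `frames` as returned by recv_multipart(copy=False). First frame: the
    # pickled RPCProcessRequest with logits_processors stripped out. Optional
    # second frame: the cloudpickled logits processors, which only apply to
    # sampling (not pooling) requests.
    request = pickle.loads(frames[0].buffer)
    if len(frames) > 1:
        assert isinstance(request.params, SamplingParams)
        request.params.logits_processors = cloudpickle.loads(frames[1].buffer)
    return request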