[Frontend] Use MQLLMEngine for embeddings models too (#8584)

parent 855c8ae2c9
commit 76515f303b
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import List, Mapping, Optional, Union
 
+from vllm import PoolingParams
 from vllm.inputs import PromptInputs
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
@@ -21,9 +22,9 @@ class MQEngineDeadError(RuntimeError):
 
 
 @dataclass
-class RPCGenerateRequest:
+class RPCProcessRequest:
     inputs: PromptInputs
-    sampling_params: SamplingParams
+    params: Union[SamplingParams, PoolingParams]
     request_id: str
     lora_request: Optional[LoRARequest] = None
     trace_headers: Optional[Mapping[str, str]] = None
@@ -55,7 +56,7 @@ class RPCStartupResponse:
     tracing_enabled: bool
 
 
-RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest,
+RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
                       RPCStartupRequest]
 
 REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
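The renamed dataclass now carries either kind of per-request params. A minimal sketch of constructing both request kinds, assuming the post-commit imports; the prompts and request ids are illustrative:

from vllm import PoolingParams, SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest

# One RPC type now covers both generation and embedding requests.
gen_req = RPCProcessRequest(
    inputs="Hello, my name is",  # a plain-string PromptInputs
    params=SamplingParams(temperature=0.8, max_tokens=16),
    request_id="gen-0",
)
emb_req = RPCProcessRequest(
    inputs="Embed this sentence.",
    params=PoolingParams(),      # embedding path
    request_id="emb-0",
)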
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -11,6 +11,7 @@ import zmq.asyncio
 from zmq import Frame  # type: ignore[attr-defined]
 from zmq.asyncio import Socket
 
+from vllm import PoolingParams
 from vllm.config import DecodingConfig, EngineConfig, ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 # yapf conflicts with isort for this block
@@ -19,8 +20,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                          IPC_OUTPUT_EXT, RPC_REQUEST_T,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                         RPCError, RPCGenerateRequest,
-                                         RPCHealthRequest, RPCStartupRequest,
+                                         RPCError, RPCHealthRequest,
+                                         RPCProcessRequest, RPCStartupRequest,
                                          RPCStartupResponse)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
@@ -111,20 +112,8 @@ class MQLLMEngineClient:
 
     @staticmethod
     def is_unsupported_config(engine_args: AsyncEngineArgs):
-        if engine_args.pipeline_parallel_size > 1:
-            return True
-
-        is_embedding = ModelConfig(
-            model=engine_args.model,
-            revision=engine_args.revision,
-            tokenizer=engine_args.model,
-            tokenizer_mode="auto",
-            trust_remote_code=engine_args.trust_remote_code,
-            quantization=engine_args.quantization,
-            seed=0,
-            dtype="auto").embedding_mode
-
-        return is_embedding
+        # Pipeline parallel not yet supported
+        return engine_args.pipeline_parallel_size > 1
 
     @contextmanager
     def get_data_socket(self) -> Iterator[Socket]:
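With the embedding check removed, this helper only gates pipeline parallelism, so embedding models now flow through the MQ engine. A hedged sketch of how a frontend could use it to choose a backend; the model name and the fallback wiring are assumptions, not part of this diff:

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient

engine_args = AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct")  # assumed model

if MQLLMEngineClient.is_unsupported_config(engine_args):
    # Only pipeline parallelism forces the in-process engine now.
    engine = AsyncLLMEngine.from_engine_args(engine_args)
else:
    # Otherwise the server runs MQLLMEngine in a background process
    # and connects an MQLLMEngineClient to it over IPC.
    ...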
@@ -382,12 +371,9 @@ class MQLLMEngineClient:
 
     @property
     def dead_error(self) -> BaseException:
-        if self._errored_with is not None:
-            return ENGINE_DEAD_ERROR(self._errored_with)
-        else:
-            return ENGINE_DEAD_ERROR()
+        return ENGINE_DEAD_ERROR(self._errored_with)
 
-    async def generate(
+    def generate(
         self,
         inputs: PromptInputs,
         sampling_params: SamplingParams,
@@ -396,6 +382,67 @@ class MQLLMEngineClient:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncGenerator[RequestOutput, None]:
+        """Generate outputs for a request.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            sampling_params: The sampling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+            prompt_adapter_request: Prompt Adapter request to use
+                for generation, if any.
+        """
+        return self._process_request(inputs, sampling_params, request_id,
+                                     lora_request, trace_headers,
+                                     prompt_adapter_request)
+
+    def encode(
+        self,
+        inputs: PromptInputs,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+        """Generate outputs for a request from an embedding model.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+        """
+        return self._process_request(inputs, pooling_params, request_id,
+                                     lora_request, trace_headers)
+
+    async def _process_request(
+        self,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
+            EmbeddingRequestOutput, None]]:
+        """Send an RPCProcessRequest to the RPCServer and stream responses."""
 
         # If already dead, error out.
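Because `encode` returns the async generator produced by `_process_request`, callers consume it exactly like `generate`. A minimal usage sketch, assuming `client` is an already-connected MQLLMEngineClient with a running server:

from vllm import PoolingParams

async def embed(client, text: str, request_id: str):
    # Embedding requests yield a single, final EmbeddingRequestOutput.
    final = None
    async for output in client.encode(inputs=text,
                                      pooling_params=PoolingParams(),
                                      request_id=request_id):
        final = output
    return final.outputs.embedding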
@@ -410,19 +457,19 @@ class MQLLMEngineClient:
         try:
             # 2) Detach logits processors so that they can be pickled
             # separately (may require cloudpickle which is slower)
-            if sampling_params.logits_processors:
+            if isinstance(params, SamplingParams) and params.logits_processors:
                 # Defensive shallow copy
-                sampling_params = copy.copy(sampling_params)
-                logits_processors = sampling_params.logits_processors
-                sampling_params.logits_processors = None
+                params = copy.copy(params)
+                logits_processors = params.logits_processors
+                params.logits_processors = None
                 lp_bytes = cloudpickle.dumps(logits_processors)
             else:
                 lp_bytes = None
 
             request_bytes = pickle.dumps(
-                RPCGenerateRequest(
+                RPCProcessRequest(
                     inputs=inputs,
-                    sampling_params=sampling_params,
+                    params=params,
                     request_id=request_id,
                     lora_request=lora_request,
                     trace_headers=trace_headers,
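The detach step exists because logits processors are often closures that the stdlib pickle rejects, while cloudpickle handles them at some speed cost. A standalone sketch of the two-frame pattern, independent of vLLM types (Params is a stand-in class, not a vLLM name):

import copy
import pickle

import cloudpickle

class Params:
    def __init__(self, logits_processors=None):
        self.logits_processors = logits_processors

bias = 2.0
params = Params(logits_processors=[lambda token_ids, logits: logits + bias])

# Sender: detach the closure, ship it in a second frame via cloudpickle.
params = copy.copy(params)  # defensive shallow copy, as in the diff
lp_bytes = cloudpickle.dumps(params.logits_processors)
params.logits_processors = None
frames = [pickle.dumps(params), lp_bytes]

# Receiver: reattach before handing the request to the engine.
received = pickle.loads(frames[0])
received.logits_processors = cloudpickle.loads(frames[1])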
@@ -452,8 +499,3 @@ class MQLLMEngineClient:
                 await self.abort(request_id)
         finally:
             self.output_queues.pop(request_id)
-
-    async def encode(self, *args,
-                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
-        raise NotImplementedError(
-            "Embeddings not supported with multiprocessing backend")
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -6,7 +6,7 @@ from typing import Iterator, List, Optional, Union
 import cloudpickle
 import zmq
 
-from vllm import AsyncEngineArgs, LLMEngine
+from vllm import AsyncEngineArgs, LLMEngine, SamplingParams
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 # yapf conflicts with isort for this block
@@ -15,8 +15,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          IPC_HEALTH_EXT, IPC_INPUT_EXT,
                                          IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
-                                         RPCError, RPCGenerateRequest,
-                                         RPCHealthRequest, RPCStartupRequest,
+                                         RPCError, RPCHealthRequest,
+                                         RPCProcessRequest, RPCStartupRequest,
                                          RPCStartupResponse)
 # yapf: enable
 from vllm.logger import init_logger
@@ -39,8 +39,8 @@ class MQLLMEngine:
     in concurrent manner. It runs a background loop and uses zeromq to
     receive new requests and stream outputs incrementally via ipc.
 
-    The :class:`LLMEngine.generate` is kicked off when a new
-    RPCGenerateRequest is received by the input_socket.
+    The :class:`LLMEngine` generate or encode process is kicked off when a new
+    RPCProcessRequest is received by the input_socket.
 
     The self.engine_loop checks the input_socket for new requests,
     adds them to the LLMEngine if there are any, calls the internal
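For readers unfamiliar with the pattern the docstring describes, a toy sketch of a zeromq-over-ipc receive loop; the socket types, ipc paths, and payloads here are illustrative, not the engine's actual wiring:

import pickle

import zmq

ctx = zmq.Context()
input_socket = ctx.socket(zmq.PULL)           # requests in
input_socket.bind("ipc:///tmp/demo_input")    # hypothetical ipc path
output_socket = ctx.socket(zmq.PUSH)          # outputs back
output_socket.bind("ipc:///tmp/demo_output")

while True:
    if input_socket.poll(timeout=10_000) == 0:
        break  # no request within 10s; end the toy loop
    request = pickle.loads(input_socket.recv())
    # The real engine adds the request, steps, and streams outputs as
    # they are produced; here we just echo the payload back.
    output_socket.send(pickle.dumps(f"echo: {request}"))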
@@ -213,12 +213,13 @@ class MQLLMEngine:
             frames = self.input_socket.recv_multipart(copy=False)
             request = pickle.loads(frames[0].buffer)
 
-            if isinstance(request, RPCGenerateRequest):
+            if isinstance(request, RPCProcessRequest):
                 if len(frames) > 1:
                     # Use cloudpickle for logits processors
+                    assert isinstance(request.params, SamplingParams)
                     lprocs = cloudpickle.loads(frames[1].buffer)
-                    request.sampling_params.logits_processors = lprocs
-                self._handle_generate_request(request)
+                    request.params.logits_processors = lprocs
+                self._handle_process_request(request)
             elif isinstance(request, RPCAbortRequest):
                 self._handle_abort_request(request)
             elif isinstance(request, RPCHealthRequest):
@@ -231,8 +232,8 @@ class MQLLMEngine:
             self._send_unhealthy(e)
             raise e
 
-    def _handle_generate_request(self, request: RPCGenerateRequest):
-        """Handle RPCGenerateRequest by adding it to the LLMEngine."""
+    def _handle_process_request(self, request: RPCProcessRequest):
+        """Handle RPCProcessRequest by adding it to the LLMEngine."""
         request_id = request.request_id
 
         if self._errored_with is not None:
@@ -245,7 +246,7 @@ class MQLLMEngine:
         self.engine.add_request(
             request_id=request_id,
             inputs=request.inputs,
-            params=request.sampling_params,
+            params=request.params,
             lora_request=request.lora_request,
             trace_headers=request.trace_headers,
             prompt_adapter_request=request.prompt_adapter_request)
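The server-side handler ends up as a plain passthrough because `LLMEngine.add_request` already accepts the params union, as this hunk shows. A hedged sketch of driving the engine directly; the model choice is an assumption:

from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m"))  # assumed model

# The same entry point serves both kinds of request; an embedding
# model would pass params=PoolingParams() instead.
engine.add_request(request_id="gen-0",
                   inputs="The capital of France is",
                   params=SamplingParams(max_tokens=8))

while engine.has_unfinished_requests():
    for output in engine.step():
        if output.finished:
            print(output.outputs[0].text)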