[Misc] Rename embedding classes to pooling (#10801)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung, 2024-12-01 14:36:51 +08:00 (committed by GitHub)
parent f877a7d12a
commit d2f058e76c
25 changed files with 166 additions and 123 deletions


@@ -10,7 +10,7 @@ prompts = [
 # Create an LLM.
 model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
-# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+# Generate embedding. The output is a list of PoolingRequestOutputs.
 outputs = model.encode(prompts)
 # Print the outputs.
 for output in outputs:
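With the renamed classes, each item returned by model.encode() above is a PoolingRequestOutput whose outputs field is a PoolingOutput carrying the embedding vector. A minimal sketch of reading the results (the print logic is illustrative, not taken from this commit):

    from vllm import LLM

    llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
    outputs = llm.encode(["Hello, my name is"])
    for output in outputs:
        # output is a PoolingRequestOutput; output.outputs is a PoolingOutput.
        embedding = output.outputs.embedding  # List[float]
        print(f"{output.request_id}: {len(embedding)}-dim embedding")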


@@ -3,7 +3,7 @@ from typing import List
 import pytest
-from vllm import LLM, EmbeddingRequestOutput, PoolingParams
+from vllm import LLM, PoolingParams, PoolingRequestOutput
 from vllm.distributed import cleanup_dist_env_and_memory
 MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@@ -43,8 +43,8 @@ def llm():
     cleanup_dist_env_and_memory()
-def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
-                         o2: List[EmbeddingRequestOutput]):
+def assert_outputs_equal(o1: List[PoolingRequestOutput],
+                         o2: List[PoolingRequestOutput]):
     assert [o.outputs for o in o1] == [o.outputs for o in o2]


@@ -3,7 +3,7 @@ import warnings
 import pytest
 import torch.cuda
-from vllm.model_executor.models import (is_embedding_model,
+from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
 from vllm.model_executor.models.adapters import as_embedding_model
@@ -31,7 +31,7 @@ def test_registry_imports(model_arch):
     # All vLLM models should be convertible to an embedding model
     embed_model = as_embedding_model(model_cls)
-    assert is_embedding_model(embed_model)
+    assert is_pooling_model(embed_model)
     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)


@@ -8,10 +8,10 @@ from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.worker.embedding_model_runner import (
-    ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 from vllm.worker.multi_step_model_runner import StatefulModelInput
+from vllm.worker.pooling_model_runner import (
+    ModelInputForGPUWithPoolingMetadata)
 class MockAttentionBackend(AttentionBackend):


@@ -7,8 +7,8 @@ from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (CompletionOutput, EmbeddingOutput,
-                          EmbeddingRequestOutput, RequestOutput)
+from vllm.outputs import (CompletionOutput, PoolingOutput,
+                          PoolingRequestOutput, RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
@@ -25,8 +25,8 @@ __all__ = [
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
-    "EmbeddingOutput",
-    "EmbeddingRequestOutput",
+    "PoolingOutput",
+    "PoolingRequestOutput",
     "LLMEngine",
     "EngineArgs",
     "AsyncLLMEngine",
@@ -34,3 +34,26 @@ __all__ = [
     "initialize_ray_cluster",
     "PoolingParams",
 ]
+def __getattr__(name: str):
+    import warnings
+    if name == "EmbeddingOutput":
+        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
+               "The original name will be removed in an upcoming version.")
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+        return PoolingOutput
+    if name == "EmbeddingRequestOutput":
+        msg = ("EmbeddingRequestOutput has been renamed to "
+               "PoolingRequestOutput. "
+               "The original name will be removed in an upcoming version.")
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+        return PoolingRequestOutput
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
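The module-level __getattr__ added above keeps the old names importable while steering users to the new ones. A short sketch of the expected behaviour under this shim (the assertions are illustrative, not part of the commit):

    import warnings

    import vllm

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        old_cls = vllm.EmbeddingRequestOutput  # not in the namespace, falls through to __getattr__
    assert old_cls is vllm.PoolingRequestOutput
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)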


@@ -359,7 +359,7 @@ class ModelConfig:
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
-            "embedding": ModelRegistry.is_embedding_model(architectures),
+            "embedding": ModelRegistry.is_pooling_model(architectures),
         }
         supported_tasks_lst: List[_Task] = [
             task for task, is_supported in task_support.items() if is_supported


@@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -74,7 +74,7 @@ STOP_ITERATION = Exception()  # Sentinel
 class AsyncStream:
-    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    """A stream of RequestOutputs or PoolingRequestOutputs for a request
     that can be iterated over asynchronously via an async generator."""
     def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
@@ -83,7 +83,7 @@ class AsyncStream:
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
-    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                               Exception]) -> None:
         if not self._finished:
             self._queue.put_nowait(item)
@@ -103,7 +103,7 @@ class AsyncStream:
     async def generator(
         self
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         try:
             while True:
                 result = await self._queue.get()
@@ -154,7 +154,7 @@ class RequestTracker:
     def process_request_output(self,
                                request_output: Union[RequestOutput,
-                                                     EmbeddingRequestOutput],
+                                                     PoolingRequestOutput],
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
@@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
     async def step_async(
         self, virtual_engine: int
-    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are ran asynchronously if possible.
@@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...
     @overload
@@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...
     @deprecate_kwargs(
@@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient):
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient):
                 Only applicable with priority scheduling.
         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.
         Details:
@@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient):
             trace_headers=trace_headers,
             priority=priority,
         ):
-            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
+            yield LLMEngine.validate_output(output, PoolingRequestOutput)
     async def abort(self, request_id: str) -> None:
         """Abort a request.


@@ -40,7 +40,7 @@ from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
+from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
-_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
+_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
 @dataclass
@@ -112,7 +112,7 @@ class SchedulerContext:
     def __init__(self, multi_step_stream_outputs: bool = False):
         self.output_queue: Deque[OutputData] = deque()
         self.request_outputs: List[Union[RequestOutput,
-                                         EmbeddingRequestOutput]] = []
+                                         PoolingRequestOutput]] = []
         self.seq_group_metadata_list: Optional[
             List[SequenceGroupMetadata]] = None
         self.scheduler_outputs: Optional[SchedulerOutputs] = None
@@ -1314,7 +1314,7 @@ class LLMEngine:
             else:
                 seq.append_token_id(sample.output_token, sample.logprobs)
-    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         .. figure:: https://i.imgur.com/sv2HssD.png


@@ -35,7 +35,7 @@ from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -495,7 +495,7 @@ class MQLLMEngineClient(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...
     @overload
@@ -507,7 +507,7 @@ class MQLLMEngineClient(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...
     @deprecate_kwargs(
@@ -524,7 +524,7 @@ class MQLLMEngineClient(EngineClient):
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None  # DEPRECATED
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -540,7 +540,7 @@ class MQLLMEngineClient(EngineClient):
             trace_headers: OpenTelemetry trace headers.
         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.
         """
         if inputs is not None:
@@ -549,7 +549,7 @@ class MQLLMEngineClient(EngineClient):
                 and request_id is not None)
         return cast(
-            AsyncGenerator[EmbeddingRequestOutput, None],
+            AsyncGenerator[PoolingRequestOutput, None],
             self._process_request(prompt,
                                   pooling_params,
                                   request_id,
@@ -567,7 +567,7 @@ class MQLLMEngineClient(EngineClient):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
-            EmbeddingRequestOutput, None]]:
+            PoolingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""
         # If already dead, error out.


@@ -11,8 +11,7 @@ from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -209,7 +208,7 @@ class EngineClient(ABC):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""
         ...


@@ -26,7 +26,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest, LLMGuidedOptions)
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -679,7 +679,7 @@ class LLM:
         prompt_token_ids: Optional[List[int]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...
     @overload  # LEGACY: multi (prompt + optional token ids)
@@ -691,7 +691,7 @@ class LLM:
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...
     @overload  # LEGACY: single (token ids + optional prompt)
@@ -704,7 +704,7 @@ class LLM:
         prompt_token_ids: List[int],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...
     @overload  # LEGACY: multi (token ids + optional prompt)
@@ -717,7 +717,7 @@ class LLM:
         prompt_token_ids: List[List[int]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...
     @overload  # LEGACY: single or multi token ids [pos-only]
@@ -728,7 +728,7 @@ class LLM:
         prompt_token_ids: Union[List[int], List[List[int]]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...
     @overload
@@ -741,7 +741,7 @@ class LLM:
                                        Sequence[PoolingParams]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...
     @deprecate_kwargs(
@@ -759,7 +759,7 @@ class LLM:
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         """Generates the completions for the input prompts.
         This class automatically batches the given prompts, considering
@@ -778,7 +778,7 @@ class LLM:
                 generation, if any.
         Returns:
-            A list of ``EmbeddingRequestOutput`` objects containing the
+            A list of ``PoolingRequestOutput`` objects containing the
             generated embeddings in the same order as the input prompts.
         Note:
@@ -821,7 +821,7 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return self.engine_class.validate_outputs(outputs,
-                                                  EmbeddingRequestOutput)
+                                                  PoolingRequestOutput)
     def score(
         self,
@@ -832,7 +832,7 @@ class LLM:
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         """Generates similarity scores for all pairs <text,text_pair>.
         The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
@@ -854,7 +854,7 @@ class LLM:
                 generation, if any.
         Returns:
-            A list of ``EmbeddingRequestOutput`` objects containing the
+            A list of ``PoolingRequestOutput`` objects containing the
             generated scores in the same order as the input prompts.
         """
         task = self.llm_engine.model_config.task
@@ -943,7 +943,7 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return self.engine_class.validate_outputs(outputs,
-                                                  EmbeddingRequestOutput)
+                                                  PoolingRequestOutput)
     def start_profile(self) -> None:
         self.llm_engine.start_profile()
@@ -1085,7 +1085,7 @@ class LLM:
     def _run_engine(
         self, *, use_tqdm: bool
-    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         # Initialize tqdm.
         if use_tqdm:
             num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1098,7 +1098,7 @@ class LLM:
         )
         # Run the engine.
-        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
+        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
         total_in_toks = 0
         total_out_toks = 0
         while self.llm_engine.has_unfinished_requests():


@@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.logger import init_logger
-from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
+from vllm.outputs import PoolingOutput, PoolingRequestOutput
 from vllm.utils import merge_async_iterators, random_uuid
 logger = init_logger(__name__)
 def _get_embedding(
-    output: EmbeddingOutput,
+    output: PoolingOutput,
     encoding_format: Literal["float", "base64"],
 ) -> Union[List[float], str]:
     if encoding_format == "float":
@@ -40,7 +40,7 @@ def _get_embedding(
 def request_output_to_embedding_response(
-        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+        final_res_batch: List[PoolingRequestOutput], request_id: str,
         created_time: int, model_name: str,
         encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
     data: List[EmbeddingResponseData] = []
@@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(str(e))
         # Schedule the request and get the result generator.
-        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
+        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
         try:
             pooling_params = request.to_pooling_params()
@@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         num_prompts = len(engine_prompts)
         # Non-streaming response
-        final_res_batch: List[Optional[EmbeddingRequestOutput]]
+        final_res_batch: List[Optional[PoolingRequestOutput]]
         final_res_batch = [None] * num_prompts
         try:
             async for i, res in result_generator:
@@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         assert all(final_res is not None for final_res in final_res_batch)
-        final_res_batch_checked = cast(List[EmbeddingRequestOutput],
+        final_res_batch_checked = cast(List[PoolingRequestOutput],
                                        final_res_batch)
         response = request_output_to_embedding_response(


@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
 from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
-from vllm.outputs import EmbeddingRequestOutput
+from vllm.outputs import PoolingRequestOutput
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import make_async, merge_async_iterators, random_uuid
@@ -21,7 +21,7 @@ logger = init_logger(__name__)
 def request_output_to_score_response(
-        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+        final_res_batch: List[PoolingRequestOutput], request_id: str,
         created_time: int, model_name: str) -> ScoreResponse:
     data: List[ScoreResponseData] = []
     score = None
@@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing):
             return self.create_error_response(str(e))
         # Schedule the request and get the result generator.
-        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
+        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
         input_pairs = make_pairs(request.text_1, request.text_2)
@@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing):
         num_prompts = len(engine_prompts)
         # Non-streaming response
-        final_res_batch: List[Optional[EmbeddingRequestOutput]]
+        final_res_batch: List[Optional[PoolingRequestOutput]]
         final_res_batch = [None] * num_prompts
         try:
@@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing):
         assert all(final_res is not None for final_res in final_res_batch)
-        final_res_batch_checked = cast(List[EmbeddingRequestOutput],
+        final_res_batch_checked = cast(List[PoolingRequestOutput],
                                        final_res_batch)
         response = request_output_to_score_response(


@@ -1,15 +1,14 @@
 from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
                          SupportsPP, has_inner_state, supports_lora,
                          supports_multimodal, supports_pp)
-from .interfaces_base import (VllmModelForEmbedding,
-                              VllmModelForTextGeneration, is_embedding_model,
-                              is_text_generation_model)
+from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
+                              is_pooling_model, is_text_generation_model)
 from .registry import ModelRegistry
 __all__ = [
     "ModelRegistry",
-    "VllmModelForEmbedding",
-    "is_embedding_model",
+    "VllmModelForPooling",
+    "is_pooling_model",
     "VllmModelForTextGeneration",
     "is_text_generation_model",
     "HasInnerState",


@@ -4,7 +4,7 @@ from typing import Any, TypeVar
 import torch
 import torch.nn as nn
-from .interfaces_base import VllmModelForEmbedding, is_embedding_model
+from .interfaces_base import VllmModelForPooling, is_pooling_model
 _T = TypeVar("_T", bound=type[nn.Module])
@@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
 def as_embedding_model(cls: _T) -> _T:
     """Subclass an existing vLLM model to support embeddings."""
     # Avoid modifying existing embedding models
-    if is_embedding_model(cls):
+    if is_pooling_model(cls):
         return cls
     # Lazy import
@@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T:
     from .utils import AutoWeightsLoader, WeightsMapper
-    class ModelForEmbedding(cls, VllmModelForEmbedding):
+    class ModelForEmbedding(cls, VllmModelForPooling):
         def __init__(
             self,
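The adapter keeps its as_embedding_model name but now checks and builds classes against the pooling interfaces. A sketch of the resulting contract, mirroring the registry test earlier in this commit (the Llama import is an assumed example of a text-generation model class):

    from vllm.model_executor.models import is_pooling_model
    from vllm.model_executor.models.adapters import as_embedding_model
    from vllm.model_executor.models.llama import LlamaForCausalLM  # example model class

    pooling_cls = as_embedding_model(LlamaForCausalLM)
    assert is_pooling_model(pooling_cls)
    # Wrapping a class that already satisfies the pooling interface is a no-op.
    assert as_embedding_model(pooling_cls) is pooling_cls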


@@ -7,7 +7,7 @@ from typing_extensions import TypeIs, TypeVar
 from vllm.logger import init_logger
 from vllm.utils import supports_kw
-from .interfaces_base import is_embedding_model
+from .interfaces_base import is_pooling_model
 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
@@ -389,4 +389,4 @@ def _supports_cross_encoding(
 def supports_cross_encoding(
     model: Union[Type[object], object],
 ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
-    return is_embedding_model(model) and _supports_cross_encoding(model)
+    return is_pooling_model(model) and _supports_cross_encoding(model)


@@ -141,7 +141,7 @@ def is_text_generation_model(
 @runtime_checkable
-class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
+class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]):
     def pooler(
         self,
@@ -153,23 +153,22 @@ class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
 @overload
-def is_embedding_model(
-    model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
+def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]:
     ...
 @overload
-def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
+def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
     ...
-def is_embedding_model(
+def is_pooling_model(
     model: Union[Type[object], object],
-) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]:
+) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
     if not is_vllm_model(model):
         return False
     if isinstance(model, type):
-        return isinstance(model, VllmModelForEmbedding)
-    return isinstance(model, VllmModelForEmbedding)
+        return isinstance(model, VllmModelForPooling)
+    return isinstance(model, VllmModelForPooling)


@@ -24,7 +24,7 @@ from .adapters import as_embedding_model
 from .interfaces import (has_inner_state, is_attention_free,
                          supports_cross_encoding, supports_multimodal,
                          supports_pp)
-from .interfaces_base import is_embedding_model, is_text_generation_model
+from .interfaces_base import is_pooling_model, is_text_generation_model
 logger = init_logger(__name__)
@@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
 class _ModelInfo:
     architecture: str
     is_text_generation_model: bool
-    is_embedding_model: bool
+    is_pooling_model: bool
     supports_cross_encoding: bool
     supports_multimodal: bool
     supports_pp: bool
@@ -220,19 +220,19 @@ class _ModelInfo:
     @staticmethod
     def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
-        is_embedding_model_ = is_embedding_model(model)
-        if not is_embedding_model_:
+        is_pooling_model_ = is_pooling_model(model)
+        if not is_pooling_model_:
             try:
                 as_embedding_model(model)
             except Exception:
                 pass
             else:
-                is_embedding_model_ = True
+                is_pooling_model_ = True
         return _ModelInfo(
             architecture=model.__name__,
             is_text_generation_model=is_text_generation_model(model),
-            is_embedding_model=is_embedding_model_,
+            is_pooling_model=is_pooling_model_,
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
             supports_pp=supports_pp(model),
@@ -441,12 +441,12 @@ class _ModelRegistry:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.is_text_generation_model
-    def is_embedding_model(
+    def is_pooling_model(
         self,
         architectures: Union[str, List[str]],
     ) -> bool:
         model_cls, _ = self.inspect_model_cls(architectures)
-        return model_cls.is_embedding_model
+        return model_cls.is_pooling_model
     def is_cross_encoder_model(
         self,
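The renamed registry method takes architecture names, which is how ModelConfig queries it above. A brief sketch of querying it directly (the architecture string is an example taken from a model's config.json):

    from vllm import ModelRegistry

    arch = ["LlamaForCausalLM"]  # example architecture list
    print(ModelRegistry.is_text_generation_model(arch))
    print(ModelRegistry.is_pooling_model(arch))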


@@ -53,8 +53,8 @@ class CompletionOutput:
 @dataclass
-class EmbeddingOutput:
-    """The output data of one completion output of a request.
+class PoolingOutput:
+    """The output data of one pooling output of a request.
     Args:
         embedding: The embedding vector, which is a list of floats. The
@@ -63,7 +63,7 @@ class EmbeddingOutput:
     embedding: List[float]
     def __repr__(self) -> str:
-        return (f"EmbeddingOutput("
+        return (f"PoolingOutput("
                 f"embedding={len(self.embedding)})")
@@ -316,18 +316,18 @@ class RequestOutput:
                 f"multi_modal_placeholders={self.multi_modal_placeholders})")
-class EmbeddingRequestOutput:
+class PoolingRequestOutput:
     """
-    The output data of an embedding request to the LLM.
+    The output data of a pooling request to the LLM.
     Args:
-        request_id (str): A unique identifier for the embedding request.
-        outputs (EmbeddingOutput): The embedding results for the given input.
+        request_id (str): A unique identifier for the pooling request.
+        outputs (PoolingOutput): The pooling results for the given input.
         prompt_token_ids (List[int]): A list of token IDs used in the prompt.
-        finished (bool): A flag indicating whether the embedding is completed.
+        finished (bool): A flag indicating whether the pooling is completed.
     """
-    def __init__(self, request_id: str, outputs: "EmbeddingOutput",
+    def __init__(self, request_id: str, outputs: "PoolingOutput",
                  prompt_token_ids: List[int], finished: bool):
         self.request_id = request_id
         self.prompt_token_ids = prompt_token_ids
@@ -336,11 +336,11 @@ class EmbeddingRequestOutput:
     @classmethod
     def from_seq_group(cls,
-                       seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
+                       seq_group: 'SequenceGroup') -> "PoolingRequestOutput":
         if seq_group.embeddings is None:
             raise ValueError(
                 "Embeddings are missing in seq_group for EmbeddingRequest.")
-        output = EmbeddingOutput(seq_group.embeddings)
+        output = PoolingOutput(seq_group.embeddings)
         prompt_token_ids = seq_group.prompt_token_ids
         finished = seq_group.is_finished()
@@ -348,15 +348,15 @@ class EmbeddingRequestOutput:
     def __repr__(self):
         """
-        Returns a string representation of an EmbeddingRequestOutput instance.
+        Returns a string representation of an PoolingRequestOutput instance.
         The representation includes the request_id and the number of outputs,
-        providing a quick overview of the embedding request's results.
+        providing a quick overview of the pooling request's results.
         Returns:
-            str: A string representation of the EmbeddingRequestOutput instance.
+            str: A string representation of the PoolingRequestOutput instance.
         """
-        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
+        return (f"PoolingRequestOutput(request_id='{self.request_id}', "
                 f"outputs={repr(self.outputs)}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
                 f"finished={self.finished})")
@@ -415,7 +415,30 @@ class RequestOutputFactory:
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
-            return EmbeddingRequestOutput.from_seq_group(seq_group)
+            return PoolingRequestOutput.from_seq_group(seq_group)
         else:
             return RequestOutput.from_seq_group(seq_group, use_cache,
                                                 seq_id_to_seq_group)
+def __getattr__(name: str):
+    import warnings
+    if name == "EmbeddingOutput":
+        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
+               "The original name will be removed in an upcoming version.")
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+        return PoolingOutput
+    if name == "EmbeddingRequestOutput":
+        msg = ("EmbeddingRequestOutput has been renamed to "
+               "PoolingRequestOutput. "
+               "The original name will be removed in an upcoming version.")
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+        return PoolingRequestOutput
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
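Because LLMEngine.step() and LLM._run_engine() now return a mix of RequestOutput and PoolingRequestOutput, callers branch on the concrete type. A small illustrative sketch (the handling logic is assumed, not part of this commit):

    from vllm.outputs import PoolingRequestOutput, RequestOutput

    def summarize(outputs):
        for out in outputs:
            if isinstance(out, PoolingRequestOutput):
                print(out.request_id, "embedding dim:", len(out.outputs.embedding))
            elif isinstance(out, RequestOutput):
                print(out.request_id, "text:", out.outputs[0].text)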


@@ -9,7 +9,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -133,7 +133,7 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         """Add new request to the AsyncLLM."""
         if self.detokenizer.is_request_active(request_id):


@@ -1,11 +1,11 @@
 import asyncio
 from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 class AsyncStream:
-    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    """A stream of RequestOutputs or PoolingRequestOutputs for a request
     that can be iterated over asynchronously via an async generator."""
     STOP_ITERATION = Exception()  # Sentinel
@@ -16,7 +16,7 @@ class AsyncStream:
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
-    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                               Exception]) -> None:
         if not self._finished:
             self._queue.put_nowait(item)
@@ -32,7 +32,7 @@ class AsyncStream:
     async def generator(
         self
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         finished = False
         try:
             while True:


@@ -16,12 +16,12 @@ from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
 @dataclasses.dataclass(frozen=True)
 class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
     """
-    Used by the CPUEmbeddingModelRunner.
+    Used by the CPUPoolingModelRunner.
     """
     pooling_metadata: Optional["PoolingMetadata"] = None
-class CPUEmbeddingModelRunner(
+class CPUPoolingModelRunner(
         CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
     _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
         ModelInputForCPUWithPoolingMetadata)


@@ -14,9 +14,9 @@ from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
-from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
+from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                      LoraNotSupportedWorkerBase, WorkerBase,
                                      WorkerInput)
@@ -164,7 +164,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             else {"return_hidden_states": True}
         ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
         if self.model_config.task == "embedding":
-            ModelRunnerClass = CPUEmbeddingModelRunner
+            ModelRunnerClass = CPUPoolingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = CPUEncoderDecoderModelRunner
         self.model_runner: CPUModelRunnerBase = ModelRunnerClass(


@@ -21,12 +21,12 @@ logger = init_logger(__name__)
 @dataclasses.dataclass(frozen=True)
 class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
     """
-    Used by the EmbeddingModelRunner.
+    Used by the PoolingModelRunner.
     """
     pooling_metadata: Optional["PoolingMetadata"] = None
-class EmbeddingModelRunner(
+class PoolingModelRunner(
         GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
     _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
         ModelInputForGPUWithPoolingMetadata)
@@ -52,7 +52,7 @@ class EmbeddingModelRunner(
     ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError(
-                "EmbeddingModelRunner does not support multi-step execution.")
+                "PoolingModelRunner does not support multi-step execution.")
         if self.lora_config:
             assert model_input.lora_requests is not None


@@ -22,9 +22,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
 from vllm.worker.cache_engine import CacheEngine
-from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
+from vllm.worker.pooling_model_runner import PoolingModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)
@@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase):
         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
         if model_config.task == "embedding":
-            ModelRunnerClass = EmbeddingModelRunner
+            ModelRunnerClass = PoolingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner
         self.model_runner: GPUModelRunnerBase = ModelRunnerClass(