[Misc] Rename embedding classes to pooling (#10801)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-01 14:36:51 +08:00 committed by GitHub
parent f877a7d12a
commit d2f058e76c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 166 additions and 123 deletions

View File

@ -10,7 +10,7 @@ prompts = [
# Create an LLM.
model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
# Generate embedding. The output is a list of PoolingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:

View File

@ -3,7 +3,7 @@ from typing import List
import pytest
from vllm import LLM, EmbeddingRequestOutput, PoolingParams
from vllm import LLM, PoolingParams, PoolingRequestOutput
from vllm.distributed import cleanup_dist_env_and_memory
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@ -43,8 +43,8 @@ def llm():
cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
o2: List[EmbeddingRequestOutput]):
def assert_outputs_equal(o1: List[PoolingRequestOutput],
o2: List[PoolingRequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]

View File

@ -3,7 +3,7 @@ import warnings
import pytest
import torch.cuda
from vllm.model_executor.models import (is_embedding_model,
from vllm.model_executor.models import (is_pooling_model,
is_text_generation_model,
supports_multimodal)
from vllm.model_executor.models.adapters import as_embedding_model
@ -31,7 +31,7 @@ def test_registry_imports(model_arch):
# All vLLM models should be convertible to an embedding model
embed_model = as_embedding_model(model_cls)
assert is_embedding_model(embed_model)
assert is_pooling_model(embed_model)
if model_arch in _MULTIMODAL_MODELS:
assert supports_multimodal(model_cls)

View File

@ -8,10 +8,10 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.multi_step_model_runner import StatefulModelInput
from vllm.worker.pooling_model_runner import (
ModelInputForGPUWithPoolingMetadata)
class MockAttentionBackend(AttentionBackend):

View File

@ -7,8 +7,8 @@ from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@ -25,8 +25,8 @@ __all__ = [
"SamplingParams",
"RequestOutput",
"CompletionOutput",
"EmbeddingOutput",
"EmbeddingRequestOutput",
"PoolingOutput",
"PoolingRequestOutput",
"LLMEngine",
"EngineArgs",
"AsyncLLMEngine",
@ -34,3 +34,26 @@ __all__ = [
"initialize_ray_cluster",
"PoolingParams",
]
def __getattr__(name: str):
    """Resolve deprecated attribute names on this module.

    The names ``EmbeddingOutput`` and ``EmbeddingRequestOutput`` are kept
    as aliases of ``PoolingOutput`` and ``PoolingRequestOutput``. Looking
    one of them up emits a ``DeprecationWarning`` and returns the renamed
    class.

    Args:
        name: The attribute name being looked up on the module.

    Returns:
        The class that replaces the deprecated name.

    Raises:
        AttributeError: If ``name`` is not one of the deprecated aliases.
    """
    import warnings

    # Deprecated name -> (replacement class, warning text).
    renamed = {
        "EmbeddingOutput":
        (PoolingOutput,
         "EmbeddingOutput has been renamed to PoolingOutput. "
         "The original name will be removed in an upcoming version."),
        "EmbeddingRequestOutput":
        (PoolingRequestOutput,
         "EmbeddingRequestOutput has been renamed to "
         "PoolingRequestOutput. "
         "The original name will be removed in an upcoming version."),
    }

    entry = renamed.get(name)
    if entry is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    replacement, msg = entry
    # stacklevel=2 attributes the warning to the frame above this hook
    # rather than to this function itself.
    warnings.warn(DeprecationWarning(msg), stacklevel=2)
    return replacement

View File

@ -359,7 +359,7 @@ class ModelConfig:
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_embedding_model(architectures),
"embedding": ModelRegistry.is_pooling_model(architectures),
}
supported_tasks_lst: List[_Task] = [
task for task, is_supported in task_support.items() if is_supported

View File

@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@ -74,7 +74,7 @@ STOP_ITERATION = Exception() # Sentinel
class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
@ -83,7 +83,7 @@ class AsyncStream:
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
def put(self, item: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)
@ -103,7 +103,7 @@ class AsyncStream:
async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
try:
while True:
result = await self._queue.get()
@ -154,7 +154,7 @@ class RequestTracker:
def process_request_output(self,
request_output: Union[RequestOutput,
EmbeddingRequestOutput],
PoolingRequestOutput],
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
async def step_async(
self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
The workers are run asynchronously if possible.
@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
RequestOutput, PoolingRequestOutput], None]]:
...
@overload
@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
RequestOutput, PoolingRequestOutput], None]]:
...
@deprecate_kwargs(
@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient):
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient):
Only applicable with priority scheduling.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
Details:
@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers=trace_headers,
priority=priority,
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
yield LLMEngine.validate_output(output, PoolingRequestOutput)
async def abort(self, request_id: str) -> None:
"""Abort a request.

View File

@ -40,7 +40,7 @@ from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
@dataclass
@ -112,7 +112,7 @@ class SchedulerContext:
def __init__(self, multi_step_stream_outputs: bool = False):
self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
PoolingRequestOutput]] = []
self.seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None
@ -1314,7 +1314,7 @@ class LLMEngine:
else:
seq.append_token_id(sample.output_token, sample.logprobs)
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png

View File

@ -35,7 +35,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@ -495,7 +495,7 @@ class MQLLMEngineClient(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
) -> AsyncGenerator[PoolingRequestOutput, None]:
...
@overload
@ -507,7 +507,7 @@ class MQLLMEngineClient(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
) -> AsyncGenerator[PoolingRequestOutput, None]:
...
@deprecate_kwargs(
@ -524,7 +524,7 @@ class MQLLMEngineClient(EngineClient):
priority: int = 0,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
@ -540,7 +540,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: OpenTelemetry trace headers.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
The output `PoolingRequestOutput` objects from the LLMEngine
for the request.
"""
if inputs is not None:
@ -549,7 +549,7 @@ class MQLLMEngineClient(EngineClient):
and request_id is not None)
return cast(
AsyncGenerator[EmbeddingRequestOutput, None],
AsyncGenerator[PoolingRequestOutput, None],
self._process_request(prompt,
pooling_params,
request_id,
@ -567,7 +567,7 @@ class MQLLMEngineClient(EngineClient):
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
EmbeddingRequestOutput, None]]:
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out.

View File

@ -11,8 +11,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
RequestOutput)
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
@ -209,7 +208,7 @@ class EngineClient(ABC):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
...

View File

@ -26,7 +26,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@ -679,7 +679,7 @@ class LLM:
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
@ -691,7 +691,7 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
@ -704,7 +704,7 @@ class LLM:
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
@ -717,7 +717,7 @@ class LLM:
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
@ -728,7 +728,7 @@ class LLM:
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@overload
@ -741,7 +741,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
...
@deprecate_kwargs(
@ -759,7 +759,7 @@ class LLM:
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
@ -778,7 +778,7 @@ class LLM:
generation, if any.
Returns:
A list of ``EmbeddingRequestOutput`` objects containing the
A list of ``PoolingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts.
Note:
@ -821,7 +821,7 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
PoolingRequestOutput)
def score(
self,
@ -832,7 +832,7 @@ class LLM:
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
) -> List[PoolingRequestOutput]:
"""Generates similarity scores for all pairs <text,text_pair>.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
@ -854,7 +854,7 @@ class LLM:
generation, if any.
Returns:
A list of ``EmbeddingRequestOutput`` objects containing the
A list of ``PoolingRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
task = self.llm_engine.model_config.task
@ -943,7 +943,7 @@ class LLM:
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
PoolingRequestOutput)
def start_profile(self) -> None:
self.llm_engine.start_profile()
@ -1085,7 +1085,7 @@ class LLM:
def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
) -> List[Union[RequestOutput, PoolingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
@ -1098,7 +1098,7 @@ class LLM:
)
# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():

View File

@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
def _get_embedding(
output: EmbeddingOutput,
output: PoolingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
@ -40,7 +40,7 @@ def _get_embedding(
def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str,
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = []
@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing):
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing):
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = request_output_to_embedding_response(

View File

@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators, random_uuid
@ -21,7 +21,7 @@ logger = init_logger(__name__)
def request_output_to_score_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str) -> ScoreResponse:
data: List[ScoreResponseData] = []
score = None
@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing):
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
input_pairs = make_pairs(request.text_1, request.text_2)
@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing):
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing):
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = request_output_to_score_response(

View File

@ -1,15 +1,14 @@
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, has_inner_state, supports_lora,
supports_multimodal, supports_pp)
from .interfaces_base import (VllmModelForEmbedding,
VllmModelForTextGeneration, is_embedding_model,
is_text_generation_model)
from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
is_pooling_model, is_text_generation_model)
from .registry import ModelRegistry
__all__ = [
"ModelRegistry",
"VllmModelForEmbedding",
"is_embedding_model",
"VllmModelForPooling",
"is_pooling_model",
"VllmModelForTextGeneration",
"is_text_generation_model",
"HasInnerState",
@ -20,4 +19,4 @@ __all__ = [
"supports_multimodal",
"SupportsPP",
"supports_pp",
]
]

View File

@ -4,7 +4,7 @@ from typing import Any, TypeVar
import torch
import torch.nn as nn
from .interfaces_base import VllmModelForEmbedding, is_embedding_model
from .interfaces_base import VllmModelForPooling, is_pooling_model
_T = TypeVar("_T", bound=type[nn.Module])
@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
def as_embedding_model(cls: _T) -> _T:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if is_embedding_model(cls):
if is_pooling_model(cls):
return cls
# Lazy import
@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T:
from .utils import AutoWeightsLoader, WeightsMapper
class ModelForEmbedding(cls, VllmModelForEmbedding):
class ModelForEmbedding(cls, VllmModelForPooling):
def __init__(
self,

View File

@ -7,7 +7,7 @@ from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
from vllm.utils import supports_kw
from .interfaces_base import is_embedding_model
from .interfaces_base import is_pooling_model
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
@ -389,4 +389,4 @@ def _supports_cross_encoding(
def supports_cross_encoding(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
return is_embedding_model(model) and _supports_cross_encoding(model)
return is_pooling_model(model) and _supports_cross_encoding(model)

View File

@ -141,7 +141,7 @@ def is_text_generation_model(
@runtime_checkable
class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]):
def pooler(
self,
@ -153,23 +153,22 @@ class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
@overload
def is_embedding_model(
model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]:
...
@overload
def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
...
def is_embedding_model(
def is_pooling_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]:
) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
if not is_vllm_model(model):
return False
if isinstance(model, type):
return isinstance(model, VllmModelForEmbedding)
return isinstance(model, VllmModelForPooling)
return isinstance(model, VllmModelForEmbedding)
return isinstance(model, VllmModelForPooling)

View File

@ -24,7 +24,7 @@ from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free,
supports_cross_encoding, supports_multimodal,
supports_pp)
from .interfaces_base import is_embedding_model, is_text_generation_model
from .interfaces_base import is_pooling_model, is_text_generation_model
logger = init_logger(__name__)
@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class _ModelInfo:
architecture: str
is_text_generation_model: bool
is_embedding_model: bool
is_pooling_model: bool
supports_cross_encoding: bool
supports_multimodal: bool
supports_pp: bool
@ -220,19 +220,19 @@ class _ModelInfo:
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_embedding_model_ = is_embedding_model(model)
if not is_embedding_model_:
is_pooling_model_ = is_pooling_model(model)
if not is_pooling_model_:
try:
as_embedding_model(model)
except Exception:
pass
else:
is_embedding_model_ = True
is_pooling_model_ = True
return _ModelInfo(
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_embedding_model=is_embedding_model_,
is_pooling_model=is_pooling_model_,
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
@ -441,12 +441,12 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_text_generation_model
def is_embedding_model(
def is_pooling_model(
self,
architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_embedding_model
return model_cls.is_pooling_model
def is_cross_encoder_model(
self,

View File

@ -53,8 +53,8 @@ class CompletionOutput:
@dataclass
class EmbeddingOutput:
"""The output data of one completion output of a request.
class PoolingOutput:
"""The output data of one pooling output of a request.
Args:
embedding: The embedding vector, which is a list of floats. The
@ -63,7 +63,7 @@ class EmbeddingOutput:
embedding: List[float]
def __repr__(self) -> str:
return (f"EmbeddingOutput("
return (f"PoolingOutput("
f"embedding={len(self.embedding)})")
@ -316,18 +316,18 @@ class RequestOutput:
f"multi_modal_placeholders={self.multi_modal_placeholders})")
class EmbeddingRequestOutput:
class PoolingRequestOutput:
"""
The output data of an embedding request to the LLM.
The output data of a pooling request to the LLM.
Args:
request_id (str): A unique identifier for the embedding request.
outputs (EmbeddingOutput): The embedding results for the given input.
request_id (str): A unique identifier for the pooling request.
outputs (PoolingOutput): The pooling results for the given input.
prompt_token_ids (List[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the embedding is completed.
finished (bool): A flag indicating whether the pooling is completed.
"""
def __init__(self, request_id: str, outputs: "EmbeddingOutput",
def __init__(self, request_id: str, outputs: "PoolingOutput",
prompt_token_ids: List[int], finished: bool):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids
@ -336,11 +336,11 @@ class EmbeddingRequestOutput:
@classmethod
def from_seq_group(cls,
seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
seq_group: 'SequenceGroup') -> "PoolingRequestOutput":
if seq_group.embeddings is None:
raise ValueError(
"Embeddings are missing in seq_group for EmbeddingRequest.")
output = EmbeddingOutput(seq_group.embeddings)
output = PoolingOutput(seq_group.embeddings)
prompt_token_ids = seq_group.prompt_token_ids
finished = seq_group.is_finished()
@ -348,15 +348,15 @@ class EmbeddingRequestOutput:
def __repr__(self):
"""
Returns a string representation of an EmbeddingRequestOutput instance.
Returns a string representation of a PoolingRequestOutput instance.
The representation includes the request_id and the number of outputs,
providing a quick overview of the embedding request's results.
providing a quick overview of the pooling request's results.
Returns:
str: A string representation of the EmbeddingRequestOutput instance.
str: A string representation of the PoolingRequestOutput instance.
"""
return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
return (f"PoolingRequestOutput(request_id='{self.request_id}', "
f"outputs={repr(self.outputs)}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"finished={self.finished})")
@ -415,7 +415,30 @@ class RequestOutputFactory:
# Determine the type based on a condition, for example:
if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None:
return EmbeddingRequestOutput.from_seq_group(seq_group)
return PoolingRequestOutput.from_seq_group(seq_group)
else:
return RequestOutput.from_seq_group(seq_group, use_cache,
seq_id_to_seq_group)
def __getattr__(name: str):
    """Serve the pre-rename class names as deprecated aliases.

    ``EmbeddingOutput`` maps to ``PoolingOutput`` and
    ``EmbeddingRequestOutput`` maps to ``PoolingRequestOutput``. A
    ``DeprecationWarning`` is issued each time one of the old names is
    resolved through this hook.

    Args:
        name: The attribute name requested from the module.

    Returns:
        The renamed class corresponding to the deprecated ``name``.

    Raises:
        AttributeError: If ``name`` is not a known deprecated alias.
    """
    import warnings

    if name == "EmbeddingOutput":
        replacement = PoolingOutput
        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
               "The original name will be removed in an upcoming version.")
    elif name == "EmbeddingRequestOutput":
        replacement = PoolingRequestOutput
        msg = ("EmbeddingRequestOutput has been renamed to "
               "PoolingRequestOutput. "
               "The original name will be removed in an upcoming version.")
    else:
        # Any other name is not handled by this hook.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # stacklevel=2 attributes the warning to the frame above this hook
    # rather than to this function itself.
    warnings.warn(DeprecationWarning(msg), stacklevel=2)
    return replacement

View File

@ -9,7 +9,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@ -133,7 +133,7 @@ class AsyncLLM(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
"""Add new request to the AsyncLLM."""
if self.detokenizer.is_request_active(request_id):

View File

@ -1,11 +1,11 @@
import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
STOP_ITERATION = Exception() # Sentinel
@ -16,7 +16,7 @@ class AsyncStream:
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
def put(self, item: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
if not self._finished:
self._queue.put_nowait(item)
@ -32,7 +32,7 @@ class AsyncStream:
async def generator(
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
finished = False
try:
while True:

View File

@ -16,12 +16,12 @@ from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
"""
Used by the CPUEmbeddingModelRunner.
Used by the CPUPoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class CPUEmbeddingModelRunner(
class CPUPoolingModelRunner(
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
ModelInputForCPUWithPoolingMetadata)

View File

@ -14,9 +14,9 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerBase,
WorkerInput)
@ -164,7 +164,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else {"return_hidden_states": True}
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
if self.model_config.task == "embedding":
ModelRunnerClass = CPUEmbeddingModelRunner
ModelRunnerClass = CPUPoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = CPUEncoderDecoderModelRunner
self.model_runner: CPUModelRunnerBase = ModelRunnerClass(

View File

@ -21,12 +21,12 @@ logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
"""
Used by the EmbeddingModelRunner.
Used by the PoolingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class EmbeddingModelRunner(
class PoolingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata)
@ -52,7 +52,7 @@ class EmbeddingModelRunner(
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"EmbeddingModelRunner does not support multi-step execution.")
"PoolingModelRunner does not support multi-step execution.")
if self.lora_config:
assert model_input.lora_requests is not None

View File

@ -22,9 +22,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_config.task == "embedding":
ModelRunnerClass = EmbeddingModelRunner
ModelRunnerClass = PoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(