[Misc] Rename embedding classes to pooling (#10801)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent f877a7d12a
commit d2f058e76c
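For reference, a minimal usage sketch (not part of the diff) of how the rename surfaces to callers, based on the deprecation shim this commit adds to vllm/__init__.py: the new names import directly, while the old names still resolve through the module-level __getattr__ and emit a DeprecationWarning.

import warnings

from vllm import PoolingOutput, PoolingRequestOutput  # new names

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Deprecated alias, resolved via the __getattr__ shim added in this commit.
    from vllm import EmbeddingRequestOutput

assert EmbeddingRequestOutput is PoolingRequestOutput
assert any(issubclass(w.category, DeprecationWarning) for w in caught)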
@@ -10,7 +10,7 @@ prompts = [

 # Create an LLM.
 model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
-# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+# Generate embedding. The output is a list of PoolingRequestOutputs.
 outputs = model.encode(prompts)
 # Print the outputs.
 for output in outputs:
@@ -3,7 +3,7 @@ from typing import List

 import pytest

-from vllm import LLM, EmbeddingRequestOutput, PoolingParams
+from vllm import LLM, PoolingParams, PoolingRequestOutput
 from vllm.distributed import cleanup_dist_env_and_memory

 MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@@ -43,8 +43,8 @@ def llm():
     cleanup_dist_env_and_memory()


-def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
-                         o2: List[EmbeddingRequestOutput]):
+def assert_outputs_equal(o1: List[PoolingRequestOutput],
+                         o2: List[PoolingRequestOutput]):
     assert [o.outputs for o in o1] == [o.outputs for o in o2]


@@ -3,7 +3,7 @@ import warnings
 import pytest
 import torch.cuda

-from vllm.model_executor.models import (is_embedding_model,
+from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
 from vllm.model_executor.models.adapters import as_embedding_model
@@ -31,7 +31,7 @@ def test_registry_imports(model_arch):

     # All vLLM models should be convertible to an embedding model
     embed_model = as_embedding_model(model_cls)
-    assert is_embedding_model(embed_model)
+    assert is_pooling_model(embed_model)

     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)
@@ -8,10 +8,10 @@ from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.worker.embedding_model_runner import (
-    ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 from vllm.worker.multi_step_model_runner import StatefulModelInput
+from vllm.worker.pooling_model_runner import (
+    ModelInputForGPUWithPoolingMetadata)


 class MockAttentionBackend(AttentionBackend):
@@ -7,8 +7,8 @@ from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (CompletionOutput, EmbeddingOutput,
-                          EmbeddingRequestOutput, RequestOutput)
+from vllm.outputs import (CompletionOutput, PoolingOutput,
+                          PoolingRequestOutput, RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams

@@ -25,8 +25,8 @@ __all__ = [
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
-    "EmbeddingOutput",
-    "EmbeddingRequestOutput",
+    "PoolingOutput",
+    "PoolingRequestOutput",
     "LLMEngine",
     "EngineArgs",
     "AsyncLLMEngine",
@@ -34,3 +34,26 @@ __all__ = [
     "initialize_ray_cluster",
     "PoolingParams",
 ]
+
+
+def __getattr__(name: str):
+    import warnings
+
+    if name == "EmbeddingOutput":
+        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingOutput
+
+    if name == "EmbeddingRequestOutput":
+        msg = ("EmbeddingRequestOutput has been renamed to "
+               "PoolingRequestOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingRequestOutput
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -359,7 +359,7 @@ class ModelConfig:
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
-            "embedding": ModelRegistry.is_embedding_model(architectures),
+            "embedding": ModelRegistry.is_pooling_model(architectures),
         }
         supported_tasks_lst: List[_Task] = [
            task for task, is_supported in task_support.items() if is_supported
@@ -25,7 +25,7 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -74,7 +74,7 @@ STOP_ITERATION = Exception() # Sentinel


 class AsyncStream:
-    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    """A stream of RequestOutputs or PoolingRequestOutputs for a request
     that can be iterated over asynchronously via an async generator."""

     def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
@@ -83,7 +83,7 @@ class AsyncStream:
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False

-    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                               Exception]) -> None:
         if not self._finished:
             self._queue.put_nowait(item)
@@ -103,7 +103,7 @@ class AsyncStream:

     async def generator(
         self
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         try:
             while True:
                 result = await self._queue.get()
@@ -154,7 +154,7 @@ class RequestTracker:

     def process_request_output(self,
                                request_output: Union[RequestOutput,
-                                                     EmbeddingRequestOutput],
+                                                     PoolingRequestOutput],
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
@@ -265,7 +265,7 @@ class _AsyncLLMEngine(LLMEngine):

     async def step_async(
         self, virtual_engine: int
-    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are ran asynchronously if possible.

@@ -907,7 +907,7 @@ class AsyncLLMEngine(EngineClient):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...

     @overload
@@ -922,7 +922,7 @@ class AsyncLLMEngine(EngineClient):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
-            RequestOutput, EmbeddingRequestOutput], None]]:
+            RequestOutput, PoolingRequestOutput], None]]:
         ...

     @deprecate_kwargs(
@@ -941,7 +941,7 @@ class AsyncLLMEngine(EngineClient):
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -1070,7 +1070,7 @@ class AsyncLLMEngine(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.

         Generate outputs for a request. This method is a coroutine. It adds the
@@ -1088,7 +1088,7 @@ class AsyncLLMEngine(EngineClient):
                 Only applicable with priority scheduling.

         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.

         Details:
@@ -1141,7 +1141,7 @@ class AsyncLLMEngine(EngineClient):
                 trace_headers=trace_headers,
                 priority=priority,
         ):
-            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
+            yield LLMEngine.validate_output(output, PoolingRequestOutput)

     async def abort(self, request_id: str) -> None:
         """Abort a request.
@@ -40,7 +40,7 @@ from vllm.model_executor.guided_decoding import (
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
+from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:


 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
-_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
+_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)


 @dataclass
@@ -112,7 +112,7 @@ class SchedulerContext:
     def __init__(self, multi_step_stream_outputs: bool = False):
         self.output_queue: Deque[OutputData] = deque()
         self.request_outputs: List[Union[RequestOutput,
-                                         EmbeddingRequestOutput]] = []
+                                         PoolingRequestOutput]] = []
         self.seq_group_metadata_list: Optional[
             List[SequenceGroupMetadata]] = None
         self.scheduler_outputs: Optional[SchedulerOutputs] = None
@@ -1314,7 +1314,7 @@ class LLMEngine:
                 else:
                     seq.append_token_id(sample.output_token, sample.logprobs)

-    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.

         .. figure:: https://i.imgur.com/sv2HssD.png
@@ -35,7 +35,7 @@ from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
@@ -495,7 +495,7 @@ class MQLLMEngineClient(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...

     @overload
@@ -507,7 +507,7 @@ class MQLLMEngineClient(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         ...

     @deprecate_kwargs(
@@ -524,7 +524,7 @@ class MQLLMEngineClient(EngineClient):
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None  # DEPRECATED
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.

         Generate outputs for a request. This method is a coroutine. It adds the
@@ -540,7 +540,7 @@ class MQLLMEngineClient(EngineClient):
             trace_headers: OpenTelemetry trace headers.

         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            The output `PoolingRequestOutput` objects from the LLMEngine
             for the request.
         """
         if inputs is not None:
@@ -549,7 +549,7 @@ class MQLLMEngineClient(EngineClient):
                 and request_id is not None)

         return cast(
-            AsyncGenerator[EmbeddingRequestOutput, None],
+            AsyncGenerator[PoolingRequestOutput, None],
             self._process_request(prompt,
                                   pooling_params,
                                   request_id,
@@ -567,7 +567,7 @@ class MQLLMEngineClient(EngineClient):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
-            EmbeddingRequestOutput, None]]:
+            PoolingRequestOutput, None]]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""

         # If already dead, error out.
@@ -11,8 +11,7 @@ from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
-                          RequestOutput)
+from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -209,7 +208,7 @@ class EngineClient(ABC):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+    ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from an embedding model."""
         ...

@@ -26,7 +26,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest, LLMGuidedOptions)
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -679,7 +679,7 @@ class LLM:
         prompt_token_ids: Optional[List[int]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...

     @overload  # LEGACY: multi (prompt + optional token ids)
@@ -691,7 +691,7 @@ class LLM:
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...

     @overload  # LEGACY: single (token ids + optional prompt)
@@ -704,7 +704,7 @@ class LLM:
         prompt_token_ids: List[int],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...

     @overload  # LEGACY: multi (token ids + optional prompt)
@@ -717,7 +717,7 @@ class LLM:
         prompt_token_ids: List[List[int]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...

     @overload  # LEGACY: single or multi token ids [pos-only]
@@ -728,7 +728,7 @@ class LLM:
         prompt_token_ids: Union[List[int], List[List[int]]],
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...

     @overload
@@ -741,7 +741,7 @@ class LLM:
                                        Sequence[PoolingParams]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         ...

     @deprecate_kwargs(
@@ -759,7 +759,7 @@ class LLM:
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         """Generates the completions for the input prompts.

         This class automatically batches the given prompts, considering
@@ -778,7 +778,7 @@ class LLM:
                 generation, if any.

         Returns:
-            A list of ``EmbeddingRequestOutput`` objects containing the
+            A list of ``PoolingRequestOutput`` objects containing the
             generated embeddings in the same order as the input prompts.

         Note:
@@ -821,7 +821,7 @@ class LLM:

         outputs = self._run_engine(use_tqdm=use_tqdm)
         return self.engine_class.validate_outputs(outputs,
-                                                  EmbeddingRequestOutput)
+                                                  PoolingRequestOutput)

     def score(
         self,
@@ -832,7 +832,7 @@ class LLM:
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> List[EmbeddingRequestOutput]:
+    ) -> List[PoolingRequestOutput]:
         """Generates similarity scores for all pairs <text,text_pair>.

         The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
@@ -854,7 +854,7 @@ class LLM:
                 generation, if any.

         Returns:
-            A list of ``EmbeddingRequestOutput`` objects containing the
+            A list of ``PoolingRequestOutput`` objects containing the
             generated scores in the same order as the input prompts.
         """
         task = self.llm_engine.model_config.task
@@ -943,7 +943,7 @@ class LLM:

         outputs = self._run_engine(use_tqdm=use_tqdm)
         return self.engine_class.validate_outputs(outputs,
-                                                  EmbeddingRequestOutput)
+                                                  PoolingRequestOutput)

     def start_profile(self) -> None:
         self.llm_engine.start_profile()
@@ -1085,7 +1085,7 @@ class LLM:

     def _run_engine(
         self, *, use_tqdm: bool
-    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         # Initialize tqdm.
         if use_tqdm:
             num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1098,7 +1098,7 @@ class LLM:
             )

         # Run the engine.
-        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
+        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
         total_in_toks = 0
         total_out_toks = 0
         while self.llm_engine.has_unfinished_requests():
@@ -18,14 +18,14 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.logger import init_logger
-from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
+from vllm.outputs import PoolingOutput, PoolingRequestOutput
 from vllm.utils import merge_async_iterators, random_uuid

 logger = init_logger(__name__)


 def _get_embedding(
-    output: EmbeddingOutput,
+    output: PoolingOutput,
     encoding_format: Literal["float", "base64"],
 ) -> Union[List[float], str]:
     if encoding_format == "float":
@@ -40,7 +40,7 @@ def _get_embedding(


 def request_output_to_embedding_response(
-        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+        final_res_batch: List[PoolingRequestOutput], request_id: str,
         created_time: int, model_name: str,
         encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
     data: List[EmbeddingResponseData] = []
@@ -169,7 +169,7 @@ class OpenAIServingEmbedding(OpenAIServing):
             return self.create_error_response(str(e))

         # Schedule the request and get the result generator.
-        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
+        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
         try:
             pooling_params = request.to_pooling_params()

@@ -207,7 +207,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         num_prompts = len(engine_prompts)

         # Non-streaming response
-        final_res_batch: List[Optional[EmbeddingRequestOutput]]
+        final_res_batch: List[Optional[PoolingRequestOutput]]
         final_res_batch = [None] * num_prompts
         try:
             async for i, res in result_generator:
@@ -215,7 +215,7 @@ class OpenAIServingEmbedding(OpenAIServing):

             assert all(final_res is not None for final_res in final_res_batch)

-            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
+            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                            final_res_batch)

             response = request_output_to_embedding_response(
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
 from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
-from vllm.outputs import EmbeddingRequestOutput
+from vllm.outputs import PoolingRequestOutput
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import make_async, merge_async_iterators, random_uuid

@@ -21,7 +21,7 @@ logger = init_logger(__name__)


 def request_output_to_score_response(
-        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+        final_res_batch: List[PoolingRequestOutput], request_id: str,
         created_time: int, model_name: str) -> ScoreResponse:
     data: List[ScoreResponseData] = []
     score = None
@@ -133,7 +133,7 @@ class OpenAIServingScores(OpenAIServing):
             return self.create_error_response(str(e))

         # Schedule the request and get the result generator.
-        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
+        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []

         input_pairs = make_pairs(request.text_1, request.text_2)

@@ -194,7 +194,7 @@ class OpenAIServingScores(OpenAIServing):
         num_prompts = len(engine_prompts)

         # Non-streaming response
-        final_res_batch: List[Optional[EmbeddingRequestOutput]]
+        final_res_batch: List[Optional[PoolingRequestOutput]]
         final_res_batch = [None] * num_prompts

         try:
@@ -203,7 +203,7 @@ class OpenAIServingScores(OpenAIServing):

             assert all(final_res is not None for final_res in final_res_batch)

-            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
+            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                            final_res_batch)

             response = request_output_to_score_response(
@@ -1,15 +1,14 @@
 from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
                          SupportsPP, has_inner_state, supports_lora,
                          supports_multimodal, supports_pp)
-from .interfaces_base import (VllmModelForEmbedding,
-                              VllmModelForTextGeneration, is_embedding_model,
-                              is_text_generation_model)
+from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
+                              is_pooling_model, is_text_generation_model)
 from .registry import ModelRegistry

 __all__ = [
     "ModelRegistry",
-    "VllmModelForEmbedding",
-    "is_embedding_model",
+    "VllmModelForPooling",
+    "is_pooling_model",
     "VllmModelForTextGeneration",
     "is_text_generation_model",
     "HasInnerState",
@@ -20,4 +19,4 @@ __all__ = [
     "supports_multimodal",
     "SupportsPP",
     "supports_pp",
-]
+]
@@ -4,7 +4,7 @@ from typing import Any, TypeVar
 import torch
 import torch.nn as nn

-from .interfaces_base import VllmModelForEmbedding, is_embedding_model
+from .interfaces_base import VllmModelForPooling, is_pooling_model

 _T = TypeVar("_T", bound=type[nn.Module])

@@ -12,7 +12,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
 def as_embedding_model(cls: _T) -> _T:
     """Subclass an existing vLLM model to support embeddings."""
     # Avoid modifying existing embedding models
-    if is_embedding_model(cls):
+    if is_pooling_model(cls):
         return cls

     # Lazy import
@@ -23,7 +23,7 @@ def as_embedding_model(cls: _T) -> _T:

     from .utils import AutoWeightsLoader, WeightsMapper

-    class ModelForEmbedding(cls, VllmModelForEmbedding):
+    class ModelForEmbedding(cls, VllmModelForPooling):

         def __init__(
             self,
@@ -7,7 +7,7 @@ from typing_extensions import TypeIs, TypeVar
 from vllm.logger import init_logger
 from vllm.utils import supports_kw

-from .interfaces_base import is_embedding_model
+from .interfaces_base import is_pooling_model

 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
@@ -389,4 +389,4 @@ def _supports_cross_encoding(
 def supports_cross_encoding(
     model: Union[Type[object], object],
 ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
-    return is_embedding_model(model) and _supports_cross_encoding(model)
+    return is_pooling_model(model) and _supports_cross_encoding(model)
@@ -141,7 +141,7 @@ def is_text_generation_model(


 @runtime_checkable
-class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
+class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]):

     def pooler(
         self,
@@ -153,23 +153,22 @@ class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):


 @overload
-def is_embedding_model(
-        model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
+def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]:
     ...


 @overload
-def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
+def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
     ...


-def is_embedding_model(
+def is_pooling_model(
     model: Union[Type[object], object],
-) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]:
+) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
     if not is_vllm_model(model):
         return False

     if isinstance(model, type):
-        return isinstance(model, VllmModelForEmbedding)
+        return isinstance(model, VllmModelForPooling)

-    return isinstance(model, VllmModelForEmbedding)
+    return isinstance(model, VllmModelForPooling)
@@ -24,7 +24,7 @@ from .adapters import as_embedding_model
 from .interfaces import (has_inner_state, is_attention_free,
                          supports_cross_encoding, supports_multimodal,
                          supports_pp)
-from .interfaces_base import is_embedding_model, is_text_generation_model
+from .interfaces_base import is_pooling_model, is_text_generation_model

 logger = init_logger(__name__)

@@ -211,7 +211,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
 class _ModelInfo:
     architecture: str
     is_text_generation_model: bool
-    is_embedding_model: bool
+    is_pooling_model: bool
     supports_cross_encoding: bool
     supports_multimodal: bool
     supports_pp: bool
@@ -220,19 +220,19 @@ class _ModelInfo:

     @staticmethod
     def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
-        is_embedding_model_ = is_embedding_model(model)
-        if not is_embedding_model_:
+        is_pooling_model_ = is_pooling_model(model)
+        if not is_pooling_model_:
             try:
                 as_embedding_model(model)
             except Exception:
                 pass
             else:
-                is_embedding_model_ = True
+                is_pooling_model_ = True

         return _ModelInfo(
             architecture=model.__name__,
             is_text_generation_model=is_text_generation_model(model),
-            is_embedding_model=is_embedding_model_,
+            is_pooling_model=is_pooling_model_,
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
             supports_pp=supports_pp(model),
@@ -441,12 +441,12 @@ class _ModelRegistry:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.is_text_generation_model

-    def is_embedding_model(
+    def is_pooling_model(
         self,
         architectures: Union[str, List[str]],
     ) -> bool:
         model_cls, _ = self.inspect_model_cls(architectures)
-        return model_cls.is_embedding_model
+        return model_cls.is_pooling_model

     def is_cross_encoder_model(
         self,
@@ -53,8 +53,8 @@ class CompletionOutput:


 @dataclass
-class EmbeddingOutput:
-    """The output data of one completion output of a request.
+class PoolingOutput:
+    """The output data of one pooling output of a request.

     Args:
         embedding: The embedding vector, which is a list of floats. The
@@ -63,7 +63,7 @@ class EmbeddingOutput:
     embedding: List[float]

     def __repr__(self) -> str:
-        return (f"EmbeddingOutput("
+        return (f"PoolingOutput("
                 f"embedding={len(self.embedding)})")


@@ -316,18 +316,18 @@ class RequestOutput:
                 f"multi_modal_placeholders={self.multi_modal_placeholders})")


-class EmbeddingRequestOutput:
+class PoolingRequestOutput:
     """
-    The output data of an embedding request to the LLM.
+    The output data of a pooling request to the LLM.

     Args:
-        request_id (str): A unique identifier for the embedding request.
-        outputs (EmbeddingOutput): The embedding results for the given input.
+        request_id (str): A unique identifier for the pooling request.
+        outputs (PoolingOutput): The pooling results for the given input.
         prompt_token_ids (List[int]): A list of token IDs used in the prompt.
-        finished (bool): A flag indicating whether the embedding is completed.
+        finished (bool): A flag indicating whether the pooling is completed.
     """

-    def __init__(self, request_id: str, outputs: "EmbeddingOutput",
+    def __init__(self, request_id: str, outputs: "PoolingOutput",
                  prompt_token_ids: List[int], finished: bool):
         self.request_id = request_id
         self.prompt_token_ids = prompt_token_ids
@@ -336,11 +336,11 @@ class EmbeddingRequestOutput:

     @classmethod
     def from_seq_group(cls,
-                       seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
+                       seq_group: 'SequenceGroup') -> "PoolingRequestOutput":
         if seq_group.embeddings is None:
             raise ValueError(
                 "Embeddings are missing in seq_group for EmbeddingRequest.")
-        output = EmbeddingOutput(seq_group.embeddings)
+        output = PoolingOutput(seq_group.embeddings)
         prompt_token_ids = seq_group.prompt_token_ids
         finished = seq_group.is_finished()

@@ -348,15 +348,15 @@ class EmbeddingRequestOutput:

     def __repr__(self):
         """
-        Returns a string representation of an EmbeddingRequestOutput instance.
+        Returns a string representation of an PoolingRequestOutput instance.

         The representation includes the request_id and the number of outputs,
-        providing a quick overview of the embedding request's results.
+        providing a quick overview of the pooling request's results.

         Returns:
-            str: A string representation of the EmbeddingRequestOutput instance.
+            str: A string representation of the PoolingRequestOutput instance.
         """
-        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
+        return (f"PoolingRequestOutput(request_id='{self.request_id}', "
                 f"outputs={repr(self.outputs)}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
                 f"finished={self.finished})")
@@ -415,7 +415,30 @@ class RequestOutputFactory:
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
-            return EmbeddingRequestOutput.from_seq_group(seq_group)
+            return PoolingRequestOutput.from_seq_group(seq_group)
         else:
             return RequestOutput.from_seq_group(seq_group, use_cache,
                                                 seq_id_to_seq_group)
+
+
+def __getattr__(name: str):
+    import warnings
+
+    if name == "EmbeddingOutput":
+        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingOutput
+
+    if name == "EmbeddingRequestOutput":
+        msg = ("EmbeddingRequestOutput has been renamed to "
+               "PoolingRequestOutput. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PoolingRequestOutput
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -9,7 +9,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -133,7 +133,7 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         """Add new request to the AsyncLLM."""

         if self.detokenizer.is_request_active(request_id):
@@ -1,11 +1,11 @@
 import asyncio
 from typing import Any, AsyncGenerator, Callable, Optional, Type, Union

-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import PoolingRequestOutput, RequestOutput


 class AsyncStream:
-    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    """A stream of RequestOutputs or PoolingRequestOutputs for a request
     that can be iterated over asynchronously via an async generator."""

     STOP_ITERATION = Exception() # Sentinel
@@ -16,7 +16,7 @@ class AsyncStream:
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False

-    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
                               Exception]) -> None:
         if not self._finished:
             self._queue.put_nowait(item)
@@ -32,7 +32,7 @@ class AsyncStream:

     async def generator(
         self
-    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         finished = False
         try:
             while True:
@@ -16,12 +16,12 @@ from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
 @dataclasses.dataclass(frozen=True)
 class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
     """
-    Used by the CPUEmbeddingModelRunner.
+    Used by the CPUPoolingModelRunner.
     """
     pooling_metadata: Optional["PoolingMetadata"] = None


-class CPUEmbeddingModelRunner(
+class CPUPoolingModelRunner(
         CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
     _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
         ModelInputForCPUWithPoolingMetadata)
@@ -14,9 +14,9 @@ from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
-from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
+from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                      LoraNotSupportedWorkerBase, WorkerBase,
                                      WorkerInput)
@@ -164,7 +164,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             else {"return_hidden_states": True}
         ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
         if self.model_config.task == "embedding":
-            ModelRunnerClass = CPUEmbeddingModelRunner
+            ModelRunnerClass = CPUPoolingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = CPUEncoderDecoderModelRunner
         self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
@@ -21,12 +21,12 @@ logger = init_logger(__name__)
 @dataclasses.dataclass(frozen=True)
 class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
     """
-    Used by the EmbeddingModelRunner.
+    Used by the PoolingModelRunner.
     """
     pooling_metadata: Optional["PoolingMetadata"] = None


-class EmbeddingModelRunner(
+class PoolingModelRunner(
         GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
     _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
         ModelInputForGPUWithPoolingMetadata)
@@ -52,7 +52,7 @@ class EmbeddingModelRunner(
     ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError(
-                "EmbeddingModelRunner does not support multi-step execution.")
+                "PoolingModelRunner does not support multi-step execution.")

         if self.lora_config:
             assert model_input.lora_requests is not None
@@ -22,9 +22,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
 from vllm.worker.cache_engine import CacheEngine
-from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
+from vllm.worker.pooling_model_runner import PoolingModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)

@@ -75,7 +75,7 @@ class Worker(LocalOrDistributedWorkerBase):

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
         if model_config.task == "embedding":
-            ModelRunnerClass = EmbeddingModelRunner
+            ModelRunnerClass = PoolingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner
         self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
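As a usage illustration (a sketch, not part of the diff), LLM.encode() now returns PoolingRequestOutput objects; the model name follows the example script touched in this diff, and the prompt text here is arbitrary.

from vllm import LLM, PoolingRequestOutput

model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
outputs = model.encode(["vLLM renamed its embedding classes to pooling."])

for output in outputs:
    # Each result is a PoolingRequestOutput; output.outputs is a PoolingOutput
    # whose .embedding attribute is a plain list of floats.
    assert isinstance(output, PoolingRequestOutput)
    print(output.request_id, len(output.outputs.embedding))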
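For context, a small sketch (illustrative only; model_cls stands for any registered architecture class) of the adapter-plus-check pattern exercised by this diff: wrap a model class with as_embedding_model and verify the result now satisfies the renamed is_pooling_model check.

from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.models.adapters import as_embedding_model


def ensure_pooling(model_cls):
    # as_embedding_model returns the class unchanged if it already supports pooling.
    embed_cls = as_embedding_model(model_cls)
    assert is_pooling_model(embed_cls)
    return embed_cls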