[Core] Set pooling params based on task and model (#21128)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-07-18 20:41:17 +08:00 committed by GitHub
parent 4adc66f64d
commit 45badd05d0
24 changed files with 509 additions and 241 deletions
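The core of the change: `PoolingParams` now carries the pooling task (`encode`, `embed`, `classify`, `score`), each entrypoint verifies the params against its task, and the model's pooler reports per-task updates that the runner applies per request. A minimal sketch of the task-pinning behaviour, assuming an existing `vllm.config.ModelConfig` instance named `model_config` (its construction is elided):

# Sketch of the new PoolingParams.verify() behaviour introduced here.
# `model_config` is assumed to be an existing vllm.config.ModelConfig instance.
from vllm.pooling_params import PoolingParams

params = PoolingParams()
params.verify("embed", model_config)    # pins params.task to "embed"
params.verify("embed", model_config)    # verifying the same task again is fine
# params.verify("score", model_config)  # would raise ValueError: task cannot be overwritten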

View File

@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import importlib.util
from array import array
import numpy as np
import openai
import pytest
from scipy.spatial.distance import cosine
@ -14,10 +12,6 @@ from vllm.config import ModelConfig
from ....utils import RemoteOpenAIServer
# GritLM embedding implementation is only supported by XFormers backend.
pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
reason="GritLM requires XFormers")
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
MAX_MODEL_LEN = 4000
@ -26,11 +20,11 @@ def _arr(arr):
"""
Convert a list of integers to an array of integers.
"""
return array("i", arr)
return np.array(arr)
def test_find_array():
from vllm.model_executor.models.gritlm import GritLMPooler
from vllm.model_executor.models.gritlm import GritLMMeanPool
model_config = ModelConfig(
MODEL_NAME,
@ -41,17 +35,19 @@ def test_find_array():
dtype="bfloat16",
seed=0,
)
pooler = GritLMPooler(model_config=model_config)
pooling = GritLMMeanPool(model_config=model_config)
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=3) == -1
assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=4) == 3
assert pooling._find_array(arr, _arr([3, 5]), start_idx=0) == -1
with pytest.raises(ValueError):
pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
pooling._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
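The switch from `array("i", ...)` to NumPy in this test mirrors why `_find_array` now needs `.all()` on the slice comparison; a toy illustration of the differing `==` semantics (values are made up):

import numpy as np
from array import array

a, t = array("i", [3, 4, 5]), array("i", [3, 4, 5])
assert a == t                 # array module: == compares the whole arrays

na, nt = np.array([3, 4, 5]), np.array([3, 4, 5])
assert (na == nt).all()       # numpy: == is elementwise, so .all() is required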
def run_llm_encode(

View File

@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
@ -964,6 +964,7 @@ class LLM:
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
) -> list[PoolingRequestOutput]:
...
@ -979,6 +980,7 @@ class LLM:
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
) -> list[PoolingRequestOutput]:
...
@ -994,6 +996,7 @@ class LLM:
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
) -> list[PoolingRequestOutput]:
...
@ -1010,6 +1013,7 @@ class LLM:
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
) -> list[PoolingRequestOutput]:
...
@ -1026,6 +1030,7 @@ class LLM:
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
) -> list[PoolingRequestOutput]:
...
@ -1040,6 +1045,7 @@ class LLM:
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
) -> list[PoolingRequestOutput]:
...
@ -1059,6 +1065,7 @@ class LLM:
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
pooling_task: PoolingTask = "encode",
) -> list[PoolingRequestOutput]:
"""Apply pooling to the hidden states corresponding to the input
prompts.
@ -1080,6 +1087,7 @@ class LLM:
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
pooling_task: Override the pooling task to use.
Returns:
A list of `PoolingRequestOutput` objects containing the
@ -1116,11 +1124,12 @@ class LLM:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
elif isinstance(pooling_params, PoolingParams):
pooling_params.verify(model_config)
if isinstance(pooling_params, PoolingParams):
pooling_params.verify(pooling_task, model_config)
else:
for pooling_param in pooling_params:
pooling_param.verify(model_config)
pooling_param.verify(pooling_task, model_config)
tokenization_kwargs = dict[str, Any]()
_validate_truncation_size(model_config.max_model_len,
@ -1181,12 +1190,15 @@ class LLM:
raise ValueError("Embedding API is not supported by this model. "
"Please set `--task embed`.")
items = self.encode(prompts,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
items = self.encode(
prompts,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed",
)
return [EmbeddingRequestOutput.from_base(item) for item in items]
@ -1228,10 +1240,13 @@ class LLM:
"Classification API is not supported by this model. "
"Please set `--task classify`.")
items = self.encode(prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
items = self.encode(
prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
pooling_task="classify",
)
return [ClassificationRequestOutput.from_base(item) for item in items]
@ -1251,7 +1266,9 @@ class LLM:
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
pooling_task="embed",
)
encoded_output_1: list[PoolingRequestOutput] = encoded_output[
0:len(text_1)]
@ -1287,7 +1304,7 @@ class LLM:
if len(data_1) == 1:
data_1 = data_1 * len(data_2)
pooling_params = PoolingParams(use_cross_encoder=True)
pooling_params = PoolingParams(task="score")
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs)
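With these changes, `LLM.embed()`, `LLM.classify()` and `LLM.score()` forward an explicit `pooling_task` to `LLM.encode()`. A hedged usage sketch of the new argument (the model name and prompt are illustrative; any pooling-capable model works the same way):

from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

# LLM.embed() now calls encode(..., pooling_task="embed") under the hood;
# the same task can also be requested explicitly:
outputs = llm.encode(["What is the capital of France?"], pooling_task="embed")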

View File

@ -1347,8 +1347,8 @@ class ScoreRequest(OpenAIBaseModel):
# --8<-- [end:score-extra-params]
def to_pooling_params(self, *, use_cross_encoder: bool = False):
return PoolingParams(use_cross_encoder=use_cross_encoder)
def to_pooling_params(self):
return PoolingParams()
class RerankRequest(OpenAIBaseModel):
@ -1375,8 +1375,8 @@ class RerankRequest(OpenAIBaseModel):
# --8<-- [end:rerank-extra-params]
def to_pooling_params(self, *, use_cross_encoder: bool = False):
return PoolingParams(use_cross_encoder=use_cross_encoder)
def to_pooling_params(self):
return PoolingParams()
class RerankDocument(BaseModel):

View File

@ -6,6 +6,7 @@ from typing import Optional, Union, cast
import numpy as np
from fastapi import Request
from typing_extensions import override
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@ -21,12 +22,14 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
class ClassificationMixin(OpenAIServing):
@override
async def _preprocess(
self,
ctx: ServeContext,
@ -75,6 +78,7 @@ class ClassificationMixin(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
@override
def _build_response(
self,
ctx: ServeContext,
@ -158,3 +162,31 @@ class ServingClassification(ClassificationMixin):
)
return await super().handle(ctx) # type: ignore
@override
def _validate_request(
self,
ctx: ClassificationServeContext,
) -> Optional[ErrorResponse]:
if error := super()._validate_request(ctx):
return error
ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
return None
@override
def _create_pooling_params(
self,
ctx: ClassificationServeContext,
) -> Union[PoolingParams, ErrorResponse]:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify("classify", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return pooling_params

View File

@ -24,6 +24,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingRequestOutput)
from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
@ -45,6 +46,7 @@ def _get_embedding(
class EmbeddingMixin(OpenAIServing):
@override
async def _preprocess(
self,
ctx: ServeContext,
@ -97,6 +99,7 @@ class EmbeddingMixin(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
@override
def _build_response(
self,
ctx: ServeContext,
@ -191,11 +194,20 @@ class OpenAIServingEmbedding(EmbeddingMixin):
ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
pooling_params = ctx.request.to_pooling_params()
return None
@override
def _create_pooling_params(
self,
ctx: ServeContext[EmbeddingRequest],
) -> Union[PoolingParams, ErrorResponse]:
pooling_params = super()._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
try:
pooling_params.verify(self.model_config)
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
return None
return pooling_params

View File

@ -305,6 +305,16 @@ class OpenAIServing:
" Please, select a smaller truncation size.")
return None
def _create_pooling_params(
self,
ctx: ServeContext,
) -> Union[PoolingParams, ErrorResponse]:
if not hasattr(ctx.request, "to_pooling_params"):
return self.create_error_response(
"Request type does not support pooling parameters")
return ctx.request.to_pooling_params()
async def _prepare_generators(
self,
ctx: ServeContext,
@ -318,11 +328,9 @@ class OpenAIServing:
trace_headers = (None if ctx.raw_request is None else await
self._get_trace_headers(ctx.raw_request.headers))
if not hasattr(ctx.request, "to_pooling_params"):
return self.create_error_response(
"Request type does not support pooling parameters")
pooling_params = ctx.request.to_pooling_params()
pooling_params = self._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
if ctx.engine_prompts is None:
return self.create_error_response(

View File

@ -142,6 +142,11 @@ class OpenAIServingPooling(OpenAIServing):
try:
pooling_params = request.to_pooling_params()
try:
pooling_params.verify("encode", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"

View File

@ -55,14 +55,13 @@ class ServingScores(OpenAIServing):
texts_1: list[str],
texts_2: list[str],
request: Union[RerankRequest, ScoreRequest],
request_id=str,
request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> list[PoolingRequestOutput]:
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
input_texts = texts_1 + texts_2
engine_prompts: list[TokensPrompt] = []
@ -89,6 +88,11 @@ class ServingScores(OpenAIServing):
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params()
try:
pooling_params.verify("embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
@ -169,14 +173,13 @@ class ServingScores(OpenAIServing):
data_1: Union[list[str], list[ScoreContentPartParam]],
data_2: Union[list[str], list[ScoreContentPartParam]],
request: Union[RerankRequest, ScoreRequest],
request_id=str,
request_id: str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> list[PoolingRequestOutput]:
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
request_prompts: list[str] = []
engine_prompts: list[TokensPrompt] = []
@ -245,7 +248,12 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params(use_cross_encoder=True)
pooling_params = request.to_pooling_params()
try:
pooling_params.verify("score", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
@ -286,8 +294,7 @@ class ServingScores(OpenAIServing):
request_id: str,
raw_request: Optional[Request] = None,
truncate_prompt_tokens: Optional[int] = None,
) -> list[PoolingRequestOutput]:
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
(
lora_request,
prompt_adapter_request,
@ -374,6 +381,8 @@ class ServingScores(OpenAIServing):
raw_request,
request.truncate_prompt_tokens,
)
if isinstance(final_res_batch, ErrorResponse):
return final_res_batch
return self.request_output_to_score_response(
final_res_batch,
@ -420,6 +429,9 @@ class ServingScores(OpenAIServing):
raw_request,
request.truncate_prompt_tokens,
)
if isinstance(final_res_batch, ErrorResponse):
return final_res_batch
return self.request_output_to_rerank_response(
final_res_batch,
request_id,

View File

@ -4,6 +4,7 @@
import asyncio
import time
from abc import ABC, abstractmethod
from functools import cached_property
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
Union)
@ -15,6 +16,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.pooling_params import PoolingTask
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
@ -135,6 +137,11 @@ class ExecutorBase(ABC):
return self.collective_rpc(rpc_func)
@cached_property # Avoid unnecessary RPC calls
def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]:
output = self.collective_rpc("get_supported_pooling_tasks")
return tuple({task for tasks in output for task in tasks})
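A toy illustration of how `supported_pooling_tasks` merges the per-worker lists returned by `collective_rpc` (values are made up):

output = [["encode", "embed"], ["encode", "embed"]]   # one list per worker
supported = tuple({task for tasks in output for task in tasks})
# -> e.g. ("encode", "embed"); a set is used, so duplicates collapse and the
#    ordering is unspecified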
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Literal, Optional, TypeVar, Union
from typing import Callable, Optional, TypeVar, Union
import torch
import torch.nn as nn
@ -15,13 +15,12 @@ from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
PoolingTask = Literal["encode", "embed", "classify", "score"]
class PoolingType(IntEnum):
@ -67,6 +66,15 @@ class ResolvedPoolingConfig:
)
@dataclass(frozen=True)
class PoolingParamsUpdate:
requires_token_ids: bool = False
"""Set this flag to enable `get_prompt_token_ids` for your pooler."""
def apply(self, params: PoolingParams) -> None:
params.requires_token_ids = self.requires_token_ids
class Pooler(nn.Module, ABC):
"""The interface required for all poolers used in pooling models in vLLM."""
@ -93,7 +101,10 @@ class Pooler(nn.Module, ABC):
return SimplePooler.from_config(resolved_config)
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
"""
Construct the pooling parameters to use for a task,
or `None` if the task is not supported.
@ -121,6 +132,23 @@ def get_prompt_lens(
pooling_metadata, hidden_states.device).prompt_lens
def get_prompt_token_ids(
pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata):
assert pooling_metadata.prompt_token_ids is not None, (
"Please set `requires_token_ids=True` in `get_pooling_updates`")
return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
return [
torch.tensor(seq_data_i.prompt_token_ids)
for seq_data_i in pooling_metadata.seq_data.values()
]
def get_classification_activation_function(config: PretrainedConfig):
return PoolerClassify()
@ -165,7 +193,10 @@ class PoolingMethod(nn.Module, ABC):
raise NotImplementedError(f"Unsupported method: {pooling_type}")
@abstractmethod
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
raise NotImplementedError
@abstractmethod
@ -206,11 +237,14 @@ class PoolingMethod(nn.Module, ABC):
class CLSPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParams()
return PoolingParamsUpdate()
assert_never(task)
@ -236,11 +270,14 @@ class CLSPool(PoolingMethod):
class LastPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParams()
return PoolingParamsUpdate()
assert_never(task)
@ -262,9 +299,12 @@ class LastPool(PoolingMethod):
class AllPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
if task == "encode":
return PoolingParams()
return PoolingParamsUpdate()
# The equalities are split up to keep mypy happy
if task == "embed" or task == "classify" or task == "score":
@ -299,11 +339,14 @@ class AllPool(PoolingMethod):
class MeanPool(PoolingMethod):
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if (task == "encode" or task == "embed" or task == "classify"
or task == "score"):
return PoolingParams()
return PoolingParamsUpdate()
assert_never(task)
@ -520,8 +563,11 @@ class SimplePooler(Pooler):
self.pooling = pooling
self.head = head
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
return self.pooling.get_pooling_updates(task)
def forward(
self,
@ -559,27 +605,13 @@ class StepPooler(Pooler):
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids
def get_prompt_token_ids(
self,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata):
return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
return [
torch.tensor(seq_data_i.prompt_token_ids)
for seq_data_i in pooling_metadata.seq_data.values()
]
def extract_states(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
pooled_data = list[torch.Tensor]()
returned_token_ids = self.returned_token_ids
@ -595,9 +627,12 @@ class StepPooler(Pooler):
return pooled_data
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
if task == "encode":
return PoolingParams(logits_processing_needs_token_ids=True)
return PoolingParamsUpdate(requires_token_ids=True)
# The equalities are split up to keep mypy happy
if task == "embed" or task == "classify" or task == "score":
@ -650,19 +685,24 @@ class ClassifierPooler(nn.Module):
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
config.hf_config) if act_fn is None else act_fn
def _get_act_fn(self, use_cross_encoder: bool):
return (self.cross_encoder_act_fn
if use_cross_encoder else self.classification_act_fn)
def _get_act_fn(self, task: PoolingTask):
if task == "encode" or task == "classify":
return self.classification_act_fn
if task == "score":
return self.cross_encoder_act_fn
raise ValueError(f"Unsupported task: {task!r}")
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if task == "encode" or task == "classify" or task == "score":
return PoolingParamsUpdate()
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
if task == "encode":
return PoolingParams()
if task == "embed":
return None
if task == "classify":
return PoolingParams()
if task == "score":
return PoolingParams(use_cross_encoder=True)
assert_never(task)
@ -682,27 +722,28 @@ class ClassifierPooler(nn.Module):
else:
pooled_output = [self.classifier(data) for data in pooled_data]
task_list: list[PoolingTask]
if isinstance(pooling_metadata, V0PoolingMetadata):
use_cross_encoder_list = [
pooling_param.use_cross_encoder
for _, pooling_param in pooling_metadata.seq_groups
task_list = [
task for _, pooling_param in pooling_metadata.seq_groups
if (task := pooling_param.task) is not None
]
else:
use_cross_encoder_list = [
pooling_param.use_cross_encoder
for pooling_param in pooling_metadata.pooling_params
task_list = [
task for pooling_param in pooling_metadata.pooling_params
if (task := pooling_param.task) is not None
]
assert len(task_list) == len(pooled_output)
# shape of scores: (batch_size, num_labels)
if all(use_cross_encoder == use_cross_encoder_list[0]
for use_cross_encoder in use_cross_encoder_list):
act_fn = self._get_act_fn(use_cross_encoder_list[0])
if len(set(task_list)) <= 1:
act_fn = self._get_act_fn(task_list[0])
scores = act_fn(pooled_output)
else:
scores = torch.stack([
self._get_act_fn(use_cross_encoder)(vecs)
for use_cross_encoder, vecs in zip(use_cross_encoder_list,
pooled_output)
self._get_act_fn(task)(vecs)
for task, vecs in zip(task_list, pooled_output)
])
return build_output(scores)
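For pooler authors, the old `get_pooling_params` hook is replaced by `get_pooling_updates`, which returns a `PoolingParamsUpdate` (or `None` for unsupported tasks) that the runner then applies to the request's `PoolingParams`. A hedged sketch with a hypothetical pooler class, not part of this commit:

from typing import Optional

from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.pooling_params import PoolingParams, PoolingTask

class TokenIdAwarePooler:  # hypothetical example
    def get_pooling_updates(self, task: PoolingTask) -> Optional[PoolingParamsUpdate]:
        if task == "encode":
            # Ask the runner to keep prompt token IDs around for this request.
            return PoolingParamsUpdate(requires_token_ids=True)
        return None  # other tasks are unsupported

params = PoolingParams(task="encode")
update = TokenIdAwarePooler().get_pooling_updates("encode")
assert update is not None
update.apply(params)   # sets params.requires_token_ids = True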

View File

@ -18,13 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingMethod, PoolingTask,
PoolingMethod,
PoolingParamsUpdate,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
@ -91,8 +92,11 @@ class BertPooler(Pooler):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
return self.pooling.get_pooling_updates(task)
def forward(
self,

View File

@ -1,18 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from array import array
from typing import Optional, Union
import numpy as np
import torch
import torch.nn as nn
from typing_extensions import assert_never
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize
from vllm.model_executor.layers.pooler import (Pooler, PoolerHead,
PoolerNormalize,
PoolingParamsUpdate,
build_output, get_prompt_lens,
get_prompt_token_ids)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingTask
from vllm.sequence import PoolerOutput
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsV0Only
@ -20,7 +26,8 @@ from .interfaces import SupportsV0Only
logger = init_logger(__name__)
class GritLMPooler(nn.Module):
class GritLMMeanPool(nn.Module):
"""As `MeanPool`, but only includes non-instruction tokens."""
def __init__(self, model_config: ModelConfig):
super().__init__()
@ -38,8 +45,8 @@ class GritLMPooler(nn.Module):
for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
}
def tokens_to_ids(tokens: list[str]) -> array:
return array("i", [self.token_ids[token] for token in tokens])
def tokens_to_ids(tokens: list[str]) -> np.ndarray:
return np.array([self.token_ids[token] for token in tokens])
self.user_pattern_ids = tokens_to_ids(
["▁<", "|", "user", "|", ">", "<0x0A>"])
@ -48,32 +55,44 @@ class GritLMPooler(nn.Module):
self.embed_pattern_ids = tokens_to_ids(
["▁<", "|", "embed", "|", ">", "<0x0A>"])
self.head = PoolerHead(PoolerNormalize())
def _find_array(self, arr: array, target: array, start_idx: int) -> int:
def _find_array(
self,
arr: np.ndarray,
target: np.ndarray,
start_idx: int = 0,
end_idx: Optional[int] = None,
) -> int:
"""
Find the first occurrence of target in arr starting from start_idx.
Find the first occurrence of `target` in `arr` starting from
`start_idx`.
Args:
arr: The array to search within
target: The consecutive subsequence to find
start_idx: The starting index to search from
arr: The array to search within.
target: The consecutive subsequence to find.
start_idx: The starting index to search from (inclusive).
end_idx: The ending index to search up to (exclusive).
Returns:
int: The index of the first occurrence of target in arr.
The index of the first occurrence of `target` in `arr`.
"""
if start_idx < 0:
raise ValueError("start_idx must be non-negative")
if not target or not arr:
raise ValueError("Empty arr or target not allowed")
raise ValueError("`start_idx` must be non-negative")
if len(arr) == 0 or len(target) == 0:
raise ValueError("Empty `arr` or `target` not allowed")
arr_len = len(arr)
target_len = len(target)
for i in range(start_idx, len(arr) - target_len + 1):
if arr[i:i + target_len] == target:
if end_idx is None:
end_idx = arr_len
for i in range(start_idx, min(end_idx, arr_len - target_len + 1)):
if (arr[i:i + target_len] == target).all():
return i
return -1
def _get_instruction_len(self, prompt_token_ids: array) -> int:
def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
"""
Get the length of the instruction in the prompt.
@ -83,7 +102,6 @@ class GritLMPooler(nn.Module):
The pattern matching is done using integers instead of strings
because the prompt is given as a list of token IDs.
"""
instruction_len = 0
# Return no instruction in case of missing BOS token.
@ -98,7 +116,8 @@ class GritLMPooler(nn.Module):
embed_pattern_ids = self.embed_pattern_ids
if self._find_array(prompt_token_ids,
self.user_pattern_ids,
start_idx=1) == 1:
start_idx=1,
end_idx=2) == 1:
embed_pattern_ids = self.embed_newline_pattern_ids
# Find the embed pattern in the prompt.
@ -116,64 +135,92 @@ class GritLMPooler(nn.Module):
return instruction_len
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
# The equalities are split up to keep mypy happy
if task == "encode" or task == "embed":
return PoolingParamsUpdate(requires_token_ids=True)
if task == "classify" or task == "score":
return None
assert_never(task)
def forward_one(
self,
hidden_states: torch.Tensor,
prompt_len: Optional[torch.Tensor] = None,
instr_len: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert prompt_len is None or prompt_len == hidden_states.shape[0], \
"partial prefill not supported with MEAN pooling"
return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32)
def forward_all(
self,
hidden_states: torch.Tensor,
prompt_lens: torch.Tensor,
instr_lens: torch.Tensor,
) -> Union[list[torch.Tensor], torch.Tensor]:
offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len, instr_len in zip(prompt_lens, instr_lens):
pooled_data.append(hidden_states[offset + instr_len:offset +
prompt_len].mean(
dim=0, dtype=torch.float32))
offset += prompt_len
return pooled_data
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
instr_lens = torch.tensor(
[
self._get_instruction_len(token_ids.cpu().numpy())
for token_ids in get_prompt_token_ids(pooling_metadata)
],
device=prompt_lens.device,
)
if isinstance(hidden_states, list):
return [
self.forward_one(h, prompt_len, instr_len) for h, prompt_len,
instr_len in zip(hidden_states, prompt_lens, instr_lens)
]
return self.forward_all(hidden_states, prompt_lens, instr_lens)
class GritLMPooler(Pooler):
def __init__(self, model_config: ModelConfig):
super().__init__()
self.pooling = GritLMMeanPool(model_config)
self.head = PoolerHead(PoolerNormalize())
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
return self.pooling.get_pooling_updates(task)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""
Pool the hidden states by summing the embeddings of
non-instruction tokens.
"""
prompts_token_ids = [
token_ids.prompt_token_ids_array
for _, token_ids in pooling_metadata.seq_data.items()
]
instruction_lens = torch.tensor(
[
self._get_instruction_len(prompt_token_ids)
for prompt_token_ids in prompts_token_ids
],
device=hidden_states.device,
)
prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
mask = torch.zeros_like(hidden_states, dtype=torch.bool)
start_idx = 0
for prompt_len, instruction_len in zip(prompt_lens, instruction_lens):
end_idx = start_idx + prompt_len
mask[start_idx + instruction_len:end_idx] = True
start_idx = end_idx
masked_hidden_states = hidden_states.masked_fill(~mask, 0.0)
sum_embeddings = torch.zeros(len(prompt_lens),
hidden_states.size(1),
device=hidden_states.device)
start_idx = 0
for i, prompt_len in enumerate(prompt_lens):
end_idx = start_idx + prompt_len
sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum(
dim=0)
start_idx = end_idx
num_non_instruction_tokens = prompt_lens - instruction_lens
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
1)
pooled_data = self.head(mean_embeddings,
pooling_metadata=pooling_metadata)
pooled_outputs = [
PoolingSequenceGroupOutput(data) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return build_output(pooled_data)
class GritLM(LlamaForCausalLM, SupportsV0Only):
@ -202,7 +249,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
prefix: str = "",
**kwargs,
) -> None:
# Use full attention for pooling
# Use full attention for pooling (this is why V1 is not supported yet)
if vllm_config.model_config.runner_type == "pooling":
hf_config = vllm_config.model_config.hf_config
hf_config.is_causal = False
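The refactored `GritLMMeanPool` mean-pools only the tokens that follow the detected instruction. A toy illustration of the per-prompt computation done in `forward_all` (shapes and values are made up):

import torch

hidden_states = torch.randn(7, 4)   # one 7-token prompt, hidden size 4
prompt_len, instr_len = 7, 3        # first 3 tokens form the "<|embed|>" instruction
pooled = hidden_states[instr_len:prompt_len].mean(dim=0, dtype=torch.float32)
assert pooled.shape == (4,)         # one embedding per prompt, instruction excluded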

View File

@ -599,13 +599,6 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model)
def has_step_pooler(model: Union[type[object], object]) -> bool:
"""Check if the model uses step pooler."""
from vllm.model_executor.layers.pooler import StepPooler
return is_pooling_model(model) and isinstance(model.pooler, StepPooler)
class SupportsQuant:
"""The interface required for all models that support quantization."""

View File

@ -14,14 +14,15 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingMethod, PoolingTask,
PoolingMethod,
PoolingParamsUpdate,
PoolingType)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsCrossEncoding, SupportsV0Only
@ -270,8 +271,11 @@ class ModernBertPooler(Pooler):
eps=config.norm_eps,
bias=config.norm_bias)
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def get_pooling_updates(
self,
task: PoolingTask,
) -> Optional[PoolingParamsUpdate]:
return self.pooling.get_pooling_updates(task)
def forward(
self,

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional
import msgspec
@ -10,12 +10,14 @@ from vllm.sampling_params import RequestOutputKind
if TYPE_CHECKING:
from vllm.config import ModelConfig
PoolingTask = Literal["encode", "embed", "classify", "score"]
class PoolingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""API parameters for pooling models. This
"""API parameters for pooling models.
Attributes:
dimensions: Reduce the dimensions of embeddings
@ -24,24 +26,33 @@ class PoolingParams(
dimensions: Optional[int] = None
use_cross_encoder: bool = False
"""Internal use only."""
logits_processing_needs_token_ids: bool = False
"""Internal use only."""
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
task: Optional[PoolingTask] = None
"""Internal use only."""
requires_token_ids: bool = False
"""Internal use only."""
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
return PoolingParams(
dimensions=self.dimensions,
use_cross_encoder=self.use_cross_encoder,
logits_processing_needs_token_ids=self.
logits_processing_needs_token_ids,
task=self.task,
requires_token_ids=self.requires_token_ids,
)
def verify(self, model_config: "ModelConfig") -> None:
def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None:
if self.task is None:
self.task = task
elif self.task != task:
msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
raise ValueError(msg)
# NOTE: Task validation needs to be done against the model instance,
# which is not available in model config. So, it's not included
# in this method
if self.dimensions is not None:
if not model_config.is_matryoshka:
raise ValueError(
@ -61,12 +72,10 @@ class PoolingParams(
raise ValueError("Dimensions must be greater than 0")
def __repr__(self) -> str:
return (
f"PoolingParams("
f"dimensions={self.dimensions}, "
f"use_cross_encoder={self.use_cross_encoder}, "
f"logits_processing_needs_token_ids={self.logits_processing_needs_token_ids})"
)
return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"task={self.task}, "
f"requires_token_ids={self.requires_token_ids})")
def __post_init__(self) -> None:
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\

View File

@ -181,6 +181,12 @@ class EngineCore:
def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""
if pooling_params := request.pooling_params:
supported_pooling_tasks = (
self.model_executor.supported_pooling_tasks)
if pooling_params.task not in supported_pooling_tasks:
raise ValueError(f"Unsupported task: {pooling_params.task!r} "
f"Supported tasks: {supported_pooling_tasks}")
if request.mm_hashes is not None:
# Here, if hash exists for a multimodal input, then it will be

View File

@ -8,7 +8,6 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import has_step_pooler
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
@ -54,9 +53,6 @@ class CPUModelRunner(GPUModelRunner):
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)
if has_step_pooler(self.model):
self.input_batch.logits_processing_needs_token_ids = True
if self.lora_config:
self.model = self.load_lora_model(self.model, self.model_config,
self.scheduler_config,

View File

@ -70,7 +70,6 @@ class InputBatch:
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
is_spec_decode: bool = False,
logits_processing_needs_token_ids: bool = False,
):
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
@ -79,8 +78,6 @@ class InputBatch:
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self.logits_processing_needs_token_ids = (
logits_processing_needs_token_ids)
self._req_ids: list[Optional[str]] = []
self.req_id_to_index: dict[str, int] = {}
@ -233,6 +230,9 @@ class InputBatch:
# req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
dtype=bool)
self.req_output_token_ids: list[Optional[list[int]]] = []
# This is updated each time the batch constituents change.
@ -365,9 +365,12 @@ class InputBatch:
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
elif pooling_params := request.pooling_params:
self.pooling_params[req_id] = pooling_params
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids)
else:
assert request.pooling_params is not None
self.pooling_params[req_id] = request.pooling_params
raise NotImplementedError(request)
# Add request lora ID
if request.lora_request:
@ -620,9 +623,9 @@ class InputBatch:
copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs)
needs_prompt_token_ids = (not self.no_penalties or
(self.num_reqs > 0
and self.logits_processing_needs_token_ids))
needs_prompt_token_ids = (
not self.no_penalties
or self.logits_processing_needs_token_ids[:num_reqs].any())
if needs_prompt_token_ids:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
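The batch-wide constructor flag is replaced by a per-request boolean array, so prompt token IDs are gathered whenever any active request asks for them. A toy sketch of the check (sizes are made up):

import numpy as np

logits_processing_needs_token_ids = np.zeros(8, dtype=bool)
logits_processing_needs_token_ids[1] = True   # e.g. a request whose pooler set requires_token_ids

num_reqs, no_penalties = 3, True
needs_prompt_token_ids = (not no_penalties
                          or logits_processing_needs_token_ids[:num_reqs].any())
# -> True, because request 1 asked for its prompt token IDs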

View File

@ -4,7 +4,7 @@
import gc
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args
import numpy as np
import torch
@ -32,12 +32,13 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (has_step_pooler,
is_mixture_of_experts)
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
is_pooling_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
@ -404,6 +405,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
pooling_params = new_req_data.pooling_params
if sampling_params and \
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
@ -411,6 +413,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
generator = None
if pooling_params:
assert pooling_params.task is not None, (
"You did not set `task` in the API")
model = cast(VllmModelForPooling, self.model)
to_update = (model.pooler.get_pooling_updates(
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
to_update.apply(pooling_params)
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
@ -1092,6 +1106,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def get_model(self) -> nn.Module:
return self.model
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
return []
return [
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
def apply_grammar_bitmask(
self,
scheduler_output: "SchedulerOutput",
@ -1737,8 +1761,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
model_loader.load_weights(self.model,
model_config=self.model_config)
if has_step_pooler(self.model):
self.input_batch.logits_processing_needs_token_ids = True
if self.lora_config:
self.model = self.load_lora_model(self.model,
self.model_config,
@ -2138,17 +2160,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_num_tokens = num_tokens // num_reqs
model = cast(VllmModelForPooling, self.model)
dummy_task = self.get_supported_pooling_tasks()[0]
dummy_pooling_params = PoolingParams(task=dummy_task)
to_update = model.pooler.get_pooling_updates(dummy_task)
assert to_update is not None
to_update.apply(dummy_pooling_params)
dummy_metadata = PoolingMetadata(
prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
device=self.device),
prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
dtype=torch.int32,
device=self.device),
pooling_params=[PoolingParams()] * num_reqs)
pooling_params=[dummy_pooling_params] * num_reqs)
try:
pooler_output = self.model.pooler(hidden_states=hidden_states_list,
pooling_metadata=dummy_metadata)
pooler_output = model.pooler(hidden_states=hidden_states_list,
pooling_metadata=dummy_metadata)
except RuntimeError as e:
if 'out of memory' in str(e):
raise RuntimeError(

View File

@ -23,6 +23,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@ -309,6 +310,9 @@ class Worker(WorkerBase):
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
return self.model_runner.get_supported_pooling_tasks()
@torch.inference_mode()
def execute_model(
self,

View File

@ -3,7 +3,7 @@
import bisect
import gc
import time
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast, get_args
from unittest.mock import patch
import numpy as np
@ -25,10 +25,12 @@ from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
is_pin_memory_available, prev_power_of_2)
@ -483,6 +485,16 @@ class TPUModelRunner(LoRAModelRunnerMixin):
def get_model(self) -> nn.Module:
return self.model
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
return []
return [
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each

View File

@ -19,6 +19,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput
@ -275,6 +276,9 @@ class TPUWorker:
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
return self.model_runner.get_supported_pooling_tasks()
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()

View File

@ -4,7 +4,7 @@
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
TypeVar, get_args)
import torch
import torch.nn as nn
@ -12,6 +12,8 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
if TYPE_CHECKING:
@ -223,6 +225,16 @@ class ModelRunnerBase(ABC, Generic[T]):
def get_model(self) -> nn.Module:
raise NotImplementedError
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
return []
return [
task for task in get_args(PoolingTask)
if model.pooler.get_pooling_updates(task)
]
def execute_model(
self,
model_input: T,

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
import torch
@ -10,6 +10,7 @@ from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
@ -195,7 +196,20 @@ class PoolingModelRunner(
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
assert pooling_params is not None
assert pooling_params.task is not None, (
"You did not set `task` in the API")
to_update = (cast(VllmModelForPooling,
self.model).pooler.get_pooling_updates(
pooling_params.task))
assert to_update is not None, (
f"{pooling_params.task=} is not supported by the model")
to_update.apply(pooling_params)
seq_groups.append((seq_ids, pooling_params))
seq_data: Dict[int, SequenceData] = {}