From 45badd05d04254eb75c48cee7b1d454a80de2165 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 18 Jul 2025 20:41:17 +0800 Subject: [PATCH] [Core] Set pooling params based on task and model (#21128) Signed-off-by: DarkLight1337 --- tests/models/language/pooling/test_gritlm.py | 26 +-- vllm/entrypoints/llm.py | 49 +++-- vllm/entrypoints/openai/protocol.py | 8 +- .../openai/serving_classification.py | 32 +++ vllm/entrypoints/openai/serving_embedding.py | 18 +- vllm/entrypoints/openai/serving_engine.py | 18 +- vllm/entrypoints/openai/serving_pooling.py | 5 + vllm/entrypoints/openai/serving_score.py | 30 ++- vllm/executor/executor_base.py | 7 + vllm/model_executor/layers/pooler.py | 149 ++++++++----- vllm/model_executor/models/bert.py | 12 +- vllm/model_executor/models/gritlm.py | 203 +++++++++++------- vllm/model_executor/models/interfaces.py | 7 - vllm/model_executor/models/modernbert.py | 12 +- vllm/pooling_params.py | 45 ++-- vllm/v1/engine/core.py | 6 + vllm/v1/worker/cpu_model_runner.py | 4 - vllm/v1/worker/gpu_input_batch.py | 19 +- vllm/v1/worker/gpu_model_runner.py | 48 ++++- vllm/v1/worker/gpu_worker.py | 4 + vllm/v1/worker/tpu_model_runner.py | 14 +- vllm/v1/worker/tpu_worker.py | 4 + vllm/worker/model_runner_base.py | 14 +- vllm/worker/pooling_model_runner.py | 16 +- 24 files changed, 509 insertions(+), 241 deletions(-) diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index c2f70bb647a4e..1274657991bfe 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations -import importlib.util -from array import array - +import numpy as np import openai import pytest from scipy.spatial.distance import cosine @@ -14,10 +12,6 @@ from vllm.config import ModelConfig from ....utils import RemoteOpenAIServer -# GritLM embedding implementation is only supported by XFormers backend. -pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"), - reason="GritLM requires XFormers") - MODEL_NAME = "parasail-ai/GritLM-7B-vllm" MAX_MODEL_LEN = 4000 @@ -26,11 +20,11 @@ def _arr(arr): """ Convert a list of integers to an array of integers. 
""" - return array("i", arr) + return np.array(arr) def test_find_array(): - from vllm.model_executor.models.gritlm import GritLMPooler + from vllm.model_executor.models.gritlm import GritLMMeanPool model_config = ModelConfig( MODEL_NAME, @@ -41,17 +35,19 @@ def test_find_array(): dtype="bfloat16", seed=0, ) - pooler = GritLMPooler(model_config=model_config) + pooling = GritLMMeanPool(model_config=model_config) arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 - assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 - assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 - assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 + assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 + assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 + assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 + assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=3) == -1 + assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=4) == 3 + assert pooling._find_array(arr, _arr([3, 5]), start_idx=0) == -1 with pytest.raises(ValueError): - pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) + pooling._find_array(arr, _arr([3, 4, 5]), start_idx=-1) def run_llm_encode( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e7398ecc23c8f..78f9d32d811d3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, PoolingRequestOutput, RequestOutput, ScoringRequestOutput) -from vllm.pooling_params import PoolingParams +from vllm.pooling_params import PoolingParams, PoolingTask from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) @@ -964,6 +964,7 @@ class LLM: use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", ) -> list[PoolingRequestOutput]: ... @@ -979,6 +980,7 @@ class LLM: use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", ) -> list[PoolingRequestOutput]: ... @@ -994,6 +996,7 @@ class LLM: use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", ) -> list[PoolingRequestOutput]: ... @@ -1010,6 +1013,7 @@ class LLM: use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", ) -> list[PoolingRequestOutput]: ... @@ -1026,6 +1030,7 @@ class LLM: use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", ) -> list[PoolingRequestOutput]: ... 
@@ -1040,6 +1045,7 @@ class LLM: use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", ) -> list[PoolingRequestOutput]: ... @@ -1059,6 +1065,7 @@ class LLM: use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input prompts. @@ -1080,6 +1087,7 @@ class LLM: lora_request: LoRA request to use for generation, if any. prompt_adapter_request: Prompt Adapter request to use for generation, if any. + pooling_task: Override the pooling task to use. Returns: A list of `PoolingRequestOutput` objects containing the @@ -1116,11 +1124,12 @@ class LLM: if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() - elif isinstance(pooling_params, PoolingParams): - pooling_params.verify(model_config) + + if isinstance(pooling_params, PoolingParams): + pooling_params.verify(pooling_task, model_config) else: for pooling_param in pooling_params: - pooling_param.verify(model_config) + pooling_param.verify(pooling_task, model_config) tokenization_kwargs = dict[str, Any]() _validate_truncation_size(model_config.max_model_len, @@ -1181,12 +1190,15 @@ class LLM: raise ValueError("Embedding API is not supported by this model. " "Please set `--task embed`.") - items = self.encode(prompts, - truncate_prompt_tokens=truncate_prompt_tokens, - use_tqdm=use_tqdm, - pooling_params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + items = self.encode( + prompts, + truncate_prompt_tokens=truncate_prompt_tokens, + use_tqdm=use_tqdm, + pooling_params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + pooling_task="embed", + ) return [EmbeddingRequestOutput.from_base(item) for item in items] @@ -1228,10 +1240,13 @@ class LLM: "Classification API is not supported by this model. 
" "Please set `--task classify`.") - items = self.encode(prompts, - use_tqdm=use_tqdm, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + items = self.encode( + prompts, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + pooling_task="classify", + ) return [ClassificationRequestOutput.from_base(item) for item in items] @@ -1251,7 +1266,9 @@ class LLM: truncate_prompt_tokens=truncate_prompt_tokens, use_tqdm=use_tqdm, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + pooling_task="embed", + ) encoded_output_1: list[PoolingRequestOutput] = encoded_output[ 0:len(text_1)] @@ -1287,7 +1304,7 @@ class LLM: if len(data_1) == 1: data_1 = data_1 * len(data_2) - pooling_params = PoolingParams(use_cross_encoder=True) + pooling_params = PoolingParams(task="score") tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a421ed1fc3278..95e5bcd3bae1e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1347,8 +1347,8 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] - def to_pooling_params(self, *, use_cross_encoder: bool = False): - return PoolingParams(use_cross_encoder=use_cross_encoder) + def to_pooling_params(self): + return PoolingParams() class RerankRequest(OpenAIBaseModel): @@ -1375,8 +1375,8 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] - def to_pooling_params(self, *, use_cross_encoder: bool = False): - return PoolingParams(use_cross_encoder=use_cross_encoder) + def to_pooling_params(self): + return PoolingParams() class RerankDocument(BaseModel): diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 3ac4f01ea6028..e4ea5ab8dc5f2 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -6,6 +6,7 @@ from typing import Optional, Union, cast import numpy as np from fastapi import Request +from typing_extensions import override from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -21,12 +22,14 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext, from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger from vllm.outputs import ClassificationOutput, PoolingRequestOutput +from vllm.pooling_params import PoolingParams logger = init_logger(__name__) class ClassificationMixin(OpenAIServing): + @override async def _preprocess( self, ctx: ServeContext, @@ -75,6 +78,7 @@ class ClassificationMixin(OpenAIServing): logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) + @override def _build_response( self, ctx: ServeContext, @@ -158,3 +162,31 @@ class ServingClassification(ClassificationMixin): ) return await super().handle(ctx) # type: ignore + + @override + def _validate_request( + self, + ctx: ClassificationServeContext, + ) -> Optional[ErrorResponse]: + if error := super()._validate_request(ctx): + return error + + ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens + + return None + + @override + def _create_pooling_params( + self, + ctx: ClassificationServeContext, + ) -> 
Union[PoolingParams, ErrorResponse]: + pooling_params = super()._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params + + try: + pooling_params.verify("classify", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + return pooling_params diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e87decfe636ac..f5ce86a78c8b4 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -24,6 +24,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput) +from vllm.pooling_params import PoolingParams logger = init_logger(__name__) @@ -45,6 +46,7 @@ def _get_embedding( class EmbeddingMixin(OpenAIServing): + @override async def _preprocess( self, ctx: ServeContext, @@ -97,6 +99,7 @@ class EmbeddingMixin(OpenAIServing): logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) + @override def _build_response( self, ctx: ServeContext, @@ -191,11 +194,20 @@ class OpenAIServingEmbedding(EmbeddingMixin): ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens - pooling_params = ctx.request.to_pooling_params() + return None + + @override + def _create_pooling_params( + self, + ctx: ServeContext[EmbeddingRequest], + ) -> Union[PoolingParams, ErrorResponse]: + pooling_params = super()._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params try: - pooling_params.verify(self.model_config) + pooling_params.verify("embed", self.model_config) except ValueError as e: return self.create_error_response(str(e)) - return None + return pooling_params diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 462317a0878c7..393e32f0ed9b6 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -305,6 +305,16 @@ class OpenAIServing: " Please, select a smaller truncation size.") return None + def _create_pooling_params( + self, + ctx: ServeContext, + ) -> Union[PoolingParams, ErrorResponse]: + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters") + + return ctx.request.to_pooling_params() + async def _prepare_generators( self, ctx: ServeContext, @@ -318,11 +328,9 @@ class OpenAIServing: trace_headers = (None if ctx.raw_request is None else await self._get_trace_headers(ctx.raw_request.headers)) - if not hasattr(ctx.request, "to_pooling_params"): - return self.create_error_response( - "Request type does not support pooling parameters") - - pooling_params = ctx.request.to_pooling_params() + pooling_params = self._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params if ctx.engine_prompts is None: return self.create_error_response( diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index c2ed50d04d124..eec21087b9957 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -142,6 +142,11 @@ class OpenAIServingPooling(OpenAIServing): try: pooling_params = request.to_pooling_params() + try: + pooling_params.verify("encode", self.model_config) + except ValueError as e: + return 
self.create_error_response(str(e)) + for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 8d47a417f9cde..35f6581768a32 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -55,14 +55,13 @@ class ServingScores(OpenAIServing): texts_1: list[str], texts_2: list[str], request: Union[RerankRequest, ScoreRequest], - request_id=str, + request_id: str, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[Union[LoRARequest, None]] = None, prompt_adapter_request: Optional[Union[PromptAdapterRequest, None]] = None, trace_headers: Optional[Mapping[str, str]] = None, - ) -> list[PoolingRequestOutput]: - + ) -> Union[list[PoolingRequestOutput], ErrorResponse]: input_texts = texts_1 + texts_2 engine_prompts: list[TokensPrompt] = [] @@ -89,6 +88,11 @@ class ServingScores(OpenAIServing): generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] pooling_params = request.to_pooling_params() + try: + pooling_params.verify("embed", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" @@ -169,14 +173,13 @@ class ServingScores(OpenAIServing): data_1: Union[list[str], list[ScoreContentPartParam]], data_2: Union[list[str], list[ScoreContentPartParam]], request: Union[RerankRequest, ScoreRequest], - request_id=str, + request_id: str, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[Union[LoRARequest, None]] = None, prompt_adapter_request: Optional[Union[PromptAdapterRequest, None]] = None, trace_headers: Optional[Mapping[str, str]] = None, - ) -> list[PoolingRequestOutput]: - + ) -> Union[list[PoolingRequestOutput], ErrorResponse]: request_prompts: list[str] = [] engine_prompts: list[TokensPrompt] = [] @@ -245,7 +248,12 @@ class ServingScores(OpenAIServing): # Schedule the request and get the result generator. 
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - pooling_params = request.to_pooling_params(use_cross_encoder=True) + pooling_params = request.to_pooling_params() + + try: + pooling_params.verify("score", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" @@ -286,8 +294,7 @@ class ServingScores(OpenAIServing): request_id: str, raw_request: Optional[Request] = None, truncate_prompt_tokens: Optional[int] = None, - ) -> list[PoolingRequestOutput]: - + ) -> Union[list[PoolingRequestOutput], ErrorResponse]: ( lora_request, prompt_adapter_request, @@ -374,6 +381,8 @@ class ServingScores(OpenAIServing): raw_request, request.truncate_prompt_tokens, ) + if isinstance(final_res_batch, ErrorResponse): + return final_res_batch return self.request_output_to_score_response( final_res_batch, @@ -420,6 +429,9 @@ class ServingScores(OpenAIServing): raw_request, request.truncate_prompt_tokens, ) + if isinstance(final_res_batch, ErrorResponse): + return final_res_batch + return self.request_output_to_rerank_response( final_res_batch, request_id, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 99e12201c96af..ca9f1376b9f40 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -4,6 +4,7 @@ import asyncio import time from abc import ABC, abstractmethod +from functools import cached_property from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, Union) @@ -15,6 +16,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.pooling_params import PoolingTask from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import make_async @@ -135,6 +137,11 @@ class ExecutorBase(ABC): return self.collective_rpc(rpc_func) + @cached_property # Avoid unnecessary RPC calls + def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]: + output = self.collective_rpc("get_supported_pooling_tasks") + return tuple({task for tasks in output for task in tasks}) + def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 74916492f574d..6a474b8e73a35 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import IntEnum -from typing import Callable, Literal, Optional, TypeVar, Union +from typing import Callable, Optional, TypeVar, Union import torch import torch.nn as nn @@ -15,13 +15,12 @@ from vllm.config import ModelConfig, PoolerConfig from vllm.model_executor.pooling_metadata import ( # noqa: E501 PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors -from vllm.pooling_params import PoolingParams +from vllm.pooling_params import PoolingParams, PoolingTask from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.utils import resolve_obj_by_qualname from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] -PoolingTask = Literal["encode", "embed", "classify", 
"score"] class PoolingType(IntEnum): @@ -67,6 +66,15 @@ class ResolvedPoolingConfig: ) +@dataclass(frozen=True) +class PoolingParamsUpdate: + requires_token_ids: bool = False + """Set this flag to enable `get_prompt_token_ids` for your pooler.""" + + def apply(self, params: PoolingParams) -> None: + params.requires_token_ids = self.requires_token_ids + + class Pooler(nn.Module, ABC): """The interface required for all poolers used in pooling models in vLLM.""" @@ -93,7 +101,10 @@ class Pooler(nn.Module, ABC): return SimplePooler.from_config(resolved_config) - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: """ Construct the pooling parameters to use for a task, or `None` if the task is not supported. @@ -121,6 +132,23 @@ def get_prompt_lens( pooling_metadata, hidden_states.device).prompt_lens +def get_prompt_token_ids( + pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: + if isinstance(pooling_metadata, V1PoolingMetadata): + assert pooling_metadata.prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`") + + return [ + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) + ] + + return [ + torch.tensor(seq_data_i.prompt_token_ids) + for seq_data_i in pooling_metadata.seq_data.values() + ] + + def get_classification_activation_function(config: PretrainedConfig): return PoolerClassify() @@ -165,7 +193,10 @@ class PoolingMethod(nn.Module, ABC): raise NotImplementedError(f"Unsupported method: {pooling_type}") @abstractmethod - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: raise NotImplementedError @abstractmethod @@ -206,11 +237,14 @@ class PoolingMethod(nn.Module, ABC): class CLSPool(PoolingMethod): - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: # The equalities are split up to keep mypy happy if (task == "encode" or task == "embed" or task == "classify" or task == "score"): - return PoolingParams() + return PoolingParamsUpdate() assert_never(task) @@ -236,11 +270,14 @@ class CLSPool(PoolingMethod): class LastPool(PoolingMethod): - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: # The equalities are split up to keep mypy happy if (task == "encode" or task == "embed" or task == "classify" or task == "score"): - return PoolingParams() + return PoolingParamsUpdate() assert_never(task) @@ -262,9 +299,12 @@ class LastPool(PoolingMethod): class AllPool(PoolingMethod): - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: if task == "encode": - return PoolingParams() + return PoolingParamsUpdate() # The equalities are split up to keep mypy happy if task == "embed" or task == "classify" or task == "score": @@ -299,11 +339,14 @@ class AllPool(PoolingMethod): class MeanPool(PoolingMethod): - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: # The equalities are split up to keep mypy happy if 
(task == "encode" or task == "embed" or task == "classify" or task == "score"): - return PoolingParams() + return PoolingParamsUpdate() assert_never(task) @@ -520,8 +563,11 @@ class SimplePooler(Pooler): self.pooling = pooling self.head = head - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: - return self.pooling.get_pooling_params(task) + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: + return self.pooling.get_pooling_updates(task) def forward( self, @@ -559,27 +605,13 @@ class StepPooler(Pooler): self.step_tag_id = step_tag_id self.returned_token_ids = returned_token_ids - def get_prompt_token_ids( - self, - pooling_metadata: PoolingMetadata, - ) -> list[torch.Tensor]: - if isinstance(pooling_metadata, V1PoolingMetadata): - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] - return [ - torch.tensor(seq_data_i.prompt_token_ids) - for seq_data_i in pooling_metadata.seq_data.values() - ] - def extract_states( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: pooled_data_lst = self.pooling(hidden_states, pooling_metadata) - prompt_token_ids = self.get_prompt_token_ids(pooling_metadata) + prompt_token_ids = get_prompt_token_ids(pooling_metadata) pooled_data = list[torch.Tensor]() returned_token_ids = self.returned_token_ids @@ -595,9 +627,12 @@ class StepPooler(Pooler): return pooled_data - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: if task == "encode": - return PoolingParams(logits_processing_needs_token_ids=True) + return PoolingParamsUpdate(requires_token_ids=True) # The equalities are split up to keep mypy happy if task == "embed" or task == "classify" or task == "score": @@ -650,19 +685,24 @@ class ClassifierPooler(nn.Module): self.cross_encoder_act_fn = get_cross_encoder_activation_function( config.hf_config) if act_fn is None else act_fn - def _get_act_fn(self, use_cross_encoder: bool): - return (self.cross_encoder_act_fn - if use_cross_encoder else self.classification_act_fn) + def _get_act_fn(self, task: PoolingTask): + if task == "encode" or task == "classify": + return self.classification_act_fn + if task == "score": + return self.cross_encoder_act_fn + + raise ValueError(f"Unsupported task: {task!r}") + + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: + # The equalities are split up to keep mypy happy + if task == "encode" or task == "classify" or task == "score": + return PoolingParamsUpdate() - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: - if task == "encode": - return PoolingParams() if task == "embed": return None - if task == "classify": - return PoolingParams() - if task == "score": - return PoolingParams(use_cross_encoder=True) assert_never(task) @@ -682,27 +722,28 @@ class ClassifierPooler(nn.Module): else: pooled_output = [self.classifier(data) for data in pooled_data] + task_list: list[PoolingTask] if isinstance(pooling_metadata, V0PoolingMetadata): - use_cross_encoder_list = [ - pooling_param.use_cross_encoder - for _, pooling_param in pooling_metadata.seq_groups + task_list = [ + task for _, pooling_param in pooling_metadata.seq_groups + if (task := pooling_param.task) is not None ] else: - use_cross_encoder_list = [ - 
pooling_param.use_cross_encoder - for pooling_param in pooling_metadata.pooling_params + task_list = [ + task for pooling_param in pooling_metadata.pooling_params + if (task := pooling_param.task) is not None ] + assert len(task_list) == len(pooled_output) + # shape of scores: (batch_size, num_labels) - if all(use_cross_encoder == use_cross_encoder_list[0] - for use_cross_encoder in use_cross_encoder_list): - act_fn = self._get_act_fn(use_cross_encoder_list[0]) + if len(set(task_list)) <= 1: + act_fn = self._get_act_fn(task_list[0]) scores = act_fn(pooled_output) else: scores = torch.stack([ - self._get_act_fn(use_cross_encoder)(vecs) - for use_cross_encoder, vecs in zip(use_cross_encoder_list, - pooled_output) + self._get_act_fn(task)(vecs) + for task, vecs in zip(task_list, pooled_output) ]) return build_output(scores) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index bd4445c49a039..006f547bb4617 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -18,13 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, - PoolingMethod, PoolingTask, + PoolingMethod, + PoolingParamsUpdate, PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.pooling_params import PoolingParams +from vllm.pooling_params import PoolingTask from vllm.sequence import IntermediateTensors from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only @@ -91,8 +92,11 @@ class BertPooler(Pooler): self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: - return self.pooling.get_pooling_params(task) + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: + return self.pooling.get_pooling_updates(task) def forward( self, diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index ba0e22892d86c..8443482119b0c 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -1,18 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from array import array +from typing import Optional, Union +import numpy as np import torch import torch.nn as nn +from typing_extensions import assert_never from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize +from vllm.model_executor.layers.pooler import (Pooler, PoolerHead, + PoolerNormalize, + PoolingParamsUpdate, + build_output, get_prompt_lens, + get_prompt_token_ids) from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.pooling_metadata import (PoolingMetadata, - PoolingTensors) -from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.pooling_params import PoolingTask +from vllm.sequence import PoolerOutput from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from .interfaces import SupportsV0Only @@ -20,7 +26,8 @@ from .interfaces import SupportsV0Only 
logger = init_logger(__name__) -class GritLMPooler(nn.Module): +class GritLMMeanPool(nn.Module): + """As `MeanPool`, but only includes non-instruction tokens.""" def __init__(self, model_config: ModelConfig): super().__init__() @@ -38,8 +45,8 @@ class GritLMPooler(nn.Module): for tok in ["", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"] } - def tokens_to_ids(tokens: list[str]) -> array: - return array("i", [self.token_ids[token] for token in tokens]) + def tokens_to_ids(tokens: list[str]) -> np.ndarray: + return np.array([self.token_ids[token] for token in tokens]) self.user_pattern_ids = tokens_to_ids( ["▁<", "|", "user", "|", ">", "<0x0A>"]) @@ -48,32 +55,44 @@ class GritLMPooler(nn.Module): self.embed_pattern_ids = tokens_to_ids( ["▁<", "|", "embed", "|", ">", "<0x0A>"]) - self.head = PoolerHead(PoolerNormalize()) - - def _find_array(self, arr: array, target: array, start_idx: int) -> int: + def _find_array( + self, + arr: np.ndarray, + target: np.ndarray, + start_idx: int = 0, + end_idx: Optional[int] = None, + ) -> int: """ - Find the first occurrence of target in arr starting from start_idx. + Find the first occurrence of `target` in `arr` starting from + `start_idx`. Args: - arr: The array to search within - target: The consecutive subsequence to find - start_idx: The starting index to search from + arr: The array to search within. + target: The consecutive subsequence to find. + start_idx: The starting index to search from (inclusive). + end_idx: The ending index to search from (exclusive). Returns: - int: The index of the first occurrence of target in arr. + The index of the first occurrence of `target` in `arr`. """ if start_idx < 0: - raise ValueError("start_idx must be non-negative") - if not target or not arr: - raise ValueError("Empty arr or target not allowed") + raise ValueError("`start_idx` must be non-negative") + if len(arr) == 0 or len(target) == 0: + raise ValueError("Empty `arr` or `target` not allowed") + arr_len = len(arr) target_len = len(target) - for i in range(start_idx, len(arr) - target_len + 1): - if arr[i:i + target_len] == target: + + if end_idx is None: + end_idx = arr_len + + for i in range(start_idx, min(end_idx, arr_len - target_len + 1)): + if (arr[i:i + target_len] == target).all(): return i + return -1 - def _get_instruction_len(self, prompt_token_ids: array) -> int: + def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int: """ Get the length of the instruction in the prompt. @@ -83,7 +102,6 @@ class GritLMPooler(nn.Module): The pattern matching is done using integers instead of strings because the prompt is given as a list of token IDs. """ - instruction_len = 0 # Return no instruction in case of missing BOS token. @@ -98,7 +116,8 @@ class GritLMPooler(nn.Module): embed_pattern_ids = self.embed_pattern_ids if self._find_array(prompt_token_ids, self.user_pattern_ids, - start_idx=1) == 1: + start_idx=1, + end_idx=2) == 1: embed_pattern_ids = self.embed_newline_pattern_ids # Find the embed pattern in the prompt. 
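The test at the top of this patch exercises the new `start_idx`/`end_idx` semantics of `_find_array`; a standalone sketch of the same search logic, illustrative only, mirroring `GritLMMeanPool._find_array` after the move from `array("i", ...)` to numpy:

from typing import Optional

import numpy as np

def find_array(arr: np.ndarray, target: np.ndarray,
               start_idx: int = 0,
               end_idx: Optional[int] = None) -> int:
    # `end_idx` is an exclusive bound on where a match may *start*, which is
    # why the test expects end_idx=3 -> -1 but end_idx=4 -> 3 for [3, 4, 5].
    if start_idx < 0:
        raise ValueError("`start_idx` must be non-negative")
    if len(arr) == 0 or len(target) == 0:
        raise ValueError("Empty `arr` or `target` not allowed")

    if end_idx is None:
        end_idx = len(arr)

    for i in range(start_idx, min(end_idx, len(arr) - len(target) + 1)):
        if (arr[i:i + len(target)] == target).all():
            return i
    return -1

arr = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
assert find_array(arr, np.array([3, 4, 5])) == 3
assert find_array(arr, np.array([3, 4, 5]), end_idx=3) == -1
assert find_array(arr, np.array([3, 4, 5]), end_idx=4) == 3
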
@@ -116,64 +135,92 @@ class GritLMPooler(nn.Module): return instruction_len + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: + # The equalities are split up to keep mypy happy + if task == "encode" or task == "embed": + return PoolingParamsUpdate(requires_token_ids=True) + + if task == "classify" or task == "score": + return None + + assert_never(task) + + def forward_one( + self, + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + instr_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + "partial prefill not supported with MEAN pooling" + + return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32) + + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, + instr_lens: torch.Tensor, + ) -> Union[list[torch.Tensor], torch.Tensor]: + offset = 0 + pooled_data = list[torch.Tensor]() + + for prompt_len, instr_len in zip(prompt_lens, instr_lens): + pooled_data.append(hidden_states[offset + instr_len:offset + + prompt_len].mean( + dim=0, dtype=torch.float32)) + offset += prompt_len + + return pooled_data + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) + instr_lens = torch.tensor( + [ + self._get_instruction_len(token_ids.cpu().numpy()) + for token_ids in get_prompt_token_ids(pooling_metadata) + ], + device=prompt_lens.device, + ) + + if isinstance(hidden_states, list): + return [ + self.forward_one(h, prompt_len, instr_len) for h, prompt_len, + instr_len in zip(hidden_states, prompt_lens, instr_lens) + ] + + return self.forward_all(hidden_states, prompt_lens, instr_lens) + + +class GritLMPooler(Pooler): + + def __init__(self, model_config: ModelConfig): + super().__init__() + + self.pooling = GritLMMeanPool(model_config) + self.head = PoolerHead(PoolerNormalize()) + + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: + return self.pooling.get_pooling_updates(task) + def forward( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> PoolerOutput: - """ - Pool the hidden states by summing the embeddings of - non-instruction tokens. 
- """ - prompts_token_ids = [ - token_ids.prompt_token_ids_array - for _, token_ids in pooling_metadata.seq_data.items() - ] - - instruction_lens = torch.tensor( - [ - self._get_instruction_len(prompt_token_ids) - for prompt_token_ids in prompts_token_ids - ], - device=hidden_states.device, - ) - - prompt_lens = PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens - - mask = torch.zeros_like(hidden_states, dtype=torch.bool) - - start_idx = 0 - for prompt_len, instruction_len in zip(prompt_lens, instruction_lens): - end_idx = start_idx + prompt_len - mask[start_idx + instruction_len:end_idx] = True - start_idx = end_idx - - masked_hidden_states = hidden_states.masked_fill(~mask, 0.0) - - sum_embeddings = torch.zeros(len(prompt_lens), - hidden_states.size(1), - device=hidden_states.device) - - start_idx = 0 - for i, prompt_len in enumerate(prompt_lens): - end_idx = start_idx + prompt_len - sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum( - dim=0) - start_idx = end_idx - - num_non_instruction_tokens = prompt_lens - instruction_lens - mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( - 1) - - pooled_data = self.head(mean_embeddings, - pooling_metadata=pooling_metadata) - - pooled_outputs = [ - PoolingSequenceGroupOutput(data) for data in pooled_data - ] - - return PoolerOutput(outputs=pooled_outputs) + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return build_output(pooled_data) class GritLM(LlamaForCausalLM, SupportsV0Only): @@ -202,7 +249,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only): prefix: str = "", **kwargs, ) -> None: - # Use full attention for pooling + # Use full attention for pooling (this is why V1 is not supported yet) if vllm_config.model_config.runner_type == "pooling": hf_config = vllm_config.model_config.hf_config hf_config.is_causal = False diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 417f90594497a..b60f1a5b6ff20 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -599,13 +599,6 @@ def supports_cross_encoding( return is_pooling_model(model) and _supports_cross_encoding(model) -def has_step_pooler(model: Union[type[object], object]) -> bool: - """Check if the model uses step pooler.""" - from vllm.model_executor.layers.pooler import StepPooler - - return is_pooling_model(model) and isinstance(model.pooler, StepPooler) - - class SupportsQuant: """The interface required for all models that support quantization.""" diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 94a7ddcc01c93..74986f9f57340 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -14,14 +14,15 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, - PoolingMethod, PoolingTask, + PoolingMethod, + PoolingParamsUpdate, PoolingType) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.pooling_params import PoolingParams +from 
vllm.pooling_params import PoolingTask from vllm.sequence import IntermediateTensors from .interfaces import SupportsCrossEncoding, SupportsV0Only @@ -270,8 +271,11 @@ class ModernBertPooler(Pooler): eps=config.norm_eps, bias=config.norm_bias) - def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]: - return self.pooling.get_pooling_params(task) + def get_pooling_updates( + self, + task: PoolingTask, + ) -> Optional[PoolingParamsUpdate]: + return self.pooling.get_pooling_updates(task) def forward( self, diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 1a7305727e111..868facbe2557a 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Literal, Optional import msgspec @@ -10,12 +10,14 @@ from vllm.sampling_params import RequestOutputKind if TYPE_CHECKING: from vllm.config import ModelConfig +PoolingTask = Literal["encode", "embed", "classify", "score"] + class PoolingParams( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] array_like=True): # type: ignore[call-arg] - """API parameters for pooling models. This + """API parameters for pooling models. Attributes: dimensions: Reduce the dimensions of embeddings @@ -24,24 +26,33 @@ class PoolingParams( dimensions: Optional[int] = None - use_cross_encoder: bool = False - """Internal use only.""" - - logits_processing_needs_token_ids: bool = False - """Internal use only.""" - output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + task: Optional[PoolingTask] = None + """Internal use only.""" + + requires_token_ids: bool = False + """Internal use only.""" + def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" return PoolingParams( dimensions=self.dimensions, - use_cross_encoder=self.use_cross_encoder, - logits_processing_needs_token_ids=self. - logits_processing_needs_token_ids, + task=self.task, + requires_token_ids=self.requires_token_ids, ) - def verify(self, model_config: "ModelConfig") -> None: + def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: + if self.task is None: + self.task = task + elif self.task != task: + msg = f"You cannot overwrite {self.task=!r} with {task=!r}!" + raise ValueError(msg) + + # NOTE: Task validation needs to done against the model instance, + # which is not available in model config. 
So, it's not included + # in this method + if self.dimensions is not None: if not model_config.is_matryoshka: raise ValueError( @@ -61,12 +72,10 @@ class PoolingParams( raise ValueError("Dimensions must be greater than 0") def __repr__(self) -> str: - return ( - f"PoolingParams(" - f"dimensions={self.dimensions}, " - f"use_cross_encoder={self.use_cross_encoder}, " - f"logits_processing_needs_token_ids={self.logits_processing_needs_token_ids})" - ) + return (f"PoolingParams(" + f"dimensions={self.dimensions}, " + f"task={self.task}, " + f"requires_token_ids={self.requires_token_ids})") def __post_init__(self) -> None: assert self.output_kind == RequestOutputKind.FINAL_ONLY,\ diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f5c59bef47818..b3210197750b6 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -181,6 +181,12 @@ class EngineCore: def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" + if pooling_params := request.pooling_params: + supported_pooling_tasks = ( + self.model_executor.supported_pooling_tasks) + if pooling_params.task not in supported_pooling_tasks: + raise ValueError(f"Unsupported task: {pooling_params.task!r} " + f"Supported tasks: {supported_pooling_tasks}") if request.mm_hashes is not None: # Here, if hash exists for a multimodal input, then it will be diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 410a54e7466f6..c315dcb183256 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -8,7 +8,6 @@ import torch from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models.interfaces import has_step_pooler from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) @@ -54,9 +53,6 @@ class CPUModelRunner(GPUModelRunner): logger.info("Starting to load model %s...", self.model_config.model) self.model = get_model(vllm_config=self.vllm_config) - if has_step_pooler(self.model): - self.input_batch.logits_processing_needs_token_ids = True - if self.lora_config: self.model = self.load_lora_model(self.model, self.model_config, self.scheduler_config, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a9b..a242c7fca5ef1 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -70,7 +70,6 @@ class InputBatch: vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group is_spec_decode: bool = False, - logits_processing_needs_token_ids: bool = False, ): self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs @@ -79,8 +78,6 @@ class InputBatch: self.device = device self.pin_memory = pin_memory self.vocab_size = vocab_size - self.logits_processing_needs_token_ids = ( - logits_processing_needs_token_ids) self._req_ids: list[Optional[str]] = [] self.req_id_to_index: dict[str, int] = {} @@ -233,6 +230,9 @@ class InputBatch: # req_index -> bad_words_token_ids self.bad_words_token_ids: dict[int, list[list[int]]] = {} + self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, + dtype=bool) + self.req_output_token_ids: list[Optional[list[int]]] = [] # This is updated each time the batch constituents change. 
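The `verify()` changes above bind each request to a single task before it reaches the engine core, and `EngineCore.add_request` then rejects tasks the executor does not support. A minimal sketch of the task-binding rule, simplified; the real method also validates `dimensions` against the model config:

from vllm.pooling_params import PoolingParams, PoolingTask

def bind_task(params: PoolingParams, task: PoolingTask) -> None:
    # Mirrors the first branch of PoolingParams.verify(): the first caller
    # fixes the task, and any later caller must agree with it.
    if params.task is None:
        params.task = task
    elif params.task != task:
        raise ValueError(
            f"You cannot overwrite {params.task=!r} with {task=!r}!")

params = PoolingParams()
bind_task(params, "embed")       # ok: task was unset
bind_task(params, "embed")       # ok: same task again
# bind_task(params, "classify")  # would raise ValueError
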
@@ -365,9 +365,12 @@ class InputBatch: if sampling_params.bad_words_token_ids: self.bad_words_token_ids[ req_index] = sampling_params.bad_words_token_ids + elif pooling_params := request.pooling_params: + self.pooling_params[req_id] = pooling_params + self.logits_processing_needs_token_ids[req_index] = ( + pooling_params.requires_token_ids) else: - assert request.pooling_params is not None - self.pooling_params[req_id] = request.pooling_params + raise NotImplementedError(request) # Add request lora ID if request.lora_request: @@ -620,9 +623,9 @@ class InputBatch: copy_slice(self.repetition_penalties_cpu_tensor, self.repetition_penalties, num_reqs) - needs_prompt_token_ids = (not self.no_penalties or - (self.num_reqs > 0 - and self.logits_processing_needs_token_ids)) + needs_prompt_token_ids = ( + not self.no_penalties + or self.logits_processing_needs_token_ids[:num_reqs].any()) if needs_prompt_token_ids: # The prompt tokens are used only for applying penalties or # step pooling during the sampling/pooling process. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 60fb78c060c24..c3eeb6c2e390b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4,7 +4,7 @@ import gc import time from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args import numpy as np import torch @@ -32,12 +32,13 @@ from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader -from vllm.model_executor.models.interfaces import (has_step_pooler, - is_mixture_of_experts) +from vllm.model_executor.models.interfaces import is_mixture_of_experts +from vllm.model_executor.models.interfaces_base import (VllmModelForPooling, + is_pooling_model) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality -from vllm.pooling_params import PoolingParams +from vllm.pooling_params import PoolingParams, PoolingTask from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, @@ -404,6 +405,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params + if sampling_params and \ sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) @@ -411,6 +413,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): else: generator = None + if pooling_params: + assert pooling_params.task is not None, ( + "You did not set `task` in the API") + + model = cast(VllmModelForPooling, self.model) + to_update = (model.pooler.get_pooling_updates( + pooling_params.task)) + assert to_update is not None, ( + f"{pooling_params.task=} is not supported by the model") + + to_update.apply(pooling_params) + self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, @@ -1092,6 +1106,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): def get_model(self) -> nn.Module: return self.model + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + model = self.get_model() + if 
not is_pooling_model(model): + return [] + + return [ + task for task in get_args(PoolingTask) + if model.pooler.get_pooling_updates(task) + ] + def apply_grammar_bitmask( self, scheduler_output: "SchedulerOutput", @@ -1737,8 +1761,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) model_loader.load_weights(self.model, model_config=self.model_config) - if has_step_pooler(self.model): - self.input_batch.logits_processing_needs_token_ids = True if self.lora_config: self.model = self.load_lora_model(self.model, self.model_config, @@ -2138,17 +2160,25 @@ class GPUModelRunner(LoRAModelRunnerMixin): req_num_tokens = num_tokens // num_reqs + model = cast(VllmModelForPooling, self.model) + dummy_task = self.get_supported_pooling_tasks()[0] + dummy_pooling_params = PoolingParams(task=dummy_task) + + to_update = model.pooler.get_pooling_updates(dummy_task) + assert to_update is not None + to_update.apply(dummy_pooling_params) + dummy_metadata = PoolingMetadata( prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list], device=self.device), prompt_token_ids=torch.zeros((num_reqs, req_num_tokens), dtype=torch.int32, device=self.device), - pooling_params=[PoolingParams()] * num_reqs) + pooling_params=[dummy_pooling_params] * num_reqs) try: - pooler_output = self.model.pooler(hidden_states=hidden_states_list, - pooling_metadata=dummy_metadata) + pooler_output = model.pooler(hidden_states=hidden_states_list, + pooling_metadata=dummy_metadata) except RuntimeError as e: if 'out of memory' in str(e): raise RuntimeError( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6458b55777a4d..1610d0ecee2f1 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -23,6 +23,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.pooling_params import PoolingTask from vllm.sequence import IntermediateTensors from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec @@ -309,6 +310,9 @@ class Worker(WorkerBase): def get_model(self) -> nn.Module: return self.model_runner.get_model() + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + return self.model_runner.get_supported_pooling_tasks() + @torch.inference_mode() def execute_model( self, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 8565df429738f..1b55e5d61aa9a 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast, get_args from unittest.mock import patch import numpy as np @@ -25,10 +25,12 @@ from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.tpu import TPUModelLoader +from vllm.model_executor.models.interfaces_base import is_pooling_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, PlaceholderRange) from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.pooling_params import PoolingTask from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv, is_pin_memory_available, 
prev_power_of_2) @@ -483,6 +485,16 @@ class TPUModelRunner(LoRAModelRunnerMixin): def get_model(self) -> nn.Module: return self.model + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + model = self.get_model() + if not is_pooling_model(model): + return [] + + return [ + task for task in get_args(PoolingTask) + if model.pooler.get_pooling_updates(task) + ] + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index c4bf40d665477..592d9fc17c9e4 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -19,6 +19,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.pooling_params import PoolingTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT from vllm.v1.core.sched.output import SchedulerOutput @@ -275,6 +276,9 @@ class TPUWorker: def get_model(self) -> nn.Module: return self.model_runner.get_model() + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + return self.model_runner.get_supported_pooling_tasks() + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index d567ce4a6e78f..b0737dfe31978 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -4,7 +4,7 @@ import dataclasses from abc import ABC, abstractmethod from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, - TypeVar) + TypeVar, get_args) import torch import torch.nn as nn @@ -12,6 +12,8 @@ import torch.nn as nn from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.models.interfaces_base import is_pooling_model +from vllm.pooling_params import PoolingTask from vllm.sequence import IntermediateTensors, SequenceGroupMetadata if TYPE_CHECKING: @@ -223,6 +225,16 @@ class ModelRunnerBase(ABC, Generic[T]): def get_model(self) -> nn.Module: raise NotImplementedError + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + model = self.get_model() + if not is_pooling_model(model): + return [] + + return [ + task for task in get_args(PoolingTask) + if model.pooler.get_pooling_updates(task) + ] + def execute_model( self, model_input: T, diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index f80955f71a5a3..2c3f4eb3ad4d4 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast import torch @@ -10,6 +10,7 @@ from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalKwargs from vllm.pooling_params import PoolingParams @@ -195,7 +196,20 @@ class PoolingModelRunner( 
seq_groups: List[Tuple[List[int], PoolingParams]] = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + assert pooling_params is not None + assert pooling_params.task is not None, ( + "You did not set `task` in the API") + + to_update = (cast(VllmModelForPooling, + self.model).pooler.get_pooling_updates( + pooling_params.task)) + assert to_update is not None, ( + f"{pooling_params.task=} is not supported by the model") + + to_update.apply(pooling_params) + seq_groups.append((seq_ids, pooling_params)) seq_data: Dict[int, SequenceData] = {}
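Taken together, the model-runner changes expose which tasks a loaded pooler supports and apply the pooler's `PoolingParamsUpdate` to each request; `ClassifierPooler` likewise uses the bound task to pick its activation function. A condensed sketch of that flow, assuming a generic runner where `model` and `pooling_params` stand in for the runner's real state:

from typing import cast, get_args

from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
                                                        is_pooling_model)
from vllm.pooling_params import PoolingParams, PoolingTask

def supported_pooling_tasks(model) -> list[PoolingTask]:
    # A task is supported iff the pooler returns a PoolingParamsUpdate for it,
    # matching the runners' get_supported_pooling_tasks() added in this patch.
    if not is_pooling_model(model):
        return []
    return [
        task for task in get_args(PoolingTask)
        if model.pooler.get_pooling_updates(task)
    ]

def prepare_pooling_request(model, pooling_params: PoolingParams) -> None:
    # The serving layer has already called PoolingParams.verify(), so the
    # task must be set by the time the request reaches the runner.
    assert pooling_params.task is not None, "You did not set `task` in the API"

    pooler = cast(VllmModelForPooling, model).pooler
    to_update = pooler.get_pooling_updates(pooling_params.task)
    assert to_update is not None, (
        f"{pooling_params.task=} is not supported by the model")

    # e.g. StepPooler returns PoolingParamsUpdate(requires_token_ids=True),
    # which makes the input batch keep the prompt token IDs around.
    to_update.apply(pooling_params)
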