[Core] Set pooling params based on task and model (#21128)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 4adc66f64d
commit 45badd05d0
@@ -2,9 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from __future__ import annotations

-import importlib.util
-from array import array
-
+import numpy as np
 import openai
 import pytest
 from scipy.spatial.distance import cosine
@@ -14,10 +12,6 @@ from vllm.config import ModelConfig

 from ....utils import RemoteOpenAIServer

-# GritLM embedding implementation is only supported by XFormers backend.
-pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
-                                reason="GritLM requires XFormers")
-
 MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
 MAX_MODEL_LEN = 4000

@@ -26,11 +20,11 @@ def _arr(arr):
     """
     Convert a list of integers to an array of integers.
     """
-    return array("i", arr)
+    return np.array(arr)


 def test_find_array():
-    from vllm.model_executor.models.gritlm import GritLMPooler
+    from vllm.model_executor.models.gritlm import GritLMMeanPool

     model_config = ModelConfig(
         MODEL_NAME,
@@ -41,17 +35,19 @@ def test_find_array():
         dtype="bfloat16",
         seed=0,
     )
-    pooler = GritLMPooler(model_config=model_config)
+    pooling = GritLMMeanPool(model_config=model_config)

     arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

-    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
-    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
-    assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
-    assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
+    assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
+    assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
+    assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
+    assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=3) == -1
+    assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=4) == 3
+    assert pooling._find_array(arr, _arr([3, 5]), start_idx=0) == -1

     with pytest.raises(ValueError):
-        pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
+        pooling._find_array(arr, _arr([3, 4, 5]), start_idx=-1)


 def run_llm_encode(
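A side note on the test's switch from array("i", ...) to np.array: equality semantics change, which is why the pooler now needs (…).all(). A minimal standalone sketch (not part of the diff) that demonstrates the difference:

# array("i") equality yields a single bool; NumPy yields an element-wise array.
import numpy as np
from array import array

a = array("i", [3, 4, 5])
b = array("i", [3, 4, 5])
print(a == b)  # True

x = np.array([3, 4, 5])
y = np.array([3, 4, 5])
print(x == y)          # [ True  True  True]
print((x == y).all())  # True -- what the new _find_array checks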
@@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput, RequestOutput,
                           ScoringRequestOutput)
-from vllm.pooling_params import PoolingParams
+from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
@@ -964,6 +964,7 @@ class LLM:
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -979,6 +980,7 @@ class LLM:
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -994,6 +996,7 @@ class LLM:
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -1010,6 +1013,7 @@ class LLM:
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -1026,6 +1030,7 @@ class LLM:
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -1040,6 +1045,7 @@ class LLM:
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         ...
@@ -1059,6 +1065,7 @@ class LLM:
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        pooling_task: PoolingTask = "encode",
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
         prompts.
@@ -1080,6 +1087,7 @@ class LLM:
             lora_request: LoRA request to use for generation, if any.
             prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.
+            pooling_task: Override the pooling task to use.

         Returns:
             A list of `PoolingRequestOutput` objects containing the
@@ -1116,11 +1124,12 @@ class LLM:
         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()
-        elif isinstance(pooling_params, PoolingParams):
-            pooling_params.verify(model_config)
+
+        if isinstance(pooling_params, PoolingParams):
+            pooling_params.verify(pooling_task, model_config)
         else:
             for pooling_param in pooling_params:
-                pooling_param.verify(model_config)
+                pooling_param.verify(pooling_task, model_config)

         tokenization_kwargs = dict[str, Any]()
         _validate_truncation_size(model_config.max_model_len,
@@ -1181,12 +1190,15 @@ class LLM:
             raise ValueError("Embedding API is not supported by this model. "
                              "Please set `--task embed`.")

-        items = self.encode(prompts,
-                            truncate_prompt_tokens=truncate_prompt_tokens,
-                            use_tqdm=use_tqdm,
-                            pooling_params=pooling_params,
-                            lora_request=lora_request,
-                            prompt_adapter_request=prompt_adapter_request)
+        items = self.encode(
+            prompts,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            use_tqdm=use_tqdm,
+            pooling_params=pooling_params,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+            pooling_task="embed",
+        )

         return [EmbeddingRequestOutput.from_base(item) for item in items]
@@ -1228,10 +1240,13 @@ class LLM:
                 "Classification API is not supported by this model. "
                 "Please set `--task classify`.")

-        items = self.encode(prompts,
-                            use_tqdm=use_tqdm,
-                            lora_request=lora_request,
-                            prompt_adapter_request=prompt_adapter_request)
+        items = self.encode(
+            prompts,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+            pooling_task="classify",
+        )

         return [ClassificationRequestOutput.from_base(item) for item in items]
@@ -1251,7 +1266,9 @@ class LLM:
             truncate_prompt_tokens=truncate_prompt_tokens,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            pooling_task="embed",
+        )

         encoded_output_1: list[PoolingRequestOutput] = encoded_output[
             0:len(text_1)]
@@ -1287,7 +1304,7 @@ class LLM:
         if len(data_1) == 1:
             data_1 = data_1 * len(data_2)

-        pooling_params = PoolingParams(use_cross_encoder=True)
+        pooling_params = PoolingParams(task="score")
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                   truncate_prompt_tokens, tokenization_kwargs)
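For context, a hypothetical caller-side sketch of the new pooling_task argument (the model name is illustrative, and the exact output attributes may differ by version):

from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

# embed() now forwards pooling_task="embed" internally; calling encode()
# directly lets the caller pick the task (default remains "encode").
outputs = llm.encode(["Hello, world!"], pooling_task="embed")
print(outputs[0].outputs.data.shape)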
@@ -1347,8 +1347,8 @@ class ScoreRequest(OpenAIBaseModel):

     # --8<-- [end:score-extra-params]

-    def to_pooling_params(self, *, use_cross_encoder: bool = False):
-        return PoolingParams(use_cross_encoder=use_cross_encoder)
+    def to_pooling_params(self):
+        return PoolingParams()


 class RerankRequest(OpenAIBaseModel):
@@ -1375,8 +1375,8 @@ class RerankRequest(OpenAIBaseModel):

     # --8<-- [end:rerank-extra-params]

-    def to_pooling_params(self, *, use_cross_encoder: bool = False):
-        return PoolingParams(use_cross_encoder=use_cross_encoder)
+    def to_pooling_params(self):
+        return PoolingParams()


 class RerankDocument(BaseModel):
@@ -6,6 +6,7 @@ from typing import Optional, Union, cast

 import numpy as np
 from fastapi import Request
+from typing_extensions import override

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
@@ -21,12 +22,14 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import ClassificationOutput, PoolingRequestOutput
+from vllm.pooling_params import PoolingParams

 logger = init_logger(__name__)


 class ClassificationMixin(OpenAIServing):

+    @override
     async def _preprocess(
         self,
         ctx: ServeContext,
@@ -75,6 +78,7 @@ class ClassificationMixin(OpenAIServing):
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

+    @override
     def _build_response(
         self,
         ctx: ServeContext,
@@ -158,3 +162,31 @@ class ServingClassification(ClassificationMixin):
         )

         return await super().handle(ctx)  # type: ignore
+
+    @override
+    def _validate_request(
+        self,
+        ctx: ClassificationServeContext,
+    ) -> Optional[ErrorResponse]:
+        if error := super()._validate_request(ctx):
+            return error
+
+        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
+
+        return None
+
+    @override
+    def _create_pooling_params(
+        self,
+        ctx: ClassificationServeContext,
+    ) -> Union[PoolingParams, ErrorResponse]:
+        pooling_params = super()._create_pooling_params(ctx)
+        if isinstance(pooling_params, ErrorResponse):
+            return pooling_params
+
+        try:
+            pooling_params.verify("classify", self.model_config)
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        return pooling_params
@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput)
+from vllm.pooling_params import PoolingParams

 logger = init_logger(__name__)

@@ -45,6 +46,7 @@ def _get_embedding(

 class EmbeddingMixin(OpenAIServing):

+    @override
     async def _preprocess(
         self,
         ctx: ServeContext,
@@ -97,6 +99,7 @@ class EmbeddingMixin(OpenAIServing):
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

+    @override
     def _build_response(
         self,
         ctx: ServeContext,
@@ -191,11 +194,20 @@ class OpenAIServingEmbedding(EmbeddingMixin):

         ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens

-        pooling_params = ctx.request.to_pooling_params()
+        return None
+
+    @override
+    def _create_pooling_params(
+        self,
+        ctx: ServeContext[EmbeddingRequest],
+    ) -> Union[PoolingParams, ErrorResponse]:
+        pooling_params = super()._create_pooling_params(ctx)
+        if isinstance(pooling_params, ErrorResponse):
+            return pooling_params

         try:
-            pooling_params.verify(self.model_config)
+            pooling_params.verify("embed", self.model_config)
         except ValueError as e:
             return self.create_error_response(str(e))

-        return None
+        return pooling_params
@@ -305,6 +305,16 @@ class OpenAIServing:
                 " Please, select a smaller truncation size.")
         return None

+    def _create_pooling_params(
+        self,
+        ctx: ServeContext,
+    ) -> Union[PoolingParams, ErrorResponse]:
+        if not hasattr(ctx.request, "to_pooling_params"):
+            return self.create_error_response(
+                "Request type does not support pooling parameters")
+
+        return ctx.request.to_pooling_params()
+
     async def _prepare_generators(
         self,
         ctx: ServeContext,
@@ -318,11 +328,9 @@ class OpenAIServing:
         trace_headers = (None if ctx.raw_request is None else await
                          self._get_trace_headers(ctx.raw_request.headers))

-        if not hasattr(ctx.request, "to_pooling_params"):
-            return self.create_error_response(
-                "Request type does not support pooling parameters")
-
-        pooling_params = ctx.request.to_pooling_params()
+        pooling_params = self._create_pooling_params(ctx)
+        if isinstance(pooling_params, ErrorResponse):
+            return pooling_params

         if ctx.engine_prompts is None:
             return self.create_error_response(
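The serving layer now follows a template-method pattern: the base class builds the PoolingParams and each endpoint overrides the hook to pin its own task. A minimal self-contained sketch with simplified stand-ins (these are not vLLM's actual classes):

# Simplified stand-in for PoolingParams with only the task-pinning logic.
class PoolingParamsSketch:
    def __init__(self):
        self.task = None

    def verify(self, task):
        if self.task is None:
            self.task = task
        elif self.task != task:
            raise ValueError(f"cannot overwrite {self.task!r} with {task!r}")

class BaseServing:
    def _create_pooling_params(self, request):
        return request.to_pooling_params()

class EmbeddingServing(BaseServing):
    def _create_pooling_params(self, request):
        params = super()._create_pooling_params(request)
        params.verify("embed")  # every embeddings request is an "embed" task
        return params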
@@ -142,6 +142,11 @@ class OpenAIServingPooling(OpenAIServing):
         try:
             pooling_params = request.to_pooling_params()

+            try:
+                pooling_params.verify("encode", self.model_config)
+            except ValueError as e:
+                return self.create_error_response(str(e))
+
             for i, engine_prompt in enumerate(engine_prompts):
                 request_id_item = f"{request_id}-{i}"
@@ -55,14 +55,13 @@ class ServingScores(OpenAIServing):
         texts_1: list[str],
         texts_2: list[str],
         request: Union[RerankRequest, ScoreRequest],
-        request_id=str,
+        request_id: str,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[Union[LoRARequest, None]] = None,
         prompt_adapter_request: Optional[Union[PromptAdapterRequest,
                                                None]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> list[PoolingRequestOutput]:
-
+    ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
         input_texts = texts_1 + texts_2

         engine_prompts: list[TokensPrompt] = []
@@ -89,6 +88,11 @@ class ServingScores(OpenAIServing):
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
         pooling_params = request.to_pooling_params()

+        try:
+            pooling_params.verify("embed", self.model_config)
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
         for i, engine_prompt in enumerate(engine_prompts):

             request_id_item = f"{request_id}-{i}"
@@ -169,14 +173,13 @@ class ServingScores(OpenAIServing):
         data_1: Union[list[str], list[ScoreContentPartParam]],
         data_2: Union[list[str], list[ScoreContentPartParam]],
         request: Union[RerankRequest, ScoreRequest],
-        request_id=str,
+        request_id: str,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[Union[LoRARequest, None]] = None,
         prompt_adapter_request: Optional[Union[PromptAdapterRequest,
                                                None]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> list[PoolingRequestOutput]:
-
+    ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
         request_prompts: list[str] = []
         engine_prompts: list[TokensPrompt] = []

@@ -245,7 +248,12 @@ class ServingScores(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

-        pooling_params = request.to_pooling_params(use_cross_encoder=True)
+        pooling_params = request.to_pooling_params()
+
+        try:
+            pooling_params.verify("score", self.model_config)
+        except ValueError as e:
+            return self.create_error_response(str(e))

         for i, engine_prompt in enumerate(engine_prompts):
             request_id_item = f"{request_id}-{i}"
@@ -286,8 +294,7 @@ class ServingScores(OpenAIServing):
         request_id: str,
         raw_request: Optional[Request] = None,
         truncate_prompt_tokens: Optional[int] = None,
-    ) -> list[PoolingRequestOutput]:
-
+    ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
         (
             lora_request,
             prompt_adapter_request,
@@ -374,6 +381,8 @@ class ServingScores(OpenAIServing):
             raw_request,
             request.truncate_prompt_tokens,
         )
+        if isinstance(final_res_batch, ErrorResponse):
+            return final_res_batch

         return self.request_output_to_score_response(
             final_res_batch,
@@ -420,6 +429,9 @@ class ServingScores(OpenAIServing):
             raw_request,
             request.truncate_prompt_tokens,
         )
+        if isinstance(final_res_batch, ErrorResponse):
+            return final_res_batch

         return self.request_output_to_rerank_response(
             final_res_batch,
             request_id,
@@ -4,6 +4,7 @@
 import asyncio
 import time
 from abc import ABC, abstractmethod
+from functools import cached_property
 from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
                     Union)

@@ -15,6 +16,7 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.pooling_params import PoolingTask
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest, PoolerOutput
 from vllm.utils import make_async
@@ -135,6 +137,11 @@ class ExecutorBase(ABC):

         return self.collective_rpc(rpc_func)

+    @cached_property  # Avoid unnecessary RPC calls
+    def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]:
+        output = self.collective_rpc("get_supported_pooling_tasks")
+        return tuple({task for tasks in output for task in tasks})
+
     def execute_model(
         self, execute_model_req: ExecuteModelRequest
     ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
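A toy equivalent of what supported_pooling_tasks computes: collective_rpc returns one task list per worker, and the set comprehension unions and dedupes them.

output = [["encode", "embed"], ["encode", "embed"]]  # e.g. one list per TP worker
print(tuple({task for tasks in output for task in tasks}))
# -> ('encode', 'embed') in some order (set iteration order is unspecified)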
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import IntEnum
-from typing import Callable, Literal, Optional, TypeVar, Union
+from typing import Callable, Optional, TypeVar, Union

 import torch
 import torch.nn as nn
@@ -15,13 +15,12 @@ from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.pooling_metadata import (  # noqa: E501
     PoolingMetadata as V0PoolingMetadata)
 from vllm.model_executor.pooling_metadata import PoolingTensors
-from vllm.pooling_params import PoolingParams
+from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
-PoolingTask = Literal["encode", "embed", "classify", "score"]


 class PoolingType(IntEnum):
@@ -67,6 +66,15 @@ class ResolvedPoolingConfig:
     )


+@dataclass(frozen=True)
+class PoolingParamsUpdate:
+    requires_token_ids: bool = False
+    """Set this flag to enable `get_prompt_token_ids` for your pooler."""
+
+    def apply(self, params: PoolingParams) -> None:
+        params.requires_token_ids = self.requires_token_ids
+
+
 class Pooler(nn.Module, ABC):
     """The interface required for all poolers used in pooling models in vLLM."""

@@ -93,7 +101,10 @@ class Pooler(nn.Module, ABC):

         return SimplePooler.from_config(resolved_config)

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
         """
         Construct the pooling parameters to use for a task,
         or `None` if the task is not supported.
@@ -121,6 +132,23 @@ def get_prompt_lens(
         pooling_metadata, hidden_states.device).prompt_lens


+def get_prompt_token_ids(
+        pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
+    if isinstance(pooling_metadata, V1PoolingMetadata):
+        assert pooling_metadata.prompt_token_ids is not None, (
+            "Please set `requires_token_ids=True` in `get_pooling_updates`")
+
+        return [
+            pooling_metadata.prompt_token_ids[i, :num]
+            for i, num in enumerate(pooling_metadata.prompt_lens)
+        ]
+
+    return [
+        torch.tensor(seq_data_i.prompt_token_ids)
+        for seq_data_i in pooling_metadata.seq_data.values()
+    ]
+
+
 def get_classification_activation_function(config: PretrainedConfig):
     return PoolerClassify()

@@ -165,7 +193,10 @@ class PoolingMethod(nn.Module, ABC):
         raise NotImplementedError(f"Unsupported method: {pooling_type}")

     @abstractmethod
-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
         raise NotImplementedError

     @abstractmethod
@@ -206,11 +237,14 @@ class PoolingMethod(nn.Module, ABC):

 class CLSPool(PoolingMethod):

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
         # The equalities are split up to keep mypy happy
         if (task == "encode" or task == "embed" or task == "classify"
                 or task == "score"):
-            return PoolingParams()
+            return PoolingParamsUpdate()

         assert_never(task)

@@ -236,11 +270,14 @@ class CLSPool(PoolingMethod):

 class LastPool(PoolingMethod):

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
         # The equalities are split up to keep mypy happy
         if (task == "encode" or task == "embed" or task == "classify"
                 or task == "score"):
-            return PoolingParams()
+            return PoolingParamsUpdate()

         assert_never(task)

@@ -262,9 +299,12 @@ class LastPool(PoolingMethod):

 class AllPool(PoolingMethod):

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
         if task == "encode":
-            return PoolingParams()
+            return PoolingParamsUpdate()

         # The equalities are split up to keep mypy happy
         if task == "embed" or task == "classify" or task == "score":
@@ -299,11 +339,14 @@ class AllPool(PoolingMethod):

 class MeanPool(PoolingMethod):

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
         # The equalities are split up to keep mypy happy
         if (task == "encode" or task == "embed" or task == "classify"
                 or task == "score"):
-            return PoolingParams()
+            return PoolingParamsUpdate()

         assert_never(task)

@@ -520,8 +563,11 @@ class SimplePooler(Pooler):
         self.pooling = pooling
         self.head = head

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
-        return self.pooling.get_pooling_params(task)
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
+        return self.pooling.get_pooling_updates(task)

     def forward(
         self,
@@ -559,27 +605,13 @@ class StepPooler(Pooler):
         self.step_tag_id = step_tag_id
         self.returned_token_ids = returned_token_ids

-    def get_prompt_token_ids(
-        self,
-        pooling_metadata: PoolingMetadata,
-    ) -> list[torch.Tensor]:
-        if isinstance(pooling_metadata, V1PoolingMetadata):
-            return [
-                pooling_metadata.prompt_token_ids[i, :num]
-                for i, num in enumerate(pooling_metadata.prompt_lens)
-            ]
-        return [
-            torch.tensor(seq_data_i.prompt_token_ids)
-            for seq_data_i in pooling_metadata.seq_data.values()
-        ]
-
     def extract_states(
         self,
         hidden_states: Union[torch.Tensor, list[torch.Tensor]],
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
         pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
-        prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
+        prompt_token_ids = get_prompt_token_ids(pooling_metadata)

         pooled_data = list[torch.Tensor]()
         returned_token_ids = self.returned_token_ids
@@ -595,9 +627,12 @@ class StepPooler(Pooler):

         return pooled_data

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
         if task == "encode":
-            return PoolingParams(logits_processing_needs_token_ids=True)
+            return PoolingParamsUpdate(requires_token_ids=True)

         # The equalities are split up to keep mypy happy
         if task == "embed" or task == "classify" or task == "score":
@@ -650,19 +685,24 @@ class ClassifierPooler(nn.Module):
         self.cross_encoder_act_fn = get_cross_encoder_activation_function(
             config.hf_config) if act_fn is None else act_fn

-    def _get_act_fn(self, use_cross_encoder: bool):
-        return (self.cross_encoder_act_fn
-                if use_cross_encoder else self.classification_act_fn)
+    def _get_act_fn(self, task: PoolingTask):
+        if task == "encode" or task == "classify":
+            return self.classification_act_fn
+        if task == "score":
+            return self.cross_encoder_act_fn
+
+        raise ValueError(f"Unsupported task: {task!r}")

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
-        if task == "encode":
-            return PoolingParams()
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
+        # The equalities are split up to keep mypy happy
+        if task == "encode" or task == "classify" or task == "score":
+            return PoolingParamsUpdate()
+
         if task == "embed":
             return None
-        if task == "classify":
-            return PoolingParams()
-        if task == "score":
-            return PoolingParams(use_cross_encoder=True)

         assert_never(task)

@@ -682,27 +722,28 @@ class ClassifierPooler(nn.Module):
         else:
             pooled_output = [self.classifier(data) for data in pooled_data]

+        task_list: list[PoolingTask]
         if isinstance(pooling_metadata, V0PoolingMetadata):
-            use_cross_encoder_list = [
-                pooling_param.use_cross_encoder
-                for _, pooling_param in pooling_metadata.seq_groups
+            task_list = [
+                task for _, pooling_param in pooling_metadata.seq_groups
+                if (task := pooling_param.task) is not None
             ]
         else:
-            use_cross_encoder_list = [
-                pooling_param.use_cross_encoder
-                for pooling_param in pooling_metadata.pooling_params
+            task_list = [
+                task for pooling_param in pooling_metadata.pooling_params
+                if (task := pooling_param.task) is not None
             ]

+        assert len(task_list) == len(pooled_output)
+
         # shape of scores: (batch_size, num_labels)
-        if all(use_cross_encoder == use_cross_encoder_list[0]
-               for use_cross_encoder in use_cross_encoder_list):
-            act_fn = self._get_act_fn(use_cross_encoder_list[0])
+        if len(set(task_list)) <= 1:
+            act_fn = self._get_act_fn(task_list[0])
             scores = act_fn(pooled_output)
         else:
             scores = torch.stack([
-                self._get_act_fn(use_cross_encoder)(vecs)
-                for use_cross_encoder, vecs in zip(use_cross_encoder_list,
-                                                   pooled_output)
+                self._get_act_fn(task)(vecs)
+                for task, vecs in zip(task_list, pooled_output)
             ])

         return build_output(scores)
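The new contract in a nutshell: a pooler reports per-task support by returning an Optional PoolingParamsUpdate, and the model runner applies that update onto the request's params. A minimal sketch using simplified stand-ins for the vLLM types:

from dataclasses import dataclass
from typing import Literal, Optional

PoolingTask = Literal["encode", "embed", "classify", "score"]

@dataclass(frozen=True)
class PoolingParamsUpdate:
    requires_token_ids: bool = False

class MyStepPooler:
    def get_pooling_updates(self, task: PoolingTask) -> Optional[PoolingParamsUpdate]:
        # Only "encode" is supported, and it needs the prompt token IDs.
        if task == "encode":
            return PoolingParamsUpdate(requires_token_ids=True)
        return None

pooler = MyStepPooler()
assert pooler.get_pooling_updates("encode").requires_token_ids
assert pooler.get_pooling_updates("embed") is None  # task unsupported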
@@ -18,13 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
-                                               PoolingMethod, PoolingTask,
+                                               PoolingMethod,
+                                               PoolingParamsUpdate,
                                                PoolingType)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingParams
+from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
@@ -91,8 +92,11 @@ class BertPooler(Pooler):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
-        return self.pooling.get_pooling_params(task)
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
+        return self.pooling.get_pooling_updates(task)

     def forward(
         self,
@@ -1,18 +1,24 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from array import array
 from typing import Optional, Union

+import numpy as np
 import torch
 import torch.nn as nn
+from typing_extensions import assert_never

 from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize
+from vllm.model_executor.layers.pooler import (Pooler, PoolerHead,
+                                               PoolerNormalize,
+                                               PoolingParamsUpdate,
+                                               build_output, get_prompt_lens,
+                                               get_prompt_token_ids)
 from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.model_executor.pooling_metadata import (PoolingMetadata,
-                                                  PoolingTensors)
-from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.pooling_params import PoolingTask
+from vllm.sequence import PoolerOutput
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

 from .interfaces import SupportsV0Only
@@ -20,7 +26,8 @@ from .interfaces import SupportsV0Only
 logger = init_logger(__name__)


-class GritLMPooler(nn.Module):
+class GritLMMeanPool(nn.Module):
+    """As `MeanPool`, but only includes non-instruction tokens."""

     def __init__(self, model_config: ModelConfig):
         super().__init__()
@@ -38,8 +45,8 @@
             for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
         }

-        def tokens_to_ids(tokens: list[str]) -> array:
-            return array("i", [self.token_ids[token] for token in tokens])
+        def tokens_to_ids(tokens: list[str]) -> np.ndarray:
+            return np.array([self.token_ids[token] for token in tokens])

         self.user_pattern_ids = tokens_to_ids(
             ["▁<", "|", "user", "|", ">", "<0x0A>"])
@@ -48,32 +55,44 @@
         self.embed_pattern_ids = tokens_to_ids(
             ["▁<", "|", "embed", "|", ">", "<0x0A>"])

-        self.head = PoolerHead(PoolerNormalize())
-
-    def _find_array(self, arr: array, target: array, start_idx: int) -> int:
+    def _find_array(
+        self,
+        arr: np.ndarray,
+        target: np.ndarray,
+        start_idx: int = 0,
+        end_idx: Optional[int] = None,
+    ) -> int:
         """
-        Find the first occurrence of target in arr starting from start_idx.
+        Find the first occurrence of `target` in `arr` starting from
+        `start_idx`.

         Args:
-            arr: The array to search within
-            target: The consecutive subsequence to find
-            start_idx: The starting index to search from
+            arr: The array to search within.
+            target: The consecutive subsequence to find.
+            start_idx: The starting index to search from (inclusive).
+            end_idx: The ending index to search from (exclusive).

         Returns:
-            int: The index of the first occurrence of target in arr.
+            The index of the first occurrence of `target` in `arr`.
         """
         if start_idx < 0:
-            raise ValueError("start_idx must be non-negative")
-        if not target or not arr:
-            raise ValueError("Empty arr or target not allowed")
+            raise ValueError("`start_idx` must be non-negative")
+        if len(arr) == 0 or len(target) == 0:
+            raise ValueError("Empty `arr` or `target` not allowed")

+        arr_len = len(arr)
         target_len = len(target)
-        for i in range(start_idx, len(arr) - target_len + 1):
-            if arr[i:i + target_len] == target:
+
+        if end_idx is None:
+            end_idx = arr_len
+
+        for i in range(start_idx, min(end_idx, arr_len - target_len + 1)):
+            if (arr[i:i + target_len] == target).all():
                 return i

         return -1

-    def _get_instruction_len(self, prompt_token_ids: array) -> int:
+    def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
         """
         Get the length of the instruction in the prompt.

@@ -83,7 +102,6 @@
         The pattern matching is done using integers instead of strings
         because the prompt is given as a list of token IDs.
         """
-
         instruction_len = 0

         # Return no instruction in case of missing BOS token.
@@ -98,7 +116,8 @@
         embed_pattern_ids = self.embed_pattern_ids
         if self._find_array(prompt_token_ids,
                             self.user_pattern_ids,
-                            start_idx=1) == 1:
+                            start_idx=1,
+                            end_idx=2) == 1:
             embed_pattern_ids = self.embed_newline_pattern_ids

         # Find the embed pattern in the prompt.
@@ -116,64 +135,92 @@

         return instruction_len

+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
+        # The equalities are split up to keep mypy happy
+        if task == "encode" or task == "embed":
+            return PoolingParamsUpdate(requires_token_ids=True)
+
+        if task == "classify" or task == "score":
+            return None
+
+        assert_never(task)
+
+    def forward_one(
+        self,
+        hidden_states: torch.Tensor,
+        prompt_len: Optional[torch.Tensor] = None,
+        instr_len: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        assert prompt_len is None or prompt_len == hidden_states.shape[0], \
+            "partial prefill not supported with MEAN pooling"
+
+        return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32)
+
+    def forward_all(
+        self,
+        hidden_states: torch.Tensor,
+        prompt_lens: torch.Tensor,
+        instr_lens: torch.Tensor,
+    ) -> Union[list[torch.Tensor], torch.Tensor]:
+        offset = 0
+        pooled_data = list[torch.Tensor]()
+
+        for prompt_len, instr_len in zip(prompt_lens, instr_lens):
+            pooled_data.append(hidden_states[offset + instr_len:offset +
+                                             prompt_len].mean(
+                                                 dim=0, dtype=torch.float32))
+            offset += prompt_len
+
+        return pooled_data
+
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+        pooling_metadata: PoolingMetadata,
+    ) -> Union[list[torch.Tensor], torch.Tensor]:
+        prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
+        instr_lens = torch.tensor(
+            [
+                self._get_instruction_len(token_ids.cpu().numpy())
+                for token_ids in get_prompt_token_ids(pooling_metadata)
+            ],
+            device=prompt_lens.device,
+        )
+
+        if isinstance(hidden_states, list):
+            return [
+                self.forward_one(h, prompt_len, instr_len) for h, prompt_len,
+                instr_len in zip(hidden_states, prompt_lens, instr_lens)
+            ]
+
+        return self.forward_all(hidden_states, prompt_lens, instr_lens)
+
+
+class GritLMPooler(Pooler):
+
+    def __init__(self, model_config: ModelConfig):
+        super().__init__()
+
+        self.pooling = GritLMMeanPool(model_config)
+        self.head = PoolerHead(PoolerNormalize())
+
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
+        return self.pooling.get_pooling_updates(task)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
-        """
-        Pool the hidden states by summing the embeddings of
-        non-instruction tokens.
-        """
-        prompts_token_ids = [
-            token_ids.prompt_token_ids_array
-            for _, token_ids in pooling_metadata.seq_data.items()
-        ]
-
-        instruction_lens = torch.tensor(
-            [
-                self._get_instruction_len(prompt_token_ids)
-                for prompt_token_ids in prompts_token_ids
-            ],
-            device=hidden_states.device,
-        )
-
-        prompt_lens = PoolingTensors.from_pooling_metadata(
-            pooling_metadata, hidden_states.device).prompt_lens
-
-        mask = torch.zeros_like(hidden_states, dtype=torch.bool)
-
-        start_idx = 0
-        for prompt_len, instruction_len in zip(prompt_lens, instruction_lens):
-            end_idx = start_idx + prompt_len
-            mask[start_idx + instruction_len:end_idx] = True
-            start_idx = end_idx
-
-        masked_hidden_states = hidden_states.masked_fill(~mask, 0.0)
-
-        sum_embeddings = torch.zeros(len(prompt_lens),
-                                     hidden_states.size(1),
-                                     device=hidden_states.device)
-
-        start_idx = 0
-        for i, prompt_len in enumerate(prompt_lens):
-            end_idx = start_idx + prompt_len
-            sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum(
-                dim=0)
-            start_idx = end_idx
-
-        num_non_instruction_tokens = prompt_lens - instruction_lens
-        mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
-            1)
-
-        pooled_data = self.head(mean_embeddings,
-                                pooling_metadata=pooling_metadata)
-
-        pooled_outputs = [
-            PoolingSequenceGroupOutput(data) for data in pooled_data
-        ]
-
-        return PoolerOutput(outputs=pooled_outputs)
+        pooled_data = self.pooling(hidden_states, pooling_metadata)
+        pooled_data = self.head(pooled_data, pooling_metadata)
+        return build_output(pooled_data)


 class GritLM(LlamaForCausalLM, SupportsV0Only):
@@ -202,7 +249,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
         prefix: str = "",
         **kwargs,
     ) -> None:
-        # Use full attention for pooling
+        # Use full attention for pooling (this is why V1 is not supported yet)
         if vllm_config.model_config.runner_type == "pooling":
             hf_config = vllm_config.model_config.hf_config
             hf_config.is_causal = False
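A toy worked example of the pooling rule GritLMMeanPool implements (average only the hidden states of the non-instruction tokens):

import torch

hidden_states = torch.arange(12, dtype=torch.float32).reshape(6, 2)
prompt_len, instr_len = 6, 2  # first 2 tokens are the instruction

pooled = hidden_states[instr_len:prompt_len].mean(dim=0, dtype=torch.float32)
print(pooled)  # tensor([7., 8.]) -- mean of rows 2..5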
@@ -599,13 +599,6 @@ def supports_cross_encoding(
     return is_pooling_model(model) and _supports_cross_encoding(model)


-def has_step_pooler(model: Union[type[object], object]) -> bool:
-    """Check if the model uses step pooler."""
-    from vllm.model_executor.layers.pooler import StepPooler
-
-    return is_pooling_model(model) and isinstance(model.pooler, StepPooler)
-
-
 class SupportsQuant:
     """The interface required for all models that support quantization."""
@@ -14,14 +14,15 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
-                                               PoolingMethod, PoolingTask,
+                                               PoolingMethod,
+                                               PoolingParamsUpdate,
                                                PoolingType)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingParams
+from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsCrossEncoding, SupportsV0Only
@@ -270,8 +271,11 @@ class ModernBertPooler(Pooler):
             eps=config.norm_eps,
             bias=config.norm_bias)

-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
-        return self.pooling.get_pooling_params(task)
+    def get_pooling_updates(
+        self,
+        task: PoolingTask,
+    ) -> Optional[PoolingParamsUpdate]:
+        return self.pooling.get_pooling_updates(task)

     def forward(
         self,
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Literal, Optional

 import msgspec

@@ -10,12 +10,14 @@ from vllm.sampling_params import RequestOutputKind
 if TYPE_CHECKING:
     from vllm.config import ModelConfig

+PoolingTask = Literal["encode", "embed", "classify", "score"]
+

 class PoolingParams(
         msgspec.Struct,
         omit_defaults=True,  # type: ignore[call-arg]
         array_like=True):  # type: ignore[call-arg]
-    """API parameters for pooling models. This
+    """API parameters for pooling models.

     Attributes:
         dimensions: Reduce the dimensions of embeddings
@@ -24,24 +26,33 @@ class PoolingParams(

     dimensions: Optional[int] = None

-    use_cross_encoder: bool = False
-    """Internal use only."""
-
-    logits_processing_needs_token_ids: bool = False
-    """Internal use only."""
-
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

+    task: Optional[PoolingTask] = None
+    """Internal use only."""
+
+    requires_token_ids: bool = False
+    """Internal use only."""
+
     def clone(self) -> "PoolingParams":
         """Returns a deep copy of the PoolingParams instance."""
         return PoolingParams(
             dimensions=self.dimensions,
-            use_cross_encoder=self.use_cross_encoder,
-            logits_processing_needs_token_ids=self.
-            logits_processing_needs_token_ids,
+            task=self.task,
+            requires_token_ids=self.requires_token_ids,
         )

-    def verify(self, model_config: "ModelConfig") -> None:
+    def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None:
+        if self.task is None:
+            self.task = task
+        elif self.task != task:
+            msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
+            raise ValueError(msg)
+
+        # NOTE: Task validation needs to be done against the model instance,
+        # which is not available in model config. So, it's not included
+        # in this method
+
         if self.dimensions is not None:
             if not model_config.is_matryoshka:
                 raise ValueError(
@@ -61,12 +72,10 @@ class PoolingParams(
                 raise ValueError("Dimensions must be greater than 0")

     def __repr__(self) -> str:
-        return (
-            f"PoolingParams("
-            f"dimensions={self.dimensions}, "
-            f"use_cross_encoder={self.use_cross_encoder}, "
-            f"logits_processing_needs_token_ids={self.logits_processing_needs_token_ids})"
-        )
+        return (f"PoolingParams("
+                f"dimensions={self.dimensions}, "
+                f"task={self.task}, "
+                f"requires_token_ids={self.requires_token_ids})")

     def __post_init__(self) -> None:
         assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
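A behavior sketch of the new verify: the first call pins the task, and later calls must agree. The model_config argument is only consulted for the matryoshka dimensions check, so None suffices for this illustration:

from vllm.pooling_params import PoolingParams

params = PoolingParams()
params.verify("embed", None)   # pins params.task = "embed"
params.verify("embed", None)   # same task again: OK

try:
    params.verify("score", None)  # conflicting task
except ValueError as e:
    print(e)  # You cannot overwrite self.task='embed' with task='score'!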
@@ -181,6 +181,12 @@ class EngineCore:

     def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""
+        if pooling_params := request.pooling_params:
+            supported_pooling_tasks = (
+                self.model_executor.supported_pooling_tasks)
+            if pooling_params.task not in supported_pooling_tasks:
+                raise ValueError(f"Unsupported task: {pooling_params.task!r} "
+                                 f"Supported tasks: {supported_pooling_tasks}")

         if request.mm_hashes is not None:
             # Here, if hash exists for a multimodal input, then it will be
@@ -8,7 +8,6 @@ import torch
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models.interfaces import has_step_pooler
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 logger = init_logger(__name__)
@@ -54,9 +53,6 @@ class CPUModelRunner(GPUModelRunner):
         logger.info("Starting to load model %s...", self.model_config.model)
         self.model = get_model(vllm_config=self.vllm_config)

-        if has_step_pooler(self.model):
-            self.input_batch.logits_processing_needs_token_ids = True
-
         if self.lora_config:
             self.model = self.load_lora_model(self.model, self.model_config,
                                               self.scheduler_config,
@@ -70,7 +70,6 @@ class InputBatch:
         vocab_size: int,
         block_sizes: list[int],  # The block_size of each kv cache group
         is_spec_decode: bool = False,
-        logits_processing_needs_token_ids: bool = False,
     ):
         self.is_spec_decode = is_spec_decode
         self.max_num_reqs = max_num_reqs
@@ -79,8 +78,6 @@ class InputBatch:
         self.device = device
         self.pin_memory = pin_memory
         self.vocab_size = vocab_size
-        self.logits_processing_needs_token_ids = (
-            logits_processing_needs_token_ids)

         self._req_ids: list[Optional[str]] = []
         self.req_id_to_index: dict[str, int] = {}
@@ -233,6 +230,9 @@ class InputBatch:
         # req_index -> bad_words_token_ids
         self.bad_words_token_ids: dict[int, list[list[int]]] = {}

+        self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
+                                                          dtype=bool)
+
         self.req_output_token_ids: list[Optional[list[int]]] = []

         # This is updated each time the batch constituents change.
@@ -365,9 +365,12 @@ class InputBatch:
             if sampling_params.bad_words_token_ids:
                 self.bad_words_token_ids[
                     req_index] = sampling_params.bad_words_token_ids
+        elif pooling_params := request.pooling_params:
+            self.pooling_params[req_id] = pooling_params
+            self.logits_processing_needs_token_ids[req_index] = (
+                pooling_params.requires_token_ids)
         else:
-            assert request.pooling_params is not None
-            self.pooling_params[req_id] = request.pooling_params
+            raise NotImplementedError(request)

         # Add request lora ID
         if request.lora_request:
@@ -620,9 +623,9 @@ class InputBatch:
             copy_slice(self.repetition_penalties_cpu_tensor,
                        self.repetition_penalties, num_reqs)

-        needs_prompt_token_ids = (not self.no_penalties or
-                                  (self.num_reqs > 0
-                                   and self.logits_processing_needs_token_ids))
+        needs_prompt_token_ids = (
+            not self.no_penalties
+            or self.logits_processing_needs_token_ids[:num_reqs].any())
         if needs_prompt_token_ids:
             # The prompt tokens are used only for applying penalties or
             # step pooling during the sampling/pooling process.
@@ -4,7 +4,7 @@
 import gc
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args

 import numpy as np
 import torch
@@ -32,12 +32,13 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.interfaces import (has_step_pooler,
-                                                   is_mixture_of_experts)
+from vllm.model_executor.models.interfaces import is_mixture_of_experts
+from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
+                                                        is_pooling_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.pooling_params import PoolingParams
+from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
@@ -404,6 +405,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
+            pooling_params = new_req_data.pooling_params

             if sampling_params and \
                 sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                 generator = torch.Generator(device=self.device)
@@ -411,6 +413,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             else:
                 generator = None

+            if pooling_params:
+                assert pooling_params.task is not None, (
+                    "You did not set `task` in the API")
+
+                model = cast(VllmModelForPooling, self.model)
+                to_update = (model.pooler.get_pooling_updates(
+                    pooling_params.task))
+                assert to_update is not None, (
+                    f"{pooling_params.task=} is not supported by the model")
+
+                to_update.apply(pooling_params)
+
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
@@ -1092,6 +1106,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def get_model(self) -> nn.Module:
         return self.model

+    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
+        model = self.get_model()
+        if not is_pooling_model(model):
+            return []
+
+        return [
+            task for task in get_args(PoolingTask)
+            if model.pooler.get_pooling_updates(task)
+        ]
+
     def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
@@ -1737,8 +1761,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             )
             model_loader.load_weights(self.model,
                                       model_config=self.model_config)
-        if has_step_pooler(self.model):
-            self.input_batch.logits_processing_needs_token_ids = True
         if self.lora_config:
             self.model = self.load_lora_model(self.model,
                                               self.model_config,
@@ -2138,17 +2160,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         req_num_tokens = num_tokens // num_reqs

+        model = cast(VllmModelForPooling, self.model)
+        dummy_task = self.get_supported_pooling_tasks()[0]
+        dummy_pooling_params = PoolingParams(task=dummy_task)
+
+        to_update = model.pooler.get_pooling_updates(dummy_task)
+        assert to_update is not None
+        to_update.apply(dummy_pooling_params)
+
         dummy_metadata = PoolingMetadata(
             prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
                                      device=self.device),
             prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
                                          dtype=torch.int32,
                                          device=self.device),
-            pooling_params=[PoolingParams()] * num_reqs)
+            pooling_params=[dummy_pooling_params] * num_reqs)

         try:
-            pooler_output = self.model.pooler(hidden_states=hidden_states_list,
-                                              pooling_metadata=dummy_metadata)
+            pooler_output = model.pooler(hidden_states=hidden_states_list,
+                                         pooling_metadata=dummy_metadata)
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 raise RuntimeError(
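Why get_args works for enumerating tasks: PoolingTask is a Literal type, and typing.get_args unpacks its members at runtime. A two-line demonstration:

from typing import Literal, get_args

PoolingTask = Literal["encode", "embed", "classify", "score"]
print(get_args(PoolingTask))  # ('encode', 'embed', 'classify', 'score')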
@@ -23,6 +23,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
 from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -309,6 +310,9 @@ class Worker(WorkerBase):
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()

+    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
+        return self.model_runner.get_supported_pooling_tasks()
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -3,7 +3,7 @@
 import bisect
 import gc
 import time
-from typing import TYPE_CHECKING, Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast, get_args
 from unittest.mock import patch

 import numpy as np
@@ -25,10 +25,12 @@ from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.tpu import TPUModelLoader
+from vllm.model_executor.models.interfaces_base import is_pooling_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                     PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
+from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
                         is_pin_memory_available, prev_power_of_2)
@@ -483,6 +485,16 @@ class TPUModelRunner(LoRAModelRunnerMixin):
     def get_model(self) -> nn.Module:
         return self.model

+    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
+        model = self.get_model()
+        if not is_pooling_model(model):
+            return []
+
+        return [
+            task for task in get_args(PoolingTask)
+            if model.pooler.get_pooling_updates(task)
+        ]
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
@@ -19,6 +19,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.pooling_params import PoolingTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -275,6 +276,9 @@ class TPUWorker:
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()

+    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
+        return self.model_runner.get_supported_pooling_tasks()
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return self.model_runner.get_kv_cache_spec()
@@ -4,7 +4,7 @@
 import dataclasses
 from abc import ABC, abstractmethod
 from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
-                    TypeVar)
+                    TypeVar, get_args)

 import torch
 import torch.nn as nn
@@ -12,6 +12,8 @@ import torch.nn as nn
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.models.interfaces_base import is_pooling_model
+from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata

 if TYPE_CHECKING:
@@ -223,6 +225,16 @@ class ModelRunnerBase(ABC, Generic[T]):
     def get_model(self) -> nn.Module:
         raise NotImplementedError

+    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
+        model = self.get_model()
+        if not is_pooling_model(model):
+            return []
+
+        return [
+            task for task in get_args(PoolingTask)
+            if model.pooler.get_pooling_updates(task)
+        ]
+
     def execute_model(
         self,
         model_input: T,
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import dataclasses
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

 import torch

@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.models.interfaces_base import VllmModelForPooling
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MultiModalKwargs
 from vllm.pooling_params import PoolingParams
@@ -195,7 +196,20 @@ class PoolingModelRunner(
         seq_groups: List[Tuple[List[int], PoolingParams]] = []
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             seq_ids = list(seq_group_metadata.seq_data.keys())
+
             pooling_params = seq_group_metadata.pooling_params
+            assert pooling_params is not None
+            assert pooling_params.task is not None, (
+                "You did not set `task` in the API")
+
+            to_update = (cast(VllmModelForPooling,
+                              self.model).pooler.get_pooling_updates(
+                                  pooling_params.task))
+            assert to_update is not None, (
+                f"{pooling_params.task=} is not supported by the model")
+
+            to_update.apply(pooling_params)
+
             seq_groups.append((seq_ids, pooling_params))

         seq_data: Dict[int, SequenceData] = {}