mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 19:55:42 +08:00
[Core] Set pooling params based on task and model (#21128)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
4adc66f64d
commit
45badd05d0
@ -2,9 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib.util
|
import numpy as np
|
||||||
from array import array
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
from scipy.spatial.distance import cosine
|
from scipy.spatial.distance import cosine
|
||||||
@ -14,10 +12,6 @@ from vllm.config import ModelConfig
|
|||||||
|
|
||||||
from ....utils import RemoteOpenAIServer
|
from ....utils import RemoteOpenAIServer
|
||||||
|
|
||||||
# GritLM embedding implementation is only supported by XFormers backend.
|
|
||||||
pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
|
|
||||||
reason="GritLM requires XFormers")
|
|
||||||
|
|
||||||
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
|
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
|
||||||
MAX_MODEL_LEN = 4000
|
MAX_MODEL_LEN = 4000
|
||||||
|
|
||||||
@ -26,11 +20,11 @@ def _arr(arr):
|
|||||||
"""
|
"""
|
||||||
Convert a list of integers to an array of integers.
|
Convert a list of integers to an array of integers.
|
||||||
"""
|
"""
|
||||||
return array("i", arr)
|
return np.array(arr)
|
||||||
|
|
||||||
|
|
||||||
def test_find_array():
|
def test_find_array():
|
||||||
from vllm.model_executor.models.gritlm import GritLMPooler
|
from vllm.model_executor.models.gritlm import GritLMMeanPool
|
||||||
|
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig(
|
||||||
MODEL_NAME,
|
MODEL_NAME,
|
||||||
@ -41,17 +35,19 @@ def test_find_array():
|
|||||||
dtype="bfloat16",
|
dtype="bfloat16",
|
||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
pooler = GritLMPooler(model_config=model_config)
|
pooling = GritLMMeanPool(model_config=model_config)
|
||||||
|
|
||||||
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||||
|
|
||||||
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
|
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3
|
||||||
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
|
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3
|
||||||
assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
|
assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1
|
||||||
assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1
|
assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=3) == -1
|
||||||
|
assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=4) == 3
|
||||||
|
assert pooling._find_array(arr, _arr([3, 5]), start_idx=0) == -1
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
|
pooling._find_array(arr, _arr([3, 4, 5]), start_idx=-1)
|
||||||
|
|
||||||
|
|
||||||
def run_llm_encode(
|
def run_llm_encode(
|
||||||
|
|||||||
@ -44,7 +44,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
|
|||||||
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
|
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
|
||||||
PoolingRequestOutput, RequestOutput,
|
PoolingRequestOutput, RequestOutput,
|
||||||
ScoringRequestOutput)
|
ScoringRequestOutput)
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams, PoolingTask
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||||
RequestOutputKind, SamplingParams)
|
RequestOutputKind, SamplingParams)
|
||||||
@ -964,6 +964,7 @@ class LLM:
|
|||||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
pooling_task: PoolingTask = "encode",
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> list[PoolingRequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -979,6 +980,7 @@ class LLM:
|
|||||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
pooling_task: PoolingTask = "encode",
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> list[PoolingRequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -994,6 +996,7 @@ class LLM:
|
|||||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
pooling_task: PoolingTask = "encode",
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> list[PoolingRequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -1010,6 +1013,7 @@ class LLM:
|
|||||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
pooling_task: PoolingTask = "encode",
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> list[PoolingRequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -1026,6 +1030,7 @@ class LLM:
|
|||||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
pooling_task: PoolingTask = "encode",
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> list[PoolingRequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -1040,6 +1045,7 @@ class LLM:
|
|||||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
pooling_task: PoolingTask = "encode",
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> list[PoolingRequestOutput]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -1059,6 +1065,7 @@ class LLM:
|
|||||||
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
|
||||||
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
|
pooling_task: PoolingTask = "encode",
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> list[PoolingRequestOutput]:
|
||||||
"""Apply pooling to the hidden states corresponding to the input
|
"""Apply pooling to the hidden states corresponding to the input
|
||||||
prompts.
|
prompts.
|
||||||
@ -1080,6 +1087,7 @@ class LLM:
|
|||||||
lora_request: LoRA request to use for generation, if any.
|
lora_request: LoRA request to use for generation, if any.
|
||||||
prompt_adapter_request: Prompt Adapter request to use for
|
prompt_adapter_request: Prompt Adapter request to use for
|
||||||
generation, if any.
|
generation, if any.
|
||||||
|
pooling_task: Override the pooling task to use.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of `PoolingRequestOutput` objects containing the
|
A list of `PoolingRequestOutput` objects containing the
|
||||||
@ -1116,11 +1124,12 @@ class LLM:
|
|||||||
if pooling_params is None:
|
if pooling_params is None:
|
||||||
# Use default pooling params.
|
# Use default pooling params.
|
||||||
pooling_params = PoolingParams()
|
pooling_params = PoolingParams()
|
||||||
elif isinstance(pooling_params, PoolingParams):
|
|
||||||
pooling_params.verify(model_config)
|
if isinstance(pooling_params, PoolingParams):
|
||||||
|
pooling_params.verify(pooling_task, model_config)
|
||||||
else:
|
else:
|
||||||
for pooling_param in pooling_params:
|
for pooling_param in pooling_params:
|
||||||
pooling_param.verify(model_config)
|
pooling_param.verify(pooling_task, model_config)
|
||||||
|
|
||||||
tokenization_kwargs = dict[str, Any]()
|
tokenization_kwargs = dict[str, Any]()
|
||||||
_validate_truncation_size(model_config.max_model_len,
|
_validate_truncation_size(model_config.max_model_len,
|
||||||
@ -1181,12 +1190,15 @@ class LLM:
|
|||||||
raise ValueError("Embedding API is not supported by this model. "
|
raise ValueError("Embedding API is not supported by this model. "
|
||||||
"Please set `--task embed`.")
|
"Please set `--task embed`.")
|
||||||
|
|
||||||
items = self.encode(prompts,
|
items = self.encode(
|
||||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
prompts,
|
||||||
use_tqdm=use_tqdm,
|
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||||
pooling_params=pooling_params,
|
use_tqdm=use_tqdm,
|
||||||
lora_request=lora_request,
|
pooling_params=pooling_params,
|
||||||
prompt_adapter_request=prompt_adapter_request)
|
lora_request=lora_request,
|
||||||
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
pooling_task="embed",
|
||||||
|
)
|
||||||
|
|
||||||
return [EmbeddingRequestOutput.from_base(item) for item in items]
|
return [EmbeddingRequestOutput.from_base(item) for item in items]
|
||||||
|
|
||||||
@ -1228,10 +1240,13 @@ class LLM:
|
|||||||
"Classification API is not supported by this model. "
|
"Classification API is not supported by this model. "
|
||||||
"Please set `--task classify`.")
|
"Please set `--task classify`.")
|
||||||
|
|
||||||
items = self.encode(prompts,
|
items = self.encode(
|
||||||
use_tqdm=use_tqdm,
|
prompts,
|
||||||
lora_request=lora_request,
|
use_tqdm=use_tqdm,
|
||||||
prompt_adapter_request=prompt_adapter_request)
|
lora_request=lora_request,
|
||||||
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
pooling_task="classify",
|
||||||
|
)
|
||||||
|
|
||||||
return [ClassificationRequestOutput.from_base(item) for item in items]
|
return [ClassificationRequestOutput.from_base(item) for item in items]
|
||||||
|
|
||||||
@ -1251,7 +1266,9 @@ class LLM:
|
|||||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||||
use_tqdm=use_tqdm,
|
use_tqdm=use_tqdm,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request)
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
pooling_task="embed",
|
||||||
|
)
|
||||||
|
|
||||||
encoded_output_1: list[PoolingRequestOutput] = encoded_output[
|
encoded_output_1: list[PoolingRequestOutput] = encoded_output[
|
||||||
0:len(text_1)]
|
0:len(text_1)]
|
||||||
@ -1287,7 +1304,7 @@ class LLM:
|
|||||||
if len(data_1) == 1:
|
if len(data_1) == 1:
|
||||||
data_1 = data_1 * len(data_2)
|
data_1 = data_1 * len(data_2)
|
||||||
|
|
||||||
pooling_params = PoolingParams(use_cross_encoder=True)
|
pooling_params = PoolingParams(task="score")
|
||||||
tokenization_kwargs: dict[str, Any] = {}
|
tokenization_kwargs: dict[str, Any] = {}
|
||||||
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
|
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
|
||||||
truncate_prompt_tokens, tokenization_kwargs)
|
truncate_prompt_tokens, tokenization_kwargs)
|
||||||
|
|||||||
@ -1347,8 +1347,8 @@ class ScoreRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
# --8<-- [end:score-extra-params]
|
# --8<-- [end:score-extra-params]
|
||||||
|
|
||||||
def to_pooling_params(self, *, use_cross_encoder: bool = False):
|
def to_pooling_params(self):
|
||||||
return PoolingParams(use_cross_encoder=use_cross_encoder)
|
return PoolingParams()
|
||||||
|
|
||||||
|
|
||||||
class RerankRequest(OpenAIBaseModel):
|
class RerankRequest(OpenAIBaseModel):
|
||||||
@ -1375,8 +1375,8 @@ class RerankRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
# --8<-- [end:rerank-extra-params]
|
# --8<-- [end:rerank-extra-params]
|
||||||
|
|
||||||
def to_pooling_params(self, *, use_cross_encoder: bool = False):
|
def to_pooling_params(self):
|
||||||
return PoolingParams(use_cross_encoder=use_cross_encoder)
|
return PoolingParams()
|
||||||
|
|
||||||
|
|
||||||
class RerankDocument(BaseModel):
|
class RerankDocument(BaseModel):
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import Optional, Union, cast
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
@ -21,12 +22,14 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
|
|||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
|
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ClassificationMixin(OpenAIServing):
|
class ClassificationMixin(OpenAIServing):
|
||||||
|
|
||||||
|
@override
|
||||||
async def _preprocess(
|
async def _preprocess(
|
||||||
self,
|
self,
|
||||||
ctx: ServeContext,
|
ctx: ServeContext,
|
||||||
@ -75,6 +78,7 @@ class ClassificationMixin(OpenAIServing):
|
|||||||
logger.exception("Error in preprocessing prompt inputs")
|
logger.exception("Error in preprocessing prompt inputs")
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
@override
|
||||||
def _build_response(
|
def _build_response(
|
||||||
self,
|
self,
|
||||||
ctx: ServeContext,
|
ctx: ServeContext,
|
||||||
@ -158,3 +162,31 @@ class ServingClassification(ClassificationMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return await super().handle(ctx) # type: ignore
|
return await super().handle(ctx) # type: ignore
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _validate_request(
|
||||||
|
self,
|
||||||
|
ctx: ClassificationServeContext,
|
||||||
|
) -> Optional[ErrorResponse]:
|
||||||
|
if error := super()._validate_request(ctx):
|
||||||
|
return error
|
||||||
|
|
||||||
|
ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _create_pooling_params(
|
||||||
|
self,
|
||||||
|
ctx: ClassificationServeContext,
|
||||||
|
) -> Union[PoolingParams, ErrorResponse]:
|
||||||
|
pooling_params = super()._create_pooling_params(ctx)
|
||||||
|
if isinstance(pooling_params, ErrorResponse):
|
||||||
|
return pooling_params
|
||||||
|
|
||||||
|
try:
|
||||||
|
pooling_params.verify("classify", self.model_config)
|
||||||
|
except ValueError as e:
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
return pooling_params
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
|
||||||
PoolingRequestOutput)
|
PoolingRequestOutput)
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -45,6 +46,7 @@ def _get_embedding(
|
|||||||
|
|
||||||
class EmbeddingMixin(OpenAIServing):
|
class EmbeddingMixin(OpenAIServing):
|
||||||
|
|
||||||
|
@override
|
||||||
async def _preprocess(
|
async def _preprocess(
|
||||||
self,
|
self,
|
||||||
ctx: ServeContext,
|
ctx: ServeContext,
|
||||||
@ -97,6 +99,7 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
logger.exception("Error in preprocessing prompt inputs")
|
logger.exception("Error in preprocessing prompt inputs")
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
@override
|
||||||
def _build_response(
|
def _build_response(
|
||||||
self,
|
self,
|
||||||
ctx: ServeContext,
|
ctx: ServeContext,
|
||||||
@ -191,11 +194,20 @@ class OpenAIServingEmbedding(EmbeddingMixin):
|
|||||||
|
|
||||||
ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
|
ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
|
||||||
|
|
||||||
pooling_params = ctx.request.to_pooling_params()
|
return None
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _create_pooling_params(
|
||||||
|
self,
|
||||||
|
ctx: ServeContext[EmbeddingRequest],
|
||||||
|
) -> Union[PoolingParams, ErrorResponse]:
|
||||||
|
pooling_params = super()._create_pooling_params(ctx)
|
||||||
|
if isinstance(pooling_params, ErrorResponse):
|
||||||
|
return pooling_params
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pooling_params.verify(self.model_config)
|
pooling_params.verify("embed", self.model_config)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
return None
|
return pooling_params
|
||||||
|
|||||||
@ -305,6 +305,16 @@ class OpenAIServing:
|
|||||||
" Please, select a smaller truncation size.")
|
" Please, select a smaller truncation size.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _create_pooling_params(
|
||||||
|
self,
|
||||||
|
ctx: ServeContext,
|
||||||
|
) -> Union[PoolingParams, ErrorResponse]:
|
||||||
|
if not hasattr(ctx.request, "to_pooling_params"):
|
||||||
|
return self.create_error_response(
|
||||||
|
"Request type does not support pooling parameters")
|
||||||
|
|
||||||
|
return ctx.request.to_pooling_params()
|
||||||
|
|
||||||
async def _prepare_generators(
|
async def _prepare_generators(
|
||||||
self,
|
self,
|
||||||
ctx: ServeContext,
|
ctx: ServeContext,
|
||||||
@ -318,11 +328,9 @@ class OpenAIServing:
|
|||||||
trace_headers = (None if ctx.raw_request is None else await
|
trace_headers = (None if ctx.raw_request is None else await
|
||||||
self._get_trace_headers(ctx.raw_request.headers))
|
self._get_trace_headers(ctx.raw_request.headers))
|
||||||
|
|
||||||
if not hasattr(ctx.request, "to_pooling_params"):
|
pooling_params = self._create_pooling_params(ctx)
|
||||||
return self.create_error_response(
|
if isinstance(pooling_params, ErrorResponse):
|
||||||
"Request type does not support pooling parameters")
|
return pooling_params
|
||||||
|
|
||||||
pooling_params = ctx.request.to_pooling_params()
|
|
||||||
|
|
||||||
if ctx.engine_prompts is None:
|
if ctx.engine_prompts is None:
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
|
|||||||
@ -142,6 +142,11 @@ class OpenAIServingPooling(OpenAIServing):
|
|||||||
try:
|
try:
|
||||||
pooling_params = request.to_pooling_params()
|
pooling_params = request.to_pooling_params()
|
||||||
|
|
||||||
|
try:
|
||||||
|
pooling_params.verify("encode", self.model_config)
|
||||||
|
except ValueError as e:
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
for i, engine_prompt in enumerate(engine_prompts):
|
for i, engine_prompt in enumerate(engine_prompts):
|
||||||
request_id_item = f"{request_id}-{i}"
|
request_id_item = f"{request_id}-{i}"
|
||||||
|
|
||||||
|
|||||||
@ -55,14 +55,13 @@ class ServingScores(OpenAIServing):
|
|||||||
texts_1: list[str],
|
texts_1: list[str],
|
||||||
texts_2: list[str],
|
texts_2: list[str],
|
||||||
request: Union[RerankRequest, ScoreRequest],
|
request: Union[RerankRequest, ScoreRequest],
|
||||||
request_id=str,
|
request_id: str,
|
||||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||||
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
||||||
None]] = None,
|
None]] = None,
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||||
|
|
||||||
input_texts = texts_1 + texts_2
|
input_texts = texts_1 + texts_2
|
||||||
|
|
||||||
engine_prompts: list[TokensPrompt] = []
|
engine_prompts: list[TokensPrompt] = []
|
||||||
@ -89,6 +88,11 @@ class ServingScores(OpenAIServing):
|
|||||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||||
pooling_params = request.to_pooling_params()
|
pooling_params = request.to_pooling_params()
|
||||||
|
|
||||||
|
try:
|
||||||
|
pooling_params.verify("embed", self.model_config)
|
||||||
|
except ValueError as e:
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
for i, engine_prompt in enumerate(engine_prompts):
|
for i, engine_prompt in enumerate(engine_prompts):
|
||||||
|
|
||||||
request_id_item = f"{request_id}-{i}"
|
request_id_item = f"{request_id}-{i}"
|
||||||
@ -169,14 +173,13 @@ class ServingScores(OpenAIServing):
|
|||||||
data_1: Union[list[str], list[ScoreContentPartParam]],
|
data_1: Union[list[str], list[ScoreContentPartParam]],
|
||||||
data_2: Union[list[str], list[ScoreContentPartParam]],
|
data_2: Union[list[str], list[ScoreContentPartParam]],
|
||||||
request: Union[RerankRequest, ScoreRequest],
|
request: Union[RerankRequest, ScoreRequest],
|
||||||
request_id=str,
|
request_id: str,
|
||||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||||
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
||||||
None]] = None,
|
None]] = None,
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||||
|
|
||||||
request_prompts: list[str] = []
|
request_prompts: list[str] = []
|
||||||
engine_prompts: list[TokensPrompt] = []
|
engine_prompts: list[TokensPrompt] = []
|
||||||
|
|
||||||
@ -245,7 +248,12 @@ class ServingScores(OpenAIServing):
|
|||||||
# Schedule the request and get the result generator.
|
# Schedule the request and get the result generator.
|
||||||
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||||
|
|
||||||
pooling_params = request.to_pooling_params(use_cross_encoder=True)
|
pooling_params = request.to_pooling_params()
|
||||||
|
|
||||||
|
try:
|
||||||
|
pooling_params.verify("score", self.model_config)
|
||||||
|
except ValueError as e:
|
||||||
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
for i, engine_prompt in enumerate(engine_prompts):
|
for i, engine_prompt in enumerate(engine_prompts):
|
||||||
request_id_item = f"{request_id}-{i}"
|
request_id_item = f"{request_id}-{i}"
|
||||||
@ -286,8 +294,7 @@ class ServingScores(OpenAIServing):
|
|||||||
request_id: str,
|
request_id: str,
|
||||||
raw_request: Optional[Request] = None,
|
raw_request: Optional[Request] = None,
|
||||||
truncate_prompt_tokens: Optional[int] = None,
|
truncate_prompt_tokens: Optional[int] = None,
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
|
||||||
|
|
||||||
(
|
(
|
||||||
lora_request,
|
lora_request,
|
||||||
prompt_adapter_request,
|
prompt_adapter_request,
|
||||||
@ -374,6 +381,8 @@ class ServingScores(OpenAIServing):
|
|||||||
raw_request,
|
raw_request,
|
||||||
request.truncate_prompt_tokens,
|
request.truncate_prompt_tokens,
|
||||||
)
|
)
|
||||||
|
if isinstance(final_res_batch, ErrorResponse):
|
||||||
|
return final_res_batch
|
||||||
|
|
||||||
return self.request_output_to_score_response(
|
return self.request_output_to_score_response(
|
||||||
final_res_batch,
|
final_res_batch,
|
||||||
@ -420,6 +429,9 @@ class ServingScores(OpenAIServing):
|
|||||||
raw_request,
|
raw_request,
|
||||||
request.truncate_prompt_tokens,
|
request.truncate_prompt_tokens,
|
||||||
)
|
)
|
||||||
|
if isinstance(final_res_batch, ErrorResponse):
|
||||||
|
return final_res_batch
|
||||||
|
|
||||||
return self.request_output_to_rerank_response(
|
return self.request_output_to_rerank_response(
|
||||||
final_res_batch,
|
final_res_batch,
|
||||||
request_id,
|
request_id,
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from functools import cached_property
|
||||||
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
|
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
|
||||||
Union)
|
Union)
|
||||||
|
|
||||||
@ -15,6 +16,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.pooling_params import PoolingTask
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||||
from vllm.utils import make_async
|
from vllm.utils import make_async
|
||||||
@ -135,6 +137,11 @@ class ExecutorBase(ABC):
|
|||||||
|
|
||||||
return self.collective_rpc(rpc_func)
|
return self.collective_rpc(rpc_func)
|
||||||
|
|
||||||
|
@cached_property # Avoid unnecessary RPC calls
|
||||||
|
def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]:
|
||||||
|
output = self.collective_rpc("get_supported_pooling_tasks")
|
||||||
|
return tuple({task for tasks in output for task in tasks})
|
||||||
|
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self, execute_model_req: ExecuteModelRequest
|
self, execute_model_req: ExecuteModelRequest
|
||||||
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Callable, Literal, Optional, TypeVar, Union
|
from typing import Callable, Optional, TypeVar, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -15,13 +15,12 @@ from vllm.config import ModelConfig, PoolerConfig
|
|||||||
from vllm.model_executor.pooling_metadata import ( # noqa: E501
|
from vllm.model_executor.pooling_metadata import ( # noqa: E501
|
||||||
PoolingMetadata as V0PoolingMetadata)
|
PoolingMetadata as V0PoolingMetadata)
|
||||||
from vllm.model_executor.pooling_metadata import PoolingTensors
|
from vllm.model_executor.pooling_metadata import PoolingTensors
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams, PoolingTask
|
||||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
||||||
from vllm.utils import resolve_obj_by_qualname
|
from vllm.utils import resolve_obj_by_qualname
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
|
||||||
|
|
||||||
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
|
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
|
||||||
PoolingTask = Literal["encode", "embed", "classify", "score"]
|
|
||||||
|
|
||||||
|
|
||||||
class PoolingType(IntEnum):
|
class PoolingType(IntEnum):
|
||||||
@ -67,6 +66,15 @@ class ResolvedPoolingConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PoolingParamsUpdate:
|
||||||
|
requires_token_ids: bool = False
|
||||||
|
"""Set this flag to enable `get_prompt_token_ids` for your pooler."""
|
||||||
|
|
||||||
|
def apply(self, params: PoolingParams) -> None:
|
||||||
|
params.requires_token_ids = self.requires_token_ids
|
||||||
|
|
||||||
|
|
||||||
class Pooler(nn.Module, ABC):
|
class Pooler(nn.Module, ABC):
|
||||||
"""The interface required for all poolers used in pooling models in vLLM."""
|
"""The interface required for all poolers used in pooling models in vLLM."""
|
||||||
|
|
||||||
@ -93,7 +101,10 @@ class Pooler(nn.Module, ABC):
|
|||||||
|
|
||||||
return SimplePooler.from_config(resolved_config)
|
return SimplePooler.from_config(resolved_config)
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
"""
|
"""
|
||||||
Construct the pooling parameters to use for a task,
|
Construct the pooling parameters to use for a task,
|
||||||
or `None` if the task is not supported.
|
or `None` if the task is not supported.
|
||||||
@ -121,6 +132,23 @@ def get_prompt_lens(
|
|||||||
pooling_metadata, hidden_states.device).prompt_lens
|
pooling_metadata, hidden_states.device).prompt_lens
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_token_ids(
|
||||||
|
pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
|
||||||
|
if isinstance(pooling_metadata, V1PoolingMetadata):
|
||||||
|
assert pooling_metadata.prompt_token_ids is not None, (
|
||||||
|
"Please set `requires_token_ids=True` in `get_pooling_updates`")
|
||||||
|
|
||||||
|
return [
|
||||||
|
pooling_metadata.prompt_token_ids[i, :num]
|
||||||
|
for i, num in enumerate(pooling_metadata.prompt_lens)
|
||||||
|
]
|
||||||
|
|
||||||
|
return [
|
||||||
|
torch.tensor(seq_data_i.prompt_token_ids)
|
||||||
|
for seq_data_i in pooling_metadata.seq_data.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_classification_activation_function(config: PretrainedConfig):
|
def get_classification_activation_function(config: PretrainedConfig):
|
||||||
return PoolerClassify()
|
return PoolerClassify()
|
||||||
|
|
||||||
@ -165,7 +193,10 @@ class PoolingMethod(nn.Module, ABC):
|
|||||||
raise NotImplementedError(f"Unsupported method: {pooling_type}")
|
raise NotImplementedError(f"Unsupported method: {pooling_type}")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -206,11 +237,14 @@ class PoolingMethod(nn.Module, ABC):
|
|||||||
|
|
||||||
class CLSPool(PoolingMethod):
|
class CLSPool(PoolingMethod):
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
# The equalities are split up to keep mypy happy
|
# The equalities are split up to keep mypy happy
|
||||||
if (task == "encode" or task == "embed" or task == "classify"
|
if (task == "encode" or task == "embed" or task == "classify"
|
||||||
or task == "score"):
|
or task == "score"):
|
||||||
return PoolingParams()
|
return PoolingParamsUpdate()
|
||||||
|
|
||||||
assert_never(task)
|
assert_never(task)
|
||||||
|
|
||||||
@ -236,11 +270,14 @@ class CLSPool(PoolingMethod):
|
|||||||
|
|
||||||
class LastPool(PoolingMethod):
|
class LastPool(PoolingMethod):
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
# The equalities are split up to keep mypy happy
|
# The equalities are split up to keep mypy happy
|
||||||
if (task == "encode" or task == "embed" or task == "classify"
|
if (task == "encode" or task == "embed" or task == "classify"
|
||||||
or task == "score"):
|
or task == "score"):
|
||||||
return PoolingParams()
|
return PoolingParamsUpdate()
|
||||||
|
|
||||||
assert_never(task)
|
assert_never(task)
|
||||||
|
|
||||||
@ -262,9 +299,12 @@ class LastPool(PoolingMethod):
|
|||||||
|
|
||||||
class AllPool(PoolingMethod):
|
class AllPool(PoolingMethod):
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
if task == "encode":
|
if task == "encode":
|
||||||
return PoolingParams()
|
return PoolingParamsUpdate()
|
||||||
|
|
||||||
# The equalities are split up to keep mypy happy
|
# The equalities are split up to keep mypy happy
|
||||||
if task == "embed" or task == "classify" or task == "score":
|
if task == "embed" or task == "classify" or task == "score":
|
||||||
@ -299,11 +339,14 @@ class AllPool(PoolingMethod):
|
|||||||
|
|
||||||
class MeanPool(PoolingMethod):
|
class MeanPool(PoolingMethod):
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
# The equalities are split up to keep mypy happy
|
# The equalities are split up to keep mypy happy
|
||||||
if (task == "encode" or task == "embed" or task == "classify"
|
if (task == "encode" or task == "embed" or task == "classify"
|
||||||
or task == "score"):
|
or task == "score"):
|
||||||
return PoolingParams()
|
return PoolingParamsUpdate()
|
||||||
|
|
||||||
assert_never(task)
|
assert_never(task)
|
||||||
|
|
||||||
@ -520,8 +563,11 @@ class SimplePooler(Pooler):
|
|||||||
self.pooling = pooling
|
self.pooling = pooling
|
||||||
self.head = head
|
self.head = head
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
return self.pooling.get_pooling_params(task)
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
|
return self.pooling.get_pooling_updates(task)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -559,27 +605,13 @@ class StepPooler(Pooler):
|
|||||||
self.step_tag_id = step_tag_id
|
self.step_tag_id = step_tag_id
|
||||||
self.returned_token_ids = returned_token_ids
|
self.returned_token_ids = returned_token_ids
|
||||||
|
|
||||||
def get_prompt_token_ids(
|
|
||||||
self,
|
|
||||||
pooling_metadata: PoolingMetadata,
|
|
||||||
) -> list[torch.Tensor]:
|
|
||||||
if isinstance(pooling_metadata, V1PoolingMetadata):
|
|
||||||
return [
|
|
||||||
pooling_metadata.prompt_token_ids[i, :num]
|
|
||||||
for i, num in enumerate(pooling_metadata.prompt_lens)
|
|
||||||
]
|
|
||||||
return [
|
|
||||||
torch.tensor(seq_data_i.prompt_token_ids)
|
|
||||||
for seq_data_i in pooling_metadata.seq_data.values()
|
|
||||||
]
|
|
||||||
|
|
||||||
def extract_states(
|
def extract_states(
|
||||||
self,
|
self,
|
||||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> Union[list[torch.Tensor], torch.Tensor]:
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||||
prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
|
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
|
||||||
|
|
||||||
pooled_data = list[torch.Tensor]()
|
pooled_data = list[torch.Tensor]()
|
||||||
returned_token_ids = self.returned_token_ids
|
returned_token_ids = self.returned_token_ids
|
||||||
@ -595,9 +627,12 @@ class StepPooler(Pooler):
|
|||||||
|
|
||||||
return pooled_data
|
return pooled_data
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
if task == "encode":
|
if task == "encode":
|
||||||
return PoolingParams(logits_processing_needs_token_ids=True)
|
return PoolingParamsUpdate(requires_token_ids=True)
|
||||||
|
|
||||||
# The equalities are split up to keep mypy happy
|
# The equalities are split up to keep mypy happy
|
||||||
if task == "embed" or task == "classify" or task == "score":
|
if task == "embed" or task == "classify" or task == "score":
|
||||||
@ -650,19 +685,24 @@ class ClassifierPooler(nn.Module):
|
|||||||
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
|
self.cross_encoder_act_fn = get_cross_encoder_activation_function(
|
||||||
config.hf_config) if act_fn is None else act_fn
|
config.hf_config) if act_fn is None else act_fn
|
||||||
|
|
||||||
def _get_act_fn(self, use_cross_encoder: bool):
|
def _get_act_fn(self, task: PoolingTask):
|
||||||
return (self.cross_encoder_act_fn
|
if task == "encode" or task == "classify":
|
||||||
if use_cross_encoder else self.classification_act_fn)
|
return self.classification_act_fn
|
||||||
|
if task == "score":
|
||||||
|
return self.cross_encoder_act_fn
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported task: {task!r}")
|
||||||
|
|
||||||
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
|
# The equalities are split up to keep mypy happy
|
||||||
|
if task == "encode" or task == "classify" or task == "score":
|
||||||
|
return PoolingParamsUpdate()
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
|
||||||
if task == "encode":
|
|
||||||
return PoolingParams()
|
|
||||||
if task == "embed":
|
if task == "embed":
|
||||||
return None
|
return None
|
||||||
if task == "classify":
|
|
||||||
return PoolingParams()
|
|
||||||
if task == "score":
|
|
||||||
return PoolingParams(use_cross_encoder=True)
|
|
||||||
|
|
||||||
assert_never(task)
|
assert_never(task)
|
||||||
|
|
||||||
@ -682,27 +722,28 @@ class ClassifierPooler(nn.Module):
|
|||||||
else:
|
else:
|
||||||
pooled_output = [self.classifier(data) for data in pooled_data]
|
pooled_output = [self.classifier(data) for data in pooled_data]
|
||||||
|
|
||||||
|
task_list: list[PoolingTask]
|
||||||
if isinstance(pooling_metadata, V0PoolingMetadata):
|
if isinstance(pooling_metadata, V0PoolingMetadata):
|
||||||
use_cross_encoder_list = [
|
task_list = [
|
||||||
pooling_param.use_cross_encoder
|
task for _, pooling_param in pooling_metadata.seq_groups
|
||||||
for _, pooling_param in pooling_metadata.seq_groups
|
if (task := pooling_param.task) is not None
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
use_cross_encoder_list = [
|
task_list = [
|
||||||
pooling_param.use_cross_encoder
|
task for pooling_param in pooling_metadata.pooling_params
|
||||||
for pooling_param in pooling_metadata.pooling_params
|
if (task := pooling_param.task) is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
|
assert len(task_list) == len(pooled_output)
|
||||||
|
|
||||||
# shape of scores: (batch_size, num_labels)
|
# shape of scores: (batch_size, num_labels)
|
||||||
if all(use_cross_encoder == use_cross_encoder_list[0]
|
if len(set(task_list)) <= 1:
|
||||||
for use_cross_encoder in use_cross_encoder_list):
|
act_fn = self._get_act_fn(task_list[0])
|
||||||
act_fn = self._get_act_fn(use_cross_encoder_list[0])
|
|
||||||
scores = act_fn(pooled_output)
|
scores = act_fn(pooled_output)
|
||||||
else:
|
else:
|
||||||
scores = torch.stack([
|
scores = torch.stack([
|
||||||
self._get_act_fn(use_cross_encoder)(vecs)
|
self._get_act_fn(task)(vecs)
|
||||||
for use_cross_encoder, vecs in zip(use_cross_encoder_list,
|
for task, vecs in zip(task_list, pooled_output)
|
||||||
pooled_output)
|
|
||||||
])
|
])
|
||||||
|
|
||||||
return build_output(scores)
|
return build_output(scores)
|
||||||
|
|||||||
@ -18,13 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
||||||
PoolingMethod, PoolingTask,
|
PoolingMethod,
|
||||||
|
PoolingParamsUpdate,
|
||||||
PoolingType)
|
PoolingType)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingTask
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
|
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
|
||||||
@ -91,8 +92,11 @@ class BertPooler(Pooler):
|
|||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.activation = nn.Tanh()
|
self.activation = nn.Tanh()
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
return self.pooling.get_pooling_params(task)
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
|
return self.pooling.get_pooling_updates(task)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,18 +1,24 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from array import array
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from vllm.config import ModelConfig, VllmConfig
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.pooler import PoolerHead, PoolerNormalize
|
from vllm.model_executor.layers.pooler import (Pooler, PoolerHead,
|
||||||
|
PoolerNormalize,
|
||||||
|
PoolingParamsUpdate,
|
||||||
|
build_output, get_prompt_lens,
|
||||||
|
get_prompt_token_ids)
|
||||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||||
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
PoolingTensors)
|
from vllm.pooling_params import PoolingTask
|
||||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
from vllm.sequence import PoolerOutput
|
||||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
|
|
||||||
from .interfaces import SupportsV0Only
|
from .interfaces import SupportsV0Only
|
||||||
@ -20,7 +26,8 @@ from .interfaces import SupportsV0Only
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GritLMPooler(nn.Module):
|
class GritLMMeanPool(nn.Module):
|
||||||
|
"""As `MeanPool`, but only includes non-instruction tokens."""
|
||||||
|
|
||||||
def __init__(self, model_config: ModelConfig):
|
def __init__(self, model_config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -38,8 +45,8 @@ class GritLMPooler(nn.Module):
|
|||||||
for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
|
for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"]
|
||||||
}
|
}
|
||||||
|
|
||||||
def tokens_to_ids(tokens: list[str]) -> array:
|
def tokens_to_ids(tokens: list[str]) -> np.ndarray:
|
||||||
return array("i", [self.token_ids[token] for token in tokens])
|
return np.array([self.token_ids[token] for token in tokens])
|
||||||
|
|
||||||
self.user_pattern_ids = tokens_to_ids(
|
self.user_pattern_ids = tokens_to_ids(
|
||||||
["▁<", "|", "user", "|", ">", "<0x0A>"])
|
["▁<", "|", "user", "|", ">", "<0x0A>"])
|
||||||
@ -48,32 +55,44 @@ class GritLMPooler(nn.Module):
|
|||||||
self.embed_pattern_ids = tokens_to_ids(
|
self.embed_pattern_ids = tokens_to_ids(
|
||||||
["▁<", "|", "embed", "|", ">", "<0x0A>"])
|
["▁<", "|", "embed", "|", ">", "<0x0A>"])
|
||||||
|
|
||||||
self.head = PoolerHead(PoolerNormalize())
|
def _find_array(
|
||||||
|
self,
|
||||||
def _find_array(self, arr: array, target: array, start_idx: int) -> int:
|
arr: np.ndarray,
|
||||||
|
target: np.ndarray,
|
||||||
|
start_idx: int = 0,
|
||||||
|
end_idx: Optional[int] = None,
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Find the first occurrence of target in arr starting from start_idx.
|
Find the first occurrence of `target` in `arr` starting from
|
||||||
|
`start_idx`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
arr: The array to search within
|
arr: The array to search within.
|
||||||
target: The consecutive subsequence to find
|
target: The consecutive subsequence to find.
|
||||||
start_idx: The starting index to search from
|
start_idx: The starting index to search from (inclusive).
|
||||||
|
end_idx: The ending index to search up to (exclusive).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: The index of the first occurrence of target in arr.
|
The index of the first occurrence of `target` in `arr`.
|
||||||
"""
|
"""
|
||||||
if start_idx < 0:
|
if start_idx < 0:
|
||||||
raise ValueError("start_idx must be non-negative")
|
raise ValueError("`start_idx` must be non-negative")
|
||||||
if not target or not arr:
|
if len(arr) == 0 or len(target) == 0:
|
||||||
raise ValueError("Empty arr or target not allowed")
|
raise ValueError("Empty `arr` or `target` not allowed")
|
||||||
|
|
||||||
|
arr_len = len(arr)
|
||||||
target_len = len(target)
|
target_len = len(target)
|
||||||
for i in range(start_idx, len(arr) - target_len + 1):
|
|
||||||
if arr[i:i + target_len] == target:
|
if end_idx is None:
|
||||||
|
end_idx = arr_len
|
||||||
|
|
||||||
|
for i in range(start_idx, min(end_idx, arr_len - target_len + 1)):
|
||||||
|
if (arr[i:i + target_len] == target).all():
|
||||||
return i
|
return i
|
||||||
|
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def _get_instruction_len(self, prompt_token_ids: array) -> int:
|
def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int:
|
||||||
"""
|
"""
|
||||||
Get the length of the instruction in the prompt.
|
Get the length of the instruction in the prompt.
|
||||||
|
|
||||||
@ -83,7 +102,6 @@ class GritLMPooler(nn.Module):
|
|||||||
The pattern matching is done using integers instead of strings
|
The pattern matching is done using integers instead of strings
|
||||||
because the prompt is given as a list of token IDs.
|
because the prompt is given as a list of token IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
instruction_len = 0
|
instruction_len = 0
|
||||||
|
|
||||||
# Return no instruction in case of missing BOS token.
|
# Return no instruction in case of missing BOS token.
|
||||||
@ -98,7 +116,8 @@ class GritLMPooler(nn.Module):
|
|||||||
embed_pattern_ids = self.embed_pattern_ids
|
embed_pattern_ids = self.embed_pattern_ids
|
||||||
if self._find_array(prompt_token_ids,
|
if self._find_array(prompt_token_ids,
|
||||||
self.user_pattern_ids,
|
self.user_pattern_ids,
|
||||||
start_idx=1) == 1:
|
start_idx=1,
|
||||||
|
end_idx=2) == 1:
|
||||||
embed_pattern_ids = self.embed_newline_pattern_ids
|
embed_pattern_ids = self.embed_newline_pattern_ids
|
||||||
|
|
||||||
# Find the embed pattern in the prompt.
|
# Find the embed pattern in the prompt.
|
||||||
@ -116,64 +135,92 @@ class GritLMPooler(nn.Module):
|
|||||||
|
|
||||||
return instruction_len
|
return instruction_len
|
||||||
|
|
||||||
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
|
# The equalities are split up to keep mypy happy
|
||||||
|
if task == "encode" or task == "embed":
|
||||||
|
return PoolingParamsUpdate(requires_token_ids=True)
|
||||||
|
|
||||||
|
if task == "classify" or task == "score":
|
||||||
|
return None
|
||||||
|
|
||||||
|
assert_never(task)
|
||||||
|
|
||||||
|
def forward_one(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
prompt_len: Optional[torch.Tensor] = None,
|
||||||
|
instr_len: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert prompt_len is None or prompt_len == hidden_states.shape[0], \
|
||||||
|
"partial prefill not supported with MEAN pooling"
|
||||||
|
|
||||||
|
return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32)
|
||||||
|
|
||||||
|
def forward_all(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
prompt_lens: torch.Tensor,
|
||||||
|
instr_lens: torch.Tensor,
|
||||||
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
|
offset = 0
|
||||||
|
pooled_data = list[torch.Tensor]()
|
||||||
|
|
||||||
|
for prompt_len, instr_len in zip(prompt_lens, instr_lens):
|
||||||
|
pooled_data.append(hidden_states[offset + instr_len:offset +
|
||||||
|
prompt_len].mean(
|
||||||
|
dim=0, dtype=torch.float32))
|
||||||
|
offset += prompt_len
|
||||||
|
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Union[list[torch.Tensor], torch.Tensor]:
|
||||||
|
prompt_lens = get_prompt_lens(hidden_states, pooling_metadata)
|
||||||
|
instr_lens = torch.tensor(
|
||||||
|
[
|
||||||
|
self._get_instruction_len(token_ids.cpu().numpy())
|
||||||
|
for token_ids in get_prompt_token_ids(pooling_metadata)
|
||||||
|
],
|
||||||
|
device=prompt_lens.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(hidden_states, list):
|
||||||
|
return [
|
||||||
|
self.forward_one(h, prompt_len, instr_len) for h, prompt_len,
|
||||||
|
instr_len in zip(hidden_states, prompt_lens, instr_lens)
|
||||||
|
]
|
||||||
|
|
||||||
|
return self.forward_all(hidden_states, prompt_lens, instr_lens)
|
||||||
|
|
||||||
|
|
||||||
|
class GritLMPooler(Pooler):
|
||||||
|
|
||||||
|
def __init__(self, model_config: ModelConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pooling = GritLMMeanPool(model_config)
|
||||||
|
self.head = PoolerHead(PoolerNormalize())
|
||||||
|
|
||||||
|
def get_pooling_updates(
|
||||||
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
|
return self.pooling.get_pooling_updates(task)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
"""
|
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||||
Pool the hidden states by summing the embeddings of
|
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||||
non-instruction tokens.
|
return build_output(pooled_data)
|
||||||
"""
|
|
||||||
prompts_token_ids = [
|
|
||||||
token_ids.prompt_token_ids_array
|
|
||||||
for _, token_ids in pooling_metadata.seq_data.items()
|
|
||||||
]
|
|
||||||
|
|
||||||
instruction_lens = torch.tensor(
|
|
||||||
[
|
|
||||||
self._get_instruction_len(prompt_token_ids)
|
|
||||||
for prompt_token_ids in prompts_token_ids
|
|
||||||
],
|
|
||||||
device=hidden_states.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_lens = PoolingTensors.from_pooling_metadata(
|
|
||||||
pooling_metadata, hidden_states.device).prompt_lens
|
|
||||||
|
|
||||||
mask = torch.zeros_like(hidden_states, dtype=torch.bool)
|
|
||||||
|
|
||||||
start_idx = 0
|
|
||||||
for prompt_len, instruction_len in zip(prompt_lens, instruction_lens):
|
|
||||||
end_idx = start_idx + prompt_len
|
|
||||||
mask[start_idx + instruction_len:end_idx] = True
|
|
||||||
start_idx = end_idx
|
|
||||||
|
|
||||||
masked_hidden_states = hidden_states.masked_fill(~mask, 0.0)
|
|
||||||
|
|
||||||
sum_embeddings = torch.zeros(len(prompt_lens),
|
|
||||||
hidden_states.size(1),
|
|
||||||
device=hidden_states.device)
|
|
||||||
|
|
||||||
start_idx = 0
|
|
||||||
for i, prompt_len in enumerate(prompt_lens):
|
|
||||||
end_idx = start_idx + prompt_len
|
|
||||||
sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum(
|
|
||||||
dim=0)
|
|
||||||
start_idx = end_idx
|
|
||||||
|
|
||||||
num_non_instruction_tokens = prompt_lens - instruction_lens
|
|
||||||
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
|
|
||||||
1)
|
|
||||||
|
|
||||||
pooled_data = self.head(mean_embeddings,
|
|
||||||
pooling_metadata=pooling_metadata)
|
|
||||||
|
|
||||||
pooled_outputs = [
|
|
||||||
PoolingSequenceGroupOutput(data) for data in pooled_data
|
|
||||||
]
|
|
||||||
|
|
||||||
return PoolerOutput(outputs=pooled_outputs)
|
|
||||||
|
|
||||||
|
|
||||||
class GritLM(LlamaForCausalLM, SupportsV0Only):
|
class GritLM(LlamaForCausalLM, SupportsV0Only):
|
||||||
@ -202,7 +249,7 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Use full attention for pooling
|
# Use full attention for pooling (this is why V1 is not supported yet)
|
||||||
if vllm_config.model_config.runner_type == "pooling":
|
if vllm_config.model_config.runner_type == "pooling":
|
||||||
hf_config = vllm_config.model_config.hf_config
|
hf_config = vllm_config.model_config.hf_config
|
||||||
hf_config.is_causal = False
|
hf_config.is_causal = False
|
||||||
|
|||||||
@ -599,13 +599,6 @@ def supports_cross_encoding(
|
|||||||
return is_pooling_model(model) and _supports_cross_encoding(model)
|
return is_pooling_model(model) and _supports_cross_encoding(model)
|
||||||
|
|
||||||
|
|
||||||
def has_step_pooler(model: Union[type[object], object]) -> bool:
|
|
||||||
"""Check if the model uses step pooler."""
|
|
||||||
from vllm.model_executor.layers.pooler import StepPooler
|
|
||||||
|
|
||||||
return is_pooling_model(model) and isinstance(model.pooler, StepPooler)
|
|
||||||
|
|
||||||
|
|
||||||
class SupportsQuant:
|
class SupportsQuant:
|
||||||
"""The interface required for all models that support quantization."""
|
"""The interface required for all models that support quantization."""
|
||||||
|
|
||||||
|
|||||||
@ -14,14 +14,15 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
|||||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
||||||
PoolingMethod, PoolingTask,
|
PoolingMethod,
|
||||||
|
PoolingParamsUpdate,
|
||||||
PoolingType)
|
PoolingType)
|
||||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingTask
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
from .interfaces import SupportsCrossEncoding, SupportsV0Only
|
||||||
@ -270,8 +271,11 @@ class ModernBertPooler(Pooler):
|
|||||||
eps=config.norm_eps,
|
eps=config.norm_eps,
|
||||||
bias=config.norm_bias)
|
bias=config.norm_bias)
|
||||||
|
|
||||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
def get_pooling_updates(
|
||||||
return self.pooling.get_pooling_params(task)
|
self,
|
||||||
|
task: PoolingTask,
|
||||||
|
) -> Optional[PoolingParamsUpdate]:
|
||||||
|
return self.pooling.get_pooling_updates(task)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Literal, Optional
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
|
|
||||||
@ -10,12 +10,14 @@ from vllm.sampling_params import RequestOutputKind
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
|
|
||||||
|
PoolingTask = Literal["encode", "embed", "classify", "score"]
|
||||||
|
|
||||||
|
|
||||||
class PoolingParams(
|
class PoolingParams(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
omit_defaults=True, # type: ignore[call-arg]
|
||||||
array_like=True): # type: ignore[call-arg]
|
array_like=True): # type: ignore[call-arg]
|
||||||
"""API parameters for pooling models. This
|
"""API parameters for pooling models.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
dimensions: Reduce the dimensions of embeddings
|
dimensions: Reduce the dimensions of embeddings
|
||||||
@ -24,24 +26,33 @@ class PoolingParams(
|
|||||||
|
|
||||||
dimensions: Optional[int] = None
|
dimensions: Optional[int] = None
|
||||||
|
|
||||||
use_cross_encoder: bool = False
|
|
||||||
"""Internal use only."""
|
|
||||||
|
|
||||||
logits_processing_needs_token_ids: bool = False
|
|
||||||
"""Internal use only."""
|
|
||||||
|
|
||||||
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
||||||
|
|
||||||
|
task: Optional[PoolingTask] = None
|
||||||
|
"""Internal use only."""
|
||||||
|
|
||||||
|
requires_token_ids: bool = False
|
||||||
|
"""Internal use only."""
|
||||||
|
|
||||||
def clone(self) -> "PoolingParams":
|
def clone(self) -> "PoolingParams":
|
||||||
"""Returns a deep copy of the PoolingParams instance."""
|
"""Returns a deep copy of the PoolingParams instance."""
|
||||||
return PoolingParams(
|
return PoolingParams(
|
||||||
dimensions=self.dimensions,
|
dimensions=self.dimensions,
|
||||||
use_cross_encoder=self.use_cross_encoder,
|
task=self.task,
|
||||||
logits_processing_needs_token_ids=self.
|
requires_token_ids=self.requires_token_ids,
|
||||||
logits_processing_needs_token_ids,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def verify(self, model_config: "ModelConfig") -> None:
|
def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None:
|
||||||
|
if self.task is None:
|
||||||
|
self.task = task
|
||||||
|
elif self.task != task:
|
||||||
|
msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# NOTE: Task validation needs to be done against the model instance,
|
||||||
|
# which is not available in model config. So, it's not included
|
||||||
|
# in this method
|
||||||
|
|
||||||
if self.dimensions is not None:
|
if self.dimensions is not None:
|
||||||
if not model_config.is_matryoshka:
|
if not model_config.is_matryoshka:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -61,12 +72,10 @@ class PoolingParams(
|
|||||||
raise ValueError("Dimensions must be greater than 0")
|
raise ValueError("Dimensions must be greater than 0")
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (f"PoolingParams("
|
||||||
f"PoolingParams("
|
f"dimensions={self.dimensions}, "
|
||||||
f"dimensions={self.dimensions}, "
|
f"task={self.task}, "
|
||||||
f"use_cross_encoder={self.use_cross_encoder}, "
|
f"requires_token_ids={self.requires_token_ids})")
|
||||||
f"logits_processing_needs_token_ids={self.logits_processing_needs_token_ids})"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
|
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
|
||||||
|
|||||||
@ -181,6 +181,12 @@ class EngineCore:
|
|||||||
|
|
||||||
def add_request(self, request: EngineCoreRequest):
|
def add_request(self, request: EngineCoreRequest):
|
||||||
"""Add request to the scheduler."""
|
"""Add request to the scheduler."""
|
||||||
|
if pooling_params := request.pooling_params:
|
||||||
|
supported_pooling_tasks = (
|
||||||
|
self.model_executor.supported_pooling_tasks)
|
||||||
|
if pooling_params.task not in supported_pooling_tasks:
|
||||||
|
raise ValueError(f"Unsupported task: {pooling_params.task!r} "
|
||||||
|
f"Supported tasks: {supported_pooling_tasks}")
|
||||||
|
|
||||||
if request.mm_hashes is not None:
|
if request.mm_hashes is not None:
|
||||||
# Here, if hash exists for a multimodal input, then it will be
|
# Here, if hash exists for a multimodal input, then it will be
|
||||||
|
|||||||
@ -8,7 +8,6 @@ import torch
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models.interfaces import has_step_pooler
|
|
||||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -54,9 +53,6 @@ class CPUModelRunner(GPUModelRunner):
|
|||||||
logger.info("Starting to load model %s...", self.model_config.model)
|
logger.info("Starting to load model %s...", self.model_config.model)
|
||||||
self.model = get_model(vllm_config=self.vllm_config)
|
self.model = get_model(vllm_config=self.vllm_config)
|
||||||
|
|
||||||
if has_step_pooler(self.model):
|
|
||||||
self.input_batch.logits_processing_needs_token_ids = True
|
|
||||||
|
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
self.model = self.load_lora_model(self.model, self.model_config,
|
self.model = self.load_lora_model(self.model, self.model_config,
|
||||||
self.scheduler_config,
|
self.scheduler_config,
|
||||||
|
|||||||
@ -70,7 +70,6 @@ class InputBatch:
|
|||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
block_sizes: list[int], # The block_size of each kv cache group
|
block_sizes: list[int], # The block_size of each kv cache group
|
||||||
is_spec_decode: bool = False,
|
is_spec_decode: bool = False,
|
||||||
logits_processing_needs_token_ids: bool = False,
|
|
||||||
):
|
):
|
||||||
self.is_spec_decode = is_spec_decode
|
self.is_spec_decode = is_spec_decode
|
||||||
self.max_num_reqs = max_num_reqs
|
self.max_num_reqs = max_num_reqs
|
||||||
@ -79,8 +78,6 @@ class InputBatch:
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.pin_memory = pin_memory
|
self.pin_memory = pin_memory
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.logits_processing_needs_token_ids = (
|
|
||||||
logits_processing_needs_token_ids)
|
|
||||||
|
|
||||||
self._req_ids: list[Optional[str]] = []
|
self._req_ids: list[Optional[str]] = []
|
||||||
self.req_id_to_index: dict[str, int] = {}
|
self.req_id_to_index: dict[str, int] = {}
|
||||||
@ -233,6 +230,9 @@ class InputBatch:
|
|||||||
# req_index -> bad_words_token_ids
|
# req_index -> bad_words_token_ids
|
||||||
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
||||||
|
|
||||||
|
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
|
||||||
|
dtype=bool)
|
||||||
|
|
||||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||||
|
|
||||||
# This is updated each time the batch constituents change.
|
# This is updated each time the batch constituents change.
|
||||||
@ -365,9 +365,12 @@ class InputBatch:
|
|||||||
if sampling_params.bad_words_token_ids:
|
if sampling_params.bad_words_token_ids:
|
||||||
self.bad_words_token_ids[
|
self.bad_words_token_ids[
|
||||||
req_index] = sampling_params.bad_words_token_ids
|
req_index] = sampling_params.bad_words_token_ids
|
||||||
|
elif pooling_params := request.pooling_params:
|
||||||
|
self.pooling_params[req_id] = pooling_params
|
||||||
|
self.logits_processing_needs_token_ids[req_index] = (
|
||||||
|
pooling_params.requires_token_ids)
|
||||||
else:
|
else:
|
||||||
assert request.pooling_params is not None
|
raise NotImplementedError(request)
|
||||||
self.pooling_params[req_id] = request.pooling_params
|
|
||||||
|
|
||||||
# Add request lora ID
|
# Add request lora ID
|
||||||
if request.lora_request:
|
if request.lora_request:
|
||||||
@ -620,9 +623,9 @@ class InputBatch:
|
|||||||
copy_slice(self.repetition_penalties_cpu_tensor,
|
copy_slice(self.repetition_penalties_cpu_tensor,
|
||||||
self.repetition_penalties, num_reqs)
|
self.repetition_penalties, num_reqs)
|
||||||
|
|
||||||
needs_prompt_token_ids = (not self.no_penalties or
|
needs_prompt_token_ids = (
|
||||||
(self.num_reqs > 0
|
not self.no_penalties
|
||||||
and self.logits_processing_needs_token_ids))
|
or self.logits_processing_needs_token_ids[:num_reqs].any())
|
||||||
if needs_prompt_token_ids:
|
if needs_prompt_token_ids:
|
||||||
# The prompt tokens are used only for applying penalties or
|
# The prompt tokens are used only for applying penalties or
|
||||||
# step pooling during the sampling/pooling process.
|
# step pooling during the sampling/pooling process.
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -32,12 +32,13 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
|
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
||||||
from vllm.model_executor.models.interfaces import (has_step_pooler,
|
from vllm.model_executor.models.interfaces import is_mixture_of_experts
|
||||||
is_mixture_of_experts)
|
from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
|
||||||
|
is_pooling_model)
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams, PoolingTask
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||||
@ -404,6 +405,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
req_id = new_req_data.req_id
|
req_id = new_req_data.req_id
|
||||||
sampling_params = new_req_data.sampling_params
|
sampling_params = new_req_data.sampling_params
|
||||||
pooling_params = new_req_data.pooling_params
|
pooling_params = new_req_data.pooling_params
|
||||||
|
|
||||||
if sampling_params and \
|
if sampling_params and \
|
||||||
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||||
generator = torch.Generator(device=self.device)
|
generator = torch.Generator(device=self.device)
|
||||||
@ -411,6 +413,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
generator = None
|
generator = None
|
||||||
|
|
||||||
|
if pooling_params:
|
||||||
|
assert pooling_params.task is not None, (
|
||||||
|
"You did not set `task` in the API")
|
||||||
|
|
||||||
|
model = cast(VllmModelForPooling, self.model)
|
||||||
|
to_update = (model.pooler.get_pooling_updates(
|
||||||
|
pooling_params.task))
|
||||||
|
assert to_update is not None, (
|
||||||
|
f"{pooling_params.task=} is not supported by the model")
|
||||||
|
|
||||||
|
to_update.apply(pooling_params)
|
||||||
|
|
||||||
self.requests[req_id] = CachedRequestState(
|
self.requests[req_id] = CachedRequestState(
|
||||||
req_id=req_id,
|
req_id=req_id,
|
||||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||||
@ -1092,6 +1106,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
|
||||||
|
model = self.get_model()
|
||||||
|
if not is_pooling_model(model):
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
task for task in get_args(PoolingTask)
|
||||||
|
if model.pooler.get_pooling_updates(task)
|
||||||
|
]
|
||||||
|
|
||||||
def apply_grammar_bitmask(
|
def apply_grammar_bitmask(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@ -1737,8 +1761,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
model_loader.load_weights(self.model,
|
model_loader.load_weights(self.model,
|
||||||
model_config=self.model_config)
|
model_config=self.model_config)
|
||||||
if has_step_pooler(self.model):
|
|
||||||
self.input_batch.logits_processing_needs_token_ids = True
|
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
self.model = self.load_lora_model(self.model,
|
self.model = self.load_lora_model(self.model,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
@ -2138,17 +2160,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
req_num_tokens = num_tokens // num_reqs
|
req_num_tokens = num_tokens // num_reqs
|
||||||
|
|
||||||
|
model = cast(VllmModelForPooling, self.model)
|
||||||
|
dummy_task = self.get_supported_pooling_tasks()[0]
|
||||||
|
dummy_pooling_params = PoolingParams(task=dummy_task)
|
||||||
|
|
||||||
|
to_update = model.pooler.get_pooling_updates(dummy_task)
|
||||||
|
assert to_update is not None
|
||||||
|
to_update.apply(dummy_pooling_params)
|
||||||
|
|
||||||
dummy_metadata = PoolingMetadata(
|
dummy_metadata = PoolingMetadata(
|
||||||
prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
|
prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
|
||||||
device=self.device),
|
device=self.device),
|
||||||
prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
|
prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device),
|
device=self.device),
|
||||||
pooling_params=[PoolingParams()] * num_reqs)
|
pooling_params=[dummy_pooling_params] * num_reqs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pooler_output = self.model.pooler(hidden_states=hidden_states_list,
|
pooler_output = model.pooler(hidden_states=hidden_states_list,
|
||||||
pooling_metadata=dummy_metadata)
|
pooling_metadata=dummy_metadata)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if 'out of memory' in str(e):
|
if 'out of memory' in str(e):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor import set_random_seed
|
from vllm.model_executor import set_random_seed
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.pooling_params import PoolingTask
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
|
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
@ -309,6 +310,9 @@ class Worker(WorkerBase):
|
|||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model_runner.get_model()
|
return self.model_runner.get_model()
|
||||||
|
|
||||||
|
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
|
||||||
|
return self.model_runner.get_supported_pooling_tasks()
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import bisect
|
import bisect
|
||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
from typing import TYPE_CHECKING, Any, Optional, cast, get_args
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -25,10 +25,12 @@ from vllm.logger import init_logger
|
|||||||
from vllm.lora.layers import BaseLayerWithLoRA
|
from vllm.lora.layers import BaseLayerWithLoRA
|
||||||
from vllm.model_executor.model_loader import get_model_loader
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
from vllm.model_executor.model_loader.tpu import TPUModelLoader
|
from vllm.model_executor.model_loader.tpu import TPUModelLoader
|
||||||
|
from vllm.model_executor.models.interfaces_base import is_pooling_model
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
|
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
|
||||||
PlaceholderRange)
|
PlaceholderRange)
|
||||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
|
from vllm.pooling_params import PoolingTask
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
|
||||||
is_pin_memory_available, prev_power_of_2)
|
is_pin_memory_available, prev_power_of_2)
|
||||||
@ -483,6 +485,16 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
|
||||||
|
model = self.get_model()
|
||||||
|
if not is_pooling_model(model):
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
task for task in get_args(PoolingTask)
|
||||||
|
if model.pooler.get_pooling_updates(task)
|
||||||
|
]
|
||||||
|
|
||||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||||
"""
|
"""
|
||||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor import set_random_seed
|
from vllm.model_executor import set_random_seed
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.pooling_params import PoolingTask
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||||
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
|
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
@ -275,6 +276,9 @@ class TPUWorker:
|
|||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model_runner.get_model()
|
return self.model_runner.get_model()
|
||||||
|
|
||||||
|
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
|
||||||
|
return self.model_runner.get_supported_pooling_tasks()
|
||||||
|
|
||||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||||
return self.model_runner.get_kv_cache_spec()
|
return self.model_runner.get_kv_cache_spec()
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
|
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
|
||||||
TypeVar)
|
TypeVar, get_args)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -12,6 +12,8 @@ import torch.nn as nn
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.model_executor.models.interfaces_base import is_pooling_model
|
||||||
|
from vllm.pooling_params import PoolingTask
|
||||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -223,6 +225,16 @@ class ModelRunnerBase(ABC, Generic[T]):
|
|||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
|
||||||
|
model = self.get_model()
|
||||||
|
if not is_pooling_model(model):
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
task for task in get_args(PoolingTask)
|
||||||
|
if model.pooler.get_pooling_updates(task)
|
||||||
|
]
|
||||||
|
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
model_input: T,
|
model_input: T,
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -10,6 +10,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed import get_pp_group
|
from vllm.distributed import get_pp_group
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
|
||||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
@ -195,7 +196,20 @@ class PoolingModelRunner(
|
|||||||
seq_groups: List[Tuple[List[int], PoolingParams]] = []
|
seq_groups: List[Tuple[List[int], PoolingParams]] = []
|
||||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
|
|
||||||
pooling_params = seq_group_metadata.pooling_params
|
pooling_params = seq_group_metadata.pooling_params
|
||||||
|
assert pooling_params is not None
|
||||||
|
assert pooling_params.task is not None, (
|
||||||
|
"You did not set `task` in the API")
|
||||||
|
|
||||||
|
to_update = (cast(VllmModelForPooling,
|
||||||
|
self.model).pooler.get_pooling_updates(
|
||||||
|
pooling_params.task))
|
||||||
|
assert to_update is not None, (
|
||||||
|
f"{pooling_params.task=} is not supported by the model")
|
||||||
|
|
||||||
|
to_update.apply(pooling_params)
|
||||||
|
|
||||||
seq_groups.append((seq_ids, pooling_params))
|
seq_groups.append((seq_ids, pooling_params))
|
||||||
|
|
||||||
seq_data: Dict[int, SequenceData] = {}
|
seq_data: Dict[int, SequenceData] = {}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user