[Model][6/N] Improve all pooling task | Support chunked prefill with ALL pooling (#27145)

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
wang.yuqi 2025-12-04 21:44:15 +08:00 committed by GitHub
parent 1b7c7f5159
commit 74c4d80c6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 224 additions and 93 deletions

View File

@ -54,7 +54,7 @@ th:not(:first-child) {
| beam-search | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\* Chunked prefill and prefix caching are only applicable to last-token pooling.
\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
### Feature x Hardware

View File

@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM):
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
# chunked prefill does not support all pooling
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_token_classify(llm: LLM):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_score_api(llm: LLM):

View File

@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
# token_classify uses ALL pooling, which does not support chunked prefill.
task = "token_classify"
input_text = ["This product was excellent and exceeded my expectations"]
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": "test",
"input": input_text,
"encoding_format": "float",
"task": task,
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(
f"Task {task} is not supported"
)
poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 8
assert len(poolings.data[0].data[0]) == 2
@pytest.mark.asyncio

View File

@ -42,7 +42,7 @@ def llm():
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
def test_token_embed(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384)

View File

@ -36,6 +36,13 @@ def llm():
cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_config(llm: LLM):
vllm_config = llm.llm_engine.vllm_config
assert vllm_config.cache_config.enable_prefix_caching
assert vllm_config.scheduler_config.enable_chunked_prefill
def test_pooling_params(llm: LLM):
def get_outputs(use_activation):
outputs = llm.reward(

View File

@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModel
from tests.models.utils import check_embeddings_close
from vllm import TokensPrompt
@pytest.mark.parametrize(
"model",
["Qwen/Qwen3-Embedding-0.6B"],
)
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, model: str):
chunk_size = 10
n_prompt_tokens = [55, 56, 57]
token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
with vllm_runner(
model,
runner="pooling",
max_model_len=128,
max_num_batched_tokens=chunk_size,
enforce_eager=True,
# `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner
enable_chunked_prefill=True,
enable_prefix_caching=True,
) as vllm_model:
vllm_outputs = vllm_model.token_embed(
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
)
with hf_runner(
model,
auto_cls=AutoModel,
) as hf_model:
hf_outputs = []
for token_prompt in token_prompts:
inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])})
input_ids = inputs["input_ids"]
output = hf_model.model(input_ids)
hf_outputs.append(output.last_hidden_state.cpu().float()[0])
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
check_embeddings_close(
embeddings_0_lst=hf_output,
embeddings_1_lst=vllm_output,
name_0="hf",
name_1="vllm",
tol=1e-2,
)

View File

@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
max_model_len=128,
enforce_eager=True,
runner="pooling",
enable_chunked_prefill=False,
enable_prefix_caching=True,
) as vllm_model:
pooling_outputs = vllm_model.llm.encode(

View File

@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support chunked prefill.",
True,
"Pooling models with causal attn and all pooling support chunked prefill.",
),
(
"BAAI/bge-base-en",
@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support prefix caching.",
True,
"Pooling models with causal attn and all pooling support prefix caching.",
),
(
"BAAI/bge-base-en",

View File

@ -1780,20 +1780,22 @@ class ModelConfig:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support chunked prefill.",
pooling_type,
)
return False
else:
# pooling_type == "last"
elif pooling_type in ["all", "last"]:
logger.debug(
"Pooling models with causal attn and last pooling support "
"chunked prefill."
"Pooling models with causal attn and %s pooling support "
"chunked prefill.",
pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder"
@ -1817,20 +1819,22 @@ class ModelConfig:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support prefix caching.",
pooling_type,
)
return False
else:
# pooling_type == "last"
elif pooling_type in ["all", "last"]:
logger.debug(
"Pooling models with causal attn and last pooling support "
"prefix caching."
"Pooling models with causal attn and %s pooling support "
"prefix caching.",
pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return False

View File

@ -127,14 +127,14 @@ class PoolingMethod(nn.Module, ABC):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
raise NotImplementedError
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor)
@ -147,7 +147,7 @@ class CLSPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with CLS pooling"
)
@ -163,27 +163,65 @@ class LastPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
return hidden_states[pooling_cursor.last_token_indices_gpu]
class AllPool(PoolingMethod):
def __init__(self):
super().__init__()
vllm_config = get_current_vllm_config()
self.enable_chunked_prefill = (
vllm_config.scheduler_config.enable_chunked_prefill
)
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify"}
def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with ALL pooling"
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
) -> PoolerOutput:
raise NotImplementedError(
"forward_all is not implemented for AllPool. Use forward instead."
)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
is_finished = pooling_cursor.is_finished()
hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
)
return [hidden_states_lst[i] for i in pooling_cursor.index]
hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
if not self.enable_chunked_prefill:
return hidden_states_lst
pooling_states = pooling_metadata.pooling_states
# If chunked_prefill is enabled
# 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
for p, hs_chunk in zip(pooling_states, hidden_states_lst):
p.hidden_states_cache.append(hs_chunk)
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list: PoolerOutput = []
for p, finished in zip(pooling_states, is_finished):
if finished:
hidden_states_cache = p.hidden_states_cache
if len(hidden_states_cache) == 1:
output_list.append(hidden_states_cache[0])
else:
output_list.append(torch.concat(hidden_states_cache, dim=0))
p.clean()
else:
output_list.append(None)
return output_list
class MeanPool(PoolingMethod):
@ -194,7 +232,7 @@ class MeanPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor:
) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with MEAN pooling"
)
@ -399,7 +437,7 @@ class PoolerHead(nn.Module):
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
) -> PoolerOutput:
return self.activation(pooled_data)
@ -418,7 +456,7 @@ class EmbeddingPoolerHead(PoolerHead):
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
):
) -> PoolerOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
@ -586,8 +624,12 @@ class ClassifierPooler(Pooler):
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
def forward(
self, pooled_data: torch.Tensor, pooling_param: PoolingParams
) -> torch.Tensor:
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
) -> PoolerOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None
pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_tokens, hidden_dimension]
@ -630,9 +672,13 @@ class TokenClassifierPoolerHead(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
hidden_states: torch.Tensor | None,
pooling_param: PoolingParams,
) -> torch.Tensor:
) -> PoolerOutput:
# for unfinished chunked prefill
if hidden_states is None:
return None
hidden_states = hidden_states.to(self.head_dtype)
# hidden_states shape: [n_token, hidden_size]
@ -686,17 +732,20 @@ class StepPooler(Pooler):
self,
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]:
) -> PoolerOutput:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
pooled_data = list[torch.Tensor]()
pooling_params = pooling_metadata.pooling_params
pooled_data: PoolerOutput = []
for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params
):
# for unfinished chunked prefill
if data is None:
pooled_data.append(data)
continue
step_tag_id = pooling_param.step_tag_id
returned_token_ids = pooling_param.returned_token_ids

View File

@ -64,7 +64,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type
from .interfaces_base import attn_type
logger = init_logger(__name__)
@ -220,7 +220,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
)
@default_pooling_type("All")
@attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor(
TerratorchMultiModalProcessor,
info=TerratorchProcessingInfo,

View File

@ -89,7 +89,7 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
PoolerOutput = torch.Tensor | list[torch.Tensor]
PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None
@dataclass

View File

@ -17,6 +17,7 @@ class PoolingCursor:
first_token_indices_gpu: torch.Tensor
last_token_indices_gpu: torch.Tensor
prompt_lens_cpu: torch.Tensor
seq_lens_cpu: torch.Tensor
num_scheduled_tokens_cpu: torch.Tensor
def __getitem__(self, indices: slice):
@ -25,12 +26,25 @@ class PoolingCursor:
first_token_indices_gpu=self.first_token_indices_gpu[indices],
last_token_indices_gpu=self.last_token_indices_gpu[indices],
prompt_lens_cpu=self.prompt_lens_cpu[indices],
seq_lens_cpu=self.seq_lens_cpu[indices],
num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
)
def is_partial_prefill(self):
return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)
def is_finished(self):
return self.prompt_lens_cpu == self.seq_lens_cpu
class PoolingStates:
def __init__(self):
# for chunked prefill with ALL pooling
self.hidden_states_cache: list[torch.Tensor] = []
def clean(self):
self.hidden_states_cache.clear()
@dataclass
class PoolingMetadata:
@ -39,6 +53,7 @@ class PoolingMetadata:
prompt_lens: torch.Tensor # CPU Tensor
prompt_token_ids: torch.Tensor | None
pooling_params: list[PoolingParams]
pooling_states: list[PoolingStates]
pooling_cursor: PoolingCursor | None = None
def __post_init__(self) -> None:
@ -60,6 +75,7 @@ class PoolingMetadata:
if self.prompt_token_ids is None
else self.prompt_token_ids[indices],
pooling_params=self.pooling_params[indices],
pooling_states=self.pooling_states[indices],
pooling_cursor=None
if self.pooling_cursor is None
else self.pooling_cursor[indices],
@ -74,15 +90,21 @@ class PoolingMetadata:
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
def build_pooling_cursor(
self, num_scheduled_tokens: list[int], device: torch.device
self,
num_scheduled_tokens: list[int],
seq_lens_cpu: torch.Tensor,
device: torch.device,
):
self.pooling_cursor = build_pooling_cursor(
num_scheduled_tokens, self.prompt_lens, device
num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device
)
def build_pooling_cursor(
num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device
num_scheduled_tokens: list[int],
seq_lens_cpu: torch.Tensor,
prompt_lens: torch.Tensor,
device: torch.device,
):
assert len(prompt_lens) == len(num_scheduled_tokens)
@ -99,5 +121,6 @@ def build_pooling_cursor(
first_token_indices_gpu=cumsum[:n_seq],
last_token_indices_gpu=cumsum[1:] - 1,
prompt_lens_cpu=prompt_lens,
seq_lens_cpu=seq_lens_cpu,
num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
)

View File

@ -15,7 +15,7 @@ from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
from vllm.v1.sample.logits_processor import (
BatchUpdateBuilder,
LogitsProcessors,
@ -33,7 +33,6 @@ class CachedRequestState:
prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None
pooling_params: PoolingParams | None
generator: torch.Generator | None
block_ids: tuple[list[int], ...]
@ -51,11 +50,18 @@ class CachedRequestState:
# Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0
# for pooling models
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
)
if self.pooling_params is not None:
self.pooling_states = PoolingStates()
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
@ -255,7 +261,9 @@ class InputBatch:
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
# for pooling models
self.pooling_params: dict[str, PoolingParams] = {}
self.pooling_states: dict[str, PoolingStates] = {}
# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: torch.Tensor | None = None
@ -413,7 +421,11 @@ class InputBatch:
sampling_params.bad_words_token_ids
)
elif pooling_params := request.pooling_params:
pooling_states = request.pooling_states
assert pooling_states is not None
self.pooling_params[req_id] = pooling_params
self.pooling_states[req_id] = pooling_states
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids
)
@ -469,6 +481,7 @@ class InputBatch:
if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
self.pooling_states.pop(req_id, None)
return req_index
self.greedy_reqs.discard(req_id)
@ -837,13 +850,19 @@ class InputBatch:
assert len(self.req_ids) == len(self.pooling_params)
return [self.pooling_params[req_id] for req_id in self.req_ids]
def get_pooling_states(self) -> list[PoolingStates]:
assert len(self.req_ids) == len(self.pooling_states)
return [self.pooling_states[req_id] for req_id in self.req_ids]
def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params()
pooling_states = self.get_pooling_states()
return PoolingMetadata(
prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
pooling_states=pooling_states,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:

View File

@ -131,7 +131,7 @@ from vllm.v1.outputs import (
SamplerOutput,
make_empty_encoder_model_runner_output,
)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
@ -2291,20 +2291,6 @@ class GPUModelRunner(
supported_tasks = list(model.pooler.get_supported_tasks())
if self.scheduler_config.enable_chunked_prefill:
if "token_embed" in supported_tasks:
supported_tasks.remove("token_embed")
if "token_classify" in supported_tasks:
supported_tasks.remove("token_classify")
logger.debug_once(
"Chunked prefill is not supported with "
"token_embed and token_classify tasks "
"which using ALL pooling. "
"Please turn off chunked prefill by "
"`--no-enable-chunked-prefill` before using it."
)
if "score" in supported_tasks:
num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
if num_labels != 1:
@ -2381,11 +2367,12 @@ class GPUModelRunner(
)
hidden_states = hidden_states[:num_scheduled_tokens]
seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor(
num_scheduled_tokens_np.tolist(), device=hidden_states.device
num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
)
seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
model = cast(VllmModelForPooling, self.model)
raw_pooler_output: PoolerOutput = model.pooler(
@ -2393,7 +2380,7 @@ class GPUModelRunner(
pooling_metadata=pooling_metadata,
)
raw_pooler_output = json_map_leaves(
lambda x: x.to("cpu", non_blocking=True),
lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
raw_pooler_output,
)
self._sync_device()
@ -4248,10 +4235,13 @@ class GPUModelRunner(
prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids,
pooling_params=[dummy_pooling_params] * num_reqs,
pooling_states=[PoolingStates() for i in range(num_reqs)],
)
dummy_metadata.build_pooling_cursor(
num_scheduled_tokens_list, device=hidden_states.device
num_scheduled_tokens_list,
seq_lens_cpu=dummy_prompt_lens,
device=hidden_states.device,
)
try:
@ -4278,22 +4268,12 @@ class GPUModelRunner(
supported_pooling_tasks = self.get_supported_pooling_tasks()
if not supported_pooling_tasks:
if self.scheduler_config.enable_chunked_prefill:
raise RuntimeError(
f"Model {self.model_config.model} does not support "
"any pooling tasks with chunked prefill enabled. "
"Please add --no-enable-chunked-prefill to your "
"config or CLI args. See "
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
"to learn more."
)
else:
raise RuntimeError(
f"Model {self.model_config.model} does not support "
"any pooling tasks. See "
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
"to learn more."
)
raise RuntimeError(
f"Model {self.model_config.model} does not support "
"any pooling tasks. See "
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
"to learn more."
)
output_size = dict[PoolingTask, float]()
for task in supported_pooling_tasks: