[Model][6/N] Improve all pooling task | Support chunked prefill with ALL pooling (#27145)

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
wang.yuqi 2025-12-04 21:44:15 +08:00 committed by GitHub
parent 1b7c7f5159
commit 74c4d80c6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 224 additions and 93 deletions

View File

@ -54,7 +54,7 @@ th:not(:first-child) {
| beam-search | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | | | beam-search | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | | [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\* Chunked prefill and prefix caching are only applicable to last-token pooling. \* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models. <sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
### Feature x Hardware ### Feature x Hardware

View File

@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM): def test_token_classify(llm: LLM):
# chunked prefill does not support all pooling llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_score_api(llm: LLM): def test_score_api(llm: LLM):

View File

@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str): async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
# token_classify uses ALL pooling, which does not support chunked prefill.
task = "token_classify" task = "token_classify"
input_text = ["This product was excellent and exceeded my expectations"]
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
"model": model_name, "model": model_name,
"input": "test", "input": input_text,
"encoding_format": "float", "encoding_format": "float",
"task": task, "task": task,
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" poolings = PoolingResponse.model_validate(response.json())
assert response.json()["error"]["message"].startswith( assert len(poolings.data) == 1
f"Task {task} is not supported" assert len(poolings.data[0].data) == 8
) assert len(poolings.data[0].data[0]) == 2
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -42,7 +42,7 @@ def llm():
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM): def test_token_embed(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False) outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
multi_vector = outputs[0].outputs.data multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384) assert multi_vector.shape == (11, 384)

View File

@ -36,6 +36,13 @@ def llm():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_config(llm: LLM):
    """Check that the engine behind ``llm`` has both caching features on.

    Reads the resolved vLLM config from the engine and asserts that
    prefix caching and chunked prefill are enabled.
    """
    resolved_config = llm.llm_engine.vllm_config

    cache_settings = resolved_config.cache_config
    assert cache_settings.enable_prefix_caching

    scheduler_settings = resolved_config.scheduler_config
    assert scheduler_settings.enable_chunked_prefill
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(use_activation): def get_outputs(use_activation):
outputs = llm.reward( outputs = llm.reward(

View File

@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModel
from tests.models.utils import check_embeddings_close
from vllm import TokensPrompt
@pytest.mark.parametrize(
    "model",
    ["Qwen/Qwen3-Embedding-0.6B"],
)
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, model: str):
    """Compare vLLM ``token_embed`` outputs against Hugging Face hidden states.

    vLLM is run with chunked prefill and prefix caching enabled and a token
    budget (``max_num_batched_tokens``) much smaller than each prompt. The
    per-token embeddings it returns are then compared with the
    ``last_hidden_state`` produced by the Hugging Face ``AutoModel`` for the
    same token ids.
    """
    # Token budget per step; presumably this forces every prompt below to be
    # prefilled over several chunks -- TODO confirm against the scheduler.
    chunk_size = 10
    n_prompt_tokens = [55, 56, 57]
    # Synthetic prompts: consecutive token ids starting at 1024.
    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]

    # NOTE(review): the two runner contexts are assumed to be sequential
    # (not nested), with the comparison loop after both -- confirm.
    with vllm_runner(
        model,
        runner="pooling",
        max_model_len=128,
        max_num_batched_tokens=chunk_size,
        enforce_eager=True,
        # `enable_chunked_prefill` is passed explicitly: per the original
        # note, VllmRunner sets it to `False` instead of `None`.
        enable_chunked_prefill=True,
        enable_prefix_caching=True,
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(
            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
        )

    with hf_runner(
        model,
        auto_cls=AutoModel,
    ) as hf_model:
        hf_outputs = []
        for token_prompt in token_prompts:
            # One prompt at a time, as a batch of size 1.
            inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])})
            input_ids = inputs["input_ids"]
            output = hf_model.model(input_ids)
            # Keep the hidden states of the single batch element, on CPU
            # and in float32.
            hf_outputs.append(output.last_hidden_state.cpu().float()[0])

    # Pairwise comparison, prompt by prompt.
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        check_embeddings_close(
            embeddings_0_lst=hf_output,
            embeddings_1_lst=vllm_output,
            name_0="hf",
            name_1="vllm",
            tol=1e-2,
        )

View File

@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
max_model_len=128, max_model_len=128,
enforce_eager=True, enforce_eager=True,
runner="pooling", runner="pooling",
enable_chunked_prefill=False,
enable_prefix_caching=True, enable_prefix_caching=True,
) as vllm_model: ) as vllm_model:
pooling_outputs = vllm_model.llm.encode( pooling_outputs = vllm_model.llm.encode(

View File

@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
( (
"internlm/internlm2-1_8b-reward", "internlm/internlm2-1_8b-reward",
"decoder", "decoder",
False, True,
"Pooling models with all pooling does not support chunked prefill.", "Pooling models with causal attn and all pooling support chunked prefill.",
), ),
( (
"BAAI/bge-base-en", "BAAI/bge-base-en",
@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
( (
"internlm/internlm2-1_8b-reward", "internlm/internlm2-1_8b-reward",
"decoder", "decoder",
False, True,
"Pooling models with all pooling does not support prefix caching.", "Pooling models with causal attn and all pooling support prefix caching.",
), ),
( (
"BAAI/bge-base-en", "BAAI/bge-base-en",

View File

@ -1780,20 +1780,22 @@ class ModelConfig:
return False return False
elif attn_type == "decoder": elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower() pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]: if pooling_type in ["mean", "step", "cls"]:
logger.debug( logger.debug(
"Pooling models with %s pooling does not " "Pooling models with %s pooling does not "
"support chunked prefill.", "support chunked prefill.",
pooling_type, pooling_type,
) )
return False return False
else: elif pooling_type in ["all", "last"]:
# pooling_type == "last"
logger.debug( logger.debug(
"Pooling models with causal attn and last pooling support " "Pooling models with causal attn and %s pooling support "
"chunked prefill." "chunked prefill.",
pooling_type,
) )
return True return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid, # vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types. # attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder" return attn_type != "encoder_decoder"
@ -1817,20 +1819,22 @@ class ModelConfig:
return False return False
elif attn_type == "decoder": elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower() pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]: if pooling_type in ["mean", "step", "cls"]:
logger.debug( logger.debug(
"Pooling models with %s pooling does not " "Pooling models with %s pooling does not "
"support prefix caching.", "support prefix caching.",
pooling_type, pooling_type,
) )
return False return False
else: elif pooling_type in ["all", "last"]:
# pooling_type == "last"
logger.debug( logger.debug(
"Pooling models with causal attn and last pooling support " "Pooling models with causal attn and %s pooling support "
"prefix caching." "prefix caching.",
pooling_type,
) )
return True return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid, # vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types. # attention_free or encoder_decoder attn types.
return False return False

View File

@ -127,14 +127,14 @@ class PoolingMethod(nn.Module, ABC):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor, pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor: ) -> PoolerOutput:
raise NotImplementedError raise NotImplementedError
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor: ) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor) return self.forward_all(hidden_states, pooling_cursor)
@ -147,7 +147,7 @@ class CLSPool(PoolingMethod):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor, pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor: ) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), ( assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with CLS pooling" "partial prefill not supported with CLS pooling"
) )
@ -163,27 +163,65 @@ class LastPool(PoolingMethod):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor, pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor: ) -> PoolerOutput:
return hidden_states[pooling_cursor.last_token_indices_gpu] return hidden_states[pooling_cursor.last_token_indices_gpu]
class AllPool(PoolingMethod): class AllPool(PoolingMethod):
def __init__(self):
super().__init__()
vllm_config = get_current_vllm_config()
self.enable_chunked_prefill = (
vllm_config.scheduler_config.enable_chunked_prefill
)
def get_supported_tasks(self) -> Set[PoolingTask]: def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify"} return {"token_embed", "token_classify"}
def forward_all( def forward_all(
self, self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
hidden_states: torch.Tensor, ) -> PoolerOutput:
pooling_cursor: PoolingCursor, raise NotImplementedError(
) -> list[torch.Tensor] | torch.Tensor: "forward_all is not implemented for AllPool. Use forward instead."
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with ALL pooling"
) )
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
is_finished = pooling_cursor.is_finished()
hidden_states_lst = list( hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()) hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
) )
return [hidden_states_lst[i] for i in pooling_cursor.index] hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
if not self.enable_chunked_prefill:
return hidden_states_lst
pooling_states = pooling_metadata.pooling_states
# If chunked_prefill is enabled
# 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
for p, hs_chunk in zip(pooling_states, hidden_states_lst):
p.hidden_states_cache.append(hs_chunk)
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list: PoolerOutput = []
for p, finished in zip(pooling_states, is_finished):
if finished:
hidden_states_cache = p.hidden_states_cache
if len(hidden_states_cache) == 1:
output_list.append(hidden_states_cache[0])
else:
output_list.append(torch.concat(hidden_states_cache, dim=0))
p.clean()
else:
output_list.append(None)
return output_list
class MeanPool(PoolingMethod): class MeanPool(PoolingMethod):
@ -194,7 +232,7 @@ class MeanPool(PoolingMethod):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor, pooling_cursor: PoolingCursor,
) -> list[torch.Tensor] | torch.Tensor: ) -> PoolerOutput:
assert not pooling_cursor.is_partial_prefill(), ( assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with MEAN pooling" "partial prefill not supported with MEAN pooling"
) )
@ -399,7 +437,7 @@ class PoolerHead(nn.Module):
self, self,
pooled_data: list[torch.Tensor] | torch.Tensor, pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
): ) -> PoolerOutput:
return self.activation(pooled_data) return self.activation(pooled_data)
@ -418,7 +456,7 @@ class EmbeddingPoolerHead(PoolerHead):
self, self,
pooled_data: list[torch.Tensor] | torch.Tensor, pooled_data: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
): ) -> PoolerOutput:
if isinstance(pooled_data, list): if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data) pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension] # pooled_data shape: [batchsize, hidden_dimension]
@ -586,8 +624,12 @@ class ClassifierPooler(Pooler):
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
def forward( def forward(
self, pooled_data: torch.Tensor, pooling_param: PoolingParams self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
) -> torch.Tensor: ) -> PoolerOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None
pooled_data = pooled_data.to(self.head_dtype) pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_tokens, hidden_dimension] # pooled_data shape: [n_tokens, hidden_dimension]
@ -630,9 +672,13 @@ class TokenClassifierPoolerHead(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor | None,
pooling_param: PoolingParams, pooling_param: PoolingParams,
) -> torch.Tensor: ) -> PoolerOutput:
# for unfinished chunked prefill
if hidden_states is None:
return None
hidden_states = hidden_states.to(self.head_dtype) hidden_states = hidden_states.to(self.head_dtype)
# hidden_states shape: [n_token, hidden_size] # hidden_states shape: [n_token, hidden_size]
@ -686,17 +732,20 @@ class StepPooler(Pooler):
self, self,
hidden_states: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]: ) -> PoolerOutput:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata) pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = pooling_metadata.get_prompt_token_ids() prompt_token_ids = pooling_metadata.get_prompt_token_ids()
pooled_data = list[torch.Tensor]()
pooling_params = pooling_metadata.pooling_params pooling_params = pooling_metadata.pooling_params
pooled_data: PoolerOutput = []
for data, token_id, pooling_param in zip( for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params pooled_data_lst, prompt_token_ids, pooling_params
): ):
# for unfinished chunked prefill
if data is None:
pooled_data.append(data)
continue
step_tag_id = pooling_param.step_tag_id step_tag_id = pooling_param.step_tag_id
returned_token_ids = pooling_param.returned_token_ids returned_token_ids = pooling_param.returned_token_ids

View File

@ -64,7 +64,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type from .interfaces_base import attn_type
logger = init_logger(__name__) logger = init_logger(__name__)
@ -220,7 +220,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
) )
@default_pooling_type("All") @attn_type("attention_free")
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
TerratorchMultiModalProcessor, TerratorchMultiModalProcessor,
info=TerratorchProcessingInfo, info=TerratorchProcessingInfo,

View File

@ -89,7 +89,7 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>] # [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used # The shape of each element depends on the pooler used
PoolerOutput = torch.Tensor | list[torch.Tensor] PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None
@dataclass @dataclass

View File

@ -17,6 +17,7 @@ class PoolingCursor:
first_token_indices_gpu: torch.Tensor first_token_indices_gpu: torch.Tensor
last_token_indices_gpu: torch.Tensor last_token_indices_gpu: torch.Tensor
prompt_lens_cpu: torch.Tensor prompt_lens_cpu: torch.Tensor
seq_lens_cpu: torch.Tensor
num_scheduled_tokens_cpu: torch.Tensor num_scheduled_tokens_cpu: torch.Tensor
def __getitem__(self, indices: slice): def __getitem__(self, indices: slice):
@ -25,12 +26,25 @@ class PoolingCursor:
first_token_indices_gpu=self.first_token_indices_gpu[indices], first_token_indices_gpu=self.first_token_indices_gpu[indices],
last_token_indices_gpu=self.last_token_indices_gpu[indices], last_token_indices_gpu=self.last_token_indices_gpu[indices],
prompt_lens_cpu=self.prompt_lens_cpu[indices], prompt_lens_cpu=self.prompt_lens_cpu[indices],
seq_lens_cpu=self.seq_lens_cpu[indices],
num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices], num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
) )
def is_partial_prefill(self): def is_partial_prefill(self):
return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)
def is_finished(self):
    """Return an elementwise mask of entries whose sequence length equals the prompt length.

    The result comes from comparing ``seq_lens_cpu`` with ``prompt_lens_cpu``
    position by position; an entry is true where the two are equal.
    """
    processed_lens = self.seq_lens_cpu
    target_lens = self.prompt_lens_cpu
    return target_lens == processed_lens
class PoolingStates:
    """Mutable per-request scratch state used while pooling."""

    def __init__(self):
        # Hidden-state chunks collected for chunked prefill with ALL pooling.
        self.hidden_states_cache: list[torch.Tensor] = []

    def clean(self):
        """Empty the cache in place, keeping the same list object."""
        del self.hidden_states_cache[:]
@dataclass @dataclass
class PoolingMetadata: class PoolingMetadata:
@ -39,6 +53,7 @@ class PoolingMetadata:
prompt_lens: torch.Tensor # CPU Tensor prompt_lens: torch.Tensor # CPU Tensor
prompt_token_ids: torch.Tensor | None prompt_token_ids: torch.Tensor | None
pooling_params: list[PoolingParams] pooling_params: list[PoolingParams]
pooling_states: list[PoolingStates]
pooling_cursor: PoolingCursor | None = None pooling_cursor: PoolingCursor | None = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
@ -60,6 +75,7 @@ class PoolingMetadata:
if self.prompt_token_ids is None if self.prompt_token_ids is None
else self.prompt_token_ids[indices], else self.prompt_token_ids[indices],
pooling_params=self.pooling_params[indices], pooling_params=self.pooling_params[indices],
pooling_states=self.pooling_states[indices],
pooling_cursor=None pooling_cursor=None
if self.pooling_cursor is None if self.pooling_cursor is None
else self.pooling_cursor[indices], else self.pooling_cursor[indices],
@ -74,15 +90,21 @@ class PoolingMetadata:
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)] return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
def build_pooling_cursor( def build_pooling_cursor(
self, num_scheduled_tokens: list[int], device: torch.device self,
num_scheduled_tokens: list[int],
seq_lens_cpu: torch.Tensor,
device: torch.device,
): ):
self.pooling_cursor = build_pooling_cursor( self.pooling_cursor = build_pooling_cursor(
num_scheduled_tokens, self.prompt_lens, device num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device
) )
def build_pooling_cursor( def build_pooling_cursor(
num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device num_scheduled_tokens: list[int],
seq_lens_cpu: torch.Tensor,
prompt_lens: torch.Tensor,
device: torch.device,
): ):
assert len(prompt_lens) == len(num_scheduled_tokens) assert len(prompt_lens) == len(num_scheduled_tokens)
@ -99,5 +121,6 @@ def build_pooling_cursor(
first_token_indices_gpu=cumsum[:n_seq], first_token_indices_gpu=cumsum[:n_seq],
last_token_indices_gpu=cumsum[1:] - 1, last_token_indices_gpu=cumsum[1:] - 1,
prompt_lens_cpu=prompt_lens, prompt_lens_cpu=prompt_lens,
seq_lens_cpu=seq_lens_cpu,
num_scheduled_tokens_cpu=num_scheduled_tokens_cpu, num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
) )

View File

@ -15,7 +15,7 @@ from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import swap_dict_values from vllm.utils.collection_utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
from vllm.v1.sample.logits_processor import ( from vllm.v1.sample.logits_processor import (
BatchUpdateBuilder, BatchUpdateBuilder,
LogitsProcessors, LogitsProcessors,
@ -33,7 +33,6 @@ class CachedRequestState:
prompt_token_ids: list[int] | None prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec] mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None sampling_params: SamplingParams | None
pooling_params: PoolingParams | None
generator: torch.Generator | None generator: torch.Generator | None
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
@ -51,11 +50,18 @@ class CachedRequestState:
# Used when both async_scheduling and spec_decode are enabled. # Used when both async_scheduling and spec_decode are enabled.
prev_num_draft_len: int = 0 prev_num_draft_len: int = 0
# for pooling models
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds self.prompt_token_ids, self.prompt_embeds
) )
if self.pooling_params is not None:
self.pooling_states = PoolingStates()
@property @property
def num_tokens(self) -> int: def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids) return self.num_prompt_tokens + len(self.output_token_ids)
@ -255,7 +261,9 @@ class InputBatch:
# This is updated each time the batch constituents change. # This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata() self.sampling_metadata = self._make_sampling_metadata()
# for pooling models
self.pooling_params: dict[str, PoolingParams] = {} self.pooling_params: dict[str, PoolingParams] = {}
self.pooling_states: dict[str, PoolingStates] = {}
# Cached reference to the GPU tensor of previously sampled tokens # Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: torch.Tensor | None = None self.prev_sampled_token_ids: torch.Tensor | None = None
@ -413,7 +421,11 @@ class InputBatch:
sampling_params.bad_words_token_ids sampling_params.bad_words_token_ids
) )
elif pooling_params := request.pooling_params: elif pooling_params := request.pooling_params:
pooling_states = request.pooling_states
assert pooling_states is not None
self.pooling_params[req_id] = pooling_params self.pooling_params[req_id] = pooling_params
self.pooling_states[req_id] = pooling_states
self.logits_processing_needs_token_ids[req_index] = ( self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids pooling_params.requires_token_ids
) )
@ -469,6 +481,7 @@ class InputBatch:
if self.is_pooling_model: if self.is_pooling_model:
self.pooling_params.pop(req_id, None) self.pooling_params.pop(req_id, None)
self.pooling_states.pop(req_id, None)
return req_index return req_index
self.greedy_reqs.discard(req_id) self.greedy_reqs.discard(req_id)
@ -837,13 +850,19 @@ class InputBatch:
assert len(self.req_ids) == len(self.pooling_params) assert len(self.req_ids) == len(self.pooling_params)
return [self.pooling_params[req_id] for req_id in self.req_ids] return [self.pooling_params[req_id] for req_id in self.req_ids]
def get_pooling_states(self) -> list[PoolingStates]:
    """Return the pooling state of every request, ordered like ``self.req_ids``.

    Asserts that ``self.pooling_states`` has exactly one entry per request id.
    """
    assert len(self.req_ids) == len(self.pooling_states)
    ordered_states = []
    for request_id in self.req_ids:
        ordered_states.append(self.pooling_states[request_id])
    return ordered_states
def get_pooling_metadata(self) -> PoolingMetadata: def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params() pooling_params = self.get_pooling_params()
pooling_states = self.get_pooling_states()
return PoolingMetadata( return PoolingMetadata(
prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]), prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids, prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params, pooling_params=pooling_params,
pooling_states=pooling_states,
) )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor: def _make_prompt_token_ids_tensor(self) -> torch.Tensor:

View File

@ -131,7 +131,7 @@ from vllm.v1.outputs import (
SamplerOutput, SamplerOutput,
make_empty_encoder_model_runner_output, make_empty_encoder_model_runner_output,
) )
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.logits_processor.interface import LogitsProcessor from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
@ -2291,20 +2291,6 @@ class GPUModelRunner(
supported_tasks = list(model.pooler.get_supported_tasks()) supported_tasks = list(model.pooler.get_supported_tasks())
if self.scheduler_config.enable_chunked_prefill:
if "token_embed" in supported_tasks:
supported_tasks.remove("token_embed")
if "token_classify" in supported_tasks:
supported_tasks.remove("token_classify")
logger.debug_once(
"Chunked prefill is not supported with "
"token_embed and token_classify tasks "
"which using ALL pooling. "
"Please turn off chunked prefill by "
"`--no-enable-chunked-prefill` before using it."
)
if "score" in supported_tasks: if "score" in supported_tasks:
num_labels = getattr(self.model_config.hf_config, "num_labels", 0) num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
if num_labels != 1: if num_labels != 1:
@ -2381,11 +2367,12 @@ class GPUModelRunner(
) )
hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[:num_scheduled_tokens]
seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor( pooling_metadata.build_pooling_cursor(
num_scheduled_tokens_np.tolist(), device=hidden_states.device num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
) )
seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
model = cast(VllmModelForPooling, self.model) model = cast(VllmModelForPooling, self.model)
raw_pooler_output: PoolerOutput = model.pooler( raw_pooler_output: PoolerOutput = model.pooler(
@ -2393,7 +2380,7 @@ class GPUModelRunner(
pooling_metadata=pooling_metadata, pooling_metadata=pooling_metadata,
) )
raw_pooler_output = json_map_leaves( raw_pooler_output = json_map_leaves(
lambda x: x.to("cpu", non_blocking=True), lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
raw_pooler_output, raw_pooler_output,
) )
self._sync_device() self._sync_device()
@ -4248,10 +4235,13 @@ class GPUModelRunner(
prompt_lens=dummy_prompt_lens, prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids, prompt_token_ids=dummy_token_ids,
pooling_params=[dummy_pooling_params] * num_reqs, pooling_params=[dummy_pooling_params] * num_reqs,
pooling_states=[PoolingStates() for i in range(num_reqs)],
) )
dummy_metadata.build_pooling_cursor( dummy_metadata.build_pooling_cursor(
num_scheduled_tokens_list, device=hidden_states.device num_scheduled_tokens_list,
seq_lens_cpu=dummy_prompt_lens,
device=hidden_states.device,
) )
try: try:
@ -4278,22 +4268,12 @@ class GPUModelRunner(
supported_pooling_tasks = self.get_supported_pooling_tasks() supported_pooling_tasks = self.get_supported_pooling_tasks()
if not supported_pooling_tasks: if not supported_pooling_tasks:
if self.scheduler_config.enable_chunked_prefill: raise RuntimeError(
raise RuntimeError( f"Model {self.model_config.model} does not support "
f"Model {self.model_config.model} does not support " "any pooling tasks. See "
"any pooling tasks with chunked prefill enabled. " "https://docs.vllm.ai/en/latest/models/pooling_models.html "
"Please add --no-enable-chunked-prefill to your " "to learn more."
"config or CLI args. See " )
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
"to learn more."
)
else:
raise RuntimeError(
f"Model {self.model_config.model} does not support "
"any pooling tasks. See "
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
"to learn more."
)
output_size = dict[PoolingTask, float]() output_size = dict[PoolingTask, float]()
for task in supported_pooling_tasks: for task in supported_pooling_tasks: