[Model][6/N] Improve all pooling task | Support chunked prefill with ALL pooling (#27145)

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

parent 1b7c7f5159
commit 74c4d80c6c
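What this change enables, as a minimal offline sketch (the model name and prompt are illustrative; `runner="pooling"`, the chunked-prefill/prefix-caching flags, and the `token_embed` pooling task mirror the tests added in this diff):

from vllm import LLM

# Minimal sketch, assuming a small embedding model: ALL-pooling tasks
# (token_embed / token_classify) can now run with chunked prefill enabled.
llm = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",
    runner="pooling",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
)
outputs = llm.encode(["vLLM is a fast inference engine."], pooling_task="token_embed")
print(outputs[0].outputs.data.shape)  # one embedding vector per prompt token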
@@ -54,7 +54,7 @@ th:not(:first-child) {
 | beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
 | [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |

-\* Chunked prefill and prefix caching are only applicable to last-token pooling.
+\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
 <sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.

 ### Feature x Hardware
@@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM):


 @pytest.mark.skip_global_cleanup
-def test_encode_api(llm: LLM):
-    # chunked prefill does not support all pooling
-    err_msg = "pooling_task must be one of.+"
-    with pytest.raises(ValueError, match=err_msg):
-        llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
+def test_token_classify(llm: LLM):
+    llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)


 def test_score_api(llm: LLM):
@@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
-    # token_classify uses ALL pooling, which does not support chunked prefill.
     task = "token_classify"
+    input_text = ["This product was excellent and exceeded my expectations"]
     response = requests.post(
         server.url_for("pooling"),
         json={
             "model": model_name,
-            "input": "test",
+            "input": input_text,
             "encoding_format": "float",
             "task": task,
         },
     )
-    assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(
-        f"Task {task} is not supported"
-    )
+    poolings = PoolingResponse.model_validate(response.json())
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 8
+    assert len(poolings.data[0].data[0]) == 2


 @pytest.mark.asyncio
@@ -42,7 +42,7 @@ def llm():


 @pytest.mark.skip_global_cleanup
-def test_encode_api(llm: LLM):
+def test_token_embed(llm: LLM):
     outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
     multi_vector = outputs[0].outputs.data
     assert multi_vector.shape == (11, 384)
@@ -36,6 +36,13 @@ def llm():
     cleanup_dist_env_and_memory()


+@pytest.mark.skip_global_cleanup
+def test_config(llm: LLM):
+    vllm_config = llm.llm_engine.vllm_config
+    assert vllm_config.cache_config.enable_prefix_caching
+    assert vllm_config.scheduler_config.enable_chunked_prefill
+
+
 def test_pooling_params(llm: LLM):
     def get_outputs(use_activation):
         outputs = llm.reward(
@@ -0,0 +1,53 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from transformers import AutoModel
+
+from tests.models.utils import check_embeddings_close
+from vllm import TokensPrompt
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["Qwen/Qwen3-Embedding-0.6B"],
+)
+@torch.inference_mode
+def test_embed_models(hf_runner, vllm_runner, model: str):
+    chunk_size = 10
+    n_prompt_tokens = [55, 56, 57]
+    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
+
+    with vllm_runner(
+        model,
+        runner="pooling",
+        max_model_len=128,
+        max_num_batched_tokens=chunk_size,
+        enforce_eager=True,
+        # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner
+        enable_chunked_prefill=True,
+        enable_prefix_caching=True,
+    ) as vllm_model:
+        vllm_outputs = vllm_model.token_embed(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+        )
+
+    with hf_runner(
+        model,
+        auto_cls=AutoModel,
+    ) as hf_model:
+        hf_outputs = []
+        for token_prompt in token_prompts:
+            inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])})
+            input_ids = inputs["input_ids"]
+            output = hf_model.model(input_ids)
+            hf_outputs.append(output.last_hidden_state.cpu().float()[0])
+
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        check_embeddings_close(
+            embeddings_0_lst=hf_output,
+            embeddings_1_lst=vllm_output,
+            name_0="hf",
+            name_1="vllm",
+            tol=1e-2,
+        )
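Why the new test above actually exercises chunked prefill: with `max_num_batched_tokens=chunk_size=10`, each 55-57 token prompt must be prefilled over several scheduler steps. A rough sketch of that arithmetic (the real scheduler may also batch requests together):

import math

chunk_size = 10
for n in [55, 56, 57]:
    # e.g. 55 prompt tokens -> 6 prefill chunks of at most 10 tokens each
    print(f"{n} prompt tokens -> {math.ceil(n / chunk_size)} prefill chunks")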
@@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         max_model_len=128,
         enforce_eager=True,
         runner="pooling",
-        enable_chunked_prefill=False,
         enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
@@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
         (
             "internlm/internlm2-1_8b-reward",
             "decoder",
-            False,
-            "Pooling models with all pooling does not support chunked prefill.",
+            True,
+            "Pooling models with causal attn and all pooling support chunked prefill.",
         ),
         (
             "BAAI/bge-base-en",
@@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
         (
             "internlm/internlm2-1_8b-reward",
             "decoder",
-            False,
-            "Pooling models with all pooling does not support prefix caching.",
+            True,
+            "Pooling models with causal attn and all pooling support prefix caching.",
         ),
         (
             "BAAI/bge-base-en",
@@ -1780,20 +1780,22 @@ class ModelConfig:
             return False
         elif attn_type == "decoder":
             pooling_type = self.pooler_config.pooling_type.lower()
-            if pooling_type in ["all", "mean", "step", "cls"]:
+            if pooling_type in ["mean", "step", "cls"]:
                 logger.debug(
                     "Pooling models with %s pooling does not "
                     "support chunked prefill.",
                     pooling_type,
                 )
                 return False
-            else:
-                # pooling_type == "last"
+            elif pooling_type in ["all", "last"]:
                 logger.debug(
-                    "Pooling models with causal attn and last pooling support "
-                    "chunked prefill."
+                    "Pooling models with causal attn and %s pooling support "
+                    "chunked prefill.",
+                    pooling_type,
                 )
                 return True
+            else:
+                raise ValueError(f"{pooling_type=} not supported.")
         # vllm currently does not have pooling models using hybrid,
         # attention_free or encoder_decoder attn types.
         return attn_type != "encoder_decoder"
@@ -1817,20 +1819,22 @@ class ModelConfig:
             return False
         elif attn_type == "decoder":
             pooling_type = self.pooler_config.pooling_type.lower()
-            if pooling_type in ["all", "mean", "step", "cls"]:
+            if pooling_type in ["mean", "step", "cls"]:
                 logger.debug(
                     "Pooling models with %s pooling does not "
                     "support prefix caching.",
                     pooling_type,
                 )
                 return False
-            else:
-                # pooling_type == "last"
+            elif pooling_type in ["all", "last"]:
                 logger.debug(
-                    "Pooling models with causal attn and last pooling support "
-                    "prefix caching."
+                    "Pooling models with causal attn and %s pooling support "
+                    "prefix caching.",
+                    pooling_type,
                 )
                 return True
+            else:
+                raise ValueError(f"{pooling_type=} not supported.")
         # vllm currently does not have pooling models using hybrid,
         # attention_free or encoder_decoder attn types.
         return False
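The two hunks above apply the same rule to chunked prefill and prefix caching. A hypothetical helper that restates just the decoder branch for reference (the function name is made up; this is not vLLM API):

def decoder_pooling_supports_chunking(pooling_type: str) -> bool:
    # Mirrors the decoder branch above: MEAN/STEP/CLS pooling still opt out,
    # while ALL pooling now joins LAST pooling as chunked-prefill compatible.
    pooling_type = pooling_type.lower()
    if pooling_type in ["mean", "step", "cls"]:
        return False
    elif pooling_type in ["all", "last"]:
        return True
    else:
        raise ValueError(f"{pooling_type=} not supported.")

assert decoder_pooling_supports_chunking("ALL") is True
assert decoder_pooling_supports_chunking("mean") is False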
@@ -127,14 +127,14 @@ class PoolingMethod(nn.Module, ABC):
         self,
         hidden_states: torch.Tensor,
         pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         raise NotImplementedError

     def forward(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         pooling_cursor = pooling_metadata.pooling_cursor
         return self.forward_all(hidden_states, pooling_cursor)
@@ -147,7 +147,7 @@ class CLSPool(PoolingMethod):
         self,
         hidden_states: torch.Tensor,
         pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         assert not pooling_cursor.is_partial_prefill(), (
             "partial prefill not supported with CLS pooling"
         )
@@ -163,27 +163,65 @@ class LastPool(PoolingMethod):
         self,
         hidden_states: torch.Tensor,
         pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         return hidden_states[pooling_cursor.last_token_indices_gpu]


 class AllPool(PoolingMethod):
+    def __init__(self):
+        super().__init__()
+
+        vllm_config = get_current_vllm_config()
+        self.enable_chunked_prefill = (
+            vllm_config.scheduler_config.enable_chunked_prefill
+        )
+
     def get_supported_tasks(self) -> Set[PoolingTask]:
         return {"token_embed", "token_classify"}

     def forward_all(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
-        assert not pooling_cursor.is_partial_prefill(), (
-            "partial prefill not supported with ALL pooling"
+        self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
+    ) -> PoolerOutput:
+        raise NotImplementedError(
+            "forward_all is not implemented for AllPool. Use forward instead."
         )

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooling_cursor = pooling_metadata.pooling_cursor
+        is_finished = pooling_cursor.is_finished()
         hidden_states_lst = list(
             hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
         )
-        return [hidden_states_lst[i] for i in pooling_cursor.index]
+        hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
+
+        if not self.enable_chunked_prefill:
+            return hidden_states_lst
+
+        pooling_states = pooling_metadata.pooling_states
+
+        # If chunked_prefill is enabled
+        # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
+        for p, hs_chunk in zip(pooling_states, hidden_states_lst):
+            p.hidden_states_cache.append(hs_chunk)
+
+        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
+        output_list: PoolerOutput = []
+        for p, finished in zip(pooling_states, is_finished):
+            if finished:
+                hidden_states_cache = p.hidden_states_cache
+                if len(hidden_states_cache) == 1:
+                    output_list.append(hidden_states_cache[0])
+                else:
+                    output_list.append(torch.concat(hidden_states_cache, dim=0))
+                p.clean()
+            else:
+                output_list.append(None)
+
+        return output_list


 class MeanPool(PoolingMethod):
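The core of the change is the new `AllPool.forward` above: each step's per-request hidden-state chunk is appended to `PoolingStates.hidden_states_cache`, and the full tensor is emitted only once that request's prefill finishes; unfinished requests yield `None`. A standalone sketch of that caching pattern (plain Python/torch for illustration, not vLLM internals):

import torch

cache: dict[str, list[torch.Tensor]] = {}

def accept_chunk(req_id: str, chunk: torch.Tensor, finished: bool) -> torch.Tensor | None:
    # Accumulate this step's hidden states; emit the full tensor only when prefill is done.
    cache.setdefault(req_id, []).append(chunk)
    if not finished:
        return None
    chunks = cache.pop(req_id)
    return chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)

# A 7-token prompt scheduled as chunks of 3, 3 and 1 tokens.
hidden = torch.randn(7, 4)
steps = [(hidden[:3], False), (hidden[3:6], False), (hidden[6:], True)]
out = None
for chunk, finished in steps:
    out = accept_chunk("req-0", chunk, finished)
assert out is not None and torch.equal(out, hidden)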
@@ -194,7 +232,7 @@ class MeanPool(PoolingMethod):
         self,
         hidden_states: torch.Tensor,
         pooling_cursor: PoolingCursor,
-    ) -> list[torch.Tensor] | torch.Tensor:
+    ) -> PoolerOutput:
         assert not pooling_cursor.is_partial_prefill(), (
             "partial prefill not supported with MEAN pooling"
         )
@@ -399,7 +437,7 @@ class PoolerHead(nn.Module):
         self,
         pooled_data: list[torch.Tensor] | torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ):
+    ) -> PoolerOutput:
         return self.activation(pooled_data)

@@ -418,7 +456,7 @@ class EmbeddingPoolerHead(PoolerHead):
         self,
         pooled_data: list[torch.Tensor] | torch.Tensor,
         pooling_metadata: PoolingMetadata,
-    ):
+    ) -> PoolerOutput:
         if isinstance(pooled_data, list):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_dimension]
@@ -586,8 +624,12 @@ class ClassifierPooler(Pooler):

 class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
     def forward(
-        self, pooled_data: torch.Tensor, pooling_param: PoolingParams
-    ) -> torch.Tensor:
+        self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
+    ) -> PoolerOutput:
+        # for unfinished chunked prefill
+        if pooled_data is None:
+            return None
+
         pooled_data = pooled_data.to(self.head_dtype)
         # pooled_data shape: [n_tokens, hidden_dimension]
@@ -630,9 +672,13 @@ class TokenClassifierPoolerHead(nn.Module):

     def forward(
         self,
-        hidden_states: torch.Tensor,
+        hidden_states: torch.Tensor | None,
         pooling_param: PoolingParams,
-    ) -> torch.Tensor:
+    ) -> PoolerOutput:
+        # for unfinished chunked prefill
+        if hidden_states is None:
+            return None
+
         hidden_states = hidden_states.to(self.head_dtype)
         # hidden_states shape: [n_token, hidden_size]
@@ -686,17 +732,20 @@ class StepPooler(Pooler):
         self,
         hidden_states: torch.Tensor | list[torch.Tensor],
         pooling_metadata: PoolingMetadata,
-    ) -> torch.Tensor | list[torch.Tensor]:
+    ) -> PoolerOutput:
         pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
         prompt_token_ids = pooling_metadata.get_prompt_token_ids()

-        pooled_data = list[torch.Tensor]()
-
         pooling_params = pooling_metadata.pooling_params

+        pooled_data: PoolerOutput = []
         for data, token_id, pooling_param in zip(
             pooled_data_lst, prompt_token_ids, pooling_params
         ):
+            # for unfinished chunked prefill
+            if data is None:
+                pooled_data.append(data)
+                continue
+
             step_tag_id = pooling_param.step_tag_id
             returned_token_ids = pooling_param.returned_token_ids
@@ -64,7 +64,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors

 from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
-from .interfaces_base import default_pooling_type
+from .interfaces_base import attn_type

 logger = init_logger(__name__)
@@ -220,7 +220,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
 )


-@default_pooling_type("All")
+@attn_type("attention_free")
 @MULTIMODAL_REGISTRY.register_processor(
     TerratorchMultiModalProcessor,
     info=TerratorchProcessingInfo,
@@ -89,7 +89,7 @@ class LogprobsTensors(NamedTuple):

 # [num_reqs, <dynamic>]
 # The shape of each element depends on the pooler used
-PoolerOutput = torch.Tensor | list[torch.Tensor]
+PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None


 @dataclass
@@ -17,6 +17,7 @@ class PoolingCursor:
     first_token_indices_gpu: torch.Tensor
     last_token_indices_gpu: torch.Tensor
     prompt_lens_cpu: torch.Tensor
+    seq_lens_cpu: torch.Tensor
     num_scheduled_tokens_cpu: torch.Tensor

     def __getitem__(self, indices: slice):
@@ -25,12 +26,25 @@ class PoolingCursor:
             first_token_indices_gpu=self.first_token_indices_gpu[indices],
             last_token_indices_gpu=self.last_token_indices_gpu[indices],
             prompt_lens_cpu=self.prompt_lens_cpu[indices],
+            seq_lens_cpu=self.seq_lens_cpu[indices],
             num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
         )

     def is_partial_prefill(self):
         return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)

+    def is_finished(self):
+        return self.prompt_lens_cpu == self.seq_lens_cpu
+
+
+class PoolingStates:
+    def __init__(self):
+        # for chunked prefill with ALL pooling
+        self.hidden_states_cache: list[torch.Tensor] = []
+
+    def clean(self):
+        self.hidden_states_cache.clear()
+

 @dataclass
 class PoolingMetadata:
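A toy illustration of the two cursor predicates added above, using standalone tensors (the field names follow the diff; the values are made up): `is_partial_prefill` is true when some request was scheduled for fewer tokens than its full prompt this step, and `is_finished` marks requests whose processed length has reached the prompt length.

import torch

prompt_lens = torch.tensor([55, 56])            # full prompt length per request
num_scheduled_tokens = torch.tensor([10, 56])   # tokens scheduled this step
seq_lens = torch.tensor([30, 56])               # tokens processed so far, incl. this step

is_partial_prefill = not torch.all(prompt_lens == num_scheduled_tokens)
is_finished = prompt_lens == seq_lens
print(is_partial_prefill)  # True: request 0 is being prefilled in chunks
print(is_finished)         # tensor([False,  True]): only request 1 can be pooled now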
@@ -39,6 +53,7 @@ class PoolingMetadata:
     prompt_lens: torch.Tensor  # CPU Tensor
     prompt_token_ids: torch.Tensor | None
     pooling_params: list[PoolingParams]
+    pooling_states: list[PoolingStates]
     pooling_cursor: PoolingCursor | None = None

     def __post_init__(self) -> None:
@@ -60,6 +75,7 @@ class PoolingMetadata:
             if self.prompt_token_ids is None
             else self.prompt_token_ids[indices],
             pooling_params=self.pooling_params[indices],
+            pooling_states=self.pooling_states[indices],
             pooling_cursor=None
             if self.pooling_cursor is None
             else self.pooling_cursor[indices],
@@ -74,15 +90,21 @@ class PoolingMetadata:
         return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]

     def build_pooling_cursor(
-        self, num_scheduled_tokens: list[int], device: torch.device
+        self,
+        num_scheduled_tokens: list[int],
+        seq_lens_cpu: torch.Tensor,
+        device: torch.device,
     ):
         self.pooling_cursor = build_pooling_cursor(
-            num_scheduled_tokens, self.prompt_lens, device
+            num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device
         )


 def build_pooling_cursor(
-    num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device
+    num_scheduled_tokens: list[int],
+    seq_lens_cpu: torch.Tensor,
+    prompt_lens: torch.Tensor,
+    device: torch.device,
 ):
     assert len(prompt_lens) == len(num_scheduled_tokens)
@@ -99,5 +121,6 @@ def build_pooling_cursor(
         first_token_indices_gpu=cumsum[:n_seq],
         last_token_indices_gpu=cumsum[1:] - 1,
         prompt_lens_cpu=prompt_lens,
+        seq_lens_cpu=seq_lens_cpu,
         num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
     )
@@ -15,7 +15,7 @@ from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.utils.collection_utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
-from vllm.v1.pool.metadata import PoolingMetadata
+from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
 from vllm.v1.sample.logits_processor import (
     BatchUpdateBuilder,
     LogitsProcessors,
@@ -33,7 +33,6 @@ class CachedRequestState:
     prompt_token_ids: list[int] | None
     mm_features: list[MultiModalFeatureSpec]
     sampling_params: SamplingParams | None
-    pooling_params: PoolingParams | None
     generator: torch.Generator | None

     block_ids: tuple[list[int], ...]
@@ -51,11 +50,18 @@ class CachedRequestState:
     # Used when both async_scheduling and spec_decode are enabled.
     prev_num_draft_len: int = 0

+    # for pooling models
+    pooling_params: PoolingParams | None = None
+    pooling_states: PoolingStates | None = None
+
     def __post_init__(self):
         self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
             self.prompt_token_ids, self.prompt_embeds
         )

+        if self.pooling_params is not None:
+            self.pooling_states = PoolingStates()
+
     @property
     def num_tokens(self) -> int:
         return self.num_prompt_tokens + len(self.output_token_ids)
@@ -255,7 +261,9 @@ class InputBatch:
         # This is updated each time the batch constituents change.
         self.sampling_metadata = self._make_sampling_metadata()

+        # for pooling models
         self.pooling_params: dict[str, PoolingParams] = {}
+        self.pooling_states: dict[str, PoolingStates] = {}

         # Cached reference to the GPU tensor of previously sampled tokens
         self.prev_sampled_token_ids: torch.Tensor | None = None
@@ -413,7 +421,11 @@ class InputBatch:
                 sampling_params.bad_words_token_ids
             )
         elif pooling_params := request.pooling_params:
+            pooling_states = request.pooling_states
+            assert pooling_states is not None
+
             self.pooling_params[req_id] = pooling_params
+            self.pooling_states[req_id] = pooling_states
             self.logits_processing_needs_token_ids[req_index] = (
                 pooling_params.requires_token_ids
             )
@@ -469,6 +481,7 @@ class InputBatch:

         if self.is_pooling_model:
             self.pooling_params.pop(req_id, None)
+            self.pooling_states.pop(req_id, None)
             return req_index

         self.greedy_reqs.discard(req_id)
@@ -837,13 +850,19 @@ class InputBatch:
         assert len(self.req_ids) == len(self.pooling_params)
         return [self.pooling_params[req_id] for req_id in self.req_ids]

+    def get_pooling_states(self) -> list[PoolingStates]:
+        assert len(self.req_ids) == len(self.pooling_states)
+        return [self.pooling_states[req_id] for req_id in self.req_ids]
+
     def get_pooling_metadata(self) -> PoolingMetadata:
         pooling_params = self.get_pooling_params()
+        pooling_states = self.get_pooling_states()

         return PoolingMetadata(
             prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
             prompt_token_ids=self.sampling_metadata.prompt_token_ids,
             pooling_params=pooling_params,
+            pooling_states=pooling_states,
         )

     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
@@ -131,7 +131,7 @@ from vllm.v1.outputs import (
     SamplerOutput,
     make_empty_encoder_model_runner_output,
 )
-from vllm.v1.pool.metadata import PoolingMetadata
+from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.logits_processor.interface import LogitsProcessor
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -2291,20 +2291,6 @@ class GPUModelRunner(

         supported_tasks = list(model.pooler.get_supported_tasks())

-        if self.scheduler_config.enable_chunked_prefill:
-            if "token_embed" in supported_tasks:
-                supported_tasks.remove("token_embed")
-            if "token_classify" in supported_tasks:
-                supported_tasks.remove("token_classify")
-
-            logger.debug_once(
-                "Chunked prefill is not supported with "
-                "token_embed and token_classify tasks "
-                "which using ALL pooling. "
-                "Please turn off chunked prefill by "
-                "`--no-enable-chunked-prefill` before using it."
-            )
-
         if "score" in supported_tasks:
             num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
             if num_labels != 1:
@@ -2381,11 +2367,12 @@ class GPUModelRunner(
         )

         hidden_states = hidden_states[:num_scheduled_tokens]
+        seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]

         pooling_metadata = self.input_batch.get_pooling_metadata()
         pooling_metadata.build_pooling_cursor(
-            num_scheduled_tokens_np.tolist(), device=hidden_states.device
+            num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
         )
-        seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]

         model = cast(VllmModelForPooling, self.model)
         raw_pooler_output: PoolerOutput = model.pooler(
@@ -2393,7 +2380,7 @@ class GPUModelRunner(
             pooling_metadata=pooling_metadata,
         )
         raw_pooler_output = json_map_leaves(
-            lambda x: x.to("cpu", non_blocking=True),
+            lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
             raw_pooler_output,
         )
         self._sync_device()
@@ -4248,10 +4235,13 @@ class GPUModelRunner(
             prompt_lens=dummy_prompt_lens,
             prompt_token_ids=dummy_token_ids,
             pooling_params=[dummy_pooling_params] * num_reqs,
+            pooling_states=[PoolingStates() for i in range(num_reqs)],
         )

         dummy_metadata.build_pooling_cursor(
-            num_scheduled_tokens_list, device=hidden_states.device
+            num_scheduled_tokens_list,
+            seq_lens_cpu=dummy_prompt_lens,
+            device=hidden_states.device,
         )

         try:
@@ -4278,22 +4268,12 @@ class GPUModelRunner(
         supported_pooling_tasks = self.get_supported_pooling_tasks()

         if not supported_pooling_tasks:
-            if self.scheduler_config.enable_chunked_prefill:
-                raise RuntimeError(
-                    f"Model {self.model_config.model} does not support "
-                    "any pooling tasks with chunked prefill enabled. "
-                    "Please add --no-enable-chunked-prefill to your "
-                    "config or CLI args. See "
-                    "https://docs.vllm.ai/en/latest/models/pooling_models.html "
-                    "to learn more."
-                )
-            else:
-                raise RuntimeError(
-                    f"Model {self.model_config.model} does not support "
-                    "any pooling tasks. See "
-                    "https://docs.vllm.ai/en/latest/models/pooling_models.html "
-                    "to learn more."
-                )
+            raise RuntimeError(
+                f"Model {self.model_config.model} does not support "
+                "any pooling tasks. See "
+                "https://docs.vllm.ai/en/latest/models/pooling_models.html "
+                "to learn more."
+            )

         output_size = dict[PoolingTask, float]()
         for task in supported_pooling_tasks: