mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-17 22:42:18 +08:00
[Model][6/N] Improve all pooling task | Support chunked prefill with ALL pooling (#27145)
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
1b7c7f5159
commit
74c4d80c6c
@ -54,7 +54,7 @@ th:not(:first-child) {
|
|||||||
| beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
|
| beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
|
||||||
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
|
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
|
||||||
|
|
||||||
\* Chunked prefill and prefix caching are only applicable to last-token pooling.
|
\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
|
||||||
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
|
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
|
||||||
|
|
||||||
### Feature x Hardware
|
### Feature x Hardware
|
||||||
|
|||||||
@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
def test_encode_api(llm: LLM):
|
def test_token_classify(llm: LLM):
|
||||||
# chunked prefill does not support all pooling
|
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
|
||||||
err_msg = "pooling_task must be one of.+"
|
|
||||||
with pytest.raises(ValueError, match=err_msg):
|
|
||||||
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_score_api(llm: LLM):
|
def test_score_api(llm: LLM):
|
||||||
|
|||||||
@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
|
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
|
||||||
# token_classify uses ALL pooling, which does not support chunked prefill.
|
|
||||||
task = "token_classify"
|
task = "token_classify"
|
||||||
|
input_text = ["This product was excellent and exceeded my expectations"]
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
server.url_for("pooling"),
|
server.url_for("pooling"),
|
||||||
json={
|
json={
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"input": "test",
|
"input": input_text,
|
||||||
"encoding_format": "float",
|
"encoding_format": "float",
|
||||||
"task": task,
|
"task": task,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.json()["error"]["type"] == "BadRequestError"
|
poolings = PoolingResponse.model_validate(response.json())
|
||||||
assert response.json()["error"]["message"].startswith(
|
assert len(poolings.data) == 1
|
||||||
f"Task {task} is not supported"
|
assert len(poolings.data[0].data) == 8
|
||||||
)
|
assert len(poolings.data[0].data[0]) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -42,7 +42,7 @@ def llm():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
def test_encode_api(llm: LLM):
|
def test_token_embed(llm: LLM):
|
||||||
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
|
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
|
||||||
multi_vector = outputs[0].outputs.data
|
multi_vector = outputs[0].outputs.data
|
||||||
assert multi_vector.shape == (11, 384)
|
assert multi_vector.shape == (11, 384)
|
||||||
|
|||||||
@ -36,6 +36,13 @@ def llm():
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
def test_config(llm: LLM):
|
||||||
|
vllm_config = llm.llm_engine.vllm_config
|
||||||
|
assert vllm_config.cache_config.enable_prefix_caching
|
||||||
|
assert vllm_config.scheduler_config.enable_chunked_prefill
|
||||||
|
|
||||||
|
|
||||||
def test_pooling_params(llm: LLM):
|
def test_pooling_params(llm: LLM):
|
||||||
def get_outputs(use_activation):
|
def get_outputs(use_activation):
|
||||||
outputs = llm.reward(
|
outputs = llm.reward(
|
||||||
|
|||||||
@ -0,0 +1,53 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
|
from tests.models.utils import check_embeddings_close
|
||||||
|
from vllm import TokensPrompt
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
["Qwen/Qwen3-Embedding-0.6B"],
|
||||||
|
)
|
||||||
|
@torch.inference_mode
|
||||||
|
def test_embed_models(hf_runner, vllm_runner, model: str):
|
||||||
|
chunk_size = 10
|
||||||
|
n_prompt_tokens = [55, 56, 57]
|
||||||
|
token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
|
||||||
|
|
||||||
|
with vllm_runner(
|
||||||
|
model,
|
||||||
|
runner="pooling",
|
||||||
|
max_model_len=128,
|
||||||
|
max_num_batched_tokens=chunk_size,
|
||||||
|
enforce_eager=True,
|
||||||
|
# `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner
|
||||||
|
enable_chunked_prefill=True,
|
||||||
|
enable_prefix_caching=True,
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_outputs = vllm_model.token_embed(
|
||||||
|
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
|
||||||
|
)
|
||||||
|
|
||||||
|
with hf_runner(
|
||||||
|
model,
|
||||||
|
auto_cls=AutoModel,
|
||||||
|
) as hf_model:
|
||||||
|
hf_outputs = []
|
||||||
|
for token_prompt in token_prompts:
|
||||||
|
inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])})
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
output = hf_model.model(input_ids)
|
||||||
|
hf_outputs.append(output.last_hidden_state.cpu().float()[0])
|
||||||
|
|
||||||
|
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
|
||||||
|
check_embeddings_close(
|
||||||
|
embeddings_0_lst=hf_output,
|
||||||
|
embeddings_1_lst=vllm_output,
|
||||||
|
name_0="hf",
|
||||||
|
name_1="vllm",
|
||||||
|
tol=1e-2,
|
||||||
|
)
|
||||||
@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
|
|||||||
max_model_len=128,
|
max_model_len=128,
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
enable_chunked_prefill=False,
|
|
||||||
enable_prefix_caching=True,
|
enable_prefix_caching=True,
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
pooling_outputs = vllm_model.llm.encode(
|
pooling_outputs = vllm_model.llm.encode(
|
||||||
|
|||||||
@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
|
|||||||
(
|
(
|
||||||
"internlm/internlm2-1_8b-reward",
|
"internlm/internlm2-1_8b-reward",
|
||||||
"decoder",
|
"decoder",
|
||||||
False,
|
True,
|
||||||
"Pooling models with all pooling does not support chunked prefill.",
|
"Pooling models with causal attn and all pooling support chunked prefill.",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"BAAI/bge-base-en",
|
"BAAI/bge-base-en",
|
||||||
@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
|
|||||||
(
|
(
|
||||||
"internlm/internlm2-1_8b-reward",
|
"internlm/internlm2-1_8b-reward",
|
||||||
"decoder",
|
"decoder",
|
||||||
False,
|
True,
|
||||||
"Pooling models with all pooling does not support prefix caching.",
|
"Pooling models with causal attn and all pooling support prefix caching.",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"BAAI/bge-base-en",
|
"BAAI/bge-base-en",
|
||||||
|
|||||||
@ -1780,20 +1780,22 @@ class ModelConfig:
|
|||||||
return False
|
return False
|
||||||
elif attn_type == "decoder":
|
elif attn_type == "decoder":
|
||||||
pooling_type = self.pooler_config.pooling_type.lower()
|
pooling_type = self.pooler_config.pooling_type.lower()
|
||||||
if pooling_type in ["all", "mean", "step", "cls"]:
|
if pooling_type in ["mean", "step", "cls"]:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Pooling models with %s pooling does not "
|
"Pooling models with %s pooling does not "
|
||||||
"support chunked prefill.",
|
"support chunked prefill.",
|
||||||
pooling_type,
|
pooling_type,
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
else:
|
elif pooling_type in ["all", "last"]:
|
||||||
# pooling_type == "last"
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Pooling models with causal attn and last pooling support "
|
"Pooling models with causal attn and %s pooling support "
|
||||||
"chunked prefill."
|
"chunked prefill.",
|
||||||
|
pooling_type,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{pooling_type=} not supported.")
|
||||||
# vllm currently does not have pooling models using hybrid,
|
# vllm currently does not have pooling models using hybrid,
|
||||||
# attention_free or encoder_decoder attn types.
|
# attention_free or encoder_decoder attn types.
|
||||||
return attn_type != "encoder_decoder"
|
return attn_type != "encoder_decoder"
|
||||||
@ -1817,20 +1819,22 @@ class ModelConfig:
|
|||||||
return False
|
return False
|
||||||
elif attn_type == "decoder":
|
elif attn_type == "decoder":
|
||||||
pooling_type = self.pooler_config.pooling_type.lower()
|
pooling_type = self.pooler_config.pooling_type.lower()
|
||||||
if pooling_type in ["all", "mean", "step", "cls"]:
|
if pooling_type in ["mean", "step", "cls"]:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Pooling models with %s pooling does not "
|
"Pooling models with %s pooling does not "
|
||||||
"support prefix caching.",
|
"support prefix caching.",
|
||||||
pooling_type,
|
pooling_type,
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
else:
|
elif pooling_type in ["all", "last"]:
|
||||||
# pooling_type == "last"
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Pooling models with causal attn and last pooling support "
|
"Pooling models with causal attn and %s pooling support "
|
||||||
"prefix caching."
|
"prefix caching.",
|
||||||
|
pooling_type,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{pooling_type=} not supported.")
|
||||||
# vllm currently does not have pooling models using hybrid,
|
# vllm currently does not have pooling models using hybrid,
|
||||||
# attention_free or encoder_decoder attn types.
|
# attention_free or encoder_decoder attn types.
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -127,14 +127,14 @@ class PoolingMethod(nn.Module, ABC):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
pooling_cursor: PoolingCursor,
|
pooling_cursor: PoolingCursor,
|
||||||
) -> list[torch.Tensor] | torch.Tensor:
|
) -> PoolerOutput:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> list[torch.Tensor] | torch.Tensor:
|
) -> PoolerOutput:
|
||||||
pooling_cursor = pooling_metadata.pooling_cursor
|
pooling_cursor = pooling_metadata.pooling_cursor
|
||||||
return self.forward_all(hidden_states, pooling_cursor)
|
return self.forward_all(hidden_states, pooling_cursor)
|
||||||
|
|
||||||
@ -147,7 +147,7 @@ class CLSPool(PoolingMethod):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
pooling_cursor: PoolingCursor,
|
pooling_cursor: PoolingCursor,
|
||||||
) -> list[torch.Tensor] | torch.Tensor:
|
) -> PoolerOutput:
|
||||||
assert not pooling_cursor.is_partial_prefill(), (
|
assert not pooling_cursor.is_partial_prefill(), (
|
||||||
"partial prefill not supported with CLS pooling"
|
"partial prefill not supported with CLS pooling"
|
||||||
)
|
)
|
||||||
@ -163,27 +163,65 @@ class LastPool(PoolingMethod):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
pooling_cursor: PoolingCursor,
|
pooling_cursor: PoolingCursor,
|
||||||
) -> list[torch.Tensor] | torch.Tensor:
|
) -> PoolerOutput:
|
||||||
return hidden_states[pooling_cursor.last_token_indices_gpu]
|
return hidden_states[pooling_cursor.last_token_indices_gpu]
|
||||||
|
|
||||||
|
|
||||||
class AllPool(PoolingMethod):
|
class AllPool(PoolingMethod):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
self.enable_chunked_prefill = (
|
||||||
|
vllm_config.scheduler_config.enable_chunked_prefill
|
||||||
|
)
|
||||||
|
|
||||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
return {"token_embed", "token_classify"}
|
return {"token_embed", "token_classify"}
|
||||||
|
|
||||||
def forward_all(
|
def forward_all(
|
||||||
self,
|
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
|
||||||
hidden_states: torch.Tensor,
|
) -> PoolerOutput:
|
||||||
pooling_cursor: PoolingCursor,
|
raise NotImplementedError(
|
||||||
) -> list[torch.Tensor] | torch.Tensor:
|
"forward_all is not implemented for AllPool. Use forward instead."
|
||||||
assert not pooling_cursor.is_partial_prefill(), (
|
|
||||||
"partial prefill not supported with ALL pooling"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> PoolerOutput:
|
||||||
|
pooling_cursor = pooling_metadata.pooling_cursor
|
||||||
|
is_finished = pooling_cursor.is_finished()
|
||||||
hidden_states_lst = list(
|
hidden_states_lst = list(
|
||||||
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
|
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
|
||||||
)
|
)
|
||||||
return [hidden_states_lst[i] for i in pooling_cursor.index]
|
hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
|
||||||
|
|
||||||
|
if not self.enable_chunked_prefill:
|
||||||
|
return hidden_states_lst
|
||||||
|
|
||||||
|
pooling_states = pooling_metadata.pooling_states
|
||||||
|
|
||||||
|
# If chunked_prefill is enabled
|
||||||
|
# 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
|
||||||
|
for p, hs_chunk in zip(pooling_states, hidden_states_lst):
|
||||||
|
p.hidden_states_cache.append(hs_chunk)
|
||||||
|
|
||||||
|
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
|
||||||
|
output_list: PoolerOutput = []
|
||||||
|
for p, finished in zip(pooling_states, is_finished):
|
||||||
|
if finished:
|
||||||
|
hidden_states_cache = p.hidden_states_cache
|
||||||
|
if len(hidden_states_cache) == 1:
|
||||||
|
output_list.append(hidden_states_cache[0])
|
||||||
|
else:
|
||||||
|
output_list.append(torch.concat(hidden_states_cache, dim=0))
|
||||||
|
p.clean()
|
||||||
|
else:
|
||||||
|
output_list.append(None)
|
||||||
|
|
||||||
|
return output_list
|
||||||
|
|
||||||
|
|
||||||
class MeanPool(PoolingMethod):
|
class MeanPool(PoolingMethod):
|
||||||
@ -194,7 +232,7 @@ class MeanPool(PoolingMethod):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
pooling_cursor: PoolingCursor,
|
pooling_cursor: PoolingCursor,
|
||||||
) -> list[torch.Tensor] | torch.Tensor:
|
) -> PoolerOutput:
|
||||||
assert not pooling_cursor.is_partial_prefill(), (
|
assert not pooling_cursor.is_partial_prefill(), (
|
||||||
"partial prefill not supported with MEAN pooling"
|
"partial prefill not supported with MEAN pooling"
|
||||||
)
|
)
|
||||||
@ -399,7 +437,7 @@ class PoolerHead(nn.Module):
|
|||||||
self,
|
self,
|
||||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
):
|
) -> PoolerOutput:
|
||||||
return self.activation(pooled_data)
|
return self.activation(pooled_data)
|
||||||
|
|
||||||
|
|
||||||
@ -418,7 +456,7 @@ class EmbeddingPoolerHead(PoolerHead):
|
|||||||
self,
|
self,
|
||||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
pooled_data: list[torch.Tensor] | torch.Tensor,
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
):
|
) -> PoolerOutput:
|
||||||
if isinstance(pooled_data, list):
|
if isinstance(pooled_data, list):
|
||||||
pooled_data = torch.stack(pooled_data)
|
pooled_data = torch.stack(pooled_data)
|
||||||
# pooled_data shape: [batchsize, hidden_dimension]
|
# pooled_data shape: [batchsize, hidden_dimension]
|
||||||
@ -586,8 +624,12 @@ class ClassifierPooler(Pooler):
|
|||||||
|
|
||||||
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
|
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
|
||||||
def forward(
|
def forward(
|
||||||
self, pooled_data: torch.Tensor, pooling_param: PoolingParams
|
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
|
||||||
) -> torch.Tensor:
|
) -> PoolerOutput:
|
||||||
|
# for unfinished chunked prefill
|
||||||
|
if pooled_data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
pooled_data = pooled_data.to(self.head_dtype)
|
pooled_data = pooled_data.to(self.head_dtype)
|
||||||
# pooled_data shape: [n_tokens, hidden_dimension]
|
# pooled_data shape: [n_tokens, hidden_dimension]
|
||||||
|
|
||||||
@ -630,9 +672,13 @@ class TokenClassifierPoolerHead(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor | None,
|
||||||
pooling_param: PoolingParams,
|
pooling_param: PoolingParams,
|
||||||
) -> torch.Tensor:
|
) -> PoolerOutput:
|
||||||
|
# for unfinished chunked prefill
|
||||||
|
if hidden_states is None:
|
||||||
|
return None
|
||||||
|
|
||||||
hidden_states = hidden_states.to(self.head_dtype)
|
hidden_states = hidden_states.to(self.head_dtype)
|
||||||
# hidden_states shape: [n_token, hidden_size]
|
# hidden_states shape: [n_token, hidden_size]
|
||||||
|
|
||||||
@ -686,17 +732,20 @@ class StepPooler(Pooler):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
) -> torch.Tensor | list[torch.Tensor]:
|
) -> PoolerOutput:
|
||||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
|
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
|
||||||
|
|
||||||
pooled_data = list[torch.Tensor]()
|
|
||||||
|
|
||||||
pooling_params = pooling_metadata.pooling_params
|
pooling_params = pooling_metadata.pooling_params
|
||||||
|
|
||||||
|
pooled_data: PoolerOutput = []
|
||||||
for data, token_id, pooling_param in zip(
|
for data, token_id, pooling_param in zip(
|
||||||
pooled_data_lst, prompt_token_ids, pooling_params
|
pooled_data_lst, prompt_token_ids, pooling_params
|
||||||
):
|
):
|
||||||
|
# for unfinished chunked prefill
|
||||||
|
if data is None:
|
||||||
|
pooled_data.append(data)
|
||||||
|
continue
|
||||||
|
|
||||||
step_tag_id = pooling_param.step_tag_id
|
step_tag_id = pooling_param.step_tag_id
|
||||||
returned_token_ids = pooling_param.returned_token_ids
|
returned_token_ids = pooling_param.returned_token_ids
|
||||||
|
|
||||||
|
|||||||
@ -64,7 +64,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
|
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
|
||||||
from .interfaces_base import default_pooling_type
|
from .interfaces_base import attn_type
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -220,7 +220,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@default_pooling_type("All")
|
@attn_type("attention_free")
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
TerratorchMultiModalProcessor,
|
TerratorchMultiModalProcessor,
|
||||||
info=TerratorchProcessingInfo,
|
info=TerratorchProcessingInfo,
|
||||||
|
|||||||
@ -89,7 +89,7 @@ class LogprobsTensors(NamedTuple):
|
|||||||
|
|
||||||
# [num_reqs, <dynamic>]
|
# [num_reqs, <dynamic>]
|
||||||
# The shape of each element depends on the pooler used
|
# The shape of each element depends on the pooler used
|
||||||
PoolerOutput = torch.Tensor | list[torch.Tensor]
|
PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -17,6 +17,7 @@ class PoolingCursor:
|
|||||||
first_token_indices_gpu: torch.Tensor
|
first_token_indices_gpu: torch.Tensor
|
||||||
last_token_indices_gpu: torch.Tensor
|
last_token_indices_gpu: torch.Tensor
|
||||||
prompt_lens_cpu: torch.Tensor
|
prompt_lens_cpu: torch.Tensor
|
||||||
|
seq_lens_cpu: torch.Tensor
|
||||||
num_scheduled_tokens_cpu: torch.Tensor
|
num_scheduled_tokens_cpu: torch.Tensor
|
||||||
|
|
||||||
def __getitem__(self, indices: slice):
|
def __getitem__(self, indices: slice):
|
||||||
@ -25,12 +26,25 @@ class PoolingCursor:
|
|||||||
first_token_indices_gpu=self.first_token_indices_gpu[indices],
|
first_token_indices_gpu=self.first_token_indices_gpu[indices],
|
||||||
last_token_indices_gpu=self.last_token_indices_gpu[indices],
|
last_token_indices_gpu=self.last_token_indices_gpu[indices],
|
||||||
prompt_lens_cpu=self.prompt_lens_cpu[indices],
|
prompt_lens_cpu=self.prompt_lens_cpu[indices],
|
||||||
|
seq_lens_cpu=self.seq_lens_cpu[indices],
|
||||||
num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
|
num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_partial_prefill(self):
|
def is_partial_prefill(self):
|
||||||
return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)
|
return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)
|
||||||
|
|
||||||
|
def is_finished(self):
|
||||||
|
return self.prompt_lens_cpu == self.seq_lens_cpu
|
||||||
|
|
||||||
|
|
||||||
|
class PoolingStates:
|
||||||
|
def __init__(self):
|
||||||
|
# for chunked prefill with ALL pooling
|
||||||
|
self.hidden_states_cache: list[torch.Tensor] = []
|
||||||
|
|
||||||
|
def clean(self):
|
||||||
|
self.hidden_states_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PoolingMetadata:
|
class PoolingMetadata:
|
||||||
@ -39,6 +53,7 @@ class PoolingMetadata:
|
|||||||
prompt_lens: torch.Tensor # CPU Tensor
|
prompt_lens: torch.Tensor # CPU Tensor
|
||||||
prompt_token_ids: torch.Tensor | None
|
prompt_token_ids: torch.Tensor | None
|
||||||
pooling_params: list[PoolingParams]
|
pooling_params: list[PoolingParams]
|
||||||
|
pooling_states: list[PoolingStates]
|
||||||
pooling_cursor: PoolingCursor | None = None
|
pooling_cursor: PoolingCursor | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
@ -60,6 +75,7 @@ class PoolingMetadata:
|
|||||||
if self.prompt_token_ids is None
|
if self.prompt_token_ids is None
|
||||||
else self.prompt_token_ids[indices],
|
else self.prompt_token_ids[indices],
|
||||||
pooling_params=self.pooling_params[indices],
|
pooling_params=self.pooling_params[indices],
|
||||||
|
pooling_states=self.pooling_states[indices],
|
||||||
pooling_cursor=None
|
pooling_cursor=None
|
||||||
if self.pooling_cursor is None
|
if self.pooling_cursor is None
|
||||||
else self.pooling_cursor[indices],
|
else self.pooling_cursor[indices],
|
||||||
@ -74,15 +90,21 @@ class PoolingMetadata:
|
|||||||
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
|
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
|
||||||
|
|
||||||
def build_pooling_cursor(
|
def build_pooling_cursor(
|
||||||
self, num_scheduled_tokens: list[int], device: torch.device
|
self,
|
||||||
|
num_scheduled_tokens: list[int],
|
||||||
|
seq_lens_cpu: torch.Tensor,
|
||||||
|
device: torch.device,
|
||||||
):
|
):
|
||||||
self.pooling_cursor = build_pooling_cursor(
|
self.pooling_cursor = build_pooling_cursor(
|
||||||
num_scheduled_tokens, self.prompt_lens, device
|
num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_pooling_cursor(
|
def build_pooling_cursor(
|
||||||
num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device
|
num_scheduled_tokens: list[int],
|
||||||
|
seq_lens_cpu: torch.Tensor,
|
||||||
|
prompt_lens: torch.Tensor,
|
||||||
|
device: torch.device,
|
||||||
):
|
):
|
||||||
assert len(prompt_lens) == len(num_scheduled_tokens)
|
assert len(prompt_lens) == len(num_scheduled_tokens)
|
||||||
|
|
||||||
@ -99,5 +121,6 @@ def build_pooling_cursor(
|
|||||||
first_token_indices_gpu=cumsum[:n_seq],
|
first_token_indices_gpu=cumsum[:n_seq],
|
||||||
last_token_indices_gpu=cumsum[1:] - 1,
|
last_token_indices_gpu=cumsum[1:] - 1,
|
||||||
prompt_lens_cpu=prompt_lens,
|
prompt_lens_cpu=prompt_lens,
|
||||||
|
seq_lens_cpu=seq_lens_cpu,
|
||||||
num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
|
num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from vllm.sampling_params import SamplingParams, SamplingType
|
|||||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||||
from vllm.utils.collection_utils import swap_dict_values
|
from vllm.utils.collection_utils import swap_dict_values
|
||||||
from vllm.v1.outputs import LogprobsTensors
|
from vllm.v1.outputs import LogprobsTensors
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
|
||||||
from vllm.v1.sample.logits_processor import (
|
from vllm.v1.sample.logits_processor import (
|
||||||
BatchUpdateBuilder,
|
BatchUpdateBuilder,
|
||||||
LogitsProcessors,
|
LogitsProcessors,
|
||||||
@ -33,7 +33,6 @@ class CachedRequestState:
|
|||||||
prompt_token_ids: list[int] | None
|
prompt_token_ids: list[int] | None
|
||||||
mm_features: list[MultiModalFeatureSpec]
|
mm_features: list[MultiModalFeatureSpec]
|
||||||
sampling_params: SamplingParams | None
|
sampling_params: SamplingParams | None
|
||||||
pooling_params: PoolingParams | None
|
|
||||||
generator: torch.Generator | None
|
generator: torch.Generator | None
|
||||||
|
|
||||||
block_ids: tuple[list[int], ...]
|
block_ids: tuple[list[int], ...]
|
||||||
@ -51,11 +50,18 @@ class CachedRequestState:
|
|||||||
# Used when both async_scheduling and spec_decode are enabled.
|
# Used when both async_scheduling and spec_decode are enabled.
|
||||||
prev_num_draft_len: int = 0
|
prev_num_draft_len: int = 0
|
||||||
|
|
||||||
|
# for pooling models
|
||||||
|
pooling_params: PoolingParams | None = None
|
||||||
|
pooling_states: PoolingStates | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||||
self.prompt_token_ids, self.prompt_embeds
|
self.prompt_token_ids, self.prompt_embeds
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.pooling_params is not None:
|
||||||
|
self.pooling_states = PoolingStates()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_tokens(self) -> int:
|
def num_tokens(self) -> int:
|
||||||
return self.num_prompt_tokens + len(self.output_token_ids)
|
return self.num_prompt_tokens + len(self.output_token_ids)
|
||||||
@ -255,7 +261,9 @@ class InputBatch:
|
|||||||
# This is updated each time the batch constituents change.
|
# This is updated each time the batch constituents change.
|
||||||
self.sampling_metadata = self._make_sampling_metadata()
|
self.sampling_metadata = self._make_sampling_metadata()
|
||||||
|
|
||||||
|
# for pooling models
|
||||||
self.pooling_params: dict[str, PoolingParams] = {}
|
self.pooling_params: dict[str, PoolingParams] = {}
|
||||||
|
self.pooling_states: dict[str, PoolingStates] = {}
|
||||||
|
|
||||||
# Cached reference to the GPU tensor of previously sampled tokens
|
# Cached reference to the GPU tensor of previously sampled tokens
|
||||||
self.prev_sampled_token_ids: torch.Tensor | None = None
|
self.prev_sampled_token_ids: torch.Tensor | None = None
|
||||||
@ -413,7 +421,11 @@ class InputBatch:
|
|||||||
sampling_params.bad_words_token_ids
|
sampling_params.bad_words_token_ids
|
||||||
)
|
)
|
||||||
elif pooling_params := request.pooling_params:
|
elif pooling_params := request.pooling_params:
|
||||||
|
pooling_states = request.pooling_states
|
||||||
|
assert pooling_states is not None
|
||||||
|
|
||||||
self.pooling_params[req_id] = pooling_params
|
self.pooling_params[req_id] = pooling_params
|
||||||
|
self.pooling_states[req_id] = pooling_states
|
||||||
self.logits_processing_needs_token_ids[req_index] = (
|
self.logits_processing_needs_token_ids[req_index] = (
|
||||||
pooling_params.requires_token_ids
|
pooling_params.requires_token_ids
|
||||||
)
|
)
|
||||||
@ -469,6 +481,7 @@ class InputBatch:
|
|||||||
|
|
||||||
if self.is_pooling_model:
|
if self.is_pooling_model:
|
||||||
self.pooling_params.pop(req_id, None)
|
self.pooling_params.pop(req_id, None)
|
||||||
|
self.pooling_states.pop(req_id, None)
|
||||||
return req_index
|
return req_index
|
||||||
|
|
||||||
self.greedy_reqs.discard(req_id)
|
self.greedy_reqs.discard(req_id)
|
||||||
@ -837,13 +850,19 @@ class InputBatch:
|
|||||||
assert len(self.req_ids) == len(self.pooling_params)
|
assert len(self.req_ids) == len(self.pooling_params)
|
||||||
return [self.pooling_params[req_id] for req_id in self.req_ids]
|
return [self.pooling_params[req_id] for req_id in self.req_ids]
|
||||||
|
|
||||||
|
def get_pooling_states(self) -> list[PoolingStates]:
|
||||||
|
assert len(self.req_ids) == len(self.pooling_states)
|
||||||
|
return [self.pooling_states[req_id] for req_id in self.req_ids]
|
||||||
|
|
||||||
def get_pooling_metadata(self) -> PoolingMetadata:
|
def get_pooling_metadata(self) -> PoolingMetadata:
|
||||||
pooling_params = self.get_pooling_params()
|
pooling_params = self.get_pooling_params()
|
||||||
|
pooling_states = self.get_pooling_states()
|
||||||
|
|
||||||
return PoolingMetadata(
|
return PoolingMetadata(
|
||||||
prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
|
prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]),
|
||||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||||
pooling_params=pooling_params,
|
pooling_params=pooling_params,
|
||||||
|
pooling_states=pooling_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||||
|
|||||||
@ -131,7 +131,7 @@ from vllm.v1.outputs import (
|
|||||||
SamplerOutput,
|
SamplerOutput,
|
||||||
make_empty_encoder_model_runner_output,
|
make_empty_encoder_model_runner_output,
|
||||||
)
|
)
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
||||||
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
@ -2291,20 +2291,6 @@ class GPUModelRunner(
|
|||||||
|
|
||||||
supported_tasks = list(model.pooler.get_supported_tasks())
|
supported_tasks = list(model.pooler.get_supported_tasks())
|
||||||
|
|
||||||
if self.scheduler_config.enable_chunked_prefill:
|
|
||||||
if "token_embed" in supported_tasks:
|
|
||||||
supported_tasks.remove("token_embed")
|
|
||||||
if "token_classify" in supported_tasks:
|
|
||||||
supported_tasks.remove("token_classify")
|
|
||||||
|
|
||||||
logger.debug_once(
|
|
||||||
"Chunked prefill is not supported with "
|
|
||||||
"token_embed and token_classify tasks "
|
|
||||||
"which using ALL pooling. "
|
|
||||||
"Please turn off chunked prefill by "
|
|
||||||
"`--no-enable-chunked-prefill` before using it."
|
|
||||||
)
|
|
||||||
|
|
||||||
if "score" in supported_tasks:
|
if "score" in supported_tasks:
|
||||||
num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
|
num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
|
||||||
if num_labels != 1:
|
if num_labels != 1:
|
||||||
@ -2381,11 +2367,12 @@ class GPUModelRunner(
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states[:num_scheduled_tokens]
|
hidden_states = hidden_states[:num_scheduled_tokens]
|
||||||
|
seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
|
||||||
|
|
||||||
pooling_metadata = self.input_batch.get_pooling_metadata()
|
pooling_metadata = self.input_batch.get_pooling_metadata()
|
||||||
pooling_metadata.build_pooling_cursor(
|
pooling_metadata.build_pooling_cursor(
|
||||||
num_scheduled_tokens_np.tolist(), device=hidden_states.device
|
num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
|
||||||
)
|
)
|
||||||
seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
|
|
||||||
|
|
||||||
model = cast(VllmModelForPooling, self.model)
|
model = cast(VllmModelForPooling, self.model)
|
||||||
raw_pooler_output: PoolerOutput = model.pooler(
|
raw_pooler_output: PoolerOutput = model.pooler(
|
||||||
@ -2393,7 +2380,7 @@ class GPUModelRunner(
|
|||||||
pooling_metadata=pooling_metadata,
|
pooling_metadata=pooling_metadata,
|
||||||
)
|
)
|
||||||
raw_pooler_output = json_map_leaves(
|
raw_pooler_output = json_map_leaves(
|
||||||
lambda x: x.to("cpu", non_blocking=True),
|
lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
|
||||||
raw_pooler_output,
|
raw_pooler_output,
|
||||||
)
|
)
|
||||||
self._sync_device()
|
self._sync_device()
|
||||||
@ -4248,10 +4235,13 @@ class GPUModelRunner(
|
|||||||
prompt_lens=dummy_prompt_lens,
|
prompt_lens=dummy_prompt_lens,
|
||||||
prompt_token_ids=dummy_token_ids,
|
prompt_token_ids=dummy_token_ids,
|
||||||
pooling_params=[dummy_pooling_params] * num_reqs,
|
pooling_params=[dummy_pooling_params] * num_reqs,
|
||||||
|
pooling_states=[PoolingStates() for i in range(num_reqs)],
|
||||||
)
|
)
|
||||||
|
|
||||||
dummy_metadata.build_pooling_cursor(
|
dummy_metadata.build_pooling_cursor(
|
||||||
num_scheduled_tokens_list, device=hidden_states.device
|
num_scheduled_tokens_list,
|
||||||
|
seq_lens_cpu=dummy_prompt_lens,
|
||||||
|
device=hidden_states.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -4278,22 +4268,12 @@ class GPUModelRunner(
|
|||||||
supported_pooling_tasks = self.get_supported_pooling_tasks()
|
supported_pooling_tasks = self.get_supported_pooling_tasks()
|
||||||
|
|
||||||
if not supported_pooling_tasks:
|
if not supported_pooling_tasks:
|
||||||
if self.scheduler_config.enable_chunked_prefill:
|
raise RuntimeError(
|
||||||
raise RuntimeError(
|
f"Model {self.model_config.model} does not support "
|
||||||
f"Model {self.model_config.model} does not support "
|
"any pooling tasks. See "
|
||||||
"any pooling tasks with chunked prefill enabled. "
|
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
|
||||||
"Please add --no-enable-chunked-prefill to your "
|
"to learn more."
|
||||||
"config or CLI args. See "
|
)
|
||||||
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
|
|
||||||
"to learn more."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Model {self.model_config.model} does not support "
|
|
||||||
"any pooling tasks. See "
|
|
||||||
"https://docs.vllm.ai/en/latest/models/pooling_models.html "
|
|
||||||
"to learn more."
|
|
||||||
)
|
|
||||||
|
|
||||||
output_size = dict[PoolingTask, float]()
|
output_size = dict[PoolingTask, float]()
|
||||||
for task in supported_pooling_tasks:
|
for task in supported_pooling_tasks:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user