diff --git a/tests/v1/e2e/test_pooling_chunked_prefill.py b/tests/v1/e2e/test_pooling_chunked_prefill.py
new file mode 100644
index 0000000000000..a196e359920de
--- /dev/null
+++ b/tests/v1/e2e/test_pooling_chunked_prefill.py
@@ -0,0 +1,167 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch.nn as nn
+
+from vllm.platforms import current_platform
+
+prompt = """
+Generals gathered in their masses
+Just like witches at black masses
+Evil minds that plot destruction
+Sorcerer of death's construction
+In the fields, the bodies burning
+As the war machine keeps turning
+Death and hatred to mankind
+Poisoning their brainwashed minds
+Oh, Lord, yeah
+
+Politicians hide themselves away
+They only started the war
+Why should they go out to fight?
+They leave that all to the poor, yeah
+Time will tell on their power minds
+Making war just for fun
+Treating people just like pawns in chess
+Wait till their judgment day comes, yeah
+
+Now, in darkness, world stops turning
+Ashes where their bodies burning
+No more war pigs have the power
+Hand of God has struck the hour
+Day of Judgment, God is calling
+On their knees, the war pigs crawling
+Begging mercies for their sins
+Satan, laughing, spreads his wings
+Oh, Lord, yeah
+"""
+
+
+class WrapperPooler(nn.Module):
+    def __init__(self, pooler):
+        super().__init__()
+        self.pooler = pooler
+        self.chunks = []
+
+    def get_pooling_updates(self, task):
+        return self.pooler.get_pooling_updates(task)
+
+    def forward(
+        self,
+        hidden_states,
+        pooling_metadata,
+    ):
+        self.chunks.append(hidden_states.shape[0])
+        return self.pooler(hidden_states, pooling_metadata)
+
+
+def inject_pooler(self):
+    model = self.get_model()
+    wrapper = WrapperPooler(model.pooler)
+    model.pooler = wrapper
+
+
+def retrieve_chunks(self):
+    model = self.get_model()
+    chunks = model.pooler.chunks
+    model.pooler.chunks = []
+    return chunks
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
+def test_pooling_chunked_prefill(vllm_runner, monkeypatch):
+    """Test chunked prefill for pooling models with LastPool."""
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+        model_id = "Qwen/Qwen3-Embedding-0.6B"
+
+        chunk_size = 10
+
+        # Set chunking parameters to force chunked prefill
+        # Note: Chunked prefill is automatically handled by vLLM
+        # internally based on the model size and prompt
+        with vllm_runner(
+            model_id,
+            runner="pooling",
+            long_prefill_token_threshold=chunk_size,
+            tensor_parallel_size=1,
+            enforce_eager=True,
+            enable_chunked_prefill=True,
+        ) as llm:
+            llm.get_llm().llm_engine.collective_rpc(inject_pooler)
+
+            tokenizer = llm.get_llm().get_tokenizer()
+            tokens = tokenizer(prompt)["input_ids"]
+            prompt_len = len(tokens)
+            full_chunks, last_chunk = divmod(prompt_len, chunk_size)
+            expected_chunks = [chunk_size] * full_chunks
+            if last_chunk:
+                expected_chunks.append(last_chunk)
+            llm.embed([prompt])
+            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]
+
+            # Check that WrapperPooler was called and chunks were received
+            assert len(chunks) > 1
+            assert chunks == expected_chunks
+
+        # Disable chunked prefill
+        with vllm_runner(
+            model_id,
+            runner="pooling",
+            tensor_parallel_size=1,
+            enforce_eager=True,
+        ) as llm:
+            llm.get_llm().llm_engine.collective_rpc(inject_pooler)
+            llm.embed([prompt])
+            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]
+
+            # Check that WrapperPooler was called and the prompt was not chunked
+            assert len(chunks) == 1
+            assert chunks[0] == prompt_len
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
+def test_pooling_prefix_cache(vllm_runner, monkeypatch):
+    """Test prefix caching for pooling models with LastPool."""
+
+    verses = prompt.split("\n\n")
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+        model_id = "Qwen/Qwen3-Embedding-0.6B"
+
+        with vllm_runner(
+            model_id,
+            runner="pooling",
+            enable_prefix_caching=True,
+            tensor_parallel_size=1,
+            enforce_eager=True,
+        ) as llm:
+            llm.get_llm().llm_engine.collective_rpc(inject_pooler)
+            tokenizer = llm.get_llm().get_tokenizer()
+
+            prompt1 = "\n\n".join([verses[0], verses[1]])
+            prompt2 = "\n\n".join([verses[0], verses[2]])
+            tokens1 = tokenizer(prompt1)["input_ids"]
+            tokens2 = tokenizer(prompt2)["input_ids"]
+            prompt1_len = len(tokens1)
+            prompt2_len = len(tokens2)
+
+            llm.embed([prompt1])
+            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]
+
+            assert len(chunks) == 1
+            assert chunks[0] == prompt1_len
+
+            llm.embed([prompt2])
+            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]
+
+            assert len(chunks) == 1
+            assert chunks[0] <= prompt1_len
+            assert chunks[0] < prompt2_len
+
+            cache_config = llm.get_llm().llm_engine.cache_config
+            print(f"{cache_config=}")
+            # Prefixes are cached in whole blocks
+            assert (prompt2_len - chunks[0]) % cache_config.block_size == 0