# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch.nn as nn

from vllm.platforms import current_platform
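
# These tests check that, for pooling models, chunked prefill splits a prompt
# into the expected per-step token counts and that prefix caching skips
# previously computed blocks before pooling.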
prompt = """
|
|
Generals gathered in their masses
|
|
Just like witches at black masses
|
|
Evil minds that plot destruction
|
|
Sorcerer of death's construction
|
|
In the fields, the bodies burning
|
|
As the war machine keeps turning
|
|
Death and hatred to mankind
|
|
Poisoning their brainwashed minds
|
|
Oh, Lord, yeah
|
|
|
|
Politicians hide themselves away
|
|
They only started the war
|
|
Why should they go out to fight?
|
|
They leave that all to the poor, yeah
|
|
Time will tell on their power minds
|
|
Making war just for fun
|
|
Treating people just like pawns in chess
|
|
Wait till their judgment day comes, yeah
|
|
|
|
Now, in darkness, world stops turning
|
|
Ashes where their bodies burning
|
|
No more war pigs have the power
|
|
Hand of God has struck the hour
|
|
Day of Judgment, God is calling
|
|
On their knees, the war pigs crawling
|
|
Begging mercies for their sins
|
|
Satan, laughing, spreads his wings
|
|
Oh, Lord, yeah
|
|
"""
|
|
|
|
|
|
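

# WrapperPooler wraps the model's real pooler and records the number of
# tokens (hidden_states.shape[0]) seen by each forward call, so the tests
# below can observe how a prompt was split across prefill steps.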
class WrapperPooler(nn.Module):
    def __init__(self, pooler):
        super().__init__()
        self.pooler = pooler
        self.chunks = []

    def get_pooling_updates(self, task):
        return self.pooler.get_pooling_updates(task)

    def forward(
        self,
        hidden_states,
        pooling_metadata,
    ):
        self.chunks.append(hidden_states.shape[0])
        return self.pooler(hidden_states, pooling_metadata)
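

# These helpers run on the workers via collective_rpc, so `self` is the
# worker object; serializing them across process boundaries is why the
# tests set VLLM_ALLOW_INSECURE_SERIALIZATION.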
def inject_pooler(self):
    model = self.get_model()
    wrapper = WrapperPooler(model.pooler)
    model.pooler = wrapper


def retrieve_chunks(self):
    model = self.get_model()
    chunks = model.pooler.chunks
    model.pooler.chunks = []
    return chunks
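

# With long_prefill_token_threshold=chunk_size, each prefill step should carry
# at most chunk_size prompt tokens: for example, a 104-token prompt with
# chunk_size=10 should reach the pooler as [10] * 10 + [4].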
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
def test_pooling_chunked_prefill(vllm_runner, monkeypatch):
    """Test chunked prefill for pooling models with LastPool."""

    with monkeypatch.context() as m:
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        model_id = "Qwen/Qwen3-Embedding-0.6B"

        chunk_size = 10

        # Force chunked prefill by capping the number of prompt tokens
        # scheduled per step at chunk_size.
        with vllm_runner(
            model_id,
            runner="pooling",
            long_prefill_token_threshold=chunk_size,
            tensor_parallel_size=1,
            enforce_eager=True,
            enable_chunked_prefill=True,
        ) as llm:
            llm.get_llm().llm_engine.collective_rpc(inject_pooler)

            tokenizer = llm.get_llm().get_tokenizer()
            tokens = tokenizer(prompt)["input_ids"]
            prompt_len = len(tokens)
            full_chunks, last_chunk = divmod(prompt_len, chunk_size)
            expected_chunks = [chunk_size] * full_chunks
            if last_chunk:
                expected_chunks.append(last_chunk)
            llm.embed([prompt])
            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]

            # Check that WrapperPooler was called and the prompt was split
            # into the expected chunks
            assert len(chunks) > 1
            assert chunks == expected_chunks

        # Run again without chunked prefill: the whole prompt should be
        # pooled in a single forward pass.
        with vllm_runner(
            model_id,
            runner="pooling",
            tensor_parallel_size=1,
            enforce_eager=True,
        ) as llm:
            llm.get_llm().llm_engine.collective_rpc(inject_pooler)
            llm.embed([prompt])
            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]

            # Check that WrapperPooler was called exactly once, with the
            # full prompt
            assert len(chunks) == 1
            assert chunks[0] == prompt_len
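

# The prefix-caching test builds two prompts that share their first verse:
# after embedding prompt1, the shared prefix should be served from the KV
# cache, so the pooler should only see the uncached tail of prompt2.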
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
def test_pooling_prefix_cache(vllm_runner, monkeypatch):
    """Test prefix caching for pooling models with LastPool."""

    verses = prompt.split("\n\n")

    with monkeypatch.context() as m:
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        model_id = "Qwen/Qwen3-Embedding-0.6B"

        with vllm_runner(
            model_id,
            runner="pooling",
            enable_prefix_caching=True,
            tensor_parallel_size=1,
            enforce_eager=True,
        ) as llm:
            llm.get_llm().llm_engine.collective_rpc(inject_pooler)
            tokenizer = llm.get_llm().get_tokenizer()

            prompt1 = "\n\n".join([verses[0], verses[1]])
            prompt2 = "\n\n".join([verses[0], verses[2]])
            tokens1 = tokenizer(prompt1)["input_ids"]
            tokens2 = tokenizer(prompt2)["input_ids"]
            prompt1_len = len(tokens1)
            prompt2_len = len(tokens2)

            llm.embed([prompt1])
            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]

            # First embed: nothing is cached yet, so the pooler sees the
            # full prompt in one pass
            assert len(chunks) == 1
            assert chunks[0] == prompt1_len

            llm.embed([prompt2])
            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]

            # Second embed: the shared prefix is reused from the cache, so
            # the pooler sees fewer tokens than the full prompt
            assert len(chunks) == 1
            assert chunks[0] <= prompt1_len
            assert chunks[0] < prompt2_len
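
            # For example (hypothetical numbers): with a block_size of 16 and
            # a 40-token shared prefix, two full blocks (32 tokens) would be
            # reused, and the pooler would see prompt2_len - 32 tokens.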
            cache_config = llm.get_llm().llm_engine.cache_config
            print(f"{cache_config=}")
            # Prefixes are cached at block granularity, so the number of
            # reused tokens is a multiple of the block size
            assert (prompt2_len - chunks[0]) % cache_config.block_size == 0