Support embedding models in V1 (#16188)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Maximilien de Bayser 2025-06-19 01:36:33 -03:00 committed by GitHub
parent 4959915089
commit 799397ee4f
56 changed files with 889 additions and 281 deletions

View File

@ -12,7 +12,10 @@ def parse_args():
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(
model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True
model="intfloat/e5-mistral-7b-instruct",
task="embed",
enforce_eager=True,
max_model_len=1024,
)
return parser.parse_args()
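For reference, a minimal sketch (not part of the diff) of how these example defaults are consumed; it assumes vLLM's offline LLM API and the parse_args() defined above:

from vllm import LLM

# Build the engine from the parsed CLI defaults above; task="embed" selects
# the pooling runner and max_model_len=1024 keeps memory use modest.
args = parse_args()
llm = LLM(**vars(args))
outputs = llm.embed(["A sentence to embed."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality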

View File

@ -94,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
engine_args = EngineArgs(
model="TIGER-Lab/VLM2Vec-Full",
task="embed",
max_model_len=4096,
trust_remote_code=True,
mm_processor_kwargs={"num_crops": 4},
limit_mm_per_prompt={"image": 1},
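A hedged usage sketch (not part of the diff) of the engine args above for a multimodal embedding query; the prompt string and image path are illustrative placeholders:

from PIL import Image
from vllm import LLM

llm = LLM(model="TIGER-Lab/VLM2Vec-Full", task="embed", max_model_len=4096,
          trust_remote_code=True, mm_processor_kwargs={"num_crops": 4},
          limit_mm_per_prompt={"image": 1})
out = llm.embed({
    # Prompt text is illustrative; "query.jpg" is a placeholder image path.
    "prompt": "<|image_1|> Represent the given image.",
    "multi_modal_data": {"image": Image.open("query.jpg")},
})
print(len(out[0].outputs.embedding))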

View File

@ -31,7 +31,7 @@ class TestSetting:
# basic llama model
TestSetting(
model="meta-llama/Llama-3.2-1B-Instruct",
model_args=[],
model_args=["--max-model-len", "2048"],
pp_size=2,
tp_size=2,
attn_backend="FLASHINFER",
@ -41,7 +41,7 @@ class TestSetting:
# llama model with quantization
TestSetting(
model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
model_args=["--quantization", "gptq"],
model_args=["--quantization", "gptq", "--max-model-len", "2048"],
pp_size=1,
tp_size=1,
attn_backend="FLASH_ATTN",
@ -51,7 +51,7 @@ class TestSetting:
# MoE model
TestSetting(
model="ibm/PowerMoE-3b",
model_args=[],
model_args=["--max-model-len", "2048"],
pp_size=1,
tp_size=2,
attn_backend="FLASH_ATTN",
@ -61,23 +61,27 @@ class TestSetting:
# embedding model
TestSetting(
model="BAAI/bge-multilingual-gemma2",
model_args=["--task", "embed", "--dtype", "bfloat16"],
model_args=[
"--task", "embed", "--dtype", "bfloat16", "--max-model-len",
"2048"
],
pp_size=1,
tp_size=1,
attn_backend="FLASH_ATTN",
method="encode",
fullgraph=True,
),
# encoder-based embedding model (BERT)
TestSetting(
model="BAAI/bge-base-en-v1.5",
model_args=["--task", "embed"],
pp_size=1,
tp_size=1,
attn_backend="XFORMERS",
method="encode",
fullgraph=True,
),
# TODO: bert models are not supported in V1 yet
# # encoder-based embedding model (BERT)
# TestSetting(
# model="BAAI/bge-base-en-v1.5",
# model_args=["--task", "embed"],
# pp_size=1,
# tp_size=1,
# attn_backend="XFORMERS",
# method="encode",
# fullgraph=True,
# ),
# vision language model
TestSetting(
model="microsoft/Phi-3.5-vision-instruct",

View File

@ -145,6 +145,7 @@ def run_with_both_engines(request, monkeypatch):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v0 = request.node.get_closest_marker("skip_v0")
skip_v1 = request.node.get_closest_marker("skip_v1")
if use_v1:
@ -152,6 +153,8 @@ def run_with_both_engines(request, monkeypatch):
pytest.skip("Skipping test on vllm V1")
monkeypatch.setenv('VLLM_USE_V1', '1')
else:
if skip_v0:
pytest.skip("Skipping test on vllm V0")
monkeypatch.setenv('VLLM_USE_V1', '0')
yield
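A hedged illustration (not part of the diff) of how the new skip_v0 marker pairs with the existing skip_v1 marker in a test module; the test bodies are hypothetical:

import pytest

@pytest.mark.skip_v0  # run only on the V1 engine
def test_embeds_on_v1_only(llm):
    outputs = llm.embed(["hello"])
    assert len(outputs) == 1

@pytest.mark.skip_v1  # run only on the V0 engine
def test_legacy_v0_behavior(llm):
    ...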

View File

@ -8,6 +8,8 @@ import pytest
from vllm import LLM, PoolingParams, PoolingRequestOutput
from vllm.distributed import cleanup_dist_env_and_memory
from ...models.utils import check_embeddings_close
MODEL_NAME = "intfloat/multilingual-e5-small"
PROMPTS = [
@ -27,6 +29,14 @@ TOKEN_IDS = [
]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
@ -46,9 +56,15 @@ def llm():
cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: list[PoolingRequestOutput],
def assert_outputs_match(o1: list[PoolingRequestOutput],
o2: list[PoolingRequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]
check_embeddings_close(
embeddings_0_lst=[o.outputs.data for o in o1],
embeddings_1_lst=[o.outputs.data for o in o2],
name_0="hf",
name_1="vllm",
tol=1e-2,
)
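The switch from exact equality to check_embeddings_close reflects that V0 and V1 can produce numerically close but not bit-identical embeddings; a minimal sketch of such a tolerance check (assuming a cosine-similarity comparison, which is what check_embeddings_close performs):

import torch
import torch.nn.functional as F

def embeddings_close(a: list[float], b: list[float], tol: float = 1e-2) -> bool:
    # Embeddings match if their cosine similarity is within tol of 1.0,
    # tolerating small kernel/backend differences between engines.
    sim = F.cosine_similarity(torch.tensor(a), torch.tensor(b), dim=0)
    return bool(1.0 - sim.item() <= tol)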
@pytest.mark.skip_global_cleanup
@ -63,7 +79,7 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
assert_outputs_match(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
@ -80,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
} for p in TOKEN_IDS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
assert_outputs_match(v1_output, v2_output)
@pytest.mark.skip_global_cleanup

View File

@ -21,6 +21,14 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
DTYPE = "bfloat16"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module")
def server():
args = [

View File

@ -7,6 +7,7 @@ import numpy as np
import pytest
import requests
from tests.models.utils import check_embeddings_close
from vllm.entrypoints.openai.protocol import PoolingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer
@ -223,8 +224,11 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
np.frombuffer(base64.b64decode(data.data),
dtype="float32").tolist())
assert responses_float.data[0].data == decoded_responses_base64_data[0]
assert responses_float.data[1].data == decoded_responses_base64_data[1]
check_embeddings_close(
embeddings_0_lst=[d.data for d in responses_float.data],
embeddings_1_lst=decoded_responses_base64_data,
name_0="float32",
name_1="base64")
# Default response is float32 decoded from base64 by OpenAI Client
default_response = requests.post(
@ -237,5 +241,8 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
default_response.raise_for_status()
responses_default = PoolingResponse.model_validate(default_response.json())
assert responses_float.data[0].data == responses_default.data[0].data
assert responses_float.data[1].data == responses_default.data[1].data
check_embeddings_close(
embeddings_0_lst=[d.data for d in responses_float.data],
embeddings_1_lst=[d.data for d in responses_default.data],
name_0="float32",
name_1="base64")
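For clarity, the base64 path being tested above decodes to float32 like this (a small standalone sketch, not part of the diff):

import base64
import numpy as np

def decode_base64_embedding(b64: str) -> list[float]:
    # The server encodes pooled outputs as raw float32 bytes, matching the
    # OpenAI client convention.
    return np.frombuffer(base64.b64decode(b64), dtype="float32").tolist()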

View File

@ -12,6 +12,14 @@ MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16"
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.fixture(scope="module")
def server():
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]

View File

@ -11,6 +11,15 @@ from vllm.entrypoints.openai.protocol import ScoreResponse
from ...utils import RemoteOpenAIServer
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
MODELS = [
{
"name": "BAAI/bge-reranker-v2-m3",

View File

@ -6,6 +6,14 @@ from transformers import AutoModelForSequenceClassification
from vllm.platforms import current_platform
# TODO: enable when float32 is supported by V1
# @pytest.fixture(autouse=True)
# def v1(run_with_both_engines):
# # Simple autouse wrapper to run both engines for each test
# # This can be promoted up to conftest.py to run for every
# # test in a package
# pass
@pytest.mark.parametrize(
"model",
@ -29,7 +37,7 @@ def test_models(
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)
with hf_runner(model,

View File

@ -8,6 +8,14 @@ from vllm.platforms import current_platform
from ...utils import check_embeddings_close
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
@pytest.mark.parametrize(
"model",
[
@ -20,15 +28,27 @@ from ...utils import check_embeddings_close
marks=[pytest.mark.core_model]),
pytest.param("intfloat/e5-mistral-7b-instruct",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
# the qwen models interfere with each other (see PR
# https://github.com/vllm-project/vllm/pull/18720).
# To avoid this problem, for now we skip v0 since it will be
# deprecated anyway.
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0]),
# [Encoder-only]
pytest.param("BAAI/bge-base-en-v1.5",
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
marks=[
pytest.mark.core_model, pytest.mark.cpu_model,
pytest.mark.skip_v1
]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
# [Cross-Encoder]
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
pytest.param("sentence-transformers/stsb-roberta-base-v2",
marks=[pytest.mark.skip_v1]),
],
)
def test_models(
@ -62,7 +82,7 @@ def test_models(
with vllm_runner(model,
task="embed",
max_model_len=None,
max_model_len=512,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)

View File

@ -265,8 +265,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
"GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
trust_remote_code=True),
@ -279,16 +279,16 @@ _EMBEDDING_EXAMPLE_MODELS = {
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True),
trust_remote_code=True, v0_only=True),
"NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True),
trust_remote_code=True, v0_only=True), # noqa: E501
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
# [Multimodal]
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
@ -300,10 +300,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
_CROSS_ENCODER_EXAMPLE_MODELS = {
# [Text-only]
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
}
_MULTIMODAL_EXAMPLE_MODELS = {

View File

@ -68,6 +68,7 @@ def _run_incremental_decode(tokenizer,
None,
params,
None,
None,
0.0,
None,
cache_salt=None,

View File

@ -43,6 +43,7 @@ def make_request(request_id,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,

View File

@ -39,6 +39,7 @@ def make_request(request_id,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,

View File

@ -135,6 +135,7 @@ def create_requests(num_requests: int,
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
@ -283,6 +284,7 @@ def test_schedule_partial_requests():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
@ -333,6 +335,7 @@ def test_no_mm_input_chunking():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
@ -396,6 +399,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
@ -420,6 +424,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()
@ -473,7 +478,8 @@ def test_stop_via_update_from_output():
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
@ -523,7 +529,8 @@ def test_stop_via_update_from_output():
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
@ -572,7 +579,8 @@ def test_stop_via_update_from_output():
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
@ -614,7 +622,8 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
@ -663,6 +672,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output0, model_runner_output)
@ -680,6 +690,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output1, model_runner_output)
@ -730,6 +741,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
@ -769,6 +781,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
@ -896,6 +909,7 @@ def test_kv_connector_basic():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Ensure ScheduleOutput is correct.
@ -941,6 +955,7 @@ def test_kv_connector_basic():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# We should get a local cache hit of NUM_TOKENS_PREFIX and
@ -1007,6 +1022,7 @@ def test_kv_connector_unable_to_allocate():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Just one request should be running.
@ -1087,6 +1103,7 @@ def test_kv_connector_handles_preemption():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# All can be scheduled - 1st token.
@ -1181,6 +1198,7 @@ def make_output(scheduler: Scheduler):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)

View File

@ -39,6 +39,7 @@ def make_request() -> EngineCoreRequest:
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,

View File

@ -53,6 +53,7 @@ def make_request(
mm_hashes=None,
mm_placeholders=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,

View File

@ -33,6 +33,7 @@ def test_fast_inc_detok_invalid_utf8_err_case():
None,
params,
None,
None,
0.0,
None,
cache_salt=None,

View File

@ -66,7 +66,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False,
))
),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
@ -416,7 +417,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
include_stop_str_in_output=False,
logprobs=num_sample_logprobs,
prompt_logprobs=num_prompt_logprobs,
))
),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
@ -582,7 +584,8 @@ def test_stop_token(include_stop_str_in_output: bool,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
ignore_eos=ignore_eos,
))
),
pooling_params=None)
# Add request to the detokenizer.
output_processor.add_request(request, prompt_string)
@ -678,7 +681,8 @@ def test_stop_string(include_stop_str_in_output: bool,
include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
))
),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
@ -786,6 +790,7 @@ def test_iteration_stats(dummy_test_vectors):
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(),
pooling_params=None,
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]

View File

@ -150,6 +150,7 @@ def create_request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
@ -183,6 +184,7 @@ def create_model_runner_output(
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
finished_sending=finished_sending,
finished_recving=finished_recving,
)

View File

@ -10,6 +10,7 @@ import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@ -46,7 +47,7 @@ def _compare_objs(obj1, obj2):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata)):
elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
elif a == b:
@ -201,6 +202,7 @@ def _construct_cached_request_state(req_id_suffix: int):
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_inputs=[],
mm_positions=[],
block_ids=([], ),

View File

@ -122,6 +122,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
pooling_params=None,
block_ids=([0], ),
num_computed_tokens=0,
lora_request=None,

View File

@ -4496,11 +4496,31 @@ class VllmConfig:
if self.compilation_config.full_cuda_graph and \
not self.model_config.disable_cascade_attn:
logger.warning_once(
"full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
logger.info("full_cuda_graph is not supported with "
"cascade attention. Disabling cascade attention.")
self.model_config.disable_cascade_attn = True
disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
self.scheduler_config.chunked_prefill_enabled = False
self.scheduler_config.long_prefill_token_threshold = 0
self.scheduler_config.max_num_batched_tokens = max(
self.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS)
if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False
if (self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events
and not self.cache_config.enable_prefix_caching):
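The rationale for the "last"-only carve-out above, as a small sketch (not part of the diff): with last-token pooling the output depends only on the final token's hidden state, so earlier prefill chunks need not be kept around; CLS/MEAN/ALL pooling need hidden states for every prompt token, which a partial prefill does not retain.

import torch

def last_token_pool(final_chunk_hidden: torch.Tensor) -> torch.Tensor:
    # Only the last prefill chunk has to be resident when pooling runs.
    return final_chunk_hidden[-1]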

View File

@ -1041,7 +1041,7 @@ class EngineArgs:
# Set default arguments for V0 or V1 Engine.
if use_v1:
self._set_default_args_v1(usage_context)
self._set_default_args_v1(usage_context, model_config)
else:
self._set_default_args_v0(model_config)
@ -1349,13 +1349,7 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No Embedding Models so far.
if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}",
recommend_to_remove=False)
return False
# No Encoder-Decoder, not all Mamba so far.
# No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,
recommend_to_remove=False)
@ -1523,15 +1517,38 @@ class EngineArgs:
if self.max_num_seqs is None:
self.max_num_seqs = 256
def _set_default_args_v1(self, usage_context: UsageContext) -> None:
def _set_default_args_v1(self, usage_context: UsageContext,
model_config: ModelConfig) -> None:
"""Set Default Arguments for V1 Engine."""
# V1 always uses chunked prefills.
self.enable_chunked_prefill = True
# V1 always uses chunked prefills and prefix caching
# for non-pooling tasks.
# For pooling tasks the default is False
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
else:
# V1 enables prefix caching by default.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
pooling_type = model_config.pooler_config.pooling_type
# TODO: when encoder models are supported we'll have to
# check for causal attention here.
incremental_prefill_supported = (pooling_type is not None and
pooling_type.lower() == "last")
action = "Enabling" if \
incremental_prefill_supported else "Disabling"
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = incremental_prefill_supported
logger.info("(%s) chunked prefill by default", action)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = incremental_prefill_supported
logger.info("(%s) prefix caching by default", action)
if not self.enable_chunked_prefill:
self.max_num_batched_tokens = model_config.max_model_len
# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
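A hedged illustration (not part of the diff) of the resulting V1 defaults: when the pooler uses last-token pooling, chunked prefill and prefix caching default to on, otherwise to off; explicit settings are still honored.

from vllm.engine.arg_utils import EngineArgs

# Leaving enable_chunked_prefill / enable_prefix_caching unset (None) lets
# the logic above choose the default based on the model's pooling type.
args = EngineArgs(model="intfloat/e5-mistral-7b-instruct", task="embed")
config = args.create_engine_config()
print(config.scheduler_config.chunked_prefill_enabled,
      config.cache_config.enable_prefix_caching)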

View File

@ -1266,7 +1266,7 @@ class LLM:
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
tokenizer = self.llm_engine.get_tokenizer()
tokenizer = self.get_tokenizer()
def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict):

View File

@ -9,6 +9,7 @@ from typing import Final, Literal, Optional, Union, cast
import jinja2
import numpy as np
import torch
from fastapi import Request
from typing_extensions import assert_never
@ -39,7 +40,8 @@ def _get_data(
elif encoding_format == "base64":
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
pooling_bytes = np.array(output.data, dtype="float32").tobytes()
pt_float32 = output.data.to(dtype=torch.float32)
pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
return base64.b64encode(pooling_bytes).decode("utf-8")
assert_never(encoding_format)
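The change above casts the pooled torch tensor to float32 before serializing, so the base64 payload stays compatible with the OpenAI client; a minimal round-trip sketch (not part of the diff):

import base64
import numpy as np
import torch

pooled = torch.randn(4, dtype=torch.bfloat16)           # stand-in pooler output
pt_float32 = pooled.to(dtype=torch.float32)
payload = base64.b64encode(np.array(pt_float32, dtype="float32").tobytes())
decoded = np.frombuffer(base64.b64decode(payload), dtype="float32")
assert np.allclose(decoded, pt_float32.numpy())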

View File

@ -10,11 +10,15 @@ import torch.nn.functional as F
from typing_extensions import assert_never
from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
class PoolingType(IntEnum):
@ -75,15 +79,18 @@ class SimplePooler(nn.Module):
def get_prompt_lens(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
raise NotImplementedError
@ -93,7 +100,7 @@ class SimplePooler(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
@ -106,11 +113,19 @@ class CLSPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
if isinstance(hidden_states, list):
result = []
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with CLS pooling"
result.append(req_state[0])
return result
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
return hidden_states[first_token_flat_indices]
@ -120,9 +135,12 @@ class LastPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
if isinstance(hidden_states, list):
return [h[-1] for h in hidden_states]
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
@ -133,11 +151,17 @@ class AllPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with ALL pooling"
return hidden_states
offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len in prompt_lens:
@ -151,11 +175,20 @@ class MeanPool(SimplePooler):
def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
if isinstance(hidden_states, list):
result = []
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with mean pooling"
result.append(torch.mean(req_state, dim=0,
dtype=torch.float32))
return result
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
@ -184,30 +217,53 @@ class StepPool(SimplePooler):
self.step_tag_id = step_tag_id
self.returned_token_ids = returned_token_ids
def get_prompt_token_ids(
self,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor]:
if isinstance(pooling_metadata, V1PoolingMetadata):
return [
pooling_metadata.prompt_token_ids[i, :num]
for i, num in enumerate(pooling_metadata.prompt_lens)
]
return [
torch.tensor(seq_data_i.prompt_token_ids)
for seq_data_i in pooling_metadata.seq_data.values()
]
def extract_states(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[list[torch.Tensor], torch.Tensor]:
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_token_ids = self.get_prompt_token_ids(pooling_metadata)
pooled_data: list[torch.Tensor] = []
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with mean pooling"
pooled_data = hidden_states
else:
offset = 0
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
offset += prompt_len
pooled_data.append(pooled_data_i)
pooled_data = []
returned_token_ids = self.returned_token_ids
if returned_token_ids is not None and len(returned_token_ids) > 0:
hidden_states = hidden_states[:, returned_token_ids]
step_tag_id = self.step_tag_id
offset = 0
pooled_data = list[torch.Tensor]()
for prompt_len, seq_data_i in zip(prompt_lens,
pooling_metadata.seq_data.values()):
pooled_data_i = hidden_states[offset:offset + prompt_len]
if step_tag_id is not None:
token_ids = torch.tensor(seq_data_i.prompt_token_ids)
pooled_data_i = pooled_data_i[token_ids == step_tag_id]
for data, token_id in zip(pooled_data, prompt_token_ids):
if returned_token_ids is not None and len(returned_token_ids) > 0:
data = data[:, returned_token_ids]
offset += prompt_len
pooled_data.append(pooled_data_i)
if step_tag_id is not None:
data = data[token_id == step_tag_id]
pooled_data.append(data)
return pooled_data
@ -230,10 +286,17 @@ class PoolerHead(nn.Module):
else:
pooled_data = pooled_data.to(torch.float32)
dimensions_list = [
pooling_param.dimensions
for _, pooling_param in pooling_metadata.seq_groups
]
if isinstance(pooling_metadata, V0PoolingMetadata):
dimensions_list = [
pooling_param.dimensions
for _, pooling_param in pooling_metadata.seq_groups
]
else:
assert isinstance(pooled_data, list)
dimensions_list = [
pooling_param.dimensions
for pooling_param in pooling_metadata.pooling_params
]
if any(d is not None for d in dimensions_list):
# change the output dimension
assert len(pooled_data) == len(dimensions_list)
@ -325,20 +388,41 @@ class ClassifierPooler(nn.Module):
raise NotImplementedError(f"task={config.task!r} is not supported"
" with the classification pooler")
def get_prompt_lens(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
if isinstance(pooling_metadata, V1PoolingMetadata):
return pooling_metadata.prompt_lens
assert isinstance(hidden_states, torch.Tensor)
return PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
def forward(
self,
hidden_states: torch.Tensor,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools sentence pair scores from the hidden_states."""
prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
pooled_data = list[torch.Tensor]()
if isinstance(hidden_states, list):
for req_state, prompt_len in zip(hidden_states, prompt_lens):
assert prompt_len == req_state.shape[0], \
"partial prefill not supported with classifier"
pooled_data = hidden_states
else:
offset = 0
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
offset += prompt_len
pooled_data.append(pooled_data_i)
offset = 0
pooled_data_lst = []
for prompt_len in prompt_lens:
pooled_data_i = hidden_states[offset:offset + prompt_len]
for pooled_data_i in pooled_data:
if self.pooler is not None:
final_shape_tensor = self.pooler(pooled_data_i)
@ -346,7 +430,6 @@ class ClassifierPooler(nn.Module):
final_shape_tensor = self.classifier(pooled_data_i)
pooled_data_lst.append(final_shape_tensor)
offset += prompt_len
pooled_output = torch.stack(pooled_data_lst)
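To make the new list-based code path concrete, a tiny sketch (not part of the diff) of what LastPool does when V1 hands it one hidden-state tensor per request instead of a single packed tensor:

import torch

hidden_size = 8
# Two requests with prompt lengths 5 and 3; random tensors stand in for
# real hidden states.
hidden_states = [torch.randn(5, hidden_size), torch.randn(3, hidden_size)]
pooled = [h[-1] for h in hidden_states]  # LastPool's list branch above
assert all(p.shape == (hidden_size,) for p in pooled)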

View File

@ -446,8 +446,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
softmax=False)
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
class BertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for

View File

@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsCrossEncoding
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix
@ -270,7 +270,8 @@ class ModernBertPooler(nn.Module):
return pooled_output
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -375,7 +375,12 @@ class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
) -> Optional[PoolerOutput]:
hidden_states = self._pooler.extract_states(hidden_states,
pooling_metadata)
logits, _ = self.score(hidden_states)
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data.squeeze(-1)) for data in pooled_data

View File

@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Any, Optional
import msgspec
from vllm.sampling_params import RequestOutputKind
if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -23,6 +25,7 @@ class PoolingParams(
dimensions: Optional[int] = None
additional_data: Optional[Any] = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
@ -52,3 +55,7 @@ class PoolingParams(
return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"additional_metadata={self.additional_data})")
def __post_init__(self) -> None:
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
"For pooling output_kind has to be FINAL_ONLY"

View File

@ -146,7 +146,8 @@ class KVCacheManager:
# Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching
or request.sampling_params.prompt_logprobs is not None):
or (request.sampling_params is not None
and request.sampling_params.prompt_logprobs is not None)):
return self.create_empty_block_list(), 0
# The block hashes for the request may already be computed

View File

@ -14,6 +14,7 @@ if TYPE_CHECKING:
KVConnectorMetadata)
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
@ -26,7 +27,8 @@ class NewRequestData:
mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
@ -44,6 +46,7 @@ class NewRequestData:
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,

View File

@ -402,6 +402,15 @@ class Scheduler(SchedulerInterface):
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
continue
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
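The waiting-queue check added above can be summarized as follows (a hedged sketch, not part of the diff): when chunked prefill is disabled, a request whose full prompt exceeds the remaining token budget is deferred rather than partially prefilled.

def fits_this_step(num_new_tokens: int, token_budget: int,
                   chunked_prefill_enabled: bool) -> bool:
    # Without chunked prefill the whole prompt must fit in one step;
    # otherwise the request stays in the waiting queue for a later step.
    return chunked_prefill_enabled or num_new_tokens <= token_budget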
@ -707,6 +716,7 @@ class Scheduler(SchedulerInterface):
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
new_running: list[Request] = []
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
@ -724,7 +734,8 @@ class Scheduler(SchedulerInterface):
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]
generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else []
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
@ -776,8 +787,17 @@ class Scheduler(SchedulerInterface):
del new_token_ids[num_new:] # Trim new tokens if needed.
break
pooler_output = None
if pooler_outputs:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, self.max_model_len,
pooler_output)
if stopped:
kv_transfer_params = self._free_request(request)
# Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None and logprobs:
if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
@ -802,7 +822,8 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or kv_transfer_params:
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
@ -812,6 +833,7 @@ class Scheduler(SchedulerInterface):
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,

View File

@ -1,15 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.v1.request import Request, RequestStatus
def check_stop(request: Request, max_model_len: int) -> bool:
def check_stop(request: Request,
max_model_len: int,
pooler_output: Optional[torch.Tensor] = None) -> bool:
if (request.num_tokens >= max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
if request.pooling_params:
if pooler_output is not None:
request.status = RequestStatus.FINISHED_STOPPED
return True
return False
sampling_params = request.sampling_params
assert sampling_params is not None
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):

View File

@ -7,10 +7,12 @@ from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
@ -50,7 +52,8 @@ class EngineCoreRequest(
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]
@ -104,6 +107,8 @@ class EngineCoreOutput(
new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
pooling_output: Optional[torch.Tensor] = None
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[list[EngineCoreEvent]] = None

View File

@ -17,7 +17,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@ -228,8 +228,7 @@ class AsyncLLM(EngineClient):
if self.errored:
raise EngineDeadError()
assert isinstance(params, SamplingParams), \
"Pooling is not supported in V1"
is_pooling = isinstance(params, PoolingParams)
# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)
@ -240,7 +239,7 @@ class AsyncLLM(EngineClient):
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority, data_parallel_rank)
if params.n == 1:
if is_pooling or params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)
return queue
@ -443,7 +442,7 @@ class AsyncLLM(EngineClient):
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)
def encode(
async def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
@ -451,8 +450,75 @@ class AsyncLLM(EngineClient):
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
):
raise ValueError("Not Supported on V1 yet.")
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 2) Processing the Input.
* 3) Adding the Request to the EngineCore (separate process).
A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the
per-request AsyncStream.
The caller of encode() iterates the returned AsyncGenerator,
returning the PoolingRequestOutput back to the caller.
"""
try:
# We start the output_handler on the first call to encode() so
# we can call __init__ before the event loop, which enables us
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
q = await self.add_request(
request_id,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
finished = False
while not finished:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() or await q.get()
assert isinstance(out, PoolingRequestOutput)
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
finished = out.finished
yield out
# If the request is disconnected by the client, encode()
# is cancelled. So, we abort the request if we end up here.
except asyncio.CancelledError:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
raise
# Engine is dead. Do not abort since we shut down.
except EngineDeadError:
if self.log_requests:
logger.info("Request %s failed (engine dead).", request_id)
raise
# Request validation error.
except ValueError:
if self.log_requests:
logger.info("Request %s failed (bad request).", request_id)
raise
# Unexpected error in the generate() task (possibly recoverable).
except Exception as e:
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s failed.", request_id)
raise EngineGenerateError() from e
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config
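A hedged usage sketch (not part of the diff) of the new async encode() path; class and method names follow the diff, while the engine-args values and request id are illustrative.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.pooling_params import PoolingParams
from vllm.v1.engine.async_llm import AsyncLLM

async def main():
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="intfloat/e5-mistral-7b-instruct", task="embed"))
    async for out in engine.encode("Embed this sentence.",
                                   PoolingParams(),
                                   request_id="embed-0"):
        if out.finished:
            print(out.outputs.data.shape)  # pooled embedding tensor

asyncio.run(main())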

View File

@ -60,7 +60,6 @@ class EngineCore:
executor_class: type[Executor],
log_stats: bool,
executor_fail_callback: Optional[Callable] = None):
assert vllm_config.model_config.runner_type != "pooling"
# plugins need to be loaded at the engine/scheduler level too
from vllm.plugins import load_general_plugins

View File

@ -50,6 +50,8 @@ class IncrementalDetokenizer:
request: EngineCoreRequest,
) -> "IncrementalDetokenizer":
assert request.sampling_params is not None
if tokenizer is None:
# No tokenizer => skipping detokenization.
return IncrementalDetokenizer()
@ -70,6 +72,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
# Stop strings
params = request.sampling_params
assert params is not None
self.stop = stop = params.stop
self.include_stop_str_in_output = params.include_stop_str_in_output
@ -164,6 +167,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
super().__init__(request)
sampling_params = request.sampling_params
assert sampling_params is not None
self.request_id = request.request_id
self.skip_special_tokens = sampling_params.skip_special_tokens
@ -245,20 +249,20 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
super().__init__(request)
self.tokenizer = tokenizer
params = request.sampling_params
assert params is not None
# Metadata for incremental detokenization.
self.tokens, self.prefix_offset, self.read_offset = (
convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.sampling_params.
skip_special_tokens,
skip_special_tokens=params.skip_special_tokens,
))
self.token_ids.extend(request.prompt_token_ids)
self.prompt_len = len(request.prompt_token_ids)
params = request.sampling_params
self.skip_special_tokens = params.skip_special_tokens
self.spaces_between_special_tokens = (
params.spaces_between_special_tokens)

View File

@ -15,7 +15,7 @@ from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import RequestOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
@ -221,7 +221,7 @@ class LLMEngine:
# Add the request to EngineCore.
self.engine_core.add_request(child_request)
def step(self) -> list[RequestOutput]:
def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False

View File

@ -38,6 +38,7 @@ class LogprobsProcessor:
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "LogprobsProcessor":
assert request.sampling_params is not None
num_logprobs = request.sampling_params.logprobs
num_prompt_logprobs = request.sampling_params.prompt_logprobs
return cls(

View File

@ -4,9 +4,12 @@
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast
from vllm.outputs import CompletionOutput, RequestOutput
import torch
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -29,20 +32,22 @@ class RequestOutputCollector:
def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: Optional[Union[RequestOutput, Exception]] = None
self.output: Optional[Union[RequestOutput, PoolingRequestOutput,
Exception]] = None
self.ready = asyncio.Event()
def put(self, output: Union[RequestOutput, Exception]) -> None:
def put(self, output: Union[RequestOutput, PoolingRequestOutput,
Exception]) -> None:
"""Non-blocking put operation."""
if self.output is None or isinstance(output, Exception):
self.output = output
self.ready.set()
elif isinstance(self.output, RequestOutput):
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
# This ensures that request outputs with different request indexes
# (if n > 1) do not override each other.
self.output.add(output, aggregate=self.aggregate)
async def get(self) -> RequestOutput:
async def get(self) -> Union[RequestOutput, PoolingRequestOutput]:
"""Get operation blocks on put event."""
while (output := self.output) is None:
await self.ready.wait()
@ -52,7 +57,8 @@ class RequestOutputCollector:
raise output
return output
def get_nowait(self) -> Optional[RequestOutput]:
def get_nowait(
self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
"""Non-blocking get operation."""
output = self.output
if output is not None:
@ -66,7 +72,7 @@ class RequestOutputCollector:
@dataclass
class OutputProcessorOutput:
request_outputs: list[RequestOutput]
request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
reqs_to_abort: list[str]
@ -81,8 +87,8 @@ class RequestState:
output_kind: RequestOutputKind,
prompt: Optional[str],
prompt_token_ids: list[int],
logprobs_processor: LogprobsProcessor,
detokenizer: IncrementalDetokenizer,
logprobs_processor: Optional[LogprobsProcessor],
detokenizer: Optional[IncrementalDetokenizer],
max_tokens_param: Optional[int],
arrival_time: float,
queue: Optional[RequestOutputCollector],
@ -116,27 +122,39 @@ class RequestState:
queue: Optional[RequestOutputCollector],
log_stats: bool,
) -> "RequestState":
if not request.sampling_params.detokenize:
tokenizer = None
if sampling_params := request.sampling_params:
if not sampling_params.detokenize:
tokenizer = None
output_kind = sampling_params.output_kind
logprobs_processor = LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
)
detokenizer = IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
)
max_tokens_param = sampling_params.max_tokens
else:
logprobs_processor = None
detokenizer = None
max_tokens_param = None
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind
return cls(
request_id=request.request_id,
parent_req=parent_req,
request_index=request_index,
lora_name=(request.lora_request.name
if request.lora_request is not None else None),
output_kind=request.sampling_params.output_kind,
output_kind=output_kind,
prompt=prompt,
prompt_token_ids=request.prompt_token_ids,
logprobs_processor=LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
),
detokenizer=IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
),
max_tokens_param=(request.sampling_params.max_tokens if
request.sampling_params is not None else None),
logprobs_processor=logprobs_processor,
detokenizer=detokenizer,
max_tokens_param=max_tokens_param,
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
@ -145,11 +163,12 @@ class RequestState:
def make_request_output(
self,
new_token_ids: list[int],
pooling_output: Optional[torch.Tensor],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> Optional[RequestOutput]:
) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
@ -158,15 +177,20 @@ class RequestState:
# Only the final output is required in FINAL_ONLY mode.
return None
completion_output = self._new_completion_output(
new_token_ids, finish_reason, stop_reason)
request_id = self.request_id
if pooling_output is not None:
return self._new_request_output(
request_id, [self._new_pooling_output(pooling_output)],
finished)
output = self._new_completion_output(new_token_ids, finish_reason,
stop_reason)
if self.parent_req is None:
outputs = [completion_output]
outputs = [output]
else:
request_id, outputs, finished = self.parent_req.get_outputs(
request_id, completion_output)
request_id, output)
if not outputs:
return None
@ -176,12 +200,21 @@ class RequestState:
def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
outputs: Union[list[CompletionOutput], list[PoolingOutput]],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> RequestOutput:
) -> Union[RequestOutput, PoolingRequestOutput]:
if isinstance(outputs[0], PoolingOutput):
assert len(outputs) == 1
return PoolingRequestOutput(
request_id=request_id,
outputs=outputs[0],
prompt_token_ids=self.prompt_token_ids,
finished=finished,
)
assert self.logprobs_processor is not None
if self.output_kind == RequestOutputKind.DELTA:
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
@ -193,7 +226,7 @@ class RequestState:
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
prompt_logprobs=prompt_logprobs,
outputs=outputs,
outputs=cast(list[CompletionOutput], outputs),
finished=finished,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=num_cached_tokens,
@ -206,6 +239,8 @@ class RequestState:
stop_reason: Union[int, str, None],
) -> CompletionOutput:
assert self.detokenizer is not None
assert self.logprobs_processor is not None
finished = finish_reason is not None
delta = self.output_kind == RequestOutputKind.DELTA
@ -228,6 +263,13 @@ class RequestState:
finish_reason=str(finish_reason) if finished else None,
stop_reason=stop_reason if finished else None)
def _new_pooling_output(
self,
pooling_output: torch.Tensor,
) -> PoolingOutput:
return PoolingOutput(data=pooling_output)
class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
@ -326,7 +368,8 @@ class OutputProcessor:
within the loop below.
"""
request_outputs: list[RequestOutput] = []
request_outputs: Union[list[RequestOutput],
list[PoolingRequestOutput]] = []
reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id
@ -341,25 +384,31 @@ class OutputProcessor:
iteration_stats)
new_token_ids = engine_core_output.new_token_ids
pooling_output = engine_core_output.pooling_output
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP)
if stop_string:
finish_reason = FinishReason.STOP
stop_reason = stop_string
if pooling_output is None:
assert req_state.detokenizer is not None
assert req_state.logprobs_processor is not None
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
new_token_ids, finish_reason == FinishReason.STOP)
if stop_string:
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request, if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 3) Compute sample and prompt logprobs for request,
# if required.
req_state.logprobs_processor.update_from_output(
engine_core_output)
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason,
new_token_ids, pooling_output, finish_reason, stop_reason,
kv_transfer_params, num_cached_tokens):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().

View File

@ -136,8 +136,8 @@ class Processor:
Should raise ValueError if unsupported for API Server.
"""
if not isinstance(params, SamplingParams):
raise ValueError("V1 does not yet support Pooling models.")
if isinstance(params, PoolingParams):
return
self._validate_logprobs(params)
self._validate_sampling_params(params, lora_request)
@ -263,18 +263,22 @@ class Processor:
if encoder_inputs is not None:
raise NotImplementedError
assert isinstance(params, SamplingParams)
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
sampling_params = None
pooling_params = None
if isinstance(params, SamplingParams):
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (
self.model_config.max_model_len -
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
else:
pooling_params = params.clone()
# Multimodal related.
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
@ -331,6 +335,7 @@ class Processor:
mm_hashes=sorted_mm_hashes,
mm_placeholders=sorted_mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,

View File

@ -481,8 +481,9 @@ class PrometheusStatLogger(StatLoggerBase):
finished_request.num_prompt_tokens)
self.histogram_num_generation_tokens_request.observe(
finished_request.num_generation_tokens)
self.histogram_max_tokens_request.observe(
finished_request.max_tokens_param)
if finished_request.max_tokens_param:
self.histogram_max_tokens_request.observe(
finished_request.max_tokens_param)
if self.gauge_lora_info is not None:
running_lora_adapters = \

View File

@ -106,7 +106,6 @@ class IterationStats:
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
assert num_new_generation_tokens > 0
self.num_prompt_tokens += prompt_len
first_token_latency = self._time_since(req_stats.arrival_time)

View File

@ -101,6 +101,9 @@ class ModelRunnerOutput:
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
# [num_reqs, hidden_size]
pooler_output: list[Optional[torch.Tensor]]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
@ -112,5 +115,6 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
finished_sending=None,
finished_recving=None)
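# A minimal illustration (assumed shapes) of how the new pooler_output field is
# filled for a two-request pooling batch where only the first prompt finished
# this step: the finished request carries its pooled tensor, the other a None.
pooler_output = [torch.randn(4), None]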


vllm/v1/pool/__init__.py Normal file
View File

vllm/v1/pool/metadata.py Normal file
View File

@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.pooling_params import PoolingParams
@dataclass
class PoolingMetadata:
"""Tensors for pooling."""
prompt_lens: torch.Tensor
prompt_token_ids: Optional[torch.Tensor]
pooling_params: list[PoolingParams]
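# A minimal example (illustrative values) of filling this metadata for a batch
# of two pooling requests with prompt lengths 3 and 5, reusing the torch and
# PoolingParams imports above.
example_metadata = PoolingMetadata(
    prompt_lens=torch.tensor([3, 5]),                   # one entry per request
    prompt_token_ids=None,                              # optional; not every pooler needs it
    pooling_params=[PoolingParams(), PoolingParams()],  # one PoolingParams per request
)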

View File

@ -5,6 +5,7 @@ import enum
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
@ -25,7 +26,8 @@ class Request:
multi_modal_inputs: Optional[list[MultiModalKwargs]],
multi_modal_hashes: Optional[list[str]],
multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams,
sampling_params: Optional[SamplingParams],
pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int],
client_index: int = 0,
lora_request: Optional["LoRARequest"] = None,
@ -35,18 +37,35 @@ class Request:
self.request_id = request_id
self.client_index = client_index
self.sampling_params = sampling_params
self.pooling_params = pooling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.structured_output_request = structured_output_request
self.status = (RequestStatus.WAITING_FOR_FSM
if sampling_params.guided_decoding is not None else
RequestStatus.WAITING)
self.status = RequestStatus.WAITING
if sampling_params and sampling_params.guided_decoding is not None:
self.status = RequestStatus.WAITING_FOR_FSM
self.events: list[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
# P/D: Connector-specific KV transfer parameters.
self.kv_transfer_params: Optional[dict[str, Any]] = None
if pooling_params is not None:
self.max_tokens = 1
elif sampling_params is not None:
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
if sampling_params.guided_decoding is not None:
self.status = RequestStatus.WAITING_FOR_FSM
if sampling_params.extra_args is not None:
self.kv_transfer_params = \
sampling_params.extra_args.get("kv_transfer_params")
else:
raise ValueError(
"sampling_params and pooling_params can't both be unset")
self.prompt_token_ids = prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids)
@ -63,11 +82,6 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: Connector-specific KV transfer parameters.
kv_params = (None if sampling_params.extra_args is None else
sampling_params.extra_args.get("kv_transfer_params"))
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)
if self.mm_hashes:
@ -98,10 +112,12 @@ class Request:
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params),
sampling_params=request.sampling_params) \
if request.sampling_params else None,
cache_salt=request.cache_salt,
)
@ -141,7 +157,8 @@ class Request:
@property
def use_structured_output(self) -> bool:
return self.sampling_params.guided_decoding is not None
return self.sampling_params is not None and \
self.sampling_params.guided_decoding is not None
def record_event(
self,

View File

@ -62,13 +62,15 @@ class StructuredOutputManager:
return
if TYPE_CHECKING:
assert request.sampling_params.guided_decoding is not None
assert request.sampling_params is not None and \
request.sampling_params.guided_decoding is not None
# Initialize the backend the first time it is needed.
#
# NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...).
if self.backend is None:
assert request.sampling_params is not None
backend = request.sampling_params.guided_decoding.backend
vocab_size = self.vllm_config.model_config.get_vocab_size()
if backend == "xgrammar":

View File

@ -10,9 +10,11 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
@ -27,7 +29,8 @@ class CachedRequestState:
prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]
block_ids: tuple[list[int], ...]
@ -226,6 +229,8 @@ class InputBatch:
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
self.pooling_params: dict[str, PoolingParams] = {}
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
@ -269,77 +274,83 @@ class InputBatch:
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
sampling_params = request.sampling_params
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
if sampling_params := request.sampling_params:
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.min_tokens:
self.min_tokens[req_index] = (sampling_params.min_tokens,
sampling_params.all_stop_token_ids)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.min_tokens:
self.min_tokens[req_index] = (
sampling_params.min_tokens,
sampling_params.all_stop_token_ids)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
else:
assert request.pooling_params is not None
self.pooling_params[req_id] = request.pooling_params
# Add request lora ID
if request.lora_request:
@ -392,6 +403,7 @@ class InputBatch:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
@ -602,6 +614,25 @@ class InputBatch:
bad_words_token_ids=self.bad_words_token_ids,
)
@property
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
pooling_params = []
else:
# Note: for now this assumes that all requests in the batch
# are either sampling or pooling requests
assert len(self.req_ids) == len(self.pooling_params)
pooling_params = [
self.pooling_params[req_id] for req_id in self.req_ids
]
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)
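# Sketch of the alignment performed by pooling_metadata above (hypothetical
# request ids; assumes an all-pooling batch): the per-request PoolingParams,
# stored by request id, are re-ordered to match the batch order of req_ids so
# they line up index-for-index with prompt_lens.
pooling_params_by_id = {"req-a": PoolingParams(), "req-b": PoolingParams()}
req_ids = ["req-a", "req-b"]                            # InputBatch order
ordered = [pooling_params_by_id[r] for r in req_ids]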
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(

View File

@ -36,6 +36,7 @@ from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
@ -51,6 +52,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
@ -119,6 +121,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cache_config.cache_dtype]
self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None
self.max_model_len = model_config.max_model_len
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
@ -394,7 +397,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
pooling_params = new_req_data.pooling_params
if sampling_params and \
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
@ -406,6 +411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
@ -563,7 +569,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, Any], bool, torch.Tensor,
Optional[SpecDecodeMetadata]]:
Optional[SpecDecodeMetadata], np.ndarray]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
@ -750,7 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata)
spec_decode_metadata, num_scheduled_tokens)
def _compute_cascade_attn_prefix_len(
self,
@ -1197,6 +1203,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def _pool(
self,
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
num_scheduled_tokens_np: np.ndarray,
finished_sending: Optional[set[str]],
finished_recving: Optional[set[str]],
) -> ModelRunnerOutput:
assert self.input_batch.num_reqs ==\
len(self.input_batch.pooling_params), \
"Either all or none of the requests in" \
" a batch must be pooling request"
extracted_hidden_states = list(
torch.split(hidden_states[:num_scheduled_tokens],
num_scheduled_tokens_np.tolist()))
pooling_metadata = self.input_batch.pooling_metadata
raw_pooler_output = self.model.pooler(
hidden_states=extracted_hidden_states,
pooling_metadata=pooling_metadata)
pooler_output: list[Optional[torch.Tensor]] = []
seq_lens = self.seq_lens[:self.input_batch.num_reqs]
for raw_output, seq_len, prompt_len in zip(
raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
if seq_len == prompt_len:
pooler_output.append(raw_output.data.cpu())
else:
pooler_output.append(None)
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
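# Sketch of the split performed in _pool above (assumed sizes): the flattened
# hidden states for the whole step are carved into one chunk per request using
# the per-request scheduled token counts.
flat_hidden_states = torch.randn(6, 8)             # 6 scheduled tokens, hidden size 8
num_scheduled_tokens_np = np.array([2, 4])          # request 0: 2 tokens, request 1: 4 tokens
chunks = torch.split(flat_hidden_states, num_scheduled_tokens_np.tolist())
assert [c.shape[0] for c in chunks] == [2, 4]        # one block of hidden states per request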
@torch.inference_mode()
def execute_model(
self,
@ -1214,7 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@ -1284,7 +1336,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
# Run the decoder.
# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
attn_metadata,
@ -1326,6 +1378,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
all_gather_group=get_tp_group())
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
@ -1541,6 +1598,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
)
@ -1802,7 +1860,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
num_tokens: int,
capture_attn_cudagraph: bool = False,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
@ -1899,7 +1957,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.drafter.dummy_run(num_tokens)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states[logit_indices]
return hidden_states, hidden_states[logit_indices]
@torch.inference_mode()
def _dummy_sampler_run(
@ -1978,6 +2036,48 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return sampler_output
@torch.inference_mode()
def _dummy_pooler_run(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
hidden_states_list = list(
torch.split(hidden_states, num_scheduled_tokens_list))
req_num_tokens = num_tokens // num_reqs
dummy_metadata = PoolingMetadata(
prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
device=self.device),
prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
dtype=torch.int32,
device=self.device),
pooling_params=[PoolingParams()] * num_reqs)
try:
pooler_output = self.model.pooler(hidden_states=hidden_states_list,
pooling_metadata=dummy_metadata)
except RuntimeError as e:
if 'out of memory' in str(e):
raise RuntimeError(
"CUDA out of memory occurred when warming up pooler with "
f"{num_reqs} dummy requests. Please try lowering "
"`max_num_seqs` or `gpu_memory_utilization` when "
"initializing the engine.") from e
else:
raise e
return pooler_output
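# Worked example of the distribution above, assuming num_tokens=10 and
# num_reqs=4: each dummy request gets 10 // 4 = 2 tokens and the remainder
# 10 % 4 = 2 goes to the last one, giving [2, 2, 2, 4].
num_tokens, num_reqs = 10, 4
num_scheduled_tokens_list = [num_tokens // num_reqs] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert num_scheduled_tokens_list == [2, 2, 2, 4]
assert sum(num_scheduled_tokens_list) == num_tokens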
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them.
@ -2048,13 +2148,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
hidden_states = self._dummy_run(self.max_num_tokens)
hidden_states, last_hidden_states \
= self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
sampler_output = self._dummy_sampler_run(hidden_states)
if self.is_pooling_model:
output = self._dummy_pooler_run(hidden_states)
else:
output = self._dummy_sampler_run(last_hidden_states)
else:
sampler_output = None
output = None
self._sync_device()
del hidden_states, sampler_output
del hidden_states, output
self.encoder_cache.clear()
gc.collect()

View File

@ -273,9 +273,14 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
self.model_runner._dummy_sampler_run(
hidden_states=self.model_runner._dummy_run(
num_tokens=max_num_reqs))
hidden_states, last_hidden_states = \
self.model_runner._dummy_run(num_tokens=max_num_reqs)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
else:
self.model_runner._dummy_sampler_run(
hidden_states=last_hidden_states)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.

View File

@ -231,6 +231,7 @@ class InputBatch:
self.block_table.add_row(request.block_ids, req_index)
sampling_params = request.sampling_params
assert sampling_params is not None, "pooling requests not supported yet"
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0

View File

@ -386,6 +386,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add: list[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.sampling_params is not None,\
"Pooling is not supported in TPU yet"
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
@ -395,6 +397,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
mm_inputs=new_req_data.mm_inputs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=None,
generator=None,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
@ -956,6 +959,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=None,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
)
# Check there are no new graphs compiled - all the graphs should be