[V0 Deprecation] Remove pooling model support in V0 (#23434)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Maximilien de Bayser 2025-08-29 04:04:02 -03:00 committed by GitHub
parent 934bebf192
commit 2554b27baa
38 changed files with 99 additions and 808 deletions

View File

@@ -118,6 +118,8 @@ class PPTestSettings:
         multi_node_only: bool = False,
         load_format: Optional[str] = None,
     ):
+        vllm_major_versions = ["1"] if runner == "pooling" else ["0"]
         return PPTestSettings(
             parallel_setups=[
                 ParallelSetup(tp_size=tp_base,
@@ -126,7 +128,7 @@ class PPTestSettings:
                               chunked_prefill=False),
             ],
             distributed_backends=["mp"],
-            vllm_major_versions=["0"],
+            vllm_major_versions=vllm_major_versions,
             runner=runner,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        load_format=load_format),
@@ -213,7 +215,9 @@ TEXT_GENERATION_MODELS = {
 EMBEDDING_MODELS = {  # type: ignore[var-annotated]
     # [Text-only]
     "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"),
-    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
+    # TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883
+    # is fixed
+    #"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
     "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
         load_format="dummy", runner="pooling"
     ),

View File

@@ -16,14 +16,6 @@ MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
 prompts = ["The chef prepared a delicious meal."]
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 @pytest.fixture(scope="module")
 def llm():
     # pytest caches the fixture so we use weakref.proxy to

View File

@@ -27,14 +27,6 @@ TOKEN_IDS = [
 ]
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 @pytest.fixture(scope="module")
 def llm():
     # pytest caches the fixture so we use weakref.proxy to

View File

@@ -16,14 +16,6 @@ MODEL_NAME = "internlm/internlm2-1_8b-reward"
 prompts = ["The chef prepared a delicious meal."]
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 @pytest.fixture(scope="module")
 def llm():
     # pytest caches the fixture so we use weakref.proxy to

View File

@@ -14,14 +14,6 @@ from ...models.utils import softmax
 MODEL_NAME = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 @pytest.fixture(scope="module")
 def llm():
     # pytest caches the fixture so we use weakref.proxy to

View File

@@ -32,15 +32,16 @@ MODEL_CONFIGS = [
         "tensor_parallel_size": 1,
         "tokenizer_mode": "mistral",
     },
-    {
-        "model": "sentence-transformers/all-MiniLM-L12-v2",
-        "enforce_eager": True,
-        "gpu_memory_utilization": 0.20,
-        "max_model_len": 64,
-        "max_num_batched_tokens": 64,
-        "max_num_seqs": 64,
-        "tensor_parallel_size": 1,
-    },
+    # TODO: re-enable once these tests are run with V1
+    # {
+    #     "model": "sentence-transformers/all-MiniLM-L12-v2",
+    #     "enforce_eager": True,
+    #     "gpu_memory_utilization": 0.20,
+    #     "max_model_len": 64,
+    #     "max_num_batched_tokens": 64,
+    #     "max_num_seqs": 64,
+    #     "tensor_parallel_size": 1,
+    # },
 ]

View File

@@ -24,14 +24,6 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
 DTYPE = "bfloat16"
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 @pytest.fixture(scope="module")
 def server():
     args = [

View File

@@ -14,14 +14,6 @@ MODEL_NAME = "BAAI/bge-reranker-base"
 DTYPE = "bfloat16"
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 @pytest.fixture(scope="module")
 def server():
     args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]

View File

@@ -12,15 +12,6 @@ from vllm.entrypoints.openai.protocol import ScoreResponse
 from ...utils import RemoteOpenAIServer
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 MODELS = [
     {
         "name": "BAAI/bge-reranker-v2-m3",

View File

@@ -10,14 +10,6 @@ from vllm.platforms import current_platform
 from ...utils import check_embeddings_close, check_transformers_version
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 @pytest.mark.parametrize(
     "model",
     [
@@ -32,21 +24,15 @@ def v1(run_with_both_engines):
                      "intfloat/e5-mistral-7b-instruct",
                      # CPU v1 doesn't support sliding window
                      marks=[pytest.mark.core_model]),
-        # the qwen models interfere with each other (see PR
-        # https://github.com/vllm-project/vllm/pull/18720).
-        # To avoid this problem, for now we skip v0 since it will be
-        # deprecated anyway.
        pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
-                     marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
+                     marks=[pytest.mark.cpu_model]),
        # [Encoder-only]
        pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
        pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
        pytest.param("intfloat/multilingual-e5-small"),
-        pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
-                     marks=[pytest.mark.skip_v1]),
+        pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
        # [Cross-Encoder]
-        pytest.param("sentence-transformers/stsb-roberta-base-v2",
-                     marks=[pytest.mark.skip_v1]),
+        pytest.param("sentence-transformers/stsb-roberta-base-v2"),
     ],
 )
 def test_models(

View File

@@ -13,14 +13,6 @@ from ....conftest import HfRunner
 from ...utils import check_transformers_version
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 @pytest.fixture
 def math_step_prompts():
     # ruff: noqa: E501

View File

@@ -23,15 +23,6 @@ TEXTS_2 = [
     "The capital of Germany is Berlin.",
 ]
-@pytest.fixture(autouse=True)
-def v1(run_with_both_engines):
-    # Simple autouse wrapper to run both engines for each test
-    # This can be promoted up to conftest.py to run for every
-    # test in a package
-    pass
 DTYPE = "half"

View File

@@ -323,8 +323,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
 _EMBEDDING_EXAMPLE_MODELS = {
     # [Text-only]
-    "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True),
-    "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501
+    "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
+    "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
@@ -337,9 +337,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
     "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
-                                       trust_remote_code=True, v0_only=True),
+                                       trust_remote_code=True),
     "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
-                                      trust_remote_code=True, v0_only=True), # noqa: E501
+                                      trust_remote_code=True), # noqa: E501
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
     "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B",
                                            max_transformers_version="4.53",
@@ -347,9 +347,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B",
                                                   max_transformers_version="4.53",
                                                   transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501
-    "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501
-    "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501
-    "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501
+    "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
+    "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
+    "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501
     # [Multimodal]
     "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
     "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
@@ -364,20 +364,19 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
     "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501
     # [Cross-encoder]
-    "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501
+    "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
     "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
                                                        trust_remote_code=True,
                                                        hf_overrides={
                                                            "architectures": ["GteNewForSequenceClassification"]}),# noqa: E501
-    "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
-    "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
-    "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
+    "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
+    "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
+    "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
 }
 _AUTOMATIC_CONVERTED_MODELS = {
     # Use as_seq_cls_model for automatic conversion
     "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501
-                                                      v0_only=True,
                                                       hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
                                                                     "classifier_from_token": ["Yes"], # noqa: E501
                                                                     "method": "no_post_processing"}), # noqa: E501

View File

@@ -9,10 +9,7 @@ from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
-from vllm.worker.pooling_model_runner import (
-    ModelInputForGPUWithPoolingMetadata)
 class MockAttentionBackend(AttentionBackend):
@@ -114,54 +111,3 @@ def test_model_runner_input():
     assert (received_model_input.sampling_metadata.selected_token_indices ==
             sampling_metadata.selected_token_indices)
     assert received_model_input.sampling_metadata.seq_groups is None
-def test_embedding_model_runner_input():
-    pooling_metadata = PoolingMetadata(
-        seq_groups=[[0]],
-        seq_data={},
-        prompt_lens=[1],
-    )
-    attn_metadata = AttentionMetadata(
-        num_prefills=1,
-        num_prefill_tokens=2,
-        num_decode_tokens=3,
-        slot_mapping=torch.zeros(1),
-        multi_modal_placeholder_index_maps=None,
-        enable_kv_scales_calculation=True,
-    )
-    model_input = ModelInputForGPUWithPoolingMetadata(
-        input_tokens=torch.ones(10),
-        input_positions=torch.ones(10),
-        pooling_metadata=pooling_metadata,
-        attn_metadata=attn_metadata)
-    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
-    # Test round trip serialization.
-    tensor_dict = model_input.as_broadcastable_tensor_dict()
-    attn_backend = MockAttentionBackend()
-    received_model_input = (
-        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
-            tensor_dict, attn_backend=attn_backend))
-    # Check that received copy has correct values.
-    assert isinstance(received_model_input,
-                      ModelInputForGPUWithPoolingMetadata)
-    assert received_model_input.input_tokens is not None
-    assert (
-        received_model_input.input_tokens == model_input.input_tokens).all()
-    assert received_model_input.input_positions is not None
-    assert (received_model_input.input_positions == model_input.input_positions
-            ).all()
-    assert received_model_input.multi_modal_kwargs is None
-    assert (received_model_input.multi_modal_kwargs ==
-            model_input.multi_modal_kwargs)
-    assert received_model_input.lora_requests is None
-    assert received_model_input.lora_requests == model_input.lora_requests
-    assert received_model_input.lora_mapping is None
-    assert received_model_input.lora_mapping == model_input.lora_mapping
-    for field in dataclasses.fields(AttentionMetadata):
-        assert getattr(received_model_input.attn_metadata, field.name,
-                       None) == getattr(attn_metadata, field.name, None)
-    # Pooling metadata is not broadcast.
-    assert received_model_input.pooling_metadata is None

View File

@@ -1591,7 +1591,6 @@ class Scheduler:
                 encoder_seq_data=encoder_seq_data,
                 cross_block_table=cross_block_table,
                 state=seq_group.state,
-                token_type_ids=seq_group.token_type_ids,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but

View File

@@ -1566,8 +1566,7 @@ class EngineArgs:
         use_spec_decode = self.speculative_config is not None
         if (is_gpu and not use_sliding_window and not use_spec_decode
-                and not self.enable_lora
-                and model_config.runner_type != "pooling"):
+                and not self.enable_lora):
             self.enable_chunked_prefill = True
             logger.warning(
                 "Chunked prefill is enabled by default for models "
@@ -1585,10 +1584,6 @@ class EngineArgs:
                 "OOM during the initial memory profiling phase, or result "
                 "in low performance due to small KV cache size. Consider "
                 "setting --max-model-len to a smaller value.", max_model_len)
-        elif (self.enable_chunked_prefill
-              and model_config.runner_type == "pooling"):
-            msg = "Chunked prefill is not supported for pooling models"
-            raise ValueError(msg)
         # if using prefix caching, we must set a hash algo
         if self.enable_prefix_caching:

View File

@@ -72,8 +72,8 @@ STOP_ITERATION = Exception() # Sentinel
 class AsyncStream:
-    """A stream of RequestOutputs or PoolingRequestOutputs for a request
-    that can be iterated over asynchronously via an async generator."""
+    """A stream of RequestOutputs for a request that can be iterated over
+    asynchronously via an async generator."""
     def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self.request_id = request_id
@@ -81,8 +81,7 @@ class AsyncStream:
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
-    def put(self, item: Union[RequestOutput, PoolingRequestOutput,
-                              Exception]) -> None:
+    def put(self, item: Union[RequestOutput, Exception]) -> None:
         if not self._finished:
             self._queue.put_nowait(item)
@@ -99,9 +98,7 @@ class AsyncStream:
     def finished(self) -> bool:
         return self._finished
-    async def generator(
-        self
-    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
+    async def generator(self) -> AsyncGenerator[RequestOutput, None]:
         try:
             while True:
                 result = await self._queue.get()
@@ -151,8 +148,7 @@ class RequestTracker:
             self.abort_request(rid, exception=exc)
     def process_request_output(self,
-                               request_output: Union[RequestOutput,
-                                                     PoolingRequestOutput],
+                               request_output: RequestOutput,
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
@@ -261,9 +257,7 @@ class _AsyncLLMEngine(LLMEngine):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-    async def step_async(
-        self, virtual_engine: int
-    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
+    async def step_async(self, virtual_engine: int) -> List[RequestOutput]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are ran asynchronously if possible.
@@ -405,7 +399,7 @@ class _AsyncLLMEngine(LLMEngine):
         self,
         request_id: str,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
@@ -779,14 +773,14 @@ class AsyncLLMEngine(EngineClient):
         self,
         request_id: str,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
-    ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()
@@ -908,7 +902,7 @@ class AsyncLLMEngine(EngineClient):
                 await self.abort(request_id)
                 raise
-    async def encode(
+    def encode(
         self,
         prompt: PromptType,
         pooling_params: PoolingParams,
@@ -918,85 +912,8 @@ class AsyncLLMEngine(EngineClient):
         priority: int = 0,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
-        """Generate outputs for a request from a pooling model.
-        Generate outputs for a request. This method is a coroutine. It adds the
-        request into the waiting queue of the LLMEngine and streams the outputs
-        from the LLMEngine to the caller.
-        Args:
-            prompt: The prompt to the LLM. See
-                [`PromptType`][vllm.inputs.PromptType] for more details about
-                the format of each input.
-            pooling_params: The pooling parameters of the request.
-            request_id: The unique id of the request.
-            lora_request: LoRA request to use for generation, if any.
-            trace_headers: OpenTelemetry trace headers.
-            priority: The priority of the request.
-                Only applicable with priority scheduling.
-        Yields:
-            The output `PoolingRequestOutput` objects from the LLMEngine
-            for the request.
-        Details:
-            - If the engine is not running, start the background loop,
-              which iteratively invokes
-              [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][]
-              to process the waiting requests.
-            - Add the request to the engine's `RequestTracker`.
-              On the next background loop, this request will be sent to
-              the underlying engine.
-              Also, a corresponding `AsyncStream` will be created.
-            - Wait for the request outputs from `AsyncStream` and yield them.
-        Example:
-            ```
-            # Please refer to entrypoints/api_server.py for
-            # the complete example.
-            # initialize the engine and the example input
-            # note that engine_args here is AsyncEngineArgs instance
-            engine = AsyncLLMEngine.from_engine_args(engine_args)
-            example_input = {
-                "input": "What is LLM?",
-                "request_id": 0,
-            }
-            # start the generation
-            results_generator = engine.encode(
-                example_input["input"],
-                PoolingParams(),
-                example_input["request_id"])
-            # get the results
-            final_output = None
-            async for request_output in results_generator:
-                if await request.is_disconnected():
-                    # Abort the request if the client disconnects.
-                    await engine.abort(request_id)
-                    # Return or raise an error
-                    ...
-                final_output = request_output
-            # Process and return the final output
-            ...
-            ```
-        """
-        try:
-            async for output in await self.add_request(
-                request_id,
-                prompt,
-                pooling_params,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-                priority=priority,
-                tokenization_kwargs=tokenization_kwargs,
-            ):
-                yield LLMEngine.validate_output(output, PoolingRequestOutput)
-        except asyncio.CancelledError:
-            await self.abort(request_id)
-            raise
+        raise NotImplementedError(
+            "Pooling models are not supported in vLLM V0")
     async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
         """Abort a request.
@@ -1104,8 +1021,8 @@ class AsyncLLMEngine(EngineClient):
     async def is_sleeping(self) -> bool:
         return self.engine.is_sleeping()
-    async def add_lora(self, lora_request: LoRARequest) -> None:
-        self.engine.add_lora(lora_request)
+    async def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.engine.add_lora(lora_request)
     async def collective_rpc(self,
                              method: str,

View File

@@ -40,12 +40,11 @@ from vllm.multimodal.cache import processor_only_cache_from_config
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
-from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
-                           PoolingSequenceGroupOutput, Sequence, SequenceGroup,
-                           SequenceGroupBase, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           Sequence, SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.detokenizer import Detokenizer
@@ -93,8 +92,7 @@ class SchedulerContext:
     def __init__(self) -> None:
         self.output_queue: Deque[OutputData] = deque()
-        self.request_outputs: List[Union[RequestOutput,
-                                         PoolingRequestOutput]] = []
+        self.request_outputs: List[RequestOutput] = []
         self.seq_group_metadata_list: Optional[
             List[SequenceGroupMetadata]] = None
         self.scheduler_outputs: Optional[SchedulerOutputs] = None
@@ -261,8 +259,7 @@ class LLMEngine:
         self.model_executor = executor_class(vllm_config=vllm_config)
-        if self.model_config.runner_type != "pooling":
-            self._initialize_kv_caches()
+        self._initialize_kv_caches()
         # If usage stat is enabled, collect relevant info.
         if is_usage_stats_enabled():
@@ -541,7 +538,7 @@ class LLMEngine:
         self,
         request_id: str,
         processed_inputs: ProcessorInputs,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest],
         trace_headers: Optional[Mapping[str, str]] = None,
@@ -577,7 +574,7 @@ class LLMEngine:
         encoder_seq = (None if encoder_inputs is None else Sequence(
             seq_id, encoder_inputs, block_size, eos_token_id, lora_request))
-        # Create a SequenceGroup based on SamplingParams or PoolingParams
+        # Create a SequenceGroup based on SamplingParams
         if isinstance(params, SamplingParams):
             seq_group = self._create_sequence_group_with_sampling(
                 request_id,
@@ -588,18 +585,8 @@ class LLMEngine:
                 trace_headers=trace_headers,
                 encoder_seq=encoder_seq,
                 priority=priority)
-        elif isinstance(params, PoolingParams):
-            seq_group = self._create_sequence_group_with_pooling(
-                request_id,
-                seq,
-                params,
-                arrival_time=arrival_time,
-                lora_request=lora_request,
-                encoder_seq=encoder_seq,
-                priority=priority)
         else:
-            raise ValueError(
-                "Either SamplingParams or PoolingParams must be provided.")
+            raise ValueError("SamplingParams must be provided.")
         # Add the sequence group to the scheduler with least unfinished seqs.
         costs = [
@@ -618,7 +605,7 @@ class LLMEngine:
         self,
         request_id: str,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
@@ -636,9 +623,8 @@ class LLMEngine:
             prompt: The prompt to the LLM. See
                 [PromptType][vllm.inputs.PromptType]
                 for more details about the format of each input.
-            params: Parameters for sampling or pooling.
+            params: Parameters for sampling.
                 [SamplingParams][vllm.SamplingParams] for text generation.
-                [PoolingParams][vllm.PoolingParams] for pooling.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
             lora_request: The LoRA request to add.
@@ -760,29 +746,6 @@ class LLMEngine:
         return seq_group
-    def _create_sequence_group_with_pooling(
-        self,
-        request_id: str,
-        seq: Sequence,
-        pooling_params: PoolingParams,
-        arrival_time: float,
-        lora_request: Optional[LoRARequest],
-        encoder_seq: Optional[Sequence] = None,
-        priority: int = 0,
-    ) -> SequenceGroup:
-        """Creates a SequenceGroup with PoolingParams."""
-        # Defensive copy of PoolingParams, which are used by the pooler
-        pooling_params = pooling_params.clone()
-        # Create the sequence group.
-        seq_group = SequenceGroup(request_id=request_id,
-                                  seqs=[seq],
-                                  arrival_time=arrival_time,
-                                  lora_request=lora_request,
-                                  pooling_params=pooling_params,
-                                  encoder_seq=encoder_seq,
-                                  priority=priority)
-        return seq_group
     def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
         """Aborts a request(s) with the given ID.
@@ -856,18 +819,6 @@ class LLMEngine:
             success = success and scheduler.reset_prefix_cache(device)
         return success
-    @staticmethod
-    def _process_sequence_group_outputs(
-        seq_group: SequenceGroup,
-        outputs: List[PoolingSequenceGroupOutput],
-    ) -> None:
-        seq_group.pooled_data = outputs[0].data
-        for seq in seq_group.get_seqs():
-            seq.status = SequenceStatus.FINISHED_STOPPED
-        return
     def _process_model_outputs(self,
                                ctx: SchedulerContext,
                                request_id: Optional[str] = None) -> None:
@@ -962,13 +913,10 @@ class LLMEngine:
                     seq_group.metrics.model_execute_time = (
                         o.model_execute_time)
-            if self.model_config.runner_type == "pooling":
-                self._process_sequence_group_outputs(seq_group, output)
-            else:
-                self.output_processor.process_prompt_logprob(seq_group, output)
-                if seq_group_meta.do_sample:
-                    self.output_processor.process_outputs(
-                        seq_group, output, is_async)
+            self.output_processor.process_prompt_logprob(seq_group, output)
+            if seq_group_meta.do_sample:
+                self.output_processor.process_outputs(seq_group, output,
+                                                      is_async)
             if seq_group.is_finished():
                 finished_now.append(i)
@@ -1090,7 +1038,7 @@ class LLMEngine:
                     seq.append_token_id(sample.output_token, sample.logprobs,
                                         sample.output_embed)
-    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
+    def step(self) -> List[RequestOutput]:
         """Performs one decoding iteration and returns newly generated results.
         <figure markdown="span">

View File

@@ -120,6 +120,7 @@ class RPCLoadAdapterRequest:
 @dataclass
 class RPCAdapterLoadedResponse:
     request_id: str
+    lora_loaded: bool
 RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,

View File

@@ -6,7 +6,7 @@ import copy
 import pickle
 from contextlib import contextmanager, suppress
 from typing import (Any, AsyncGenerator, Dict, Iterable, Iterator, List,
-                    Mapping, Optional, Union, cast)
+                    Mapping, Optional, Union)
 import cloudpickle
 import psutil
@@ -477,10 +477,8 @@ class MQLLMEngineClient(EngineClient):
                 Any priority other than 0 will lead to an error if the
                 scheduling policy is not "priority".
         """
-        return cast(
-            AsyncGenerator[RequestOutput, None],
-            self._process_request(prompt, sampling_params, request_id,
-                                  lora_request, trace_headers, priority))
+        return self._process_request(prompt, sampling_params, request_id,
+                                     lora_request, trace_headers, priority)
     def encode(
         self,
@@ -490,45 +488,20 @@ class MQLLMEngineClient(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
-        """Generate outputs for a request from a pooling model.
-        Generate outputs for a request. This method is a coroutine. It adds the
-        request into the waiting queue of the LLMEngine and streams the outputs
-        from the LLMEngine to the caller.
-        Args:
-            prompt: The prompt to the LLM. See
-                [`PromptType`][vllm.inputs.PromptType] for more details about
-                the format of each input.
-            pooling_params: The pooling parameters of the request.
-            request_id: The unique id of the request.
-            lora_request: LoRA request to use for generation, if any.
-            trace_headers: OpenTelemetry trace headers.
-        Yields:
-            The output `PoolingRequestOutput` objects from the LLMEngine
-            for the request.
-        """
-        return cast(
-            AsyncGenerator[PoolingRequestOutput, None],
-            self._process_request(prompt,
-                                  pooling_params,
-                                  request_id,
-                                  lora_request,
-                                  trace_headers,
-                                  priority=priority))
+        raise NotImplementedError(
+            "Pooling models are not supported in vLLM V0")
     async def _process_request(
         self,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
-            PoolingRequestOutput, None]]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Send an RPCGenerateRequest to the RPCServer and stream responses."""
         # If already dead, error out.
@@ -547,7 +520,7 @@ class MQLLMEngineClient(EngineClient):
         try:
             # 2) Detach logits processors so that they can be pickled
             # separately (may require cloudpickle which is slower)
-            if isinstance(params, SamplingParams) and params.logits_processors:
+            if params.logits_processors:
                 # Defensive shallow copy
                 params = copy.copy(params)
                 logits_processors = params.logits_processors
@@ -646,13 +619,14 @@ class MQLLMEngineClient(EngineClient):
             raise request_output
         return request_output.is_sleeping
-    async def add_lora(self, lora_request: LoRARequest) -> None:
+    async def add_lora(self, lora_request: LoRARequest) -> bool:
         """Load a new LoRA adapter into the engine for future requests."""
         # Uses the same I/O as generate requests
         request = RPCLoadAdapterRequest(lora_request)
         # Create output queue for this request.
-        queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
+        queue: asyncio.Queue[Union[
+            BaseException, RPCAdapterLoadedResponse]] = asyncio.Queue()
         self.output_queues[request.request_id] = queue
         # Send the request
@@ -666,3 +640,4 @@ class MQLLMEngineClient(EngineClient):
         # Raise on error, otherwise happily return None
         if isinstance(request_output, BaseException):
             raise request_output
+        return request_output.lora_loaded

View File

@@ -347,7 +347,7 @@ class MQLLMEngine:
     def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
         try:
-            self.engine.add_lora(request.lora_request)
+            lora_loaded = self.engine.add_lora(request.lora_request)
         except BaseException as e:
             # Send back an error if the adater fails to load
             rpc_err = RPCError(request_id=request.request_id,
@@ -357,7 +357,8 @@ class MQLLMEngine:
             return
         # Otherwise, send back the successful load message
         self._send_outputs(
-            RPCAdapterLoadedResponse(request_id=request.request_id))
+            RPCAdapterLoadedResponse(request_id=request.request_id,
+                                     lora_loaded=lora_loaded))
     def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
         is_sleeping = self.is_sleeping()

View File

@@ -3,7 +3,7 @@
 import asyncio
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, Iterable, Mapping, Optional, Union
+from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig, VllmConfig
@@ -224,6 +224,7 @@ class EngineClient(ABC):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from a pooling model."""
         ...
@@ -320,7 +321,7 @@ class EngineClient(ABC):
         ...
     @abstractmethod
-    async def add_lora(self, lora_request: LoRARequest) -> None:
+    async def add_lora(self, lora_request: LoRARequest) -> bool:
         """Load a new LoRA adapter into the engine for future requests."""
         ...

View File

@@ -1156,8 +1156,7 @@ class LLM:
                 tokenization_kwargs=tokenization_kwargs,
             )
-            if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
-                    "token_type_ids", None)):
+            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                 params = pooling_params.clone()
                 compressed = compress_token_type_ids(token_type_ids)
                 params.extra_kwargs = {"compressed_token_type_ids": compressed}

View File

@@ -7,7 +7,6 @@ from typing import Any, Optional, Union
 from fastapi import Request
-from vllm import envs
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
@@ -229,8 +228,7 @@ class ServingScores(OpenAIServing):
                 params=default_pooling_params,
                 lora_request=lora_request)
-            if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
-                    "token_type_ids", None)):
+            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                 pooling_params = default_pooling_params.clone()
                 compressed = compress_token_type_ids(token_type_ids)
                 pooling_params.extra_kwargs = {

View File

@@ -174,9 +174,6 @@ class TokenInputs(TypedDict):
     prompt_token_ids: list[int]
     """The token IDs of the prompt."""
-    token_type_ids: NotRequired[list[int]]
-    """The token type IDs of the prompt."""
     prompt: NotRequired[str]
     """
     The original prompt text corresponding to the token IDs, if available.
@@ -190,7 +187,6 @@ class TokenInputs(TypedDict):
 def token_inputs(
     prompt_token_ids: list[int],
-    token_type_ids: Optional[list[int]] = None,
     prompt: Optional[str] = None,
     cache_salt: Optional[str] = None,
 ) -> TokenInputs:
@@ -200,8 +196,6 @@ def token_inputs(
     if prompt is not None:
         inputs["prompt"] = prompt
-    if token_type_ids is not None:
-        inputs["token_type_ids"] = token_type_ids
     if cache_salt is not None:
         inputs["cache_salt"] = cache_salt

View File

@@ -355,7 +355,6 @@ class InputPreprocessor:
         mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_token_ids = parsed_content["prompt_token_ids"]
-        token_type_ids = parsed_content.get("token_type_ids")
         inputs: Union[TokenInputs, MultiModalInputs]
         if multi_modal_data := parsed_content.get("multi_modal_data"):
@@ -368,10 +367,7 @@ class InputPreprocessor:
                 mm_hash_overrides=mm_hash_overrides,
             )
         else:
-            inputs = token_inputs(
-                prompt_token_ids=prompt_token_ids,
-                token_type_ids=token_type_ids,
-            )
+            inputs = token_inputs(prompt_token_ids=prompt_token_ids)
         if cache_salt := parsed_content.get("cache_salt"):
             inputs["cache_salt"] = cache_salt
@@ -387,7 +383,6 @@ class InputPreprocessor:
         mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
         prompt_token_ids = parsed_content["prompt_token_ids"]
-        token_type_ids = parsed_content.get("token_type_ids")
         inputs: Union[TokenInputs, MultiModalInputs]
         if multi_modal_data := parsed_content.get("multi_modal_data"):
@@ -400,10 +395,7 @@ class InputPreprocessor:
                 mm_hash_overrides=mm_hash_overrides,
             )
         else:
-            inputs = token_inputs(
-                prompt_token_ids=prompt_token_ids,
-                token_type_ids=token_type_ids,
-            )
+            inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
         if cache_salt := parsed_content.get("cache_salt"):
             inputs["cache_salt"] = cache_salt

View File

@@ -13,17 +13,12 @@ import torch.nn.functional as F
 from transformers import PretrainedConfig
 from vllm.config import ModelConfig, PoolerConfig
-from vllm.model_executor.pooling_metadata import ( # noqa: E501
-    PoolingMetadata as V0PoolingMetadata)
-from vllm.model_executor.pooling_metadata import PoolingTensors
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.tasks import PoolingTask
 from vllm.utils import current_stream, resolve_obj_by_qualname
-from vllm.v1.pool.metadata import PoolingCursor
-from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
-PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
+from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
 PoolingFn = Callable[
     [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
     Union[torch.Tensor, list[torch.Tensor]]]
@@ -127,36 +122,23 @@ def get_prompt_lens(
     hidden_states: Union[torch.Tensor, list[torch.Tensor]],
     pooling_metadata: PoolingMetadata,
 ) -> torch.Tensor:
-    if isinstance(pooling_metadata, V1PoolingMetadata):
-        return pooling_metadata.prompt_lens
-    return PoolingTensors.from_pooling_metadata(
-        pooling_metadata, hidden_states[0].device).prompt_lens
+    return pooling_metadata.prompt_lens
 def get_prompt_token_ids(
         pooling_metadata: PoolingMetadata) -> list[torch.Tensor]:
-    if isinstance(pooling_metadata, V1PoolingMetadata):
-        assert pooling_metadata.prompt_token_ids is not None, (
-            "Please set `requires_token_ids=True` in `get_pooling_updates`")
-        return [
-            pooling_metadata.prompt_token_ids[i, :num]
-            for i, num in enumerate(pooling_metadata.prompt_lens)
-        ]
+    assert pooling_metadata.prompt_token_ids is not None, (
+        "Please set `requires_token_ids=True` in `get_pooling_updates`")
     return [
-        torch.tensor(seq_data_i.prompt_token_ids)
-        for seq_data_i in pooling_metadata.seq_data.values()
+        pooling_metadata.prompt_token_ids[i, :num]
+        for i, num in enumerate(pooling_metadata.prompt_lens)
     ]
 def get_pooling_params(
         pooling_metadata: PoolingMetadata) -> list[PoolingParams]:
-    if isinstance(pooling_metadata, V0PoolingMetadata):
-        pooling_params = [p for _, p in pooling_metadata.seq_groups]
-    else:
-        pooling_params = pooling_metadata.pooling_params
+    pooling_params = pooling_metadata.pooling_params
     return pooling_params

View File

@@ -24,9 +24,9 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
+from vllm.v1.pool.metadata import PoolingMetadata
 from .interfaces import SupportsCrossEncoding, SupportsQuant
 from .interfaces_base import default_pooling_type

View File

@@ -15,10 +15,10 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
                                                build_output, get_prompt_lens,
                                                get_prompt_token_ids)
 from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import PoolerOutput
 from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.v1.pool.metadata import PoolingMetadata
 from .interfaces_base import default_pooling_type

View File

@@ -22,9 +22,9 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask
+from vllm.v1.pool.metadata import PoolingMetadata
 from .interfaces import SupportsCrossEncoding
 from .interfaces_base import default_pooling_type

View File

@@ -1,90 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass
-from typing import Any, Optional
-import torch
-from vllm.pooling_params import PoolingParams
-from vllm.utils import is_pin_memory_available
-from vllm.v1.pool.metadata import PoolingCursor, build_pooling_cursor
-class PoolingMetadata:
-    """Metadata for pooling operations in the Pooler layer.
-    This class holds the necessary information for pooling operations,
-    providing context for how to perform pooling and other related operations.
-    Attributes:
-        seq_groups: List of (seq_ids, pooling_params).
-        seq_data: A mapping of sequence ID to additional sequence data.
-        prompt_lens: List of the lengths of each prompt.
-    """
-    def __init__(
-        self,
-        seq_groups: list[tuple[list[int], PoolingParams]],
-        seq_data: dict[int, Any],  # Specific data related to sequences
-        prompt_lens: list[int],
-        pooling_cursor: Optional[PoolingCursor] = None) -> None:
-        self.seq_groups = seq_groups
-        self.seq_data = seq_data
-        self.prompt_lens = prompt_lens
-        self.pooling_cursor: Optional[PoolingCursor] = pooling_cursor
-    def __repr__(self) -> str:
-        return ("PoolingMetadata("
-                f"seq_groups={self.seq_groups}, "
-                f"seq_data={self.seq_data}, "
-                f"prompt_lens={self.prompt_lens})")
-    def __getitem__(self, indices: slice):
-        return PoolingMetadata(
-            seq_groups=self.seq_groups[indices],
-            seq_data=dict(list(self.seq_data.items())[indices]),
-            prompt_lens=self.prompt_lens[indices],
-            pooling_cursor=None
-            if self.pooling_cursor is None else self.pooling_cursor[indices],
-        )
-    def build_pooling_cursor(self, num_scheduled_tokens: list[int],
-                             device: torch.device):
-        prompt_lens = torch.tensor(self.prompt_lens, device="cpu")
-        self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
-                                                   prompt_lens,
-                                                   device=device)
-@dataclass
-class PoolingTensors:
-    """Tensors for pooling."""
-    prompt_lens: torch.Tensor
-    @classmethod
-    def from_pooling_metadata(
-        cls,
-        pooling_metadata: "PoolingMetadata",
-        device: torch.device,
-    ) -> "PoolingTensors":
-        """
-        Create PoolingTensors from PoolingMetadata.
-        Args:
-            pooling_metadata: PoolingMetadata instance to convert.
-            device: Device to store the tensors.
-        """
-        # Convert prompt lengths to tensor
-        pin_memory = is_pin_memory_available()
-        prompt_lens_t = torch.tensor(
-            pooling_metadata.prompt_lens,
-            device="cpu",
-            dtype=torch.long,
-            pin_memory=pin_memory,
-        )
-        return cls(prompt_lens=prompt_lens_t.to(device=device,
-                                                non_blocking=True), )

View File

@@ -913,9 +913,6 @@ class MultiModalInputs(TypedDict):
     prompt_token_ids: list[int]
     """The processed token IDs which includes placeholder tokens."""
-    token_type_ids: NotRequired[list[int]]
-    """The token type IDs of the prompt."""
     mm_kwargs: MultiModalKwargsOptionalItems
     """Keyword arguments to be directly passed to the model after batching."""
@@ -946,6 +943,3 @@ class MultiModalEncDecInputs(MultiModalInputs):
     encoder_prompt_token_ids: list[int]
     """The processed token IDs of the encoder prompt."""
-    encoder_token_type_ids: NotRequired[list[int]]
-    """The token type IDs of the encoder prompt."""

View File

@ -508,12 +508,6 @@ class Sequence:
return [0] * len(self.inputs["prompt_embeds"]) return [0] * len(self.inputs["prompt_embeds"])
return self.inputs["prompt_token_ids"] return self.inputs["prompt_token_ids"]
@property
def token_type_ids(self) -> list[int]:
if self.inputs["type"] == "embeds":
return []
return self.inputs.get("token_type_ids", [])
@property @property
def multi_modal_data(self) -> MultiModalKwargs: def multi_modal_data(self) -> MultiModalKwargs:
if self.inputs["type"] == "multimodal": if self.inputs["type"] == "multimodal":
@ -765,10 +759,6 @@ class SequenceGroup:
return (self.encoder_seq.prompt_token_ids return (self.encoder_seq.prompt_token_ids
if self.encoder_seq is not None else None) if self.encoder_seq is not None else None)
@property
def token_type_ids(self) -> Optional[list[int]]:
return self.first_seq.token_type_ids
@property @property
def multi_modal_data(self) -> MultiModalKwargs: def multi_modal_data(self) -> MultiModalKwargs:
if self.first_seq.multi_modal_data: if self.first_seq.multi_modal_data:
@ -972,7 +962,6 @@ class SequenceGroupMetadata(
computed_block_nums: Optional[list[int]] = None computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field( state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState()) default_factory=lambda: SequenceGroupState())
token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[MultiModalKwargs] = None multi_modal_data: Optional[MultiModalKwargs] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
encoder_seq_data: Optional[SequenceData] = None encoder_seq_data: Optional[SequenceData] = None

View File

@ -24,8 +24,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry) MultiModalRegistry)
from vllm.platforms import _Backend from vllm.platforms import _Backend
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (GPUModelRunnerBase, from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder, ModelInputForGPUBuilder,
@ -161,7 +160,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[PoolerOutput]]: ) -> Optional[List[SamplerOutput]]:
if num_steps > 1: if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in " raise ValueError("num_steps > 1 is not supported in "
"EncoderDecoderModelRunner") "EncoderDecoderModelRunner")

View File

@ -86,7 +86,6 @@ class ModelInputForGPU(ModelRunnerInputBase):
input_tokens: Optional[torch.Tensor] = None input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None lora_mapping: Optional["LoRAMapping"] = None
@ -192,7 +191,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens[0].clear() # type: ignore self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore
@ -219,7 +217,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
input_tokens: Optional[List[List[int]]] = None, input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None,
# The sequence length (may be capped to the sliding window). # The sequence length (may be capped to the sliding window).
@ -284,12 +281,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)): for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear() self.input_positions[seq_id].clear()
if token_types:
self.token_types = token_types
else:
for seq_id in range(len(self.seq_ids)):
self.token_types[seq_id].clear()
self.mrope_input_positions = None self.mrope_input_positions = None
if seq_lens: if seq_lens:
@ -348,7 +339,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = input_tokens or [] self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds self.inputs_embeds = inputs_embeds
self.input_positions = input_positions or [] self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or [] self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or [] self.orig_seq_lens = orig_seq_lens or []
@ -376,7 +366,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)]
self.token_types = [[] for _ in range(self.n_seqs)]
self.mrope_input_positions = None self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs
@ -400,7 +389,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
f"inputs_embeds.shape=" f"inputs_embeds.shape="
f"{getattr(self.inputs_embeds, 'shape', None)}, " f"{getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, " f"input_positions={self.input_positions}, "
f"token_types={self.token_types}, "
f"mrope_input_positions={self.mrope_input_positions}, " f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, " f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, " f"orig_seq_lens={self.orig_seq_lens}, "
@ -522,8 +510,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
prompt_embeds = seq_data.get_token_embeddings( prompt_embeds = seq_data.get_token_embeddings(
)[context_len:seq_len] )[context_len:seq_len]
token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len() inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
@ -531,8 +517,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds inter_data.inputs_embeds = prompt_embeds
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend(
token_types if token_types else [])
inter_data.query_lens[seq_idx] = seq_len - context_len inter_data.query_lens[seq_idx] = seq_len - context_len
if seq_data.mrope_position_delta is not None: if seq_data.mrope_position_delta is not None:
@ -590,8 +574,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_idx][uncomputed_start:] seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[ inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][uncomputed_start:] seq_idx][uncomputed_start:]
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
uncomputed_start:]
context_len = prefix_cache_len context_len = prefix_cache_len
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
@ -606,8 +588,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
seq_idx][-1:] seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[ inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:] seq_idx][-1:]
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
-1:]
inter_data.query_lens[seq_idx] = 1 inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
@ -802,12 +782,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = list[int]() input_tokens = list[int]()
inputs_embeds_list = list[torch.Tensor]() inputs_embeds_list = list[torch.Tensor]()
token_types = list[int]()
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens: for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens) input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None: if inter_data.inputs_embeds is not None:
inputs_embeds_list.append( inputs_embeds_list.append(
inter_data.inputs_embeds.to( inter_data.inputs_embeds.to(
@ -890,11 +867,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
token_types_tensor = async_tensor_h2d(token_types, torch.long,
self.runner.device,
self.runner.pin_memory) \
if token_types else None
if mrope_input_positions is not None: if mrope_input_positions is not None:
for idx in range(3): for idx in range(3):
mrope_input_positions[idx].extend( mrope_input_positions[idx].extend(
@ -951,7 +923,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
token_types=token_types_tensor,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
seq_lens=seq_lens, seq_lens=seq_lens,
query_lens=query_lens, query_lens=query_lens,
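
The token-type tensor dropped from this builder was staged on the host and copied to the GPU in the same way as the other inputs. A generic sketch of that pinned-memory, non-blocking host-to-device pattern (not the vLLM async_tensor_h2d helper itself, just the idea behind it):

import torch


def async_h2d(values: list[int], device: torch.device,
              pin_memory: bool) -> torch.Tensor:
    # Stage in (optionally pinned) CPU memory, then issue an asynchronous
    # copy; the caller synchronizes later through the CUDA stream.
    cpu_tensor = torch.tensor(values,
                              dtype=torch.long,
                              device="cpu",
                              pin_memory=pin_memory)
    return cpu_tensor.to(device=device, non_blocking=True)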

View File

@ -13,10 +13,9 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces import supports_transcription from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import ( from vllm.model_executor.models.interfaces_base import is_text_generation_model
is_pooling_model, is_text_generation_model)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tasks import GenerationTask, SupportedTask
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
@ -241,20 +240,11 @@ class ModelRunnerBase(ABC, Generic[T]):
return supported_tasks return supported_tasks
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
return []
return list(model.pooler.get_supported_tasks())
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]() tasks = list[SupportedTask]()
if self.model_config.runner_type == "generate": if self.model_config.runner_type == "generate":
tasks.extend(self.get_supported_generation_tasks()) tasks.extend(self.get_supported_generation_tasks())
if self.model_config.runner_type == "pooling":
tasks.extend(self.get_supported_pooling_tasks())
return tuple(tasks) return tuple(tasks)

View File

@ -1,222 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

import torch

from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)

logger = init_logger(__name__)


@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """
    Used by the PoolingModelRunner.
    """
    pooling_metadata: Optional["PoolingMetadata"] = None


class PoolingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ):
        super().__init__(vllm_config=vllm_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        if num_steps > 1:
            raise ValueError(
                "PoolingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        # Pooling models are (ab-)used also to integrate non-text models that
        # are not autoregressive (PrithviGeoSpatialMAE).
        # These models might not use attention and do not really have a
        # prefill and decode phase. The model input is processed in one shot
        # and both decode_metadata and prefill_metadata would be None for
        # such models. See the PlaceholderAttentionMetadata class.
        # TODO: Figure out if cuda_graph is of any use for these models and
        # explore how to leverage it.
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
            if model_input.inputs_embeds is None:
                assert model_input.input_tokens is not None
                graph_batch_size = model_input.input_tokens.shape[0]
                model_executable = (
                    self.graph_runners[model_input.virtual_engine][(
                        graph_batch_size, False)])
            else:
                graph_batch_size = model_input.inputs_embeds.shape[0]
                model_executable = (
                    self.graph_runners[model_input.virtual_engine][(
                        graph_batch_size, True)])
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

        cross_enc_kwargs = {}
        if model_input.token_types is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_types

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(
                    multi_modal_kwargs,
                    device=self.device,
                ),
                **cross_enc_kwargs,
                **seqlen_agnostic_kwargs,
            )

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end.record()

        # Only perform pooling in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and self.observability_config is not None
                    and self.observability_config.collect_model_forward_time):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        pooling_metadata = model_input.pooling_metadata
        assert pooling_metadata is not None
        pooling_metadata.build_pooling_cursor(
            num_scheduled_tokens=pooling_metadata.prompt_lens,
            device=hidden_or_intermediate_states.device)

        return [
            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                              pooling_metadata=pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            pooling_params = seq_group_metadata.pooling_params
            assert pooling_params is not None

            task = pooling_params.task
            assert task is not None, "You did not set `task` in the API"

            model = cast(VllmModelForPooling, self.model)
            to_update = model.pooler.get_pooling_updates(task)
            to_update.apply(pooling_params)

            seq_groups.append((seq_ids, pooling_params))

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        pooling_metadata = PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

        return pooling_metadata

View File

@ -30,7 +30,6 @@ from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput) WorkerInput)
@ -83,9 +82,7 @@ class Worker(LocalOrDistributedWorkerBase):
else {"return_hidden_states": True} else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_config.runner_type == "pooling": if self.model_config.is_encoder_decoder:
ModelRunnerClass = PoolingModelRunner
elif self.model_config.is_encoder_decoder:
ModelRunnerClass = EncoderDecoderModelRunner ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass( self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
@ -99,7 +96,6 @@ class Worker(LocalOrDistributedWorkerBase):
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: List[CacheEngine] self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as pooling models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}
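
With the V0 pooling runner removed, pooling and embedding workloads are served only by the V1 engine. A hypothetical smoke test of an embedding model through the offline API (the checkpoint name and the task argument are examples and may differ across versions):

from vllm import LLM

# Example embedding checkpoint; any pooling model supported by vLLM works.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")
outputs = llm.embed(["vLLM now runs pooling models on the V1 engine only."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality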