Support token_type_ids in V1 with less code changes (#21985)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Maximilien de Bayser 2025-08-11 02:54:59 -03:00 committed by GitHub
parent 9c97a1c349
commit 39052dbca8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 235 additions and 130 deletions

View File

@ -126,7 +126,9 @@ def test_invocations(server: RemoteOpenAIServer):
invocation_output["results"]):
assert rerank_result.keys() == invocations_result.keys()
assert rerank_result["relevance_score"] == pytest.approx(
invocations_result["relevance_score"], rel=0.01)
invocations_result["relevance_score"], rel=0.05)
# TODO: reset this tolerance to 0.01 once we find
# an alternative to flash_attn with bfloat16
@pytest.mark.asyncio

View File

@ -220,7 +220,9 @@ class TestModel:
invocation_output["data"]):
assert score_data.keys() == invocation_data.keys()
assert score_data["score"] == pytest.approx(
invocation_data["score"], rel=0.01)
invocation_data["score"], rel=0.05)
# TODO: reset this tolerance to 0.01 once we find
# an alternative to flash_attn with bfloat16
def test_activation(self, server: RemoteOpenAIServer, model: dict[str,
Any]):

View File

@ -23,6 +23,15 @@ TEXTS_2 = [
"The capital of Germany is Berlin.",
]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
DTYPE = "half"

View File

@ -28,11 +28,15 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam,
_cosine_similarity,
_validate_score_input_lens,
compress_token_type_ids,
get_score_prompt)
# yapf: enable
from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
@ -1329,6 +1333,7 @@ class LLM:
model_config = self.llm_engine.model_config
pooling_params.verify("score", model_config)
pooling_params_list = list[PoolingParams]()
tokenization_kwargs: dict[str, Any] = {}
@ -1339,38 +1344,31 @@ class LLM:
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
if model_config.is_multimodal_model:
for q, d in input_pairs:
_, engine_prompt = get_score_prompt(
model_config=model_config,
data_1=q,
data_2=d,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
)
model_config = self.llm_engine.model_config
parsed_prompts.append(engine_prompt)
else:
for q, t in input_pairs:
if model_config.use_pad_token:
# cross_encoder models defaults to using pad_token.
prompt_inputs = tokenizer(
text=q, # type: ignore[arg-type]
text_pair=t, # type: ignore[arg-type]
**tokenization_kwargs)
else:
# `llm as reranker` models defaults to not using pad_token.
prompt_inputs = tokenizer(
text=q + t, # type: ignore[operator]
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
parsed_prompts.append(engine_prompt)
for q, d in input_pairs:
_, engine_prompt = get_score_prompt(
model_config=model_config,
data_1=q,
data_2=d,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
)
if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
"token_type_ids", None)):
params = pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids)
params.extra_kwargs = {"compressed_token_type_ids": compressed}
pooling_params_list.append(params)
else:
pooling_params_list.append(pooling_params)
parsed_prompts.append(engine_prompt)
self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
params=pooling_params_list,
use_tqdm=use_tqdm,
lora_request=lora_request,
)

View File

@ -7,6 +7,7 @@ from typing import Any, Optional, Union
from fastapi import Request
from vllm import envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
@ -17,11 +18,15 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
ScoreResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam,
_cosine_similarity,
_validate_score_input_lens,
compress_token_type_ids,
get_score_prompt)
# yapf: enable
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
@ -158,6 +163,8 @@ class ServingScores(OpenAIServing):
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
)
self._validate_input(request, engine_prompt["prompt_token_ids"],
full_prompt)
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
@ -188,64 +195,27 @@ class ServingScores(OpenAIServing):
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
if self.model_config.is_multimodal_model:
preprocess_async = make_async(self._preprocess_score,
executor=self._tokenizer_executor)
preprocess_async = make_async(self._preprocess_score,
executor=self._tokenizer_executor)
preprocessed_prompts = await asyncio.gather(
*(preprocess_async(request=request,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
data_1=t1,
data_2=t2) for t1, t2 in input_pairs))
preprocessed_prompts = await asyncio.gather(
*(preprocess_async(request=request,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
data_1=t1,
data_2=t2) for t1, t2 in input_pairs))
for full_prompt, engine_prompt in preprocessed_prompts:
request_prompts.append(full_prompt)
engine_prompts.append(engine_prompt)
else:
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
use_pad_token = self.model_config.use_pad_token
if use_pad_token:
# cross_encoder models defaults to using pad_token.
tokenized_prompts = await asyncio.gather(*(
tokenize_async(
text=t1, # type: ignore[arg-type]
text_pair=t2, # type: ignore[arg-type]
**tokenization_kwargs) for t1, t2 in input_pairs))
else:
# `llm as reranker` models defaults to not using pad_token.
tokenized_prompts = await asyncio.gather(*(
tokenize_async(
text=t1 + # type: ignore[operator]
t2,
**tokenization_kwargs) for t1, t2 in input_pairs))
for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
sep_token = tokenizer.sep_token if (tokenizer.sep_token
and use_pad_token) else ''
request_prompt = f"{t1}{sep_token}{t2}"
input_ids = prompt_inputs["input_ids"]
text_token_prompt = \
self._validate_input(request, input_ids, request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)
for full_prompt, engine_prompt in preprocessed_prompts:
request_prompts.append(full_prompt)
engine_prompts.append(engine_prompt)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params()
default_pooling_params = request.to_pooling_params()
try:
pooling_params.verify("score", self.model_config)
default_pooling_params.verify("score", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
@ -254,9 +224,19 @@ class ServingScores(OpenAIServing):
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
params=default_pooling_params,
lora_request=lora_request)
if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
"token_type_ids", None)):
pooling_params = default_pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids)
pooling_params.extra_kwargs = {
"compressed_token_type_ids": compressed
}
else:
pooling_params = (default_pooling_params)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,

View File

@ -184,15 +184,49 @@ def get_score_prompt(
model_config,
tokenizer,
)
from vllm.model_executor.model_loader import get_model_cls
full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
model = get_model_cls(model_config)
if supports_score_template(model):
full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
elif model_config.use_pad_token:
# cross_encoder models defaults to using pad_token.
prompt_inputs = tokenizer(text=prompt_1,
text_pair=prompt_2,
**tokenization_kwargs)
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
else:
# `llm as reranker` models defaults to not using pad_token.
full_prompt = prompt_1 + prompt_2
prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
engine_prompt["token_type_ids"] = token_type_ids
post_process_tokens(model_config, engine_prompt)
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
return full_prompt, engine_prompt
def compress_token_type_ids(token_type_ids: list[int]) -> int:
"""
Return position of the first 1 or the length of the list
if not found.
"""
first_one = len(token_type_ids)
err_msg = "Token type ids are expected to be a sequence"\
" of zeros followed by a sequence of ones"
for i, type_id in enumerate(token_type_ids):
if type_id == 0 and first_one < i:
raise ValueError(err_msg)
elif type_id == 1 and first_one > i:
first_one = i
elif type_id > 1:
raise ValueError(err_msg)
return first_one

View File

@ -28,7 +28,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@ -60,21 +60,13 @@ class BertEmbedding(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
# Input embeddings.
token_type_ids = _decode_token_type_ids(input_ids)
inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@ -350,25 +342,23 @@ class BertModel(nn.Module, SupportsQuant):
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(config)
self.config = vllm_config.model_config.hf_config
self.embeddings = embedding_class(self.config)
self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
position_ids=positions)
return self.encoder(hidden_states)
def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@ -468,13 +458,11 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
@ -508,8 +496,53 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
})
class BertForSequenceClassification(nn.Module, SupportsV0Only,
SupportsCrossEncoding, SupportsQuant):
# Here we encode the token type ids together with the input ids.
# Since we use int 32 for the input IDs and the vocabulary size
# is way lower than 2**31, there is room to encode additional
# bits. At the same time, for cross-encoder use cases, the
# token type ids are only 0 or 1, requiring only 1 bit.
# This means that we can store the token type ids in the 31st
# bit. We void the 32nd bit because that would produce a negative
# number, which could be used to signal other things.
#
# The reason for all of this is that all the tensors that are
# passed as input to the forward function of a module marked
# with @support_torch_compile have to be persistent. So to
# avoid adding more persistent tensors in the model runner, we
# encode more information in the same persistent tensor.
#
# Since the *ForClassification module is outside of the BertModel
# which is compiled, we can do the encoding here and then separate
# the information again in the Embedding layer. Since with bit masks
# we can do this entirely with torch operations and without branching,
# it works with torch compile.
TOKEN_TYPE_SHIFT = 30
def _encode_token_type_ids(input_ids: torch.Tensor,
token_type_ids: torch.Tensor) -> None:
# input_ids can be padded to the right
input_ids[:token_type_ids.shape[0]].bitwise_or_(
token_type_ids << TOKEN_TYPE_SHIFT)
def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
ids_mask = torch.ones(input_ids.shape,
dtype=torch.int32,
device=input_ids.device) << TOKEN_TYPE_SHIFT
tokens_mask = ids_mask.bitwise_not()
token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
input_ids.bitwise_and_(tokens_mask)
return token_type_ids
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
@ -567,8 +600,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if token_type_ids is not None:
assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
assert input_ids is not None
_encode_token_type_ids(input_ids, token_type_ids)
return self.bert(input_ids=input_ids,
position_ids=positions,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
token_type_ids=token_type_ids)
intermediate_tensors=intermediate_tensors)

View File

@ -14,13 +14,16 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT,
BertEmbeddingModel, BertModel,
_decode_token_type_ids,
_encode_token_type_ids)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix)
from vllm.sequence import IntermediateTensors
from .bert_with_rope import BertWithRope, JinaRobertaModel
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .interfaces import SupportsCrossEncoding
class RobertaEmbedding(nn.Module):
@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)
# Position embeddings.
token_type_ids = _decode_token_type_ids(input_ids)
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@ -107,7 +105,6 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@ -119,9 +116,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
position_ids=positions,
padding_idx=self.padding_idx)
return self.model(input_ids,
positions,
token_type_ids=token_type_ids,
return self.model(input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
@ -153,8 +149,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return loader.load_weights(weights_list, mapper=mapper)
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsV0Only):
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
@ -226,11 +221,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
if token_type_ids is not None:
assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
assert input_ids is not None
_encode_token_type_ids(input_ids, token_type_ids)
return self.roberta(input_ids=input_ids,
position_ids=positions,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
token_type_ids=token_type_ids)
intermediate_tensors=intermediate_tensors)
# Adapted from transformers

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
import msgspec
@ -46,6 +46,9 @@ class PoolingParams(
requires_token_ids: bool = False
"""Internal use only."""
extra_kwargs: Optional[dict[str, Any]] = None
"""Internal use only."""
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@property
@ -167,7 +170,8 @@ class PoolingParams(
f"softmax={self.softmax}, "
f"step_tag_id={self.step_tag_id}, "
f"returned_token_ids={self.returned_token_ids}, "
f"requires_token_ids={self.requires_token_ids})")
f"requires_token_ids={self.requires_token_ids}, "
f"extra_kwargs={self.extra_kwargs})")
def __post_init__(self) -> None:
assert self.output_kind == RequestOutputKind.FINAL_ONLY,\

View File

@ -336,6 +336,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.reorder_batch_threshold: Optional[int] = None
def _init_model_kwargs(self, num_tokens: int):
model_kwargs = dict[str, Any]()
num_reqs = self.input_batch.num_reqs
pooling_params = self.input_batch.pooling_metadata.pooling_params
num_pooling_reqs = len(pooling_params)
if num_pooling_reqs == 0:
return model_kwargs
assert num_pooling_reqs == num_reqs
token_type_id_requests = dict[int, Any]()
for i, param in enumerate(pooling_params):
if param.extra_kwargs is not None and \
(token_types := param.extra_kwargs.get(
"compressed_token_type_ids")) is not None:
token_type_id_requests[i] = token_types
if len(token_type_id_requests) == 0:
return model_kwargs
seq_lens = self.seq_lens[:num_reqs]
token_type_ids = []
for i in range(num_reqs):
pos = token_type_id_requests.get(i, seq_lens[i])
ids = (torch.arange(seq_lens[i]) >= pos).int()
token_type_ids.append(ids)
model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
device=self.device)
return model_kwargs
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
@ -1504,12 +1539,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
model_kwargs = self._init_model_kwargs(num_scheduled_tokens)
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
model_kwargs = self._init_model_kwargs(num_input_tokens)
inputs_embeds = None
model_mm_kwargs = {}
if self.uses_mrope:
@ -1548,6 +1585,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_mm_kwargs,
device=self.device,
),
**model_kwargs,
)
if self.use_aux_hidden_state_outputs:
@ -2211,6 +2249,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model_kwargs = self._init_model_kwargs(num_tokens)
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
@ -2252,6 +2291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_mm_kwargs,
device=self.device,
),
**model_kwargs,
)
if self.use_aux_hidden_state_outputs: