Support token_type_ids in V1 with less code changes (#21985)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Maximilien de Bayser authored 2025-08-11 02:54:59 -03:00; committed by GitHub
parent 9c97a1c349
commit 39052dbca8
10 changed files with 235 additions and 130 deletions

View File

@@ -126,7 +126,9 @@ def test_invocations(server: RemoteOpenAIServer):
                                                  invocation_output["results"]):
         assert rerank_result.keys() == invocations_result.keys()
         assert rerank_result["relevance_score"] == pytest.approx(
-            invocations_result["relevance_score"], rel=0.01)
+            invocations_result["relevance_score"], rel=0.05)
+        # TODO: reset this tolerance to 0.01 once we find
+        # an alternative to flash_attn with bfloat16


 @pytest.mark.asyncio

View File

@@ -220,7 +220,9 @@ class TestModel:
                                                invocation_output["data"]):
             assert score_data.keys() == invocation_data.keys()
             assert score_data["score"] == pytest.approx(
-                invocation_data["score"], rel=0.01)
+                invocation_data["score"], rel=0.05)
+            # TODO: reset this tolerance to 0.01 once we find
+            # an alternative to flash_attn with bfloat16

     def test_activation(self, server: RemoteOpenAIServer, model: dict[str,
                                                                       Any]):

View File

@@ -23,6 +23,15 @@ TEXTS_2 = [
     "The capital of Germany is Berlin.",
 ]


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 DTYPE = "half"

View File

@@ -28,11 +28,15 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_mistral_chat_template,
                                          parse_chat_messages,
                                          resolve_chat_template_content_format)
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                           ScoreMultiModalParam,
                                           _cosine_similarity,
                                           _validate_score_input_lens,
+                                          compress_token_type_ids,
                                           get_score_prompt)
+# yapf: enable
 from vllm.entrypoints.utils import (_validate_truncation_size,
                                     log_non_default_args)
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
@@ -1329,6 +1333,7 @@ class LLM:
         model_config = self.llm_engine.model_config

         pooling_params.verify("score", model_config)
+        pooling_params_list = list[PoolingParams]()

         tokenization_kwargs: dict[str, Any] = {}
@@ -1339,38 +1344,31 @@
         input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]

-        if model_config.is_multimodal_model:
-            for q, d in input_pairs:
-                _, engine_prompt = get_score_prompt(
-                    model_config=model_config,
-                    data_1=q,
-                    data_2=d,
-                    tokenizer=tokenizer,
-                    tokenization_kwargs=tokenization_kwargs,
-                )
-
-                parsed_prompts.append(engine_prompt)
-        else:
-            for q, t in input_pairs:
-                if model_config.use_pad_token:
-                    # cross_encoder models defaults to using pad_token.
-                    prompt_inputs = tokenizer(
-                        text=q,  # type: ignore[arg-type]
-                        text_pair=t,  # type: ignore[arg-type]
-                        **tokenization_kwargs)
-                else:
-                    # `llm as reranker` models defaults to not using pad_token.
-                    prompt_inputs = tokenizer(
-                        text=q + t,  # type: ignore[operator]
-                        **tokenization_kwargs)
-                engine_prompt = TokensPrompt(
-                    prompt_token_ids=prompt_inputs["input_ids"],
-                    token_type_ids=prompt_inputs.get("token_type_ids"))
-                parsed_prompts.append(engine_prompt)
+        model_config = self.llm_engine.model_config
+
+        for q, d in input_pairs:
+            _, engine_prompt = get_score_prompt(
+                model_config=model_config,
+                data_1=q,
+                data_2=d,
+                tokenizer=tokenizer,
+                tokenization_kwargs=tokenization_kwargs,
+            )
+
+            if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
+                    "token_type_ids", None)):
+                params = pooling_params.clone()
+                compressed = compress_token_type_ids(token_type_ids)
+                params.extra_kwargs = {"compressed_token_type_ids": compressed}
+                pooling_params_list.append(params)
+            else:
+                pooling_params_list.append(pooling_params)
+
+            parsed_prompts.append(engine_prompt)

         self._validate_and_add_requests(
             prompts=parsed_prompts,
-            params=pooling_params,
+            params=pooling_params_list,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
         )

View File

@@ -7,6 +7,7 @@ from typing import Any, Optional, Union

 from fastapi import Request

+from vllm import envs
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
@@ -17,11 +18,15 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
                                               ScoreResponseData, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                           ScoreMultiModalParam,
                                           _cosine_similarity,
                                           _validate_score_input_lens,
+                                          compress_token_type_ids,
                                           get_score_prompt)
+# yapf: enable
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
@@ -158,6 +163,8 @@ class ServingScores(OpenAIServing):
             tokenizer=tokenizer,
             tokenization_kwargs=tokenization_kwargs,
         )
+        self._validate_input(request, engine_prompt["prompt_token_ids"],
+                             full_prompt)

         if request.mm_processor_kwargs is not None:
             engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
@@ -188,64 +195,27 @@
         input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]

-        if self.model_config.is_multimodal_model:
-
-            preprocess_async = make_async(self._preprocess_score,
-                                          executor=self._tokenizer_executor)
-
-            preprocessed_prompts = await asyncio.gather(
-                *(preprocess_async(request=request,
-                                   tokenizer=tokenizer,
-                                   tokenization_kwargs=tokenization_kwargs,
-                                   data_1=t1,
-                                   data_2=t2) for t1, t2 in input_pairs))
-
-            for full_prompt, engine_prompt in preprocessed_prompts:
-                request_prompts.append(full_prompt)
-                engine_prompts.append(engine_prompt)
-
-        else:
-            tokenize_async = make_async(tokenizer.__call__,
-                                        executor=self._tokenizer_executor)
-            use_pad_token = self.model_config.use_pad_token
-
-            if use_pad_token:
-                # cross_encoder models defaults to using pad_token.
-                tokenized_prompts = await asyncio.gather(*(
-                    tokenize_async(
-                        text=t1,  # type: ignore[arg-type]
-                        text_pair=t2,  # type: ignore[arg-type]
-                        **tokenization_kwargs) for t1, t2 in input_pairs))
-            else:
-                # `llm as reranker` models defaults to not using pad_token.
-                tokenized_prompts = await asyncio.gather(*(
-                    tokenize_async(
-                        text=t1 +  # type: ignore[operator]
-                        t2,
-                        **tokenization_kwargs) for t1, t2 in input_pairs))
-
-            for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
-                sep_token = tokenizer.sep_token if (tokenizer.sep_token
-                                                    and use_pad_token) else ''
-                request_prompt = f"{t1}{sep_token}{t2}"
-
-                input_ids = prompt_inputs["input_ids"]
-                text_token_prompt = \
-                    self._validate_input(request, input_ids, request_prompt)
-                engine_prompt = TokensPrompt(
-                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
-                    token_type_ids=prompt_inputs.get("token_type_ids"))
-
-                request_prompts.append(request_prompt)
-                engine_prompts.append(engine_prompt)
+        preprocess_async = make_async(self._preprocess_score,
+                                      executor=self._tokenizer_executor)
+
+        preprocessed_prompts = await asyncio.gather(
+            *(preprocess_async(request=request,
+                               tokenizer=tokenizer,
+                               tokenization_kwargs=tokenization_kwargs,
+                               data_1=t1,
+                               data_2=t2) for t1, t2 in input_pairs))
+
+        for full_prompt, engine_prompt in preprocessed_prompts:
+            request_prompts.append(full_prompt)
+            engine_prompts.append(engine_prompt)

         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

-        pooling_params = request.to_pooling_params()
+        default_pooling_params = request.to_pooling_params()

         try:
-            pooling_params.verify("score", self.model_config)
+            default_pooling_params.verify("score", self.model_config)
         except ValueError as e:
             return self.create_error_response(str(e))
@@ -254,9 +224,19 @@ class ServingScores(OpenAIServing):
             self._log_inputs(request_id_item,
                              request_prompts[i],
-                             params=pooling_params,
+                             params=default_pooling_params,
                              lora_request=lora_request)

+            if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
+                    "token_type_ids", None)):
+                pooling_params = default_pooling_params.clone()
+                compressed = compress_token_type_ids(token_type_ids)
+                pooling_params.extra_kwargs = {
+                    "compressed_token_type_ids": compressed
+                }
+            else:
+                pooling_params = (default_pooling_params)
+
             generator = self.engine_client.encode(
                 engine_prompt,
                 pooling_params,

View File

@@ -184,15 +184,49 @@ def get_score_prompt(
         model_config,
         tokenizer,
     )
+    from vllm.model_executor.model_loader import get_model_cls

-    full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
-
-    prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
+    model = get_model_cls(model_config)
+    if supports_score_template(model):
+        full_prompt = apply_score_template(model_config, prompt_1, prompt_2)
+        prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
+    elif model_config.use_pad_token:
+        # cross_encoder models defaults to using pad_token.
+        prompt_inputs = tokenizer(text=prompt_1,
+                                  text_pair=prompt_2,
+                                  **tokenization_kwargs)
+        full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
+    else:
+        # `llm as reranker` models defaults to not using pad_token.
+        full_prompt = prompt_1 + prompt_2
+        prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)

     engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])

+    if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
+        engine_prompt["token_type_ids"] = token_type_ids
+
     post_process_tokens(model_config, engine_prompt)

     if mm_data is not None:
         engine_prompt["multi_modal_data"] = mm_data
     return full_prompt, engine_prompt
+
+
+def compress_token_type_ids(token_type_ids: list[int]) -> int:
+    """
+    Return position of the first 1 or the length of the list
+    if not found.
+    """
+    first_one = len(token_type_ids)
+    err_msg = "Token type ids are expected to be a sequence"\
+        " of zeros followed by a sequence of ones"
+    for i, type_id in enumerate(token_type_ids):
+        if type_id == 0 and first_one < i:
+            raise ValueError(err_msg)
+        elif type_id == 1 and first_one > i:
+            first_one = i
+        elif type_id > 1:
+            raise ValueError(err_msg)
+    return first_one
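
For reference, a minimal sketch (not part of the commit) of the round trip between compress_token_type_ids above and the expansion performed later by the V1 GPU model runner: the whole token-type list collapses to the index of the first 1, and an arange comparison restores it per request.

import torch

# One request: zeros for the query tokens, ones for the document tokens.
token_type_ids = [0, 0, 0, 1, 1, 1]

# Compression keeps only the position of the first 1 (here: 3), which is what
# ends up in pooling_params.extra_kwargs["compressed_token_type_ids"].
pos = token_type_ids.index(1) if 1 in token_type_ids else len(token_type_ids)

# Expansion mirrors the `(torch.arange(seq_lens[i]) >= pos).int()` expression
# in the gpu_model_runner.py change further below.
restored = (torch.arange(len(token_type_ids)) >= pos).int()
assert restored.tolist() == token_type_ids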

View File

@@ -28,7 +28,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask

-from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, SupportsQuant
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -60,21 +60,13 @@ class BertEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        input_shape = input_ids.size()
-
-        # Input embeddings.
+        token_type_ids = _decode_token_type_ids(input_ids)
+
         inputs_embeds = self.word_embeddings(input_ids)
-
-        # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)

-        if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape,
-                                         dtype=torch.long,
-                                         device=inputs_embeds.device)
-
         token_type_embeddings = self.token_type_embeddings(token_type_ids)

         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -350,25 +342,23 @@ class BertModel(nn.Module, SupportsQuant):
     ) -> None:
         super().__init__()

-        config = vllm_config.model_config.hf_config
-        self.embeddings = embedding_class(config)
+        self.config = vllm_config.model_config.hf_config
+        self.embeddings = embedding_class(self.config)
         self.encoder = BertEncoder(vllm_config=vllm_config,
                                    prefix=f"{prefix}.encoder")

     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
             hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=position_ids,
-                                            token_type_ids=token_type_ids)
+                                            position_ids=positions)
         return self.encoder(hidden_states)

     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -468,13 +458,11 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model(input_ids=input_ids,
-                          position_ids=positions,
-                          token_type_ids=token_type_ids,
+                          positions=positions,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)
@@ -508,8 +496,53 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
         })


-class BertForSequenceClassification(nn.Module, SupportsV0Only,
-                                    SupportsCrossEncoding, SupportsQuant):
+# Here we encode the token type ids together with the input ids.
+# Since we use int 32 for the input IDs and the vocabulary size
+# is way lower than 2**31, there is room to encode additional
+# bits. At the same time, for cross-encoder use cases, the
+# token type ids are only 0 or 1, requiring only 1 bit.
+# This means that we can store the token type ids in the 31st
+# bit. We avoid the 32nd bit because that would produce a negative
+# number, which could be used to signal other things.
+#
+# The reason for all of this is that all the tensors that are
+# passed as input to the forward function of a module marked
+# with @support_torch_compile have to be persistent. So to
+# avoid adding more persistent tensors in the model runner, we
+# encode more information in the same persistent tensor.
+#
+# Since the *ForClassification module is outside of the BertModel
+# which is compiled, we can do the encoding here and then separate
+# the information again in the Embedding layer. Since with bit masks
+# we can do this entirely with torch operations and without branching,
+# it works with torch compile.
+
+TOKEN_TYPE_SHIFT = 30
+
+
+def _encode_token_type_ids(input_ids: torch.Tensor,
+                           token_type_ids: torch.Tensor) -> None:
+    # input_ids can be padded to the right
+    input_ids[:token_type_ids.shape[0]].bitwise_or_(
+        token_type_ids << TOKEN_TYPE_SHIFT)
+
+
+def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
+    ids_mask = torch.ones(input_ids.shape,
+                          dtype=torch.int32,
+                          device=input_ids.device) << TOKEN_TYPE_SHIFT
+    tokens_mask = ids_mask.bitwise_not()
+
+    token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
+    input_ids.bitwise_and_(tokens_mask)
+
+    return token_type_ids
+
+
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                    SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
@@ -567,8 +600,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
         inputs_embeds: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if token_type_ids is not None:
+            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
+            assert input_ids is not None
+            _encode_token_type_ids(input_ids, token_type_ids)
         return self.bert(input_ids=input_ids,
-                         position_ids=positions,
+                         positions=positions,
                          inputs_embeds=inputs_embeds,
-                         intermediate_tensors=intermediate_tensors,
-                         token_type_ids=token_type_ids)
+                         intermediate_tensors=intermediate_tensors)
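
As a sanity check on the bit-packing described in the comment above, here is a small standalone sketch (not part of the commit) that mirrors _encode_token_type_ids and _decode_token_type_ids and verifies that both the token type ids and the original token ids survive the round trip. The token id values are made up for illustration.

import torch

TOKEN_TYPE_SHIFT = 30

input_ids = torch.tensor([101, 2054, 102, 2003, 102], dtype=torch.int32)
token_type_ids = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)
original_ids = input_ids.clone()

# Encode: set bit 30 on the ids that belong to the second segment.
input_ids.bitwise_or_(token_type_ids << TOKEN_TYPE_SHIFT)

# Decode: read bit 30 back out, then clear it to recover the token ids.
ids_mask = torch.full_like(input_ids, 1 << TOKEN_TYPE_SHIFT)
decoded_types = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
input_ids.bitwise_and_(ids_mask.bitwise_not())

assert torch.equal(decoded_types, token_type_ids)
assert torch.equal(input_ids, original_ids)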

View File

@@ -14,13 +14,16 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
+from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT,
+                                             BertEmbeddingModel, BertModel,
+                                             _decode_token_type_ids,
+                                             _encode_token_type_ids)
 from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                               maybe_prefix)
 from vllm.sequence import IntermediateTensors

 from .bert_with_rope import BertWithRope, JinaRobertaModel
-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding


 class RobertaEmbedding(nn.Module):
@@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        input_shape = input_ids.size()
-        inputs_embeds = self.word_embeddings(input_ids)

-        # Position embeddings.
+        token_type_ids = _decode_token_type_ids(input_ids)
+
+        inputs_embeds = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
-        if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape,
-                                         dtype=torch.long,
-                                         device=inputs_embeds.device)

         token_type_embeddings = self.token_type_embeddings(token_type_ids)
         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -107,7 +105,6 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -119,9 +116,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
                                   position_ids=positions,
                                   padding_idx=self.padding_idx)
-        return self.model(input_ids,
-                          positions,
-                          token_type_ids=token_type_ids,
+        return self.model(input_ids=input_ids,
+                          positions=positions,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)
@@ -153,8 +149,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         return loader.load_weights(weights_list, mapper=mapper)


-class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
-                                       SupportsV0Only):
+class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
@@ -226,11 +221,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
         replace_roberta_positions(input_ids=input_ids,
                                   position_ids=positions,
                                   padding_idx=self.padding_idx)
+        if token_type_ids is not None:
+            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
+            assert input_ids is not None
+            _encode_token_type_ids(input_ids, token_type_ids)
         return self.roberta(input_ids=input_ids,
-                            position_ids=positions,
+                            positions=positions,
                             inputs_embeds=inputs_embeds,
-                            intermediate_tensors=intermediate_tensors,
-                            token_type_ids=token_type_ids)
+                            intermediate_tensors=intermediate_tensors)


 # Adapted from transformers

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from copy import deepcopy
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import msgspec
@@ -46,6 +46,9 @@ class PoolingParams(
     requires_token_ids: bool = False
     """Internal use only."""

+    extra_kwargs: Optional[dict[str, Any]] = None
+    """Internal use only."""
+
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

     @property
@@ -167,7 +170,8 @@ class PoolingParams(
                 f"softmax={self.softmax}, "
                 f"step_tag_id={self.step_tag_id}, "
                 f"returned_token_ids={self.returned_token_ids}, "
-                f"requires_token_ids={self.requires_token_ids})")
+                f"requires_token_ids={self.requires_token_ids}, "
+                f"extra_kwargs={self.extra_kwargs})")

     def __post_init__(self) -> None:
         assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
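
A minimal sketch (values illustrative, not from the commit) of how the new extra_kwargs field is meant to be used by the entrypoint changes above: the request-level params are cloned per prompt and the compressed token type position is attached for the V1 model runner to pick up.

from vllm.pooling_params import PoolingParams

default_params = PoolingParams()

# Per-prompt copy carrying the compressed token type ids (see the llm.py and
# serving_score.py hunks above); 3 stands for the index of the first
# "document" token in that prompt.
params = default_params.clone()
params.extra_kwargs = {"compressed_token_type_ids": 3}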

View File

@@ -336,6 +336,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         self.reorder_batch_threshold: Optional[int] = None

+    def _init_model_kwargs(self, num_tokens: int):
+        model_kwargs = dict[str, Any]()
+        num_reqs = self.input_batch.num_reqs
+
+        pooling_params = self.input_batch.pooling_metadata.pooling_params
+        num_pooling_reqs = len(pooling_params)
+
+        if num_pooling_reqs == 0:
+            return model_kwargs
+
+        assert num_pooling_reqs == num_reqs
+
+        token_type_id_requests = dict[int, Any]()
+        for i, param in enumerate(pooling_params):
+            if param.extra_kwargs is not None and \
+                (token_types := param.extra_kwargs.get(
+                    "compressed_token_type_ids")) is not None:
+                token_type_id_requests[i] = token_types
+
+        if len(token_type_id_requests) == 0:
+            return model_kwargs
+
+        seq_lens = self.seq_lens[:num_reqs]
+        token_type_ids = []
+
+        for i in range(num_reqs):
+            pos = token_type_id_requests.get(i, seq_lens[i])
+            ids = (torch.arange(seq_lens[i]) >= pos).int()
+            token_type_ids.append(ids)
+
+        model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
+            device=self.device)
+        return model_kwargs
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
         Update the order of requests in the batch based on the attention
@@ -1504,12 +1539,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_input_tokens]
             model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
+            model_kwargs = self._init_model_kwargs(num_scheduled_tokens)
         else:
             # For text-only models, we use token ids as input.
             # While it is possible to use embeddings as input just like the
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
+            model_kwargs = self._init_model_kwargs(num_input_tokens)
             inputs_embeds = None
             model_mm_kwargs = {}

         if self.uses_mrope:
@@ -1548,6 +1585,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     model_mm_kwargs,
                     device=self.device,
                 ),
+                **model_kwargs,
             )

         if self.use_aux_hidden_state_outputs:
@@ -2211,6 +2249,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
+            model_kwargs = self._init_model_kwargs(num_tokens)
             if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -2252,6 +2291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     model_mm_kwargs,
                     device=self.device,
                 ),
+                **model_kwargs,
            )

            if self.use_aux_hidden_state_outputs: