Support token_type_ids in V1 with less code changes (#21985)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
parent 9c97a1c349
commit 39052dbca8
@@ -126,7 +126,9 @@ def test_invocations(server: RemoteOpenAIServer):
                                      invocation_output["results"]):
         assert rerank_result.keys() == invocations_result.keys()
         assert rerank_result["relevance_score"] == pytest.approx(
-            invocations_result["relevance_score"], rel=0.01)
+            invocations_result["relevance_score"], rel=0.05)
+        # TODO: reset this tolerance to 0.01 once we find
+        # an alternative to flash_attn with bfloat16


 @pytest.mark.asyncio
@@ -220,7 +220,9 @@ class TestModel:
                                             invocation_output["data"]):
                assert score_data.keys() == invocation_data.keys()
                assert score_data["score"] == pytest.approx(
-                   invocation_data["score"], rel=0.01)
+                   invocation_data["score"], rel=0.05)
+               # TODO: reset this tolerance to 0.01 once we find
+               # an alternative to flash_attn with bfloat16

    def test_activation(self, server: RemoteOpenAIServer, model: dict[str,
                                                                      Any]):
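For context on the relaxed tolerance in the two test hunks above, a quick illustration of plain pytest rel= semantics (nothing vLLM-specific):

from pytest import approx

# rel=0.05 accepts anything within 5% of the expected value.
assert 0.96 == approx(1.0, rel=0.05)   # 4% off -> passes the new window
assert 0.96 != approx(1.0, rel=0.01)   # 4% off -> would fail the old 1% window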
@@ -23,6 +23,15 @@ TEXTS_2 = [
     "The capital of Germany is Berlin.",
 ]

+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 DTYPE = "half"


@@ -28,11 +28,15 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_mistral_chat_template,
                                          parse_chat_messages,
                                          resolve_chat_template_content_format)
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                           ScoreMultiModalParam,
                                           _cosine_similarity,
                                           _validate_score_input_lens,
+                                          compress_token_type_ids,
                                           get_score_prompt)
+# yapf: enable
 from vllm.entrypoints.utils import (_validate_truncation_size,
                                     log_non_default_args)
 from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
@@ -1329,6 +1333,7 @@ class LLM:

         model_config = self.llm_engine.model_config
         pooling_params.verify("score", model_config)
+        pooling_params_list = list[PoolingParams]()

         tokenization_kwargs: dict[str, Any] = {}

@@ -1339,7 +1344,8 @@ class LLM:

         input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]

-        if model_config.is_multimodal_model:
+        model_config = self.llm_engine.model_config
+
         for q, d in input_pairs:
             _, engine_prompt = get_score_prompt(
                 model_config=model_config,
@@ -1349,28 +1355,20 @@ class LLM:
                     tokenization_kwargs=tokenization_kwargs,
                 )

-                parsed_prompts.append(engine_prompt)
+                if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
+                        "token_type_ids", None)):
+                    params = pooling_params.clone()
+                    compressed = compress_token_type_ids(token_type_ids)
+                    params.extra_kwargs = {"compressed_token_type_ids": compressed}
+                    pooling_params_list.append(params)
                 else:
-            for q, t in input_pairs:
-                if model_config.use_pad_token:
-                    # cross_encoder models defaults to using pad_token.
-                    prompt_inputs = tokenizer(
-                        text=q,  # type: ignore[arg-type]
-                        text_pair=t,  # type: ignore[arg-type]
-                        **tokenization_kwargs)
-                else:
-                    # `llm as reranker` models defaults to not using pad_token.
-                    prompt_inputs = tokenizer(
-                        text=q + t,  # type: ignore[operator]
-                        **tokenization_kwargs)
-                engine_prompt = TokensPrompt(
-                    prompt_token_ids=prompt_inputs["input_ids"],
-                    token_type_ids=prompt_inputs.get("token_type_ids"))
+                    pooling_params_list.append(pooling_params)
                 parsed_prompts.append(engine_prompt)

         self._validate_and_add_requests(
             prompts=parsed_prompts,
-            params=pooling_params,
+            params=pooling_params_list,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
         )
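The control flow above is easier to see in isolation. A minimal, self-contained sketch of the V1 path (FakePoolingParams and compress are illustrative stand-ins, not vLLM APIs): the token_type_ids are popped off the engine prompt, reduced to the index of the first 1, and carried in a cloned per-prompt parameter object, so the model runner can later reconstruct them without the prompt itself growing an extra field.

from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakePoolingParams:  # illustrative stand-in for vllm's PoolingParams
    extra_kwargs: Optional[dict[str, Any]] = None

    def clone(self) -> "FakePoolingParams":
        return deepcopy(self)


def compress(token_type_ids: list[int]) -> int:
    # Same idea as compress_token_type_ids: position of the first 1,
    # or len(token_type_ids) if there is none.
    return next((i for i, t in enumerate(token_type_ids) if t == 1),
                len(token_type_ids))


default_params = FakePoolingParams()
engine_prompts = [
    {"prompt_token_ids": [101, 7592, 102, 2088, 102],
     "token_type_ids": [0, 0, 0, 1, 1]},
    {"prompt_token_ids": [101, 2054, 102]},  # prompt without token type ids
]

pooling_params_list = []
for prompt in engine_prompts:
    if (token_type_ids := prompt.pop("token_type_ids", None)):
        params = default_params.clone()
        params.extra_kwargs = {
            "compressed_token_type_ids": compress(token_type_ids)
        }
        pooling_params_list.append(params)
    else:
        pooling_params_list.append(default_params)

# One params object per prompt, as passed via params=pooling_params_list above.
assert pooling_params_list[0].extra_kwargs == {"compressed_token_type_ids": 3}
assert pooling_params_list[1].extra_kwargs is None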
@@ -7,6 +7,7 @@ from typing import Any, Optional, Union

 from fastapi import Request

+from vllm import envs
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
@@ -17,11 +18,15 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
                                               ScoreResponseData, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                           ScoreMultiModalParam,
                                           _cosine_similarity,
                                           _validate_score_input_lens,
+                                          compress_token_type_ids,
                                           get_score_prompt)
+# yapf: enable
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
@@ -158,6 +163,8 @@ class ServingScores(OpenAIServing):
             tokenizer=tokenizer,
             tokenization_kwargs=tokenization_kwargs,
         )
+        self._validate_input(request, engine_prompt["prompt_token_ids"],
+                             full_prompt)
         if request.mm_processor_kwargs is not None:
             engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

@@ -188,8 +195,6 @@ class ServingScores(OpenAIServing):

         input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]

-        if self.model_config.is_multimodal_model:
-
         preprocess_async = make_async(self._preprocess_score,
                                       executor=self._tokenizer_executor)

@@ -204,48 +209,13 @@ class ServingScores(OpenAIServing):
             request_prompts.append(full_prompt)
             engine_prompts.append(engine_prompt)

-        else:
-            tokenize_async = make_async(tokenizer.__call__,
-                                        executor=self._tokenizer_executor)
-            use_pad_token = self.model_config.use_pad_token
-
-            if use_pad_token:
-                # cross_encoder models defaults to using pad_token.
-                tokenized_prompts = await asyncio.gather(*(
-                    tokenize_async(
-                        text=t1,  # type: ignore[arg-type]
-                        text_pair=t2,  # type: ignore[arg-type]
-                        **tokenization_kwargs) for t1, t2 in input_pairs))
-            else:
-                # `llm as reranker` models defaults to not using pad_token.
-                tokenized_prompts = await asyncio.gather(*(
-                    tokenize_async(
-                        text=t1 +  # type: ignore[operator]
-                        t2,
-                        **tokenization_kwargs) for t1, t2 in input_pairs))
-
-            for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
-                sep_token = tokenizer.sep_token if (tokenizer.sep_token
-                                                    and use_pad_token) else ''
-                request_prompt = f"{t1}{sep_token}{t2}"
-
-                input_ids = prompt_inputs["input_ids"]
-                text_token_prompt = \
-                    self._validate_input(request, input_ids, request_prompt)
-                engine_prompt = TokensPrompt(
-                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
-                    token_type_ids=prompt_inputs.get("token_type_ids"))
-
-                request_prompts.append(request_prompt)
-                engine_prompts.append(engine_prompt)
-
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

-        pooling_params = request.to_pooling_params()
+        default_pooling_params = request.to_pooling_params()

         try:
-            pooling_params.verify("score", self.model_config)
+            default_pooling_params.verify("score", self.model_config)
         except ValueError as e:
             return self.create_error_response(str(e))

@@ -254,9 +224,19 @@ class ServingScores(OpenAIServing):

             self._log_inputs(request_id_item,
                              request_prompts[i],
-                             params=pooling_params,
+                             params=default_pooling_params,
                              lora_request=lora_request)

+            if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop(
+                    "token_type_ids", None)):
+                pooling_params = default_pooling_params.clone()
+                compressed = compress_token_type_ids(token_type_ids)
+                pooling_params.extra_kwargs = {
+                    "compressed_token_type_ids": compressed
+                }
+            else:
+                pooling_params = (default_pooling_params)
+
             generator = self.engine_client.encode(
                 engine_prompt,
                 pooling_params,
@@ -184,15 +184,49 @@ def get_score_prompt(
         model_config,
         tokenizer,
     )
+    from vllm.model_executor.model_loader import get_model_cls
+
+    model = get_model_cls(model_config)
+    if supports_score_template(model):
         full_prompt = apply_score_template(model_config, prompt_1, prompt_2)

         prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
+    elif model_config.use_pad_token:
+        # cross_encoder models defaults to using pad_token.
+        prompt_inputs = tokenizer(text=prompt_1,
+                                  text_pair=prompt_2,
+                                  **tokenization_kwargs)
+        full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
+    else:
+        # `llm as reranker` models defaults to not using pad_token.
+        full_prompt = prompt_1 + prompt_2
+        prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)

     engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])

+    if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
+        engine_prompt["token_type_ids"] = token_type_ids
+
     post_process_tokens(model_config, engine_prompt)

     if mm_data is not None:
         engine_prompt["multi_modal_data"] = mm_data
     return full_prompt, engine_prompt


+def compress_token_type_ids(token_type_ids: list[int]) -> int:
+    """
+    Return position of the first 1 or the length of the list
+    if not found.
+    """
+    first_one = len(token_type_ids)
+    err_msg = "Token type ids are expected to be a sequence"\
+        " of zeros followed by a sequence of ones"
+    for i, type_id in enumerate(token_type_ids):
+        if type_id == 0 and first_one < i:
+            raise ValueError(err_msg)
+        elif type_id == 1 and first_one > i:
+            first_one = i
+        elif type_id > 1:
+            raise ValueError(err_msg)
+
+    return first_one
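A few concrete values for the helper just added (the import path matches the compress_token_type_ids lines in the import hunks earlier):

from vllm.entrypoints.score_utils import compress_token_type_ids

# Cross-encoder token type ids are zeros (query) followed by ones (document).
assert compress_token_type_ids([0, 0, 0, 1, 1, 1]) == 3  # first 1 at index 3
assert compress_token_type_ids([0, 0, 0]) == 3           # no 1 found -> len(list)
assert compress_token_type_ids([1, 1]) == 0              # ones from the start
# Anything out of order, e.g. [0, 1, 0], raises ValueError.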
@@ -28,7 +28,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask

-from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, SupportsQuant
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


@@ -60,21 +60,13 @@ class BertEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        input_shape = input_ids.size()

-        # Input embeddings.
+        token_type_ids = _decode_token_type_ids(input_ids)
+
         inputs_embeds = self.word_embeddings(input_ids)

-        # Position embeddings.
         position_embeddings = self.position_embeddings(position_ids)

-        if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape,
-                                         dtype=torch.long,
-                                         device=inputs_embeds.device)
-
         token_type_embeddings = self.token_type_embeddings(token_type_ids)

         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -350,25 +342,23 @@ class BertModel(nn.Module, SupportsQuant):
     ) -> None:
         super().__init__()

-        config = vllm_config.model_config.hf_config
-        self.embeddings = embedding_class(config)
+        self.config = vllm_config.model_config.hf_config
+        self.embeddings = embedding_class(self.config)
         self.encoder = BertEncoder(vllm_config=vllm_config,
                                    prefix=f"{prefix}.encoder")

     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
             hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=position_ids,
-                                            token_type_ids=token_type_ids)
+                                            position_ids=positions)
         return self.encoder(hidden_states)

     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -468,13 +458,11 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model(input_ids=input_ids,
-                          position_ids=positions,
-                          token_type_ids=token_type_ids,
+                          positions=positions,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)

@@ -508,8 +496,53 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
         })


-class BertForSequenceClassification(nn.Module, SupportsV0Only,
-                                    SupportsCrossEncoding, SupportsQuant):
+# Here we encode the token type ids together with the input ids.
+# Since we use int 32 for the input IDs and the vocabulary size
+# is way lower than 2**31, there is room to encode additional
+# bits. At the same time, for cross-encoder use cases, the
+# token type ids are only 0 or 1, requiring only 1 bit.
+# This means that we can store the token type ids in the 31st
+# bit. We void the 32nd bit because that would produce a negative
+# number, which could be used to signal other things.
+#
+# The reason for all of this is that all the tensors that are
+# passed as input to the forward function of a module marked
+# with @support_torch_compile have to be persistent. So to
+# avoid adding more persistent tensors in the model runner, we
+# encode more information in the same persistent tensor.
+#
+# Since the *ForClassification module is outside of the BertModel
+# which is compiled, we can do the encoding here and then separate
+# the information again in the Embedding layer. Since with bit masks
+# we can do this entirely with torch operations and without branching,
+# it works with torch compile.
+
+TOKEN_TYPE_SHIFT = 30
+
+
+def _encode_token_type_ids(input_ids: torch.Tensor,
+                           token_type_ids: torch.Tensor) -> None:
+    # input_ids can be padded to the right
+    input_ids[:token_type_ids.shape[0]].bitwise_or_(
+        token_type_ids << TOKEN_TYPE_SHIFT)
+
+
+def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
+
+    ids_mask = torch.ones(input_ids.shape,
+                          dtype=torch.int32,
+                          device=input_ids.device) << TOKEN_TYPE_SHIFT
+    tokens_mask = ids_mask.bitwise_not()
+
+    token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
+
+    input_ids.bitwise_and_(tokens_mask)
+
+    return token_type_ids
+
+
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
+                                    SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
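The packing scheme described in the comment block above can be exercised on its own. A standalone sketch of the round trip (torch only; it mirrors _encode_token_type_ids/_decode_token_type_ids rather than importing them):

import torch

TOKEN_TYPE_SHIFT = 30  # the token type bit lives in bit 30, below the sign bit

input_ids = torch.tensor([101, 2054, 102, 3007, 102], dtype=torch.int32)
token_type_ids = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)

# Encode: fold the 0/1 token type into bit 30 of the int32 input ids.
packed = input_ids.clone()
packed.bitwise_or_(token_type_ids << TOKEN_TYPE_SHIFT)

# Decode: mask bit 30 back out and restore the original token ids.
ids_mask = torch.full_like(packed, 1) << TOKEN_TYPE_SHIFT
recovered_types = packed.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT
recovered_ids = packed.bitwise_and(ids_mask.bitwise_not())

assert torch.equal(recovered_ids, input_ids)
assert torch.equal(recovered_types, token_type_ids)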
@@ -567,8 +600,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
         inputs_embeds: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        if token_type_ids is not None:
+            assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
+            assert input_ids is not None
+            _encode_token_type_ids(input_ids, token_type_ids)
+
         return self.bert(input_ids=input_ids,
-                         position_ids=positions,
+                         positions=positions,
                          inputs_embeds=inputs_embeds,
-                         intermediate_tensors=intermediate_tensors,
-                         token_type_ids=token_type_ids)
+                         intermediate_tensors=intermediate_tensors)
@@ -14,13 +14,16 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
                                                DispatchPooler, Pooler)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
+from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT,
+                                             BertEmbeddingModel, BertModel,
+                                             _decode_token_type_ids,
+                                             _encode_token_type_ids)
 from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
                                               maybe_prefix)
 from vllm.sequence import IntermediateTensors

 from .bert_with_rope import BertWithRope, JinaRobertaModel
-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding


 class RobertaEmbedding(nn.Module):
@@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        input_shape = input_ids.size()
-        inputs_embeds = self.word_embeddings(input_ids)

-        # Position embeddings.
+        token_type_ids = _decode_token_type_ids(input_ids)
+
+        inputs_embeds = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
-        if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape,
-                                         dtype=torch.long,
-                                         device=inputs_embeds.device)

         token_type_embeddings = self.token_type_embeddings(token_type_ids)
         embeddings = inputs_embeds + token_type_embeddings + position_embeddings
@@ -107,7 +105,6 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -119,9 +116,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
                                   position_ids=positions,
                                   padding_idx=self.padding_idx)

-        return self.model(input_ids,
-                          positions,
-                          token_type_ids=token_type_ids,
+        return self.model(input_ids=input_ids,
+                          positions=positions,
                           inputs_embeds=inputs_embeds,
                           intermediate_tensors=intermediate_tensors)

@@ -153,8 +149,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         return loader.load_weights(weights_list, mapper=mapper)


-class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
-                                       SupportsV0Only):
+class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.

     This class encapsulates the BertModel and provides an interface for
@@ -226,11 +221,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
         replace_roberta_positions(input_ids=input_ids,
                                   position_ids=positions,
                                   padding_idx=self.padding_idx)
+        if token_type_ids is not None:
+            assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
+            assert input_ids is not None
+            _encode_token_type_ids(input_ids, token_type_ids)
         return self.roberta(input_ids=input_ids,
-                            position_ids=positions,
+                            positions=positions,
                             inputs_embeds=inputs_embeds,
-                            intermediate_tensors=intermediate_tensors,
-                            token_type_ids=token_type_ids)
+                            intermediate_tensors=intermediate_tensors)


 # Adapted from transformers
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from copy import deepcopy
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import msgspec

@@ -46,6 +46,9 @@ class PoolingParams(
     requires_token_ids: bool = False
     """Internal use only."""

+    extra_kwargs: Optional[dict[str, Any]] = None
+    """Internal use only."""
+
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

     @property
@@ -167,7 +170,8 @@ class PoolingParams(
                 f"softmax={self.softmax}, "
                 f"step_tag_id={self.step_tag_id}, "
                 f"returned_token_ids={self.returned_token_ids}, "
-                f"requires_token_ids={self.requires_token_ids})")
+                f"requires_token_ids={self.requires_token_ids}, "
+                f"extra_kwargs={self.extra_kwargs})")

     def __post_init__(self) -> None:
         assert self.output_kind == RequestOutputKind.FINAL_ONLY,\
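A small usage sketch of the new field, assuming the top-level `from vllm import PoolingParams` export and that clone() returns an independent copy (which the call sites above rely on):

from vllm import PoolingParams

default_params = PoolingParams()
per_request = default_params.clone()
per_request.extra_kwargs = {"compressed_token_type_ids": 3}

assert default_params.extra_kwargs is None  # the shared default is untouched
assert per_request.extra_kwargs["compressed_token_type_ids"] == 3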
@@ -336,6 +336,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         self.reorder_batch_threshold: Optional[int] = None

+    def _init_model_kwargs(self, num_tokens: int):
+        model_kwargs = dict[str, Any]()
+        num_reqs = self.input_batch.num_reqs
+
+        pooling_params = self.input_batch.pooling_metadata.pooling_params
+
+        num_pooling_reqs = len(pooling_params)
+
+        if num_pooling_reqs == 0:
+            return model_kwargs
+
+        assert num_pooling_reqs == num_reqs
+
+        token_type_id_requests = dict[int, Any]()
+        for i, param in enumerate(pooling_params):
+            if param.extra_kwargs is not None and \
+                (token_types := param.extra_kwargs.get(
+                    "compressed_token_type_ids")) is not None:
+                token_type_id_requests[i] = token_types
+
+        if len(token_type_id_requests) == 0:
+            return model_kwargs
+
+        seq_lens = self.seq_lens[:num_reqs]
+        token_type_ids = []
+
+        for i in range(num_reqs):
+            pos = token_type_id_requests.get(i, seq_lens[i])
+            ids = (torch.arange(seq_lens[i]) >= pos).int()
+            token_type_ids.append(ids)
+
+        model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
+            device=self.device)
+        return model_kwargs
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
         Update the order of requests in the batch based on the attention
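What _init_model_kwargs reconstructs, in isolation: each request contributes a 0/1 vector whose ones start at the compressed position (or never, if the request carried no token type ids). A minimal sketch with made-up lengths:

import torch

seq_lens = [5, 4, 6]        # per-request prompt lengths
compressed = {0: 3, 2: 2}   # request index -> position of the first 1

token_type_ids = []
for i, seq_len in enumerate(seq_lens):
    pos = compressed.get(i, seq_len)  # no entry -> all zeros
    token_type_ids.append((torch.arange(seq_len) >= pos).int())

flat = torch.concat(token_type_ids)
# tensor([0, 0, 0, 1, 1,  0, 0, 0, 0,  0, 0, 1, 1, 1, 1], dtype=torch.int32)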
@@ -1504,12 +1539,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_input_tokens]
             model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
+            model_kwargs = self._init_model_kwargs(num_scheduled_tokens)
         else:
             # For text-only models, we use token ids as input.
             # While it is possible to use embeddings as input just like the
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
+            model_kwargs = self._init_model_kwargs(num_input_tokens)
             inputs_embeds = None
             model_mm_kwargs = {}
         if self.uses_mrope:
@@ -1548,6 +1585,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     model_mm_kwargs,
                     device=self.device,
                 ),
+                **model_kwargs,
             )

             if self.use_aux_hidden_state_outputs:
@@ -2211,6 +2249,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
+            model_kwargs = self._init_model_kwargs(num_tokens)
             if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -2252,6 +2291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     model_mm_kwargs,
                     device=self.device,
                 ),
+                **model_kwargs,
             )

             if self.use_aux_hidden_state_outputs: