diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index f121693e329f..73364294cbcd 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -126,7 +126,9 @@ def test_invocations(server: RemoteOpenAIServer): invocation_output["results"]): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( - invocations_result["relevance_score"], rel=0.01) + invocations_result["relevance_score"], rel=0.05) + # TODO: reset this tolerance to 0.01 once we find + # an alternative to flash_attn with bfloat16 @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 1a5df1d2dbd2..cb6ec795ae96 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -220,7 +220,9 @@ class TestModel: invocation_output["data"]): assert score_data.keys() == invocation_data.keys() assert score_data["score"] == pytest.approx( - invocation_data["score"], rel=0.01) + invocation_data["score"], rel=0.05) + # TODO: reset this tolerance to 0.01 once we find + # an alternative to flash_attn with bfloat16 def test_activation(self, server: RemoteOpenAIServer, model: dict[str, Any]): diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index ef9d5530cde1..6b5ff7068145 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -23,6 +23,15 @@ TEXTS_2 = [ "The capital of Germany is Berlin.", ] + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + DTYPE = "half" diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ca24b0c32b73..4014a961c6c2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -28,11 +28,15 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, apply_mistral_chat_template, parse_chat_messages, resolve_chat_template_content_format) +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.score_utils import (ScoreContentPartParam, ScoreMultiModalParam, _cosine_similarity, _validate_score_input_lens, + compress_token_type_ids, get_score_prompt) +# yapf: enable from vllm.entrypoints.utils import (_validate_truncation_size, log_non_default_args) from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt @@ -1329,6 +1333,7 @@ class LLM: model_config = self.llm_engine.model_config pooling_params.verify("score", model_config) + pooling_params_list = list[PoolingParams]() tokenization_kwargs: dict[str, Any] = {} @@ -1339,38 +1344,31 @@ class LLM: input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - if model_config.is_multimodal_model: - for q, d in input_pairs: - _, engine_prompt = get_score_prompt( - model_config=model_config, - data_1=q, - data_2=d, - tokenizer=tokenizer, - tokenization_kwargs=tokenization_kwargs, - ) + model_config = self.llm_engine.model_config - parsed_prompts.append(engine_prompt) - else: - for q, t in input_pairs: - if model_config.use_pad_token: - # cross_encoder models defaults to using pad_token. - prompt_inputs = tokenizer( - text=q, # type: ignore[arg-type] - text_pair=t, # type: ignore[arg-type] - **tokenization_kwargs) - else: - # `llm as reranker` models defaults to not using pad_token. 
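For context on the two branches being folded into get_score_prompt: a cross-encoder tokenizer builds one sequence from a text/text_pair call and emits token_type_ids marking the second segment, while an "LLM as reranker" model simply concatenates the two strings. A rough sketch of the first case, using a hypothetical Hugging Face cross-encoder checkpoint (the model name is an assumption, not taken from this diff):

    from transformers import AutoTokenizer

    # hypothetical cross-encoder checkpoint, for illustration only
    tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

    enc = tok(text="What is the capital of France?",
              text_pair="The capital of France is Paris.")
    # enc["token_type_ids"] is a run of 0s (query tokens) followed by 1s
    # (document tokens), the exact shape compress_token_type_ids expects.
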
- prompt_inputs = tokenizer( - text=q + t, # type: ignore[operator] - **tokenization_kwargs) - engine_prompt = TokensPrompt( - prompt_token_ids=prompt_inputs["input_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) - parsed_prompts.append(engine_prompt) + for q, d in input_pairs: + _, engine_prompt = get_score_prompt( + model_config=model_config, + data_1=q, + data_2=d, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + ) + + if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop( + "token_type_ids", None)): + params = pooling_params.clone() + compressed = compress_token_type_ids(token_type_ids) + params.extra_kwargs = {"compressed_token_type_ids": compressed} + pooling_params_list.append(params) + else: + pooling_params_list.append(pooling_params) + + parsed_prompts.append(engine_prompt) self._validate_and_add_requests( prompts=parsed_prompts, - params=pooling_params, + params=pooling_params_list, use_tqdm=use_tqdm, lora_request=lora_request, ) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 4da2094147ce..c246274514db 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -7,6 +7,7 @@ from typing import Any, Optional, Union from fastapi import Request +from vllm import envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger @@ -17,11 +18,15 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, ScoreResponseData, UsageInfo) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.score_utils import (ScoreContentPartParam, ScoreMultiModalParam, _cosine_similarity, _validate_score_input_lens, + compress_token_type_ids, get_score_prompt) +# yapf: enable from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger @@ -158,6 +163,8 @@ class ServingScores(OpenAIServing): tokenizer=tokenizer, tokenization_kwargs=tokenization_kwargs, ) + self._validate_input(request, engine_prompt["prompt_token_ids"], + full_prompt) if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs @@ -188,64 +195,27 @@ class ServingScores(OpenAIServing): input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - if self.model_config.is_multimodal_model: + preprocess_async = make_async(self._preprocess_score, + executor=self._tokenizer_executor) - preprocess_async = make_async(self._preprocess_score, - executor=self._tokenizer_executor) + preprocessed_prompts = await asyncio.gather( + *(preprocess_async(request=request, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + data_1=t1, + data_2=t2) for t1, t2 in input_pairs)) - preprocessed_prompts = await asyncio.gather( - *(preprocess_async(request=request, - tokenizer=tokenizer, - tokenization_kwargs=tokenization_kwargs, - data_1=t1, - data_2=t2) for t1, t2 in input_pairs)) - - for full_prompt, engine_prompt in preprocessed_prompts: - request_prompts.append(full_prompt) - engine_prompts.append(engine_prompt) - - else: - tokenize_async = make_async(tokenizer.__call__, - executor=self._tokenizer_executor) - use_pad_token = self.model_config.use_pad_token - - if use_pad_token: - # cross_encoder models defaults to using pad_token. 
- tokenized_prompts = await asyncio.gather(*( - tokenize_async( - text=t1, # type: ignore[arg-type] - text_pair=t2, # type: ignore[arg-type] - **tokenization_kwargs) for t1, t2 in input_pairs)) - else: - # `llm as reranker` models defaults to not using pad_token. - tokenized_prompts = await asyncio.gather(*( - tokenize_async( - text=t1 + # type: ignore[operator] - t2, - **tokenization_kwargs) for t1, t2 in input_pairs)) - - for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): - sep_token = tokenizer.sep_token if (tokenizer.sep_token - and use_pad_token) else '' - request_prompt = f"{t1}{sep_token}{t2}" - - input_ids = prompt_inputs["input_ids"] - text_token_prompt = \ - self._validate_input(request, input_ids, request_prompt) - engine_prompt = TokensPrompt( - prompt_token_ids=text_token_prompt["prompt_token_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) - - request_prompts.append(request_prompt) - engine_prompts.append(engine_prompt) + for full_prompt, engine_prompt in preprocessed_prompts: + request_prompts.append(full_prompt) + engine_prompts.append(engine_prompt) # Schedule the request and get the result generator. generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - pooling_params = request.to_pooling_params() + default_pooling_params = request.to_pooling_params() try: - pooling_params.verify("score", self.model_config) + default_pooling_params.verify("score", self.model_config) except ValueError as e: return self.create_error_response(str(e)) @@ -254,9 +224,19 @@ class ServingScores(OpenAIServing): self._log_inputs(request_id_item, request_prompts[i], - params=pooling_params, + params=default_pooling_params, lora_request=lora_request) + if envs.VLLM_USE_V1 and (token_type_ids := engine_prompt.pop( + "token_type_ids", None)): + pooling_params = default_pooling_params.clone() + compressed = compress_token_type_ids(token_type_ids) + pooling_params.extra_kwargs = { + "compressed_token_type_ids": compressed + } + else: + pooling_params = (default_pooling_params) + generator = self.engine_client.encode( engine_prompt, pooling_params, diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index f3f042355c9e..642d6389539b 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -184,15 +184,49 @@ def get_score_prompt( model_config, tokenizer, ) + from vllm.model_executor.model_loader import get_model_cls - full_prompt = apply_score_template(model_config, prompt_1, prompt_2) - - prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) + model = get_model_cls(model_config) + if supports_score_template(model): + full_prompt = apply_score_template(model_config, prompt_1, prompt_2) + prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) + elif model_config.use_pad_token: + # cross_encoder models defaults to using pad_token. + prompt_inputs = tokenizer(text=prompt_1, + text_pair=prompt_2, + **tokenization_kwargs) + full_prompt = tokenizer.decode(prompt_inputs["input_ids"]) + else: + # `llm as reranker` models defaults to not using pad_token. 
+ full_prompt = prompt_1 + prompt_2 + prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs) engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"]) + if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None: + engine_prompt["token_type_ids"] = token_type_ids + post_process_tokens(model_config, engine_prompt) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data return full_prompt, engine_prompt + + +def compress_token_type_ids(token_type_ids: list[int]) -> int: + """ + Return position of the first 1 or the length of the list + if not found. + """ + first_one = len(token_type_ids) + err_msg = "Token type ids are expected to be a sequence"\ + " of zeros followed by a sequence of ones" + for i, type_id in enumerate(token_type_ids): + if type_id == 0 and first_one < i: + raise ValueError(err_msg) + elif type_id == 1 and first_one > i: + first_one = i + elif type_id > 1: + raise ValueError(err_msg) + + return first_one diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 8f988903f78c..3d5d5d505b35 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -28,7 +28,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only +from .interfaces import SupportsCrossEncoding, SupportsQuant from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -60,21 +60,13 @@ class BertEmbedding(nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - input_shape = input_ids.size() - # Input embeddings. + token_type_ids = _decode_token_type_ids(input_ids) + inputs_embeds = self.word_embeddings(input_ids) - - # Position embeddings. 
position_embeddings = self.position_embeddings(position_ids) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device) - token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings @@ -350,25 +342,23 @@ class BertModel(nn.Module, SupportsQuant): ) -> None: super().__init__() - config = vllm_config.model_config.hf_config - self.embeddings = embedding_class(config) + self.config = vllm_config.model_config.hf_config + self.embeddings = embedding_class(self.config) self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") def forward( self, input_ids: torch.Tensor, - position_ids: torch.Tensor, + positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embeddings(input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids) + position_ids=positions) return self.encoder(hidden_states) def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -468,13 +458,11 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): self, input_ids: torch.Tensor, positions: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(input_ids=input_ids, - position_ids=positions, - token_type_ids=token_type_ids, + positions=positions, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) @@ -508,8 +496,53 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): }) -class BertForSequenceClassification(nn.Module, SupportsV0Only, - SupportsCrossEncoding, SupportsQuant): +# Here we encode the token type ids together with the input ids. +# Since we use int 32 for the input IDs and the vocabulary size +# is way lower than 2**31, there is room to encode additional +# bits. At the same time, for cross-encoder use cases, the +# token type ids are only 0 or 1, requiring only 1 bit. +# This means that we can store the token type ids in the 31st +# bit. We void the 32nd bit because that would produce a negative +# number, which could be used to signal other things. +# +# The reason for all of this is that all the tensors that are +# passed as input to the forward function of a module marked +# with @support_torch_compile have to be persistent. So to +# avoid adding more persistent tensors in the model runner, we +# encode more information in the same persistent tensor. +# +# Since the *ForClassification module is outside of the BertModel +# which is compiled, we can do the encoding here and then separate +# the information again in the Embedding layer. Since with bit masks +# we can do this entirely with torch operations and without branching, +# it works with torch compile. 
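The packing described above can be illustrated standalone; a minimal sketch, assuming int32 input ids, a vocabulary below 2**30, and token type ids restricted to {0, 1} (the constant mirrors TOKEN_TYPE_SHIFT defined just below):

    import torch

    SHIFT = 30  # mirrors TOKEN_TYPE_SHIFT

    input_ids = torch.tensor([101, 2054, 102, 3000, 102], dtype=torch.int32)
    token_type_ids = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)

    packed = input_ids | (token_type_ids << SHIFT)                # set bit 30 where type == 1
    recovered_types = (packed >> SHIFT) & 1                       # read bit 30 back
    recovered_ids = packed & ~(torch.ones_like(packed) << SHIFT)  # clear bit 30

    assert torch.equal(recovered_types, token_type_ids)
    assert torch.equal(recovered_ids, input_ids)
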
+ +TOKEN_TYPE_SHIFT = 30 + + +def _encode_token_type_ids(input_ids: torch.Tensor, + token_type_ids: torch.Tensor) -> None: + # input_ids can be padded to the right + input_ids[:token_type_ids.shape[0]].bitwise_or_( + token_type_ids << TOKEN_TYPE_SHIFT) + + +def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: + + ids_mask = torch.ones(input_ids.shape, + dtype=torch.int32, + device=input_ids.device) << TOKEN_TYPE_SHIFT + tokens_mask = ids_mask.bitwise_not() + + token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT + + input_ids.bitwise_and_(tokens_mask) + + return token_type_ids + + +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, + SupportsQuant): """A model that uses Bert to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for @@ -567,8 +600,13 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + + if token_type_ids is not None: + assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) + assert input_ids is not None + _encode_token_type_ids(input_ids, token_type_ids) + return self.bert(input_ids=input_ids, - position_ids=positions, + positions=positions, inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors, - token_type_ids=token_type_ids) + intermediate_tensors=intermediate_tensors) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 61c8faed4065..005b9179827e 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -14,13 +14,16 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, DispatchPooler, Pooler) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel +from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT, + BertEmbeddingModel, BertModel, + _decode_token_type_ids, + _encode_token_type_ids) from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix) from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, SupportsV0Only +from .interfaces import SupportsCrossEncoding class RobertaEmbedding(nn.Module): @@ -53,17 +56,12 @@ class RobertaEmbedding(nn.Module): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - input_shape = input_ids.size() - inputs_embeds = self.word_embeddings(input_ids) - # Position embeddings. 
+ token_type_ids = _decode_token_type_ids(input_ids) + + inputs_embeds = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) - if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device) token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings @@ -107,7 +105,6 @@ class RobertaEmbeddingModel(BertEmbeddingModel): self, input_ids: torch.Tensor, positions: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -119,9 +116,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel): position_ids=positions, padding_idx=self.padding_idx) - return self.model(input_ids, - positions, - token_type_ids=token_type_ids, + return self.model(input_ids=input_ids, + positions=positions, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) @@ -153,8 +149,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel): return loader.load_weights(weights_list, mapper=mapper) -class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsV0Only): +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for @@ -226,11 +221,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, replace_roberta_positions(input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx) + if token_type_ids is not None: + assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) + assert input_ids is not None + _encode_token_type_ids(input_ids, token_type_ids) return self.roberta(input_ids=input_ids, - position_ids=positions, + positions=positions, inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors, - token_type_ids=token_type_ids) + intermediate_tensors=intermediate_tensors) # Adapted from transformers diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 7077f68353fc..29f037b4372c 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional import msgspec @@ -46,6 +46,9 @@ class PoolingParams( requires_token_ids: bool = False """Internal use only.""" + extra_kwargs: Optional[dict[str, Any]] = None + """Internal use only.""" + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY @property @@ -167,7 +170,8 @@ class PoolingParams( f"softmax={self.softmax}, " f"step_tag_id={self.step_tag_id}, " f"returned_token_ids={self.returned_token_ids}, " - f"requires_token_ids={self.requires_token_ids})") + f"requires_token_ids={self.requires_token_ids}, " + f"extra_kwargs={self.extra_kwargs})") def __post_init__(self) -> None: assert self.output_kind == RequestOutputKind.FINAL_ONLY,\ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 48ff50fd6bd8..3cde7c6e9620 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -336,6 +336,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.reorder_batch_threshold: Optional[int] = None + def _init_model_kwargs(self, num_tokens: 
int): + model_kwargs = dict[str, Any]() + num_reqs = self.input_batch.num_reqs + + pooling_params = self.input_batch.pooling_metadata.pooling_params + + num_pooling_reqs = len(pooling_params) + + if num_pooling_reqs == 0: + return model_kwargs + + assert num_pooling_reqs == num_reqs + + token_type_id_requests = dict[int, Any]() + for i, param in enumerate(pooling_params): + if param.extra_kwargs is not None and \ + (token_types := param.extra_kwargs.get( + "compressed_token_type_ids")) is not None: + token_type_id_requests[i] = token_types + + if len(token_type_id_requests) == 0: + return model_kwargs + + seq_lens = self.seq_lens[:num_reqs] + token_type_ids = [] + + for i in range(num_reqs): + pos = token_type_id_requests.get(i, seq_lens[i]) + ids = (torch.arange(seq_lens[i]) >= pos).int() + token_type_ids.append(ids) + + model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( + device=self.device) + return model_kwargs + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention @@ -1504,12 +1539,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_ids = None inputs_embeds = self.inputs_embeds[:num_input_tokens] model_mm_kwargs = self._extract_mm_kwargs(scheduler_output) + model_kwargs = self._init_model_kwargs(num_scheduled_tokens) else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids[:num_input_tokens] + model_kwargs = self._init_model_kwargs(num_input_tokens) inputs_embeds = None model_mm_kwargs = {} if self.uses_mrope: @@ -1548,6 +1585,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model_mm_kwargs, device=self.device, ), + **model_kwargs, ) if self.use_aux_hidden_state_outputs: @@ -2211,6 +2249,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): + model_kwargs = self._init_model_kwargs(num_tokens) if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -2252,6 +2291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): model_mm_kwargs, device=self.device, ), + **model_kwargs, ) if self.use_aux_hidden_state_outputs:
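
End to end, the entrypoints reduce each prompt's token type ids to a single boundary index (the first position with type 1) carried in PoolingParams.extra_kwargs, and _init_model_kwargs re-expands that index against the sequence length before the ids are bit-packed into input_ids. A minimal sketch of that round trip, using simplified stand-in helpers rather than the actual functions:

    import torch

    def compress(token_type_ids: list[int]) -> int:
        # first position holding a 1; all zeros -> len(token_type_ids),
        # matching compress_token_type_ids above
        for i, type_id in enumerate(token_type_ids):
            if type_id == 1:
                return i
        return len(token_type_ids)

    def expand(boundary: int, seq_len: int) -> torch.Tensor:
        # mirrors the reconstruction in _init_model_kwargs: positions at or
        # beyond the boundary belong to the second segment
        return (torch.arange(seq_len) >= boundary).int()

    token_type_ids = [0, 0, 0, 1, 1, 1, 1]
    boundary = compress(token_type_ids)  # -> 3
    assert expand(boundary, len(token_type_ids)).tolist() == token_type_ids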