[Core] Enable inputs_embeds_size separate from hidden_size (#29741)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-11-30 17:31:12 +08:00 committed by GitHub
parent 47539cfd3e
commit 64bc09ba27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 123 additions and 18 deletions

View File

@ -19,7 +19,12 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
}
)
MODELS = ["google/siglip-base-patch16-224", "google/siglip2-base-patch16-224"]
MODELS = [
"google/siglip-base-patch16-224",
"google/siglip2-base-patch16-224",
# Different image embedding dim than text_config.hidden_size
"google/siglip2-giant-opt-patch16-384",
]
def _run_test(

View File

@ -1202,6 +1202,16 @@ class ModelConfig:
def get_hidden_size(self) -> int:
return getattr(self.hf_text_config, "hidden_size", 0)
def get_inputs_embeds_size(self) -> int:
# The size of inputs_embeds is usually identical to the size
# of the hidden states, however there are exceptions, such as
# embedding models like CLIP and SigLIP
for target_attr in ("projection_dim", "projection_size"):
if hasattr(self.hf_text_config, target_attr):
return getattr(self.hf_text_config, target_attr)
return self.get_hidden_size()
@property
def is_deepseek_mla(self) -> bool:
if not hasattr(self.hf_text_config, "model_type"):

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import cached_property
from typing import Annotated, Literal
@ -903,6 +903,41 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
def get_language_model(self) -> torch.nn.Module:
return self.text_model
def _embed_text_input_ids(
self,
input_ids: torch.Tensor,
embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: torch.Tensor | None,
handle_oov_mm_token: bool,
) -> torch.Tensor:
inputs_embeds = super()._embed_text_input_ids(
input_ids,
embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
# NOTE: inputs_embeds in model runner has size text_config.projection_dim
# (instead of text_config.hidden_size) to accommodate image embeddings
inputs_embeds_size = self.projection_dim
if inputs_embeds.shape[1] < inputs_embeds_size:
inputs_embeds = torch.cat(
[
inputs_embeds,
inputs_embeds.new_empty(
inputs_embeds.shape[0],
inputs_embeds_size - inputs_embeds.shape[1],
),
],
dim=1,
)
elif inputs_embeds.shape[1] > inputs_embeds_size:
# No need to handle this case for now
raise NotImplementedError
return inputs_embeds
def embed_input_ids(
self,
input_ids: torch.Tensor,
@ -949,10 +984,16 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
if not self._is_text_input:
return inputs_embeds
# Text inputs
return self.get_text_features(
input_ids=input_ids, position_ids=positions, inputs_embeds=inputs_embeds
)
# NOTE: inputs_embeds in model runner has size text_config.projection_dim
# (instead of text_config.hidden_size) to accommodate image embeddings
hidden_size = self.text_embed_dim
if inputs_embeds.shape[1] > hidden_size:
inputs_embeds = inputs_embeds[:, :hidden_size]
elif inputs_embeds.shape[1] < hidden_size:
# No need to handle this case for now
raise NotImplementedError
return self.get_text_features(input_ids, positions, inputs_embeds)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(

View File

@ -1,10 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""
import math
from collections.abc import Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping
from functools import cached_property
from typing import Annotated, Literal
@ -976,6 +974,7 @@ class SiglipTextEmbeddings(nn.Module):
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
@ -1145,6 +1144,41 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
def get_language_model(self) -> torch.nn.Module:
return self.text_model
def _embed_text_input_ids(
self,
input_ids: torch.Tensor,
embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: torch.Tensor | None,
handle_oov_mm_token: bool,
) -> torch.Tensor:
inputs_embeds = super()._embed_text_input_ids(
input_ids,
embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
# NOTE: inputs_embeds in model runner has size text_config.projection_size
# (instead of text_config.hidden_size) to accommodate image embeddings
inputs_embeds_size = self.text_projection_size
if inputs_embeds.shape[1] < inputs_embeds_size:
inputs_embeds = torch.cat(
[
inputs_embeds,
inputs_embeds.new_empty(
inputs_embeds.shape[0],
inputs_embeds_size - inputs_embeds.shape[1],
),
],
dim=1,
)
elif inputs_embeds.shape[1] > inputs_embeds_size:
# No need to handle this case for now
raise NotImplementedError
return inputs_embeds
def embed_input_ids(
self,
input_ids: torch.Tensor,
@ -1190,6 +1224,15 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
if not self._is_text_input:
return inputs_embeds
# NOTE: inputs_embeds in model runner has size text_config.projection_size
# (instead of text_config.hidden_size) to accommodate image embeddings
hidden_size = self.text_embed_dim
if inputs_embeds.shape[1] > hidden_size:
inputs_embeds = inputs_embeds[:, :hidden_size]
elif inputs_embeds.shape[1] < hidden_size:
# No need to handle this case for now
raise NotImplementedError
return self.get_text_features(input_ids, positions, inputs_embeds)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@ -80,6 +80,7 @@ class EagleProposer:
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
@ -151,7 +152,9 @@ class EagleProposer:
)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
(self.max_num_tokens, self.inputs_embeds_size),
dtype=self.dtype,
device=device,
)
self.backup_next_token_ids = CpuGpuBuffer(

View File

@ -17,7 +17,7 @@ class InputBuffers:
self,
max_num_reqs: int,
max_num_tokens: int,
hidden_size: int,
inputs_embeds_size: int,
vocab_size: int,
dtype: torch.dtype,
device: torch.device,

View File

@ -98,7 +98,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.hidden_size = self.model_config.get_hidden_size()
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
self.dp_size = self.parallel_config.data_parallel_size
self.dp_rank = self.parallel_config.data_parallel_rank
@ -134,7 +134,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
hidden_size=self.hidden_size,
inputs_embeds_size=self.inputs_embeds_size,
vocab_size=self.vocab_size,
dtype=self.dtype,
device=self.device,

View File

@ -44,6 +44,7 @@ class EagleSpeculator:
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
self.vocab_size = self.draft_model_config.get_vocab_size()
self.pin_memory = is_pin_memory_available()
self.dtype = vllm_config.model_config.dtype
@ -51,7 +52,7 @@ class EagleSpeculator:
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
hidden_size=self.hidden_size,
inputs_embeds_size=self.inputs_embeds_size,
vocab_size=self.vocab_size,
dtype=self.dtype,
device=device,

View File

@ -320,7 +320,7 @@ class GPUModelRunner(
# Model-related.
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
self.hidden_size = model_config.get_hidden_size()
self.inputs_embeds_size = model_config.get_inputs_embeds_size()
self.attention_chunk_size = model_config.attention_chunk_size
# Only relevant for models using ALiBi (e.g, MPT)
self.use_alibi = model_config.uses_alibi
@ -485,7 +485,7 @@ class GPUModelRunner(
# version of this tensor, avoid a RuntimeError by not creating a
# numpy buffer.
self.inputs_embeds = self._make_buffer(
self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False
self.max_num_tokens, self.inputs_embeds_size, dtype=self.dtype, numpy=False
)
self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
self.discard_request_mask = self._make_buffer(

View File

@ -215,7 +215,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
self.inputs_embeds_size = model_config.get_inputs_embeds_size()
self.vocab_size = model_config.get_vocab_size()
# Multi-modal data support
@ -1406,7 +1406,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = torch.zeros(
(num_tokens, self.hidden_size), dtype=self.dtype, device=self.device
(num_tokens, self.inputs_embeds_size),
dtype=self.dtype,
device=self.device,
)
else:
input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device)