[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 6970fa9937
commit 0b8166aa8f
@@ -66,35 +66,12 @@ Further update the model as follows:

!!! important
    The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g. image) of the request.

- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.

!!! note
    By default, vLLM merges the multimodal embeddings into text embeddings based on the location information defined in [PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing. This logic can be found in [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].

??? code

    ```python
    from .utils import merge_multimodal_embeddings

    class YourModelForImage2Seq(nn.Module):
        ...

        def get_input_embeddings(
            self,
            input_ids: torch.Tensor,
            multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        ) -> torch.Tensor:

            # `get_input_embeddings` should already be implemented for the language
            # model as one of the requirements of basic vLLM model implementation.
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            if multimodal_embeddings is not None:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids=input_ids,
                    inputs_embeds=inputs_embeds,
                    multimodal_embeddings=multimodal_embeddings,
                    placeholder_token_id=self.config.image_token_index)

            return inputs_embeds
    ```

You may override this method if additional logic is required for your model when merging embeddings.

- Implement the [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
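The rest of this commit migrates models away from the token-ID-based `merge_multimodal_embeddings` helper shown above and toward an index-based merge driven by an explicit `is_multimodal` mask. The standalone sketch below is not vLLM's actual `_merge_multimodal_embeddings`; the hidden size and placeholder token ID are made up for illustration. It only shows the core idea: compute the placeholder mask once from the token IDs, then fill those rows purely by position.

```python
import torch

hidden_size = 4
image_token_id = 32000  # illustrative placeholder ID

# A token sequence containing two image-placeholder positions.
input_ids = torch.tensor([1, image_token_id, 5, image_token_id, 2])
inputs_embeds = torch.zeros(len(input_ids), hidden_size)  # stand-in text embeddings
mm_embeds = torch.randn(2, hidden_size)                   # one row per placeholder

# The mask is computed once from the token IDs...
is_multimodal = input_ids == image_token_id

# ...and the merge itself is positional: the i-th multimodal embedding fills
# the i-th True position, with no further token-ID matching inside the merge.
inputs_embeds[is_multimodal] = mm_embeds
```

Passing the mask explicitly keeps the merge path working even when the placeholder IDs fall outside the language model's vocabulary, which is what the `handle_oov_mm_token` flag introduced later in this diff is about.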
@@ -509,9 +509,14 @@ class ModelConfig:
            else:  # task == "auto"
                pass
        else:
            debug_info = {
                "architectures": architectures,
                "is_generative_model": is_generative_model,
                "is_pooling_model": is_pooling_model,
            }
            raise AssertionError("The model should be a generative or "
                                 "pooling model when task is set to "
                                 f"{self.task!r}.")
                                 f"{self.task!r}. Found: {debug_info}")

        self.runner = runner
        self.convert = convert
@@ -38,8 +38,7 @@ from .idefics2_vision_model import (
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    is_pp_missing_parameter, maybe_prefix,
                    merge_multimodal_embeddings)
                    is_pp_missing_parameter, maybe_prefix)


class AriaImagePixelInputs(TensorSchema):
@@ -605,19 +604,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
        multimodal_embeddings = self._process_image_input(image_input)
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None \
                and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
@@ -628,10 +614,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            # always pass the input via `inputs_embeds`
            # to make sure the computation graph is consistent
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings)
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                multimodal_embeddings,
                is_multimodal=input_ids == self.config.image_token_index,
            )
            input_ids = None

        hidden_states = self.language_model(
@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
|
||||
class AyaVisionImagePixelInputs(TensorSchema):
|
||||
@ -417,23 +416,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return self._process_image_input(image_input, **kwargs)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.image_token_index,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -449,8 +431,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
|
||||
@ -348,6 +348,9 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
self.encoder = BertEncoder(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -457,6 +460,9 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.pooler = self._build_pooler(pooler_config)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -588,6 +594,9 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
),
|
||||
})
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.bert.get_input_embeddings(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_params = loader.load_weights(weights)
|
||||
@ -637,6 +646,9 @@ class BertForTokenClassification(nn.Module):
|
||||
Pooler.for_encode(pooler_config),
|
||||
})
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.bert.get_input_embeddings(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_params = loader.load_weights(weights)
|
||||
|
||||
@ -426,6 +426,9 @@ class BertWithRope(nn.Module, SupportsQuant):
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.pooler = BertPooler(self.config) if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -673,6 +676,9 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
loaded_params = loader.load_weights(weights)
|
||||
return loaded_params
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.new.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
|
||||
@ -27,7 +27,7 @@ from .blip import BlipVisionModel
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
|
||||
SupportsQuant)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
maybe_prefix)
|
||||
|
||||
# We use this internally as placeholders since there is no image token
|
||||
# defined on the HuggingFace repo
|
||||
@ -631,19 +631,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
_IMAGE_TOKEN_ID)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -689,8 +676,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
|
||||
@ -44,7 +44,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
|
||||
SupportsQuant)
|
||||
from .utils import (flatten_bn, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -1002,20 +1002,6 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
vision_embeddings = self.model.get_input_embeddings(image_tokens)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.model.vocabulary_mapping.image_token_id)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -1032,8 +1018,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
image_token_id = self.model.vocabulary_mapping.image_token_id
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.model(input_ids,
|
||||
|
||||
@ -433,6 +433,9 @@ class ChatGLMBaseModel(nn.Module):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.transformer.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -37,8 +37,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
|
||||
class Cohere2VisionImagePixelInputs(TensorSchema):
|
||||
@ -430,23 +429,6 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return self._process_image_input(image_input, **kwargs)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.image_token_id,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -462,8 +444,11 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
|
||||
@ -66,6 +66,9 @@ class DeepseekV2Model(nn.Module):
|
||||
self.norm = RMSNorm(self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -205,6 +208,9 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||
scale=logit_scale)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -101,6 +101,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -142,6 +145,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -41,8 +41,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
# The image token id may be various
|
||||
_IMAGE_TOKEN = "<image>"
|
||||
@ -346,7 +345,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
|
||||
self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN]
|
||||
|
||||
self.vision = self._init_vision_module(self.vision_config,
|
||||
quant_config,
|
||||
@ -605,19 +604,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.image_token_id)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@ -632,8 +618,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# condition is for v0 compatibility
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(input_ids,
|
||||
|
||||
@ -34,8 +34,7 @@ from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
|
||||
Qwen2VLProcessingInfo)
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
maybe_prefix)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalDataDict
|
||||
@ -796,33 +795,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_id,
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@ -830,17 +813,14 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
elif inputs_embeds is None and kwargs.get("pixel_values") is not None:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
assert input_ids is not None
|
||||
inputs_embeds = self.get_multimodal_embeddings(
|
||||
input_ids,
|
||||
image_input=image_input,
|
||||
)
|
||||
input_ids = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
input_ids=input_ids,
|
||||
|
||||
@ -60,8 +60,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
from .vision import get_vit_attn_backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -1467,18 +1466,24 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is not None and len(
|
||||
multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is None:
|
||||
return inputs_embeds
|
||||
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
[self.config.im_patch_id])
|
||||
return inputs_embeds
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -116,6 +116,9 @@ class ErnieMultiTokenPredictor(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -160,6 +163,9 @@ class ErnieMTP(nn.Module, SupportsPP):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -42,8 +42,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
|
||||
|
||||
# Cannot find the following 2 numbers from hf config.
|
||||
_IMAGE_TOKEN_ID = 71011
|
||||
@ -342,22 +341,6 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
_IMAGE_TOKEN_ID,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -373,8 +356,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == _IMAGE_TOKEN_ID,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(
|
||||
|
||||
@ -37,8 +37,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -588,22 +587,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@ -618,8 +601,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
)
|
||||
if (vision_embeddings is not None) and len(vision_embeddings) != 0:
|
||||
kwargs = self.prepare_attn_masks(
|
||||
input_ids,
|
||||
|
||||
@ -632,8 +632,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
|
||||
# them here, as the model forward has only access to the input_embeds.
|
||||
if input_ids is not None:
|
||||
@ -645,15 +647,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
|
||||
per_layer_inputs)
|
||||
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
# NOTE: this order of processing mm items is important
|
||||
[self.config.image_token_id, self.config.audio_token_id])
|
||||
return inputs_embeds
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -1552,23 +1552,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
multimodal_embeddings += video_embeddings
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if (multimodal_embeddings is not None
|
||||
and len(multimodal_embeddings) != 0
|
||||
and all(embed.numel() > 0 for embed in multimodal_embeddings)):
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
[self.config.image_token_id, self.config.video_token_id],
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def get_input_embeddings_v0(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -132,6 +132,9 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -173,6 +176,9 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .chatglm import ChatGLMBaseModel, ChatGLMModel
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import flatten_bn, merge_multimodal_embeddings
|
||||
from .utils import flatten_bn, isin_list
|
||||
|
||||
|
||||
class GLMVImagePixelInputs(TensorSchema):
|
||||
@ -607,28 +607,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=[
|
||||
self.config.boi_token_id,
|
||||
self.config.pad_token_id,
|
||||
self.config.eoi_token_id,
|
||||
],
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -644,8 +622,15 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=isin_list(input_ids, [
|
||||
self.config.boi_token_id,
|
||||
self.config.pad_token_id,
|
||||
self.config.eoi_token_id,
|
||||
]),
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
|
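Several models in this diff (GLM-4V in the hunk above, HyperCLOVA-X and InternVL further below) mark more than one placeholder token as multimodal via the `isin_list` utility. Assuming `isin_list(input_ids, ids)` returns a boolean membership mask, which is an inference from how it is called here rather than a documented contract, the equivalent effect can be sketched with `torch.isin`:

```python
import torch

# Illustrative token IDs only; GLM-4V uses boi/pad/eoi placeholder tokens.
boi_token_id, pad_token_id, eoi_token_id = 151339, 151329, 151340

input_ids = torch.tensor(
    [10, boi_token_id, pad_token_id, pad_token_id, eoi_token_id, 11])

# Membership mask over all placeholder IDs, analogous to
# isin_list(input_ids, [boi_token_id, pad_token_id, eoi_token_id]).
is_multimodal = torch.isin(
    input_ids, torch.tensor([boi_token_id, pad_token_id, eoi_token_id]))
print(is_multimodal)  # tensor([False,  True,  True,  True,  True, False])
```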
||||
@ -52,8 +52,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .blip2 import Blip2QFormerModel
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, embed_multimodal,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
|
||||
|
||||
### Audio Input
|
||||
@ -720,6 +719,9 @@ class GraniteSpeechForConditionalGeneration(
|
||||
# Split variable length features into a tuple
|
||||
return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self,
|
||||
**kwargs: object,
|
||||
@ -728,7 +730,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
if audio_input is None:
|
||||
return []
|
||||
return None
|
||||
|
||||
audio_features = self._process_audio_input(audio_input)
|
||||
return audio_features
|
||||
|
||||
@ -736,19 +738,21 @@ class GraniteSpeechForConditionalGeneration(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the merged LLM / audio embeddings."""
|
||||
if multimodal_embeddings is None \
|
||||
or len(multimodal_embeddings) == 0:
|
||||
return self.language_model.get_input_embeddings(input_ids)
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
inputs_embeds = embed_multimodal(
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
self.config.audio_token_index,
|
||||
self.language_model.model.get_input_embeddings,
|
||||
multimodal_embeddings,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -765,7 +769,11 @@ class GraniteSpeechForConditionalGeneration(
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
audio_embeds = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
audio_embeds,
|
||||
is_multimodal=input_ids == self.config.audio_token_index,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
model_output = self.language_model(input_ids, positions,
|
||||
|
||||
@ -989,6 +989,9 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -45,8 +45,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model, isin_list,
|
||||
maybe_prefix)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
EOT = "<|endofturn|>"
|
||||
@ -691,7 +691,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_multimodal_embeddings(
|
||||
self,
|
||||
**kwargs: Unpack[HCXVisionMultimodalInputs],
|
||||
) -> Optional[MultiModalEmbeddings]:
|
||||
) -> MultiModalEmbeddings:
|
||||
|
||||
multimodal_embeddings = list()
|
||||
if kwargs.get("pixel_values_images") is not None:
|
||||
@ -736,26 +736,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
multimodal_embeddings.append(_multimodal_embeddings_videos)
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
placeholder_token_id=[
|
||||
self.config.image_token_id,
|
||||
self.config.video_token_id,
|
||||
],
|
||||
)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -771,8 +751,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
multimodal_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings,
|
||||
is_multimodal=isin_list(
|
||||
input_ids,
|
||||
[self.config.image_token_id, self.config.video_token_id]),
|
||||
)
|
||||
input_ids = None
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
|
||||
@ -52,8 +52,7 @@ from .idefics2_vision_model import (
|
||||
# yapf: enable
|
||||
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
||||
from .llama import LlamaModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
|
||||
|
||||
|
||||
class Idefics3ImagePixelInputs(TensorSchema):
|
||||
@ -539,10 +538,7 @@ class Idefics3Model(nn.Module):
|
||||
|
||||
return image_hidden_states
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.text_model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -695,22 +691,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
self.config.image_token_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -726,8 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.model.text_model(input_ids,
|
||||
|
||||
@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable, Mapping, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                    Union, overload, runtime_checkable)
from typing import (TYPE_CHECKING, Callable, ClassVar, Literal, Optional,
                    Protocol, Union, overload, runtime_checkable)

import numpy as np
import torch
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.utils import supports_kw

from .interfaces_base import is_pooling_model
from .interfaces_base import VllmModel, is_pooling_model

if TYPE_CHECKING:
    from vllm.config import VllmConfig
@@ -90,7 +90,7 @@ class SupportsMultiModal(Protocol):
    """
    ...

    def get_language_model(self) -> torch.nn.Module:
    def get_language_model(self) -> VllmModel:
        """
        Returns the underlying language model used for text generation.

@@ -102,17 +102,84 @@ class SupportsMultiModal(Protocol):
        """
        ...

    @overload
    def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
        ...

    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
        *,
        is_multimodal: torch.Tensor,
        handle_oov_mm_token: bool = False,
    ) -> Tensor:
        ...

    def _get_text_embeddings(
        self,
        input_ids: Tensor,
        get_input_embeddings: Callable[[Tensor], Tensor],
        *,
        is_multimodal: Optional[Tensor],
        handle_oov_mm_token: bool,
    ) -> Tensor:
        if handle_oov_mm_token and is_multimodal is not None:
            is_text = ~is_multimodal
            text_embeds = get_input_embeddings(input_ids[is_text])

            return torch.empty(
                (input_ids.shape[0], text_embeds.shape[1]),
                dtype=text_embeds.dtype,
                device=text_embeds.device,
            ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)

        return get_input_embeddings(input_ids)

    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.
        Apply token embeddings to `input_ids`.

        If `multimodal_embeddings` is passed, scatter them into
        `input_ids` according to the mask `is_multimodal`.

        In case the multi-modal token IDs exceed the vocabulary size of
        the language model, you can set `handle_oov_mm_token=True`
        to avoid calling the language model's `get_input_embeddings` method
        on those tokens. Note however that doing so increases memory usage
        as an additional buffer is needed to hold the input embeddings.
        """
        ...
        from .utils import _merge_multimodal_embeddings

        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.get_language_model().get_input_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229.")

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )


@runtime_checkable
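The `handle_oov_mm_token` path above is the subtle part: only text rows are embedded, and `masked_scatter_` writes them back into a full-size buffer so placeholder IDs outside the vocabulary never reach the embedding table. Below is a self-contained sketch of that mechanism only; the vocabulary size, hidden size, and the out-of-vocabulary ID are invented for illustration and this is not the vLLM code itself.

```python
import torch

vocab_size, hidden_size = 8, 3
embed = torch.nn.Embedding(vocab_size, hidden_size)

input_ids = torch.tensor([1, 100, 2, 100])   # 100 is an out-of-vocabulary placeholder
is_multimodal = input_ids >= vocab_size
is_text = ~is_multimodal

# Embed only the text rows, then scatter them into an uninitialized buffer;
# the multimodal rows are overwritten later by the merge step.
text_embeds = embed(input_ids[is_text])
inputs_embeds = torch.empty(
    (input_ids.shape[0], text_embeds.shape[1]),
    dtype=text_embeds.dtype,
).masked_scatter_(is_text.unsqueeze(-1), text_embeds)

assert inputs_embeds.shape == (4, hidden_size)
```

The cost noted in the docstring is visible here: the extra `torch.empty` buffer is allocated solely so the embedding lookup can skip the OOV rows.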
@@ -41,6 +41,13 @@ class VllmModel(Protocol[T_co]):
    ) -> None:
        ...

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Apply token embeddings to `input_ids`."""
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
@@ -54,6 +61,19 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
    return supports_kw(model_init, "vllm_config")


def _check_vllm_model_get_input_embeddings(
        model: Union[type[object], object]) -> bool:
    model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
    if not callable(model_get_input_embeddings):
        logger.warning(
            "The model (%s) is missing the `get_input_embeddings` method.",
            model,
        )
        return False

    return True


def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
@@ -88,7 +108,9 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]:
def is_vllm_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
    return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
    return (_check_vllm_model_init(model)
            and _check_vllm_model_get_input_embeddings(model)
            and _check_vllm_model_forward(model))


@runtime_checkable
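For reference, the strengthened `is_vllm_model` check boils down to probing for a callable attribute. The snippet below is a minimal standalone sketch mirroring `_check_vllm_model_get_input_embeddings`; `has_get_input_embeddings` and the toy classes are hypothetical helpers, not vLLM code.

```python
def has_get_input_embeddings(model: object) -> bool:
    # Same duck-typing probe as the helper in the hunk above.
    return callable(getattr(model, "get_input_embeddings", None))


class GoodModel:

    def get_input_embeddings(self, input_ids):
        return input_ids


class BadModel:
    pass


assert has_get_input_embeddings(GoodModel())
assert not has_get_input_embeddings(BadModel())
```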
@ -40,8 +40,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, isin_list, maybe_prefix)
|
||||
|
||||
|
||||
class InternS1MultiModalProjector(nn.Module):
|
||||
@ -767,24 +766,24 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
context_token_ids = [
|
||||
token_id for token_id in (self.img_context_token_id,
|
||||
self.video_context_token_id)
|
||||
if token_id is not None
|
||||
]
|
||||
assert len(context_token_ids) >= 1
|
||||
if multimodal_embeddings is not None and len(
|
||||
multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
context_token_ids,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -802,9 +801,17 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
context_token_ids = [
|
||||
token_id for token_id in (self.img_context_token_id,
|
||||
self.video_context_token_id)
|
||||
if token_id is not None
|
||||
]
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=isin_list(input_ids, context_token_ids),
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
forward_kwargs = {
|
||||
|
||||
@ -43,7 +43,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
isin_list, maybe_prefix)
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
@ -1339,24 +1339,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
context_token_ids = [
|
||||
token_id for token_id in (self.img_context_token_id,
|
||||
self.video_context_token_id)
|
||||
if token_id is not None
|
||||
]
|
||||
assert len(context_token_ids) >= 1
|
||||
if multimodal_embeddings is not None and len(
|
||||
multimodal_embeddings) > 0:
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
context_token_ids,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -1374,9 +1374,17 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
context_token_ids = [
|
||||
token_id for token_id in (self.img_context_token_id,
|
||||
self.video_context_token_id)
|
||||
if token_id is not None
|
||||
]
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=isin_list(input_ids, context_token_ids),
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
forward_kwargs = {
|
||||
|
||||
@ -1450,24 +1450,6 @@ class BaseKeyeModule(nn.Module):
|
||||
multimodal_embeddings += video_embeddings
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
[
|
||||
self.config.image_token_id,
|
||||
self.config.video_token_id,
|
||||
],
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def get_input_embeddings_v0(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
@ -66,7 +66,6 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
|
||||
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
|
||||
SupportsPP)
|
||||
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
|
||||
from vllm.model_executor.models.utils import merge_multimodal_embeddings
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems, NestedTensors)
|
||||
@ -424,26 +423,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# `get_input_embeddings` should already be implemented for the language
|
||||
# model as one of the requirements of basic vLLM model implementation.
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None and len(
|
||||
multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.media_placeholder_token_id)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -462,14 +441,12 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if image_input is None:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
image_embeds = self._process_image_input(image_input)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
image_embeds,
|
||||
placeholder_token_id=self.config.
|
||||
media_placeholder_token_id,
|
||||
is_multimodal=input_ids ==
|
||||
self.config.media_placeholder_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
|
||||
@@ -522,6 +522,9 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,

@@ -37,9 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.inputs import NestedTensors

from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)

@@ -79,10 +79,7 @@ class LlamaModel(nn.Module):
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
@@ -194,6 +191,11 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)

def get_language_model(self) -> torch.nn.Module:
return self.model

get_input_embeddings = SupportsMultiModal.get_input_embeddings  # type: ignore

def forward(
self,
input_ids: torch.Tensor,
@@ -220,20 +222,3 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
skip_prefixes=(["lm_head."]),
)
loader.load_weights(map(transform, weights))

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)

if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)

return inputs_embeds

@@ -73,6 +73,9 @@ class LlamaModel(nn.Module):
self.config.hidden_size,
bias=False)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -149,6 +152,9 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,

@@ -18,7 +18,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)

@@ -144,10 +143,7 @@ class LlamaModel(nn.Module):
eps=self.config.rms_norm_eps,
)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
@@ -239,6 +235,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
requires_grad=False,
)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -302,11 +301,3 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
skip_substrs=skip_substrs,
)
loader.load_weights(model_weights.items())

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
return inputs_embeds

@@ -41,8 +41,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info

@@ -676,22 +675,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

return self._process_image_input(image_input)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -744,8 +727,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,

@@ -25,8 +25,8 @@ from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
flatten_bn, init_vllm_registered_model, maybe_prefix)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix)

class LlavaNextImagePixelInputs(TensorSchema):
@@ -474,19 +474,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)

if multimodal_embeddings is None \
or len(multimodal_embeddings) == 0:
return self.language_model.get_input_embeddings(input_ids)

inputs_embeds = embed_multimodal(
return super().get_input_embeddings(
input_ids,
self.config.image_token_index,
self.language_model.model.get_input_embeddings,
multimodal_embeddings,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
return inputs_embeds

def forward(
self,
@@ -549,8 +551,11 @@ model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,

@@ -30,8 +30,7 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info

@@ -415,19 +414,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_video_pixels(video_input)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.video_token_index)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -449,8 +435,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.video_token_index,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,

@@ -850,19 +850,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,

return multimodal_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_index, self.config.video_token_index])
return inputs_embeds

def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

@@ -54,8 +54,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.midashenglm import DashengConfig

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix

_Tuple2 = Union[int, tuple[int, int], Sequence[int]]

@@ -744,21 +743,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
return []
return self._process_audio_input(audio_input)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.decoder.get_input_embeddings(input_ids)
if multimodal_embeddings and len(multimodal_embeddings) > 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.audio_token_id,
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -771,8 +755,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = None
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_id,
)
input_ids = None

return self.decoder.model(input_ids,

@@ -117,6 +117,9 @@ class MiMoMultiTokenPredictor(nn.Module):

self.logits_processor = LogitsProcessor(config.vocab_size)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -158,6 +161,9 @@ class MiMoMTP(nn.Module):
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,

@@ -71,8 +71,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, flatten_bn, isin_list, maybe_prefix

# For profile run
_MAX_FRAMES_PER_VIDEO = 16
@@ -1144,23 +1143,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):

return self._process_multimodal_inputs(modalities)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert len(self.mm_token_ids) > 0
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
list(self.mm_token_ids),
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -1178,8 +1160,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)

inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, list(self.mm_token_ids)),
)
input_ids = None

hidden_states = self.llm.model(

@@ -592,10 +592,7 @@ class MiniMaxText01Model(nn.Module):
dtype=torch.long)
minimax_cache_tensors[:, slots_tensor, ...] = 0

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(self,
@@ -687,10 +684,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
batch_size)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(self,

@@ -28,7 +28,7 @@ from .llava_next import LlavaNextProcessingInfo
from .pixtral import PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)

class MiniMaxVL01ImagePixelInputs(TensorSchema):
@@ -218,22 +218,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds

def get_language_model(self) -> torch.nn.Module:
return self.language_model

@@ -403,8 +387,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,

@@ -38,8 +38,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info

@@ -524,22 +523,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,

return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -592,8 +575,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,

@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -56,8 +56,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llama4 import Llama4ForCausalLM
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .vision import run_dp_sharded_vision_model

@@ -813,24 +812,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,

return self._process_image_input(image_input)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)

if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -846,8 +827,11 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
# this condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None

return self.language_model(input_ids, positions, intermediate_tensors,

@@ -43,6 +43,9 @@ class ModernBertEmbeddings(nn.Module):
eps=config.layer_norm_eps,
bias=config.norm_bias)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.tok_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -220,6 +223,9 @@ class ModernBertModel(nn.Module):
eps=config.norm_eps,
bias=config.norm_bias)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embeddings.get_input_embeddings(input_ids)

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
weights = self.hf_to_vllm_mapper.apply(weights)
@@ -333,6 +339,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
),
})

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

self_weights = []

@@ -58,7 +58,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)

# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
@@ -819,10 +819,7 @@ class MolmoModel(nn.Module, SupportsQuant):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
@@ -1481,24 +1478,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,

return self._process_image_input(image_input)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_patch_id is not None

inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.img_patch_id,
)
return inputs_embeds

def forward(
self,
input_ids: torch.LongTensor,
@@ -1515,8 +1494,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_patch_id,
)
input_ids = None

hidden_states = self.model(input_ids,

@@ -35,8 +35,7 @@ from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.radio import RadioModel
from vllm.model_executor.models.utils import (flatten_bn,
init_vllm_registered_model,
maybe_prefix,
merge_multimodal_embeddings)
isin_list, maybe_prefix)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, MultiModalKwargsItems,
@@ -1096,8 +1095,8 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,

return modalities

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
# Validate the multimodal input keyword arguments
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if modalities is None:
@@ -1121,30 +1120,6 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,

return multimodal_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)

if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
assert len(context_token_ids) >= 1
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)

return inputs_embeds

def get_language_model(self) -> torch.nn.Module:
return self.language_model

@@ -1163,9 +1138,17 @@ class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
if token_id is not None
]
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=isin_list(input_ids, context_token_ids),
)
input_ids = None

hidden_states = self.language_model(

@@ -38,7 +38,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)

IMG_START = '<img>'
IMG_END = '</img>'
@@ -576,20 +576,24 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
context_token_ids = [self.img_context_token_id]
assert len(context_token_ids) >= 1
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
context_token_ids,
)
return inputs_embeds

# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)

return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

def forward(
self,
@@ -608,8 +612,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_context_token_id,
)
input_ids = None

forward_kwargs = {

@@ -295,6 +295,9 @@ class Olmo2Model(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"],
self.config.hidden_size))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -408,6 +411,9 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,

@@ -48,7 +48,6 @@ from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings

# Cannot find the following number from hf config.
IMAGE_TOKEN = "<image>"
@@ -501,19 +500,6 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):

return image_features

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_pad_token_id)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -529,8 +515,11 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_pad_token_id,
)
input_ids = None

# up until here we have an inputs_embeds 100% numerical identity

@@ -585,17 +585,6 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):

return multimodal_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
tmp = torch.concat(multimodal_embeddings, dim=0)
inputs_embeds[input_ids == self.image_pad_token_id] = tmp
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -612,8 +601,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)

inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.image_pad_token_id,
)
input_ids = None

# up until here we have a inputs_embeds 100% numerical identity

@@ -26,8 +26,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import get_vision_encoder_info

logger = init_logger(__name__)
@@ -362,19 +361,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
return inputs_embeds

def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
@@ -388,8 +374,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,

@@ -51,9 +51,9 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper,
_merge_multimodal_embeddings, flatten_bn,
init_vllm_registered_model, maybe_prefix)

logger = init_logger(__name__)

@@ -643,14 +643,31 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id)
return inputs_embeds
inputs_embeds = self._get_text_embeddings(
input_ids,
self.embed_tokens,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds

if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")

return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)

def forward(self,
input_ids: torch.Tensor,
@@ -666,8 +683,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
# condition is for v0 compatibility
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=self.image_token_id,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,

@@ -1342,12 +1342,12 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
image_attention_mask)
return image_embeds

def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:

modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
return []

# The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
@@ -1371,18 +1371,6 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):

return multimodal_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
return inputs_embeds

def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

@@ -1151,7 +1151,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
return None

# The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
@@ -1175,19 +1174,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):

return multimodal_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.embed_tokens(input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
return inputs_embeds

def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

@@ -50,8 +50,7 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs

try:
@@ -433,22 +432,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,

return self._process_image_input(image_input)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.vision_args.image_token_id,
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -465,8 +448,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.vision_args.image_token_id,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,

@@ -865,24 +865,26 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings += audio_embeddings
return multimodal_embeddings

# TODO (ywang96): support overlapping modality embeddings so that
# `use_audio_in_video` will work on V1.
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)

# TODO (ywang96): support overlapping modality embeddings so that
# `use_audio_in_video` will work on V1.
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.config.image_token_index,
self.config.video_token_index,
self.config.audio_token_index
])
return inputs_embeds
return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

def get_multimodal_embeddings_v0(
self, **kwargs: object) -> Optional[NestedTensors]:

@@ -1365,19 +1365,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings += video_embeddings
return multimodal_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds

def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

@@ -49,8 +49,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix

# # === Audio Inputs === #
@@ -438,19 +437,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
masked_audio_features = self._process_audio_input(audio_input)
return masked_audio_features

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@@ -467,8 +453,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
# condition is for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,

@@ -1459,19 +1459,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,

return multimodal_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds

def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,

@@ -79,7 +79,8 @@ from .qwen2_5_vl import (Qwen2_5_VisionAttention,
from .qwen2_vl import Qwen2VLProcessingInfo
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings)
_merge_multimodal_embeddings, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

logger = init_logger(__name__)
@@ -1324,17 +1325,22 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return multimodal_embeddings

def _compute_deepstack_embeds(
self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor:
visual_lens = [
x.shape[0] if isinstance(x, torch.Tensor) else len(x)
for x in multimodal_embeddings
]
self,
inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings,
is_multimodal: torch.Tensor,
) -> tuple[torch.Tensor, MultiModalEmbeddings]:
visual_lens = [len(x) for x in multimodal_embeddings]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)

multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split(  # noqa:E501
multimodal_embeddings_cat, [self.visual_dim, self.multiscale_dim],
dim=-1)
(
multimodal_embeddings_main,
multimodal_embeddings_multiscale,
) = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim],
dim=-1,
)

multimodal_embeddings = torch.split(multimodal_embeddings_main,
visual_lens,
@@ -1346,39 +1352,62 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds.size(0),
self.deepstack_num_level * inputs_embeds.size(1))

deepstack_input_embeds = merge_multimodal_embeddings(
input_ids,
deepstack_input_embeds,
multimodal_embeddings_multiscale,
placeholder_token_id=[
self.config.image_token_id, self.config.video_token_id
],
deepstack_input_embeds = _merge_multimodal_embeddings(
inputs_embeds=deepstack_input_embeds,
multimodal_embeddings=multimodal_embeddings_multiscale,
is_multimodal=is_multimodal,
)
deepstack_input_embeds = deepstack_input_embeds.view(
inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)

return deepstack_input_embeds, multimodal_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
deepstack_input_embeds = None
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if self.use_deepstack:
deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds(  # noqa:E501
input_ids, inputs_embeds, multimodal_embeddings)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
inputs_embeds = self._get_text_embeddings(
input_ids,
self.language_model.get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds

if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")

if self.use_deepstack:
if deepstack_input_embeds is None:
deepstack_input_embeds = torch.zeros_like(
inputs_embeds).unsqueeze(0).repeat(
self.deepstack_num_level, 1, 1).contiguous()
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
else:
deepstack_input_embeds = None

inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)

if deepstack_input_embeds is not None:
deepstack_input_embeds = torch.zeros_like(inputs_embeds).unsqueeze(
0).repeat(self.deepstack_num_level, 1, 1).contiguous()
self._set_deepstack_input_embeds(deepstack_input_embeds)

return inputs_embeds

@ -45,7 +45,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .qwen import QWenBaseModel, QWenModel
|
||||
from .utils import flatten_bn, merge_multimodal_embeddings
|
||||
from .utils import flatten_bn
|
||||
|
||||
|
||||
class QwenImagePixelInputs(TensorSchema):
|
||||
@ -756,21 +756,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.transformer.visual.image_pad_id)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -786,8 +771,12 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids ==
|
||||
self.transformer.visual.image_pad_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.transformer(input_ids, positions,
|
||||
|
||||
@ -218,6 +218,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.roberta.get_input_embeddings(input_ids)

def forward(
self,
input_ids: Optional[torch.Tensor],

@ -38,7 +38,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)

IMG_START = '<img>'
IMG_END = '</img>'
@ -842,19 +842,24 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_context_token_id is not None
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.img_context_token_id,
)
return inputs_embeds

# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)

return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

def forward(
self,
@ -873,8 +878,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.img_context_token_id,
)
input_ids = None

forward_kwargs = {

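The "satisfy the type checker" comment refers to the overloaded signature of the interface method: a text-only overload and a multimodal one. A compressed sketch of that pattern (hypothetical class, not the vLLM interface definition):

```python
from typing import Optional, overload

import torch


class Example:
    @overload
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...

    @overload
    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: torch.Tensor,
        *,
        is_multimodal: torch.Tensor,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        ...

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[torch.Tensor] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        # Dispatch to the text-only path unless both MM arguments are given,
        # mirroring the explicit None checks in the diff above.
        if multimodal_embeddings is None or is_multimodal is None:
            return torch.zeros(input_ids.shape[0], 4)
        out = torch.zeros(input_ids.shape[0], 4)
        out.masked_scatter_(is_multimodal.unsqueeze(-1), multimodal_embeddings)
        return out
```
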
@ -483,6 +483,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,

@ -395,6 +395,9 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,

@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)
from .vision import run_dp_sharded_vision_model

@ -996,10 +995,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
1 else cur_feature[0])
return merged_image_features

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_language_model(self) -> torch.nn.Module:
return self.language_model

def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
return []
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings

@ -1007,24 +1009,21 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
if multimodal_embeddings is None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
else:
is_text = input_ids != self.config.image_token_id
text_ids = input_ids[is_text]
text_embeds = self.language_model.model.get_input_embeddings(
text_ids)
inputs_embeds = torch.empty(input_ids.shape[0],
text_embeds.shape[-1],
dtype=text_embeds.dtype,
device=text_embeds.device)
inputs_embeds[is_text] = text_embeds
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_id)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)

return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

def forward(
self,
@ -1038,10 +1037,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None

hidden_states = self.language_model(input_ids,

@ -40,7 +40,7 @@ from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
maybe_prefix)
from .vision import VisionEncoderInfo, get_vision_encoder_info


@ -589,22 +589,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
return []
return self._process_image_input(image_input)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
@ -617,8 +601,11 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = None
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
vision_embeddings,
is_multimodal=input_ids == self.config.image_token_index,
)
input_ids = None
hidden_states = self.language_model.model(
input_ids=input_ids,

@ -233,6 +233,9 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# We do not really use any input tokens and therefore no embeddings
# to be calculated. However, due to the mandatory token ids in

@ -52,8 +52,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
flatten_bn, make_empty_intermediate_tensors_factory,
maybe_prefix)
@ -797,6 +797,9 @@ class TransformersForCausalLM(TransformersBase):
else:
self.lm_head = PPMissingLayer()

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings()(input_ids)

def compute_logits(
self,
hidden_states: torch.Tensor,
@ -873,13 +876,19 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
multimodal_embeds = self.get_multimodal_embeddings(**kwargs)
if multimodal_embeds is not None:
inputs_embeds = self.get_input_embeddings(
input_ids, multimodal_embeds)
input_ids,
multimodal_embeds,
is_multimodal=input_ids == self.config.image_token_id,
)
input_ids = None

model_output = super().forward(input_ids, positions,
intermediate_tensors, inputs_embeds)
return model_output

def get_language_model(self) -> torch.nn.Module:
return self.model

def get_multimodal_embeddings(self, **kwargs):
pixel_values = kwargs.pop("pixel_values", None)
pixel_values = pixel_values if pixel_values is not None else kwargs.pop(
@ -934,15 +943,42 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings=None,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id)
mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
multimodal_embeddings = torch.cat(multimodal_embeddings)
"""
Apply token embeddings to `input_ids`.

inputs_embeds = inputs_embeds.masked_scatter(
mask, multimodal_embeddings)
return inputs_embeds
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.

In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
on those tokens.
"""
from .utils import _merge_multimodal_embeddings

inputs_embeds = self._get_text_embeddings(
input_ids,
self.model.get_input_embeddings(),
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds

if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")

return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)

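When the flattened multimodal embeddings do not line up with the placeholder mask, the shared helper converts the failed scatter into a descriptive `ValueError`. A toy reproduction of that count check (the real code defers the check until `masked_scatter_` fails, to avoid an eager device-to-host sync):

```python
import torch


def merge_by_mask(inputs_embeds: torch.Tensor, mm_embeds: torch.Tensor,
                  is_multimodal: torch.Tensor) -> torch.Tensor:
    """Toy re-implementation of the mismatch check described above."""
    num_expected = int(is_multimodal.sum().item())
    if mm_embeds.shape[0] != num_expected:
        raise ValueError(f"Attempted to assign {mm_embeds.shape[0]} "
                         f"multimodal tokens to {num_expected} placeholders")
    inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds)
    return inputs_embeds


embeds = torch.zeros(4, 2)
mask = torch.tensor([False, True, True, False])
try:
    merge_by_mask(embeds, torch.ones(3, 2), mask)  # 3 rows vs. 2 placeholders
except ValueError as e:
    print(e)
```
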
@ -33,8 +33,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
init_vllm_registered_model, maybe_prefix)

_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
_MAX_ENCODER_BATCH_SIZE = 16
@ -555,19 +554,21 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# The audio token index is not included in the embedding table
# We need to remove it before embedding lookup
safe_input_ids = input_ids.clone()
safe_input_ids[safe_input_ids == self.config.audio_token_index] = 0
inputs_embeds = self.language_model.get_input_embeddings(
safe_input_ids)
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
return inputs_embeds
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().get_input_embeddings(input_ids)

return super().get_input_embeddings(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

def forward(self,
input_ids: torch.Tensor,
@ -601,8 +602,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)

inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_index,
)
input_ids = None

language_model = self.language_model

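Ultravox's audio placeholder ID lies outside the language model's vocabulary, which is why the old code zeroed those positions before the embedding lookup; `handle_oov_mm_token=True` keeps the same guarantee in the shared implementation. A toy sketch of the safeguard (made-up sizes and IDs):

```python
import torch
import torch.nn as nn

VOCAB_SIZE, HIDDEN = 16, 4
AUDIO_TOKEN_ID = 999  # deliberately outside the embedding table

embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN)
input_ids = torch.tensor([1, AUDIO_TOKEN_ID, AUDIO_TOKEN_ID, 2])

with torch.no_grad():
    # A direct lookup of the OOV placeholder would raise IndexError, so the
    # placeholder positions are clamped to a valid ID first (old approach)
    # or skipped entirely (what handle_oov_mm_token is for).
    safe_input_ids = input_ids.clone()
    safe_input_ids[input_ids == AUDIO_TOKEN_ID] = 0
    inputs_embeds = embed_tokens(safe_input_ids)

    # The placeholder rows are overwritten with audio embeddings afterwards,
    # so the clamped values never influence the output.
    audio_embeds = torch.randn(2, HIDDEN)
    inputs_embeds.masked_scatter_(
        (input_ids == AUDIO_TOKEN_ID).unsqueeze(-1), audio_embeds)
```
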
@ -4,7 +4,7 @@
import itertools
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
from typing import Any, Literal, Optional, Protocol, Union, overload

import torch
import torch.nn as nn
@ -391,8 +391,8 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:

def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
is_multimodal: torch.Tensor,
multimodal_embeddings: NestedTensors,
is_multimodal: torch.Tensor,
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
@ -402,63 +402,37 @@ def _merge_multimodal_embeddings(
Note:
This updates ``inputs_embeds`` in place.
"""
flattened = _flatten_embeddings(multimodal_embeddings)
try:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
flattened.to(dtype=inputs_embeds.dtype))
except RuntimeError as e:
num_expected_tokens = is_multimodal.sum().item()
assert isinstance(num_expected_tokens, int)
if len(multimodal_embeddings) == 0:
return inputs_embeds

if flattened.shape[0] != num_expected_tokens:
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype

try:
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)

# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
mm_embeds_flat.to(dtype=input_dtype))
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()

if num_actual_tokens != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)

raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"Attempted to assign {expr} = {num_actual_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders"
) from e
else:
raise ValueError("Error during masked scatter operation") from e

raise ValueError("Error during masked scatter operation") from e

return inputs_embeds


def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
multimodal_embeds: NestedTensors,
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.

``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.

Compared to ``merge_multimodal_embeddings`, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal = input_ids == multimodal_token_id
is_text = ~is_multimodal

text_embeds = get_text_embeds(input_ids[is_text])
merged_embeds = torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
)

merged_embeds[is_text] = text_embeds

return _merge_multimodal_embeddings(
merged_embeds,
is_multimodal,
multimodal_embeds,
)

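`embed_multimodal` above is removed because its only job, embedding just the text positions so that out-of-vocabulary placeholder IDs never reach the embedding table, now lives behind the `handle_oov_mm_token` flag of the shared `get_input_embeddings`. A compressed sketch of that trick for readers following the removal (the helper name is local to this sketch):

```python
from typing import Callable

import torch


def embed_text_only(
    input_ids: torch.Tensor,
    multimodal_token_id: int,
    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
    hidden_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fill only the text rows of the output buffer; placeholder IDs (which
    may exceed the vocab size) are never passed to the embedding table."""
    is_multimodal = input_ids == multimodal_token_id
    is_text = ~is_multimodal
    text_embeds = get_text_embeds(input_ids[is_text])
    merged = torch.empty(input_ids.shape[0], hidden_size,
                         dtype=text_embeds.dtype, device=text_embeds.device)
    merged[is_text] = text_embeds
    # The mask is returned so the caller can scatter the MM embeddings next.
    return merged, is_multimodal
```
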
def merge_multimodal_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
@ -491,23 +465,29 @@ def merge_multimodal_embeddings(
This updates ``inputs_embeds`` in place.
"""
if isinstance(placeholder_token_id, list):
placeholder_token_id = torch.tensor(
placeholder_token_id,
pin_memory=is_pin_memory_available()).to(device=input_ids.device,
non_blocking=True)
return _merge_multimodal_embeddings(
inputs_embeds,
torch.isin(input_ids, placeholder_token_id),
multimodal_embeddings,
)
is_multimodal = isin_list(input_ids, placeholder_token_id)
else:
is_multimodal = (input_ids == placeholder_token_id)

return _merge_multimodal_embeddings(
inputs_embeds,
(input_ids == placeholder_token_id),
multimodal_embeddings,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)


def isin_list(
elements: torch.Tensor,
test_elements_list: list[int],
) -> torch.Tensor:
test_elements = torch.tensor(
test_elements_list,
pin_memory=is_pin_memory_available(),
).to(device=elements.device, non_blocking=True)

return torch.isin(elements, test_elements)


class LayerFn(Protocol):

def __call__(self, prefix: str) -> torch.nn.Module:

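A toy call showing what the new `isin_list` helper computes when several placeholder IDs are in play (the real helper additionally pins memory and copies the ID list to the target device asynchronously):

```python
import torch

input_ids = torch.tensor([5, 7, 7, 3, 9])
placeholder_ids = [7, 9]

is_multimodal = torch.isin(input_ids, torch.tensor(placeholder_ids))
print(is_multimodal)  # tensor([False,  True,  True, False,  True])
```
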
@ -45,10 +45,8 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsTranscription)
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix

logger = init_logger(__name__)

@ -376,9 +374,14 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
audio_encoder = self.tokenizer.instruct.audio_encoder
audio_tok_id = audio_encoder.audio_token
audio_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
audio_embeddings)
inputs_embeds = self.get_input_embeddings(
input_ids,
audio_embeddings,
is_multimodal=input_ids == audio_tok_id,
)
input_ids = None

hidden_states = self.language_model.model(input_ids,
@ -421,20 +424,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,

return audio_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
audio_encoder = self.tokenizer.instruct.audio_encoder
audio_tok_id = audio_encoder.audio_token

inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id)
return inputs_embeds

def _parse_and_validate_audio_arrays(
self, **kwargs: object) -> Union[list[torch.Tensor], None]:
audio_arrays = kwargs.pop("audio_arrays", None)

@ -579,10 +579,7 @@ class WhisperDecoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
return hidden_states

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)


@ -916,7 +913,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens.

@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
@ -64,8 +65,10 @@ class EagleProposer:
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()

self.is_multimodal_model = vllm_config.model_config \
.is_multimodal_model
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
vllm_config.model_config)

self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None

@ -175,7 +178,8 @@ class EagleProposer:
last_token_indices: Optional[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
mm_embed_inputs: Optional[tuple[list[torch.Tensor],
torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
@ -219,18 +223,21 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds or None,

if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

self.inputs_embeds[:num_tokens] = self.model.get_input_embeddings(
self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]

input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None

with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
@ -372,14 +379,15 @@ class EagleProposer:
self.input_ids[:batch_size] = input_ids
self._set_positions(batch_size, clamped_positions)
self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model:
inputs_embeds = self.model.get_input_embeddings(input_ids)
self.inputs_embeds[:batch_size] = inputs_embeds
inputs_embeds = self.inputs_embeds[:input_batch_size]
if self.supports_mm_inputs:
self.inputs_embeds[:batch_size] = \
self.model.get_input_embeddings(input_ids)

input_ids = None
inputs_embeds = self.inputs_embeds[:input_batch_size]
else:
inputs_embeds = None
input_ids = self.input_ids[:input_batch_size]
inputs_embeds = None

# Run the model.
with set_forward_context(per_layer_attn_metadata,
@ -849,7 +857,7 @@ class EagleProposer:

self.attn_layer_names = list(draft_attn_layer_names)

if self.is_multimodal_model:
if self.supports_mm_inputs:
# Even if the target model is multimodal, we can also use
# text-only draft models
try:
@ -861,7 +869,7 @@ class EagleProposer:
logger.warning(
"Draft model does not support multimodal inputs, "
"falling back to text-only mode")
self.is_multimodal_model = False
self.supports_mm_inputs = False

if supports_multimodal(target_model):
# handle multimodality
@ -933,7 +941,7 @@ class EagleProposer:
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
if self.is_multimodal_model:
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:

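`propose()` now takes an optional `(mm_embeds, is_mm_embed)` pair instead of a bare list of embeddings. A small sketch of the unpacking convention used above (toy values):

```python
from typing import Optional

import torch


def unpack_mm_embed_inputs(
    mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]],
) -> tuple[Optional[list[torch.Tensor]], Optional[torch.Tensor]]:
    # Same idiom as in the diff: a missing tuple degrades to (None, None),
    # which the embedding call then treats as a text-only batch.
    return mm_embed_inputs or (None, None)


mm_embeds = [torch.randn(3, 8)]
is_mm_embed = torch.tensor([False, True, True, True, False])
print(unpack_mm_embed_inputs((mm_embeds, is_mm_embed))[1])
print(unpack_mm_embed_inputs(None))  # (None, None)
```
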
@ -368,6 +368,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)

# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)

# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
@ -1627,9 +1632,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

mm_embeds = list[torch.Tensor]()
is_mm_embed = self.is_mm_embed.cpu
is_mm_embed[:total_num_scheduled_tokens] = False

req_start_idx = 0
should_sync_mrope_positions = False
mm_embeds: list[torch.Tensor] = []

for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []

@ -1638,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens

for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
@ -1670,6 +1683,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]

req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True if is_embed is None else is_embed

mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
@ -1677,6 +1694,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds_req.append(mm_embeds_item)

if self.is_multimodal_pruning_enabled and self.uses_mrope:
assert req_state.mrope_positions is not None
should_sync_mrope_positions = True
mm_embeds_req, new_mrope_positions, new_delta = (
self.model.recompute_mrope_positions(
@ -1685,18 +1703,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mrope_positions=req_state.mrope_positions,
num_computed_tokens=req_state.num_computed_tokens,
))
assert req_state.mrope_positions is not None
req_state.mrope_positions.copy_(new_mrope_positions)
req_state.mrope_position_delta = new_delta

mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens

is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)

if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output)
self.mrope_positions.copy_to_gpu(
scheduler_output.total_num_scheduled_tokens)
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)

return mm_embeds
return mm_embeds, is_mm_embed

def _extract_encoder_inputs(
self,

@ -1990,14 +2009,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
and not self.model_config.is_encoder_decoder):
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
mm_embeds, is_mm_embed = self._gather_mm_embeddings(
scheduler_output)

# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds_scheduled = self.model.get_input_embeddings(
input_ids=self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None,
self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)

# TODO(woosuk): Avoid the copy. Optimize.
||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
||||
else:
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
mm_embeds = None
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output,
|
||||
shift_computed_tokens=1)
|
||||
mm_embed_inputs = self._gather_mm_embeddings(
|
||||
scheduler_output,
|
||||
shift_computed_tokens=1,
|
||||
)
|
||||
else:
|
||||
mm_embed_inputs = None
|
||||
|
||||
draft_token_ids = self.drafter.propose(
|
||||
target_token_ids=target_token_ids,
|
||||
@ -2599,8 +2624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
last_token_indices=token_indices_to_sample,
|
||||
sampling_metadata=sampling_metadata,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
mm_embeds=mm_embeds,
|
||||
mm_embed_inputs=mm_embed_inputs,
|
||||
)
|
||||
|
||||
return draft_token_ids
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
|
||||
@ -263,6 +263,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()

# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory)

# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
# Keep in int64 to avoid overflow with long context
@ -879,13 +886,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
padded_total_num_scheduled_tokens = _get_padded_token_len(
self.num_tokens_paddings, total_num_scheduled_tokens)

is_mm_embed = self.is_mm_embed_cpu
is_mm_embed[:padded_total_num_scheduled_tokens] = False
mm_embeds = list[torch.Tensor]()
req_start_idx = 0

for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens

# TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) here we diverge from logic in other runners, as
# we assume to only have whole mm items to process. Hence we avoid
@ -906,26 +922,53 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue

start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens,
)
assert start_idx < end_idx

mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."

assert pos_info.is_embed is None, "Expected all positions to"\
" be contiguous and embeddings."
encoder_output = self.encoder_cache[mm_hash]
mm_embeds.append(encoder_output)
return mm_embeds

def _get_model_inputs(self, input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor]):
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True

# Only whole mm items are processed
mm_embeds.append(encoder_output)

req_start_idx += num_scheduled_tokens

is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens] \
.to(self.device)

return mm_embeds, is_mm_embed

def _get_model_inputs(
self,
input_ids: torch.Tensor,
mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]],
):
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds = self.model.get_input_embeddings(
input_ids=input_ids,
input_ids,
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)

return None, inputs_embeds
else:
# For text-only models, we use token ids as input.
@ -953,9 +996,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
mm_embed_inputs = None

torch_xla.sync(wait=False)
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
@ -972,7 +1016,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
end_index = self._prepare_inputs(scheduler_output, start_index)
input_ids, inputs_embeds = self._get_model_inputs(
self.input_ids, mm_embeds)
self.input_ids, mm_embed_inputs)
torch_xla.sync(wait=False)
# Run the decoder
with set_forward_context(
@ -1325,9 +1369,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hf_config.image_token_index

placeholders_ids = placeholders_ids.to(self.device)

mm_mask = torch.tensor([False] * num_tokens)
mm_mask[:items_size] = True
mm_mask = mm_mask.to(self.device)
# Assign outputs or the graph will be cut short.
a, b = self._get_model_inputs(placeholders_ids,
[mm_embeds])
a, b = self._get_model_inputs(
placeholders_ids,
mm_embed_inputs=([mm_embeds], mm_mask),
)
assert a is None
torch_xla.sync(wait=False)

@ -1338,7 +1388,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32,
device="cpu")
placeholders_ids = placeholders_ids.to(self.device)
a, b = self._get_model_inputs(placeholders_ids, [])
a, b = self._get_model_inputs(
placeholders_ids,
mm_embed_inputs=None,
)
assert a is None
torch_xla.sync(wait=False)

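On TPU the mask is prepared on the CPU, truncated to the padded token length, and only then moved to the device, so the compiled graph keeps seeing fixed shapes. A rough sketch under that assumption (the bucket-lookup helper below is a stand-in, not vLLM's `_get_padded_token_len`):

```python
import torch


def pad_to_next_bucket(paddings: list[int], n: int) -> int:
    """Stand-in for the runner's padded-length lookup: pick the smallest
    pre-compiled bucket that fits n tokens."""
    return next(p for p in paddings if p >= n)


num_tokens_paddings = [16, 32, 64]
total_num_scheduled_tokens = 21
padded = pad_to_next_bucket(num_tokens_paddings, total_num_scheduled_tokens)

is_mm_embed_cpu = torch.zeros(64, dtype=torch.bool)  # persistent CPU buffer
is_mm_embed_cpu[:padded] = False                     # reset this step's slice
is_mm_embed_cpu[3:8] = True                          # mark one whole image item
is_mm_embed = is_mm_embed_cpu[:padded]               # fixed, padded length
# `.to(device)` would follow here; keeping the length padded avoids
# recompilation when the number of scheduled tokens changes between steps.
print(padded, is_mm_embed.shape)  # 32 torch.Size([32])
```
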