mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 01:07:04 +08:00
Rename clashing method names for vLLM model protocol (#27583)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
3226283461
commit
97d1c99302
@ -56,13 +56,13 @@ The initialization code should look like this:
|
||||
|
||||
### Computation Code
|
||||
|
||||
- Add a `get_input_embeddings` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
|
||||
- Add an `embed_input_ids` method inside `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but provides a unified interface in case `MyModel` is used within a composite multimodal model.
|
||||
|
||||
```python
|
||||
class MyModel(nn.Module):
|
||||
...
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ Further update the model as follows:
|
||||
|
||||
More conveniently, you can simply pass `**kwargs` to the [forward][torch.nn.Module.forward] method and retrieve the keyword parameters for multimodal inputs from it.
|
||||
|
||||
- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
|
||||
- Implement [embed_multimodal][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_multimodal] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
|
||||
|
||||
??? code
|
||||
|
||||
@ -49,7 +49,7 @@ Further update the model as follows:
|
||||
image_features = self.vision_encoder(image_input)
|
||||
return self.multi_modal_projector(image_features)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
def embed_multimodal(
|
||||
self,
|
||||
**kwargs: object,
|
||||
) -> MultiModalEmbeddings | None:
|
||||
@ -69,7 +69,7 @@ Further update the model as follows:
|
||||
!!! note
|
||||
By default, vLLM merges the multimodal embeddings into the text embeddings according to their locations, as defined by
|
||||
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
|
||||
This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
|
||||
This logic can be found at [embed_input_ids][vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids].
|
||||
|
||||
You may override this method if additional logic is required for your model when merging embeddings.
|
||||
|
||||
|
||||
@ -382,7 +382,7 @@ class ApertusModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -396,7 +396,7 @@ class ApertusModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -557,8 +557,8 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -239,7 +239,7 @@ class ArceeModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -254,7 +254,7 @@ class ArceeModel(nn.Module):
|
||||
hidden_states = (
|
||||
inputs_embeds
|
||||
if inputs_embeds is not None
|
||||
else self.get_input_embeddings(input_ids)
|
||||
else self.embed_input_ids(input_ids)
|
||||
)
|
||||
residual = None
|
||||
else:
|
||||
@ -423,8 +423,8 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""Load weights into the model (delegates to inner model and handles
|
||||
|
||||
@ -442,7 +442,7 @@ class ArcticModel(nn.Module):
|
||||
["hidden_states"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -456,7 +456,7 @@ class ArcticModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
@ -496,8 +496,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -613,7 +613,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
@ -629,8 +629,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if inputs_embeds is None:
|
||||
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
multimodal_embeddings = self.embed_multimodal(**kwargs)
|
||||
inputs_embeds = self.embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
|
||||
@ -417,7 +417,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -309,7 +309,7 @@ class BaiChuanModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -323,7 +323,7 @@ class BaiChuanModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -426,8 +426,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -438,7 +438,7 @@ class BailingMoeModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.word_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -452,7 +452,7 @@ class BailingMoeModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -608,8 +608,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -314,7 +314,7 @@ class BambaModel(nn.Module):
|
||||
|
||||
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -328,7 +328,7 @@ class BambaModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -493,8 +493,8 @@ class BambaForCausalLM(
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -375,7 +375,7 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
self.embeddings = embedding_class(self.config)
|
||||
self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder")
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings.word_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -486,8 +486,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
)
|
||||
self.pooler = self._build_pooler(pooler_config)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -835,8 +835,8 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
|
||||
}
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.bert.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.bert.embed_input_ids(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
@ -893,8 +893,8 @@ class BertForTokenClassification(nn.Module):
|
||||
}
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.bert.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.bert.embed_input_ids(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
|
||||
@ -463,7 +463,7 @@ class BertWithRope(nn.Module, SupportsQuant):
|
||||
)
|
||||
self.pooler = BertPooler(self.config) if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -714,8 +714,8 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
loaded_params = loader.load_weights(weights)
|
||||
return loaded_params
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.new.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.new.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -630,7 +630,7 @@ class Blip2ForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -271,7 +271,7 @@ class BloomModel(nn.Module):
|
||||
["hidden_states"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.word_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -285,7 +285,7 @@ class BloomModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
hidden_states = self.word_embeddings_layernorm(hidden_states)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -353,8 +353,8 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -886,7 +886,7 @@ class ChameleonModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
@ -912,7 +912,7 @@ class ChameleonModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -998,7 +998,7 @@ class ChameleonForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
@ -1006,7 +1006,7 @@ class ChameleonForConditionalGeneration(
|
||||
image_tokens = self.model.get_image_tokens(
|
||||
image_input["data"].to(self.config.dtype)
|
||||
)
|
||||
vision_embeddings = self.model.get_input_embeddings(image_tokens)
|
||||
vision_embeddings = self.model.embed_input_ids(image_tokens)
|
||||
return vision_embeddings
|
||||
|
||||
def forward(
|
||||
|
||||
@ -353,7 +353,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
|
||||
self.encoder.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embedding(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -368,7 +368,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
@ -451,8 +451,8 @@ class ChatGLMBaseModel(nn.Module):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
|
||||
@ -561,7 +561,7 @@ class CLIPTextTransformer(nn.Module):
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings.token_embedding(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -842,7 +842,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
}
|
||||
)
|
||||
|
||||
# Assumes that self.forward is called after self.get_input_embeddings
|
||||
# Assumes that self.forward is called after self.embed_input_ids
|
||||
self._is_text_input = True
|
||||
|
||||
def get_text_features(
|
||||
@ -903,7 +903,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.text_model
|
||||
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
@ -917,16 +917,16 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -439,7 +439,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -311,7 +311,7 @@ class CohereModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -325,7 +325,7 @@ class CohereModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -436,8 +436,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
|
||||
@ -354,7 +354,7 @@ class DbrxModel(nn.Module):
|
||||
["hidden_states"], config.d_model
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.wte(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -368,7 +368,7 @@ class DbrxModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
@ -455,8 +455,8 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -73,7 +73,7 @@ class DeepseekV2Model(nn.Module):
|
||||
self.hnorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
|
||||
self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -222,8 +222,8 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
|
||||
self.num_moe_layers = self.config.num_hidden_layers
|
||||
self.set_moe_parameters()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -142,7 +142,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -206,8 +206,8 @@ class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -557,9 +557,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object
|
||||
) -> MultiModalEmbeddings | None:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
@ -1236,7 +1236,7 @@ class DeepseekV2Model(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -1250,7 +1250,7 @@ class DeepseekV2Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -1389,8 +1389,8 @@ class DeepseekV2ForCausalLM(
|
||||
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -619,7 +619,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -398,7 +398,7 @@ class Dots1Model(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -412,7 +412,7 @@ class Dots1Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -541,8 +541,8 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -840,7 +840,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
@ -858,8 +858,8 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
vision_embeddings = self.embed_multimodal(**kwargs)
|
||||
inputs_embeds = self.embed_input_ids(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
|
||||
@ -465,7 +465,7 @@ class Ernie4_5_MoeModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -479,7 +479,7 @@ class Ernie4_5_MoeModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -726,8 +726,8 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -1656,9 +1656,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
|
||||
|
||||
return modalities
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object
|
||||
) -> MultiModalEmbeddings | None:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return None
|
||||
@ -1681,7 +1679,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
@ -1694,9 +1692,9 @@ class Ernie4_5_VLMoeForConditionalGeneration(
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
|
||||
@ -561,7 +561,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -577,7 +577,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -642,8 +642,8 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -112,7 +112,7 @@ class ErnieMultiTokenPredictor(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -160,8 +160,8 @@ class ErnieMTP(nn.Module, SupportsPP):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -357,7 +357,7 @@ class ExaoneModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.wte(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -371,7 +371,7 @@ class ExaoneModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -512,8 +512,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -344,7 +344,7 @@ class Exaone4Model(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -358,7 +358,7 @@ class Exaone4Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -498,8 +498,8 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -399,7 +399,7 @@ class FalconModel(nn.Module):
|
||||
["hidden_states"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.word_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -413,7 +413,7 @@ class FalconModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for layer in islice(self.h, self.start_layer, self.end_layer):
|
||||
@ -515,8 +515,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -461,7 +461,7 @@ class FalconH1Model(nn.Module):
|
||||
else:
|
||||
self.final_layernorm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -476,7 +476,7 @@ class FalconH1Model(nn.Module):
|
||||
hidden_states = inputs_embeds * self.embedding_multiplier
|
||||
else:
|
||||
hidden_states = (
|
||||
self.get_input_embeddings(input_ids) * self.embedding_multiplier
|
||||
self.embed_input_ids(input_ids) * self.embedding_multiplier
|
||||
)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -601,8 +601,8 @@ class FalconH1ForCausalLM(
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -333,7 +333,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -293,7 +293,7 @@ class GemmaModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -307,7 +307,7 @@ class GemmaModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
hidden_states *= self.normalizer
|
||||
residual = None
|
||||
else:
|
||||
@ -396,8 +396,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -290,7 +290,7 @@ class Gemma2Model(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -304,7 +304,7 @@ class Gemma2Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
hidden_states *= self.normalizer
|
||||
residual = None
|
||||
else:
|
||||
@ -409,8 +409,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -393,7 +393,7 @@ class Gemma3Model(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE(woosuk): Only apply the normalizer to the output of
|
||||
# vocab embedding. Don't apply it to the vision embedding.
|
||||
return self.embed_tokens(input_ids) * self.normalizer
|
||||
@ -410,7 +410,7 @@ class Gemma3Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -540,8 +540,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -685,7 +685,7 @@ class Gemma3nSelfDecoder(nn.Module):
|
||||
per_layer_inputs = per_layer_projection
|
||||
return per_layer_inputs
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
|
||||
@ -712,7 +712,7 @@ class Gemma3nSelfDecoder(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states_0 = inputs_embeds
|
||||
else:
|
||||
hidden_states_0 = self.get_input_embeddings(input_ids)
|
||||
hidden_states_0 = self.embed_input_ids(input_ids)
|
||||
|
||||
adjusted_per_layer_inputs = self.get_per_layer_inputs(
|
||||
hidden_states_0, per_layer_inputs
|
||||
@ -881,8 +881,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||
def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.self_decoder.get_per_layer_input_embeddings(input_ids)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.self_decoder.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.self_decoder.embed_input_ids(input_ids)
|
||||
|
||||
def fast_prefill_forward(
|
||||
self,
|
||||
@ -1125,8 +1125,8 @@ class Gemma3nForCausalLM(nn.Module):
|
||||
config.vocab_size, soft_cap=config.final_logit_softcapping
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -645,7 +645,7 @@ class Gemma3nForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if mm_input_by_modality is None:
|
||||
return []
|
||||
@ -664,7 +664,7 @@ class Gemma3nForConditionalGeneration(
|
||||
multimodal_embeddings.extend(audio_embeddings)
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
@ -689,9 +689,9 @@ class Gemma3nForConditionalGeneration(
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
@ -709,10 +709,10 @@ class Gemma3nForConditionalGeneration(
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# NOTE (NickLucche) During profiling, `get_input_embeddings` is not
|
||||
# NOTE (NickLucche) During profiling, `embed_input_ids` is not
|
||||
# called, hence we don't have input_ids to compute PLEs. We simply
|
||||
# select a chunk of pre-allocated PLEs. During normal execution,
|
||||
# `get_input_embeddings` is called before forward, hence this slice
|
||||
# `embed_input_ids` is called before forward, hence this slice
|
||||
# will contain PLEs computed from the actual input_ids.
|
||||
per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]]
|
||||
|
||||
|
||||
@ -275,8 +275,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -1594,9 +1594,7 @@ class Glm4vForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object
|
||||
) -> MultiModalEmbeddings | None:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not mm_input_by_modality:
|
||||
return None
|
||||
|
||||
@ -455,7 +455,7 @@ class Glm4MoeModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -469,7 +469,7 @@ class Glm4MoeModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -704,8 +704,8 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExper
|
||||
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -149,7 +149,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -211,8 +211,8 @@ class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
self.extract_moe_parameters(example_moe)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -756,9 +756,9 @@ class GLM4VForCausalLM(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.transformer
|
||||
|
||||
get_input_embeddings = SupportsMultiModal.get_input_embeddings
|
||||
embed_input_ids = SupportsMultiModal.embed_input_ids
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -213,7 +213,7 @@ class GPT2Model(nn.Module):
|
||||
["hidden_states"], config.n_embd
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.wte(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -225,7 +225,7 @@ class GPT2Model(nn.Module):
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
inputs_embeds = self.embed_input_ids(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
else:
|
||||
@ -293,8 +293,8 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -365,8 +365,8 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
}
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
|
||||
@ -230,7 +230,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
["hidden_states"], config.n_embd
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.wte(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -242,7 +242,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
inputs_embeds = self.embed_input_ids(input_ids)
|
||||
hidden_states = inputs_embeds + self.wpe(position_ids)
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
@ -306,8 +306,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -215,7 +215,7 @@ class GPTJModel(nn.Module):
|
||||
["hidden_states"], config.n_embd
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.wte(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -229,7 +229,7 @@ class GPTJModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for layer in islice(self.h, self.start_layer, self.end_layer):
|
||||
@ -319,8 +319,8 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -229,7 +229,7 @@ class GPTNeoXModel(nn.Module):
|
||||
["hidden_states"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_in(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -243,7 +243,7 @@ class GPTNeoXModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
@ -317,8 +317,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
|
||||
self.gpt_neox.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.gpt_neox.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.gpt_neox.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -269,7 +269,7 @@ class GptOssModel(nn.Module):
|
||||
)
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embedding(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -283,7 +283,7 @@ class GptOssModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
x = inputs_embeds
|
||||
else:
|
||||
x = self.get_input_embeddings(input_ids)
|
||||
x = self.embed_input_ids(input_ids)
|
||||
|
||||
residual = None
|
||||
else:
|
||||
@ -703,8 +703,8 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -318,7 +318,7 @@ class GraniteModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -332,7 +332,7 @@ class GraniteModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
|
||||
hidden_states *= self.config.embedding_multiplier
|
||||
else:
|
||||
@ -473,8 +473,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -767,7 +767,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
def embed_multimodal(
|
||||
self,
|
||||
**kwargs: object,
|
||||
) -> MultiModalEmbeddings:
|
||||
@ -779,7 +779,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
audio_features = self._process_audio_input(audio_input)
|
||||
return audio_features
|
||||
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
@ -790,9 +790,9 @@ class GraniteSpeechForConditionalGeneration(
|
||||
) -> torch.Tensor:
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
|
||||
@ -315,7 +315,7 @@ class GraniteMoeModel(nn.Module):
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -329,7 +329,7 @@ class GraniteMoeModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
hidden_states *= self.embedding_multiplier
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -531,8 +531,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
scale=1 / self.config.logits_scaling,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -366,7 +366,7 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -380,7 +380,7 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
hidden_states = hidden_states * self.embedding_multiplier
|
||||
residual = None
|
||||
else:
|
||||
@ -680,8 +680,8 @@ class GraniteMoeHybridForCausalLM(
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -182,7 +182,7 @@ class GraniteMoeSharedModel(nn.Module):
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -196,7 +196,7 @@ class GraniteMoeSharedModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
hidden_states *= self.embedding_multiplier
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -295,8 +295,8 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
scale=1 / self.config.logits_scaling,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -334,7 +334,7 @@ class Grok1Model(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states = hidden_states * self.embedding_multiplier_scale
|
||||
return hidden_states
|
||||
@ -350,7 +350,7 @@ class Grok1Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -522,8 +522,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -643,7 +643,7 @@ class HunYuanModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -657,7 +657,7 @@ class HunYuanModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -987,8 +987,8 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
|
||||
class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
|
||||
|
||||
@ -732,7 +732,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
def embed_multimodal(
|
||||
self,
|
||||
**kwargs: object,
|
||||
) -> MultiModalEmbeddings:
|
||||
|
||||
@ -550,8 +550,8 @@ class Idefics3Model(nn.Module):
|
||||
|
||||
return image_hidden_states
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.text_model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.text_model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -674,7 +674,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -94,7 +94,7 @@ class SupportsMultiModal(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
"""
|
||||
Returns multimodal embeddings generated from multimodal kwargs
|
||||
to be merged with text embeddings.
|
||||
@ -104,7 +104,13 @@ class SupportsMultiModal(Protocol):
|
||||
the appearances of their corresponding multimodal data item in the
|
||||
input prompt.
|
||||
"""
|
||||
...
|
||||
if hasattr(self, "get_multimodal_embeddings"):
|
||||
logger.warning_once(
|
||||
"`get_multimodal_embeddings` for vLLM models is deprecated and will be "
|
||||
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
|
||||
"this method to `embed_multimodal`."
|
||||
)
|
||||
return self.get_multimodal_embeddings(**kwargs)
|
||||
|
||||
def get_language_model(self) -> VllmModel:
|
||||
"""
|
||||
@ -119,10 +125,10 @@ class SupportsMultiModal(Protocol):
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_input_embeddings(self, input_ids: Tensor) -> Tensor: ...
|
||||
def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...
|
||||
|
||||
@overload
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings,
|
||||
@ -131,17 +137,17 @@ class SupportsMultiModal(Protocol):
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> Tensor: ...
|
||||
|
||||
def _get_text_embeddings(
|
||||
def _embed_text_input_ids(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
get_input_embeddings: Callable[[Tensor], Tensor],
|
||||
embed_input_ids: Callable[[Tensor], Tensor],
|
||||
*,
|
||||
is_multimodal: Tensor | None,
|
||||
handle_oov_mm_token: bool,
|
||||
) -> Tensor:
|
||||
if handle_oov_mm_token and is_multimodal is not None:
|
||||
is_text = ~is_multimodal
|
||||
text_embeds = get_input_embeddings(input_ids[is_text])
|
||||
text_embeds = embed_input_ids(input_ids[is_text])
|
||||
|
||||
return torch.empty(
|
||||
(input_ids.shape[0], text_embeds.shape[1]),
|
||||
@ -149,9 +155,9 @@ class SupportsMultiModal(Protocol):
|
||||
device=text_embeds.device,
|
||||
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
|
||||
|
||||
return get_input_embeddings(input_ids)
|
||||
return embed_input_ids(input_ids)
|
||||
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
@ -167,15 +173,15 @@ class SupportsMultiModal(Protocol):
|
||||
|
||||
In case the multi-modal token IDs exceed the vocabulary size of
|
||||
the language model, you can set `handle_oov_mm_token=True`
|
||||
to avoid calling the language model's `get_input_embeddings` method
|
||||
to avoid calling the language model's `embed_input_ids` method
|
||||
on those tokens. Note however that doing so increases memory usage
|
||||
as an additional buffer is needed to hold the input embeddings.
|
||||
"""
|
||||
from .utils import _merge_multimodal_embeddings
|
||||
|
||||
inputs_embeds = self._get_text_embeddings(
|
||||
inputs_embeds = self._embed_text_input_ids(
|
||||
input_ids,
|
||||
self.get_language_model().get_input_embeddings,
|
||||
self.get_language_model().embed_input_ids,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
@ -185,7 +191,7 @@ class SupportsMultiModal(Protocol):
|
||||
|
||||
if is_multimodal is None:
|
||||
raise ValueError(
|
||||
"`get_input_embeddings` now requires `is_multimodal` arg, "
|
||||
"`embed_input_ids` now requires `is_multimodal` arg, "
|
||||
"please update your model runner according to "
|
||||
"https://github.com/vllm-project/vllm/pull/16229."
|
||||
)
|
||||
|
||||
@ -41,24 +41,19 @@ T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
|
||||
class VllmModel(Protocol[T_co]):
|
||||
"""The interface required for all models in vLLM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None: ...
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply token embeddings to `input_ids`."""
|
||||
...
|
||||
if hasattr(self, "get_input_embeddings"):
|
||||
logger.warning_once(
|
||||
"`get_input_embeddings` for vLLM models is deprecated and will be "
|
||||
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
|
||||
"this method to `embed_input_ids`."
|
||||
)
|
||||
return self.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> T_co: ...
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co: ...
|
||||
|
||||
|
||||
def _check_vllm_model_init(model: type[object] | object) -> bool:
|
||||
@ -66,11 +61,19 @@ def _check_vllm_model_init(model: type[object] | object) -> bool:
|
||||
return supports_kw(model_init, "vllm_config")
|
||||
|
||||
|
||||
def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool:
|
||||
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||
if not callable(model_get_input_embeddings):
|
||||
def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
|
||||
model_embed_input_ids = getattr(model, "embed_input_ids", None)
|
||||
if not callable(model_embed_input_ids):
|
||||
model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||
if callable(model_get_input_embeddings):
|
||||
logger.warning(
|
||||
"`get_input_embeddings` for vLLM models is deprecated and will be "
|
||||
"removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
|
||||
"this method to `embed_input_ids`."
|
||||
)
|
||||
model.embed_input_ids = model_get_input_embeddings
|
||||
logger.warning(
|
||||
"The model (%s) is missing the `get_input_embeddings` method.",
|
||||
"The model (%s) is missing the `embed_input_ids` method.",
|
||||
model,
|
||||
)
|
||||
return False
|
||||
@ -110,7 +113,7 @@ def is_vllm_model(
|
||||
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
|
||||
return (
|
||||
_check_vllm_model_init(model)
|
||||
and _check_vllm_model_get_input_embeddings(model)
|
||||
and _check_vllm_model_embed_input_ids(model)
|
||||
and _check_vllm_model_forward(model)
|
||||
)
|
||||
|
||||
|
||||
@ -284,7 +284,7 @@ class InternLM2Model(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.tok_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -298,7 +298,7 @@ class InternLM2Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -350,8 +350,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -742,7 +742,7 @@ class InternS1ForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return []
|
||||
@ -765,7 +765,7 @@ class InternS1ForConditionalGeneration(
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
@ -778,9 +778,9 @@ class InternS1ForConditionalGeneration(
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
|
||||
@ -1344,7 +1344,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return []
|
||||
@ -1367,7 +1367,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
@ -1380,9 +1380,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
|
||||
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
|
||||
@ -275,7 +275,7 @@ class JAISModel(nn.Module):
|
||||
["hidden_states"], config.n_embd
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.wte(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -287,7 +287,7 @@ class JAISModel(nn.Module):
|
||||
) -> IntermediateTensors | torch.Tensor:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
inputs_embeds = self.embed_input_ids(input_ids)
|
||||
if self.wpe is not None:
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
@ -339,8 +339,8 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -340,7 +340,7 @@ class JambaModel(nn.Module):
|
||||
|
||||
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -354,7 +354,7 @@ class JambaModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -508,8 +508,8 @@ class JambaForCausalLM(
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -1484,9 +1484,7 @@ class BaseKeyeModule(nn.Module):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object
|
||||
) -> MultiModalEmbeddings | None:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return None
|
||||
|
||||
@ -439,7 +439,7 @@ class KimiLinearModel(nn.Module):
|
||||
"num_attention_heads must be divisible by world_size"
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -454,7 +454,7 @@ class KimiLinearModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -504,8 +504,8 @@ class KimiLinearForCausalLM(
|
||||
self.config.vocab_size, scale=logit_scale
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -404,7 +404,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> NestedTensors | None:
|
||||
def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
|
||||
# Validate the multimodal input keyword arguments
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
@ -351,7 +351,7 @@ class Lfm2Model(nn.Module):
|
||||
else:
|
||||
self.embedding_norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -365,7 +365,7 @@ class Lfm2Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -504,8 +504,8 @@ class Lfm2ForCausalLM(
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -466,7 +466,7 @@ class Lfm2MoeModel(nn.Module):
|
||||
else:
|
||||
self.embedding_norm = PPMissingLayer()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -480,7 +480,7 @@ class Lfm2MoeModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -714,8 +714,8 @@ class Lfm2MoeForCausalLM(
|
||||
self.num_routed_experts = example_layer.n_routed_experts
|
||||
self.num_redundant_experts = example_layer.n_redundant_experts
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def update_physical_experts_metadata(
|
||||
self,
|
||||
|
||||
@ -424,7 +424,7 @@ class LlamaModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -438,7 +438,7 @@ class LlamaModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -640,8 +640,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
):
|
||||
return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -82,7 +82,7 @@ class LlamaModel(nn.Module):
|
||||
)
|
||||
self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -93,7 +93,7 @@ class LlamaModel(nn.Module):
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
inputs_embeds = self.embed_input_ids(input_ids)
|
||||
hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
@ -195,7 +195,7 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
get_input_embeddings = SupportsMultiModal.get_input_embeddings # type: ignore
|
||||
embed_input_ids = SupportsMultiModal.embed_input_ids # type: ignore
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -84,7 +84,7 @@ class LlamaModel(nn.Module):
|
||||
self.config.hidden_size * 2, self.config.hidden_size, bias=False
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -158,8 +158,8 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
self.config.vocab_size, scale=logit_scale
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -172,7 +172,7 @@ class LlamaModel(nn.Module):
|
||||
eps=self.config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -183,7 +183,7 @@ class LlamaModel(nn.Module):
|
||||
input_embeds: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if input_embeds is None:
|
||||
input_embeds = self.get_input_embeddings(input_ids)
|
||||
input_embeds = self.embed_input_ids(input_ids)
|
||||
assert hidden_states.shape[-1] == input_embeds.shape[-1]
|
||||
|
||||
residual = None
|
||||
@ -261,13 +261,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: NestedTensors | None = None,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -661,7 +661,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -483,14 +483,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
@ -501,9 +501,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
|
||||
) -> torch.Tensor:
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
return super().embed_input_ids(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
return super().embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
|
||||
@ -422,7 +422,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
video_input = self._parse_and_validate_video_input(**kwargs)
|
||||
if video_input is None:
|
||||
return []
|
||||
|
||||
@ -866,7 +866,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not mm_input_by_modality:
|
||||
return []
|
||||
|
||||
@ -498,7 +498,7 @@ class FlashModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -512,7 +512,7 @@ class FlashModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -583,8 +583,8 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -135,7 +135,7 @@ class MambaModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -149,7 +149,7 @@ class MambaModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -218,8 +218,8 @@ class MambaForCausalLM(
|
||||
self.backbone.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.backbone.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.backbone.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -131,7 +131,7 @@ class Mamba2Model(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -145,7 +145,7 @@ class Mamba2Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -257,8 +257,8 @@ class Mamba2ForCausalLM(
|
||||
self.backbone.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.backbone.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.backbone.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -791,7 +791,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.decoder
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
|
||||
if audio_input is None:
|
||||
|
||||
@ -70,7 +70,7 @@ class MiMoModel(Qwen2Model):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
|
||||
@ -120,7 +120,7 @@ class MiMoMultiTokenPredictor(nn.Module):
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -164,8 +164,8 @@ class MiMoMTP(nn.Module):
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -440,7 +440,7 @@ class MiniCPMModel(nn.Module):
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
embedding = self.embed_tokens(input_ids)
|
||||
return embedding * self.config.scale_emb
|
||||
|
||||
@ -455,7 +455,7 @@ class MiniCPMModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
@ -615,8 +615,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
return MiniCPMModel(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
@ -193,7 +193,7 @@ class EagleMiniCPMModel(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
embedding = self.embed_tokens(input_ids)
|
||||
return embedding * self.config.scale_emb
|
||||
|
||||
@ -203,7 +203,7 @@ class EagleMiniCPMModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
input_embeds = self.get_input_embeddings(input_ids)
|
||||
input_embeds = self.embed_input_ids(input_ids)
|
||||
input_embeds = self.input_norm1(input_embeds)
|
||||
hidden_states = self.input_norm2(hidden_states)
|
||||
|
||||
@ -354,8 +354,8 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
vllm_config=vllm_config, prefix=prefix, start_layer=start_layer
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -1139,7 +1139,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.llm
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return []
|
||||
|
||||
@ -360,7 +360,7 @@ class MiniMaxM2Model(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -374,7 +374,7 @@ class MiniMaxM2Model(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -510,8 +510,8 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsPP):
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -620,7 +620,7 @@ class MiniMaxText01Model(nn.Module):
|
||||
)
|
||||
minimax_cache_tensors[:, slots_tensor, ...] = 0
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -709,8 +709,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -353,7 +353,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
@ -371,8 +371,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
vision_embeddings = self.embed_multimodal(**kwargs)
|
||||
inputs_embeds = self.embed_input_ids(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_index,
|
||||
|
||||
@ -549,7 +549,7 @@ class Mistral3ForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -345,7 +345,7 @@ class MixtralModel(nn.Module):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -359,7 +359,7 @@ class MixtralModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
@ -591,8 +591,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
|
||||
moe.n_redundant_experts = self.num_redundant_experts
|
||||
moe.experts.update_expert_map()
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -865,7 +865,7 @@ class Llama4ForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -46,7 +46,7 @@ class ModernBertEmbeddings(nn.Module):
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=eps, bias=config.norm_bias)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.tok_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -225,8 +225,8 @@ class ModernBertModel(nn.Module):
|
||||
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings.embed_input_ids(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||
@ -337,8 +337,8 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
}
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
self_weights = []
|
||||
@ -424,8 +424,8 @@ class ModernBertForTokenClassification(nn.Module):
|
||||
}
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["drop"])
|
||||
|
||||
@ -832,7 +832,7 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -1491,7 +1491,7 @@ class MolmoForCausalLM(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
@ -248,7 +248,7 @@ class MPTModel(nn.Module):
|
||||
["hidden_states"], config.d_model
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.wte(input_ids)
|
||||
|
||||
def forward(
|
||||
@ -262,7 +262,7 @@ class MPTModel(nn.Module):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
@ -308,8 +308,8 @@ class MPTForCausalLM(nn.Module, SupportsPP):
|
||||
self.transformer.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids)
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -655,7 +655,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
The replacement returned is not actually used to replace the placeholder
|
||||
tokens - it's just used to make sure we allocate the correct number
|
||||
of tokens.
|
||||
Actual replacement is done in get_multimodal_embeddings of
|
||||
Actual replacement is done in embed_multimodal of
|
||||
NemotronH_Nano_VL_V2
|
||||
(specifically in _process_video_input -> _create_final_video_embeddings).
|
||||
There, we create the final embeddings with text embeddings for indicator tokens
|
||||
@ -1401,7 +1401,7 @@ class NemotronH_Nano_VL_V2(
|
||||
|
||||
# Create final video embeddings, merging text embeddings for indicator
|
||||
# tokens with video embeddings
|
||||
text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids)
|
||||
text_embeddings = self.get_language_model().embed_input_ids(repl_token_ids)
|
||||
final_video_embeddings = _merge_multimodal_embeddings(
|
||||
inputs_embeds=text_embeddings,
|
||||
multimodal_embeddings=video_embeddings,
|
||||
@ -1465,7 +1465,7 @@ class NemotronH_Nano_VL_V2(
|
||||
|
||||
return modalities
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
# Validate the multimodal input keyword arguments
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if modalities is None:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user