[Model] Use merge_by_field_config for MM models (G) (#26117)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Commit 39b643dc1a (parent 711f485643)
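Note: as these diffs suggest, setting `merge_by_field_config = True` opts a model into receiving its multimodal keyword arguments already merged and batched according to the field configs it returns from `_get_mm_fields_config`, which is why the per-model `flatten_bn` calls and manual isinstance checks are dropped below. A minimal sketch of that model-side contract, under that assumption; the helper names here are illustrative, only the vLLM names visible in the diff (`flatten_bn`, `merge_by_field_config`) are real:

from typing import Optional, Union

import torch


def parse_pixel_values_legacy(
        pixel_values: Union[torch.Tensor, list[torch.Tensor], None],
) -> Optional[torch.Tensor]:
    # Old pattern: the model defends against list-of-tensors inputs itself.
    if pixel_values is None:
        return None
    if not isinstance(pixel_values, (torch.Tensor, list)):
        raise ValueError(f"Incorrect type of pixel values: {type(pixel_values)}")
    if isinstance(pixel_values, list):
        # Roughly what flatten_bn(pixel_values, concat=True) did: concatenate
        # the per-item tensors along the batch dimension.
        pixel_values = torch.cat(pixel_values, dim=0)
    return pixel_values


def parse_pixel_values_merged(
        pixel_values: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # New pattern: with merge_by_field_config = True the tensor is assumed to
    # arrive already merged per the model's field config, so only a presence
    # check remains.
    return pixel_values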
vllm/model_executor/models/gemma3_mm.py

@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)

 logger = init_logger(__name__)
@@ -289,7 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
                                     processor=hf_processor)
             for size in image_sizes
         ]
-        processed_outputs["num_crops"] = torch.tensor(num_crops)
+        processed_outputs["num_patches"] = torch.tensor(num_crops) + 1

         return processed_outputs

@@ -298,12 +298,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0))
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))

         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops + 1),
-            num_crops=MultiModalFieldConfig.batched("image"),
+                "image", num_patches),
+            num_patches=MultiModalFieldConfig.batched("image"),
         )

     def _get_prompt_updates(
@@ -460,6 +460,8 @@ class Gemma3MultiModalProjector(nn.Module):
     dummy_inputs=Gemma3DummyInputsBuilder)
 class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                      SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -526,29 +528,20 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        num_crops = kwargs.pop("num_crops", None)
+        num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
             return None

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        if not isinstance(num_crops, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         image_size = self.config.vision_config.image_size

-        return Gemma3ImagePixelInputs(
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_crops, concat=True) + 1,
-            resolve_bindings={
-                "h": image_size,
-                "w": image_size
-            })
+        return Gemma3ImagePixelInputs(pixel_values=pixel_values,
+                                      num_patches=num_patches,
+                                      resolve_bindings={
+                                          "h": image_size,
+                                          "w": image_size
+                                      })

     def _image_pixels_to_features(
         self,
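The processor now emits `num_patches` (the crop count plus one for the base image), and `pixel_values` is declared with `MultiModalFieldConfig.flat_from_sizes("image", num_patches)`, i.e. the flat stack of patches is grouped back into per-image slices by those counts. A standalone torch sketch of that grouping, with illustrative sizes rather than vLLM's implementation:

import torch

# Hypothetical example: 2 images, the first contributing 1 patch (no crops)
# and the second 3 patches (2 pan-and-scan crops + the base image).
num_patches = torch.tensor([1, 3])
image_size = 64  # illustrative patch resolution

# Flat stack of all patches across the batch: (total_patches, C, H, W).
pixel_values = torch.randn(int(num_patches.sum()), 3, image_size, image_size)

# Regroup per image by the per-image sizes; this is the slicing that
# flat_from_sizes("image", num_patches) expresses declaratively.
per_image = torch.split(pixel_values, num_patches.tolist(), dim=0)
assert [chunk.shape[0] for chunk in per_image] == [1, 3]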
vllm/model_executor/models/gemma3n_mm.py

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Any, Literal, Optional, Union, cast

 import numpy as np
 import torch
@@ -41,6 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
@@ -54,17 +55,28 @@ TOKENS_PER_IMAGE = 256
 TOKENS_PER_AUDIO = 188


-class Gemma3nImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+class Gemma3nImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


-class Gemma3nAudioInputs(TypedDict):
-    input_features: Union[torch.Tensor, list[torch.Tensor]]
-    input_features_padded: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
-    input_features_mask: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length)`"""
+class Gemma3nAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: seq_length
+        - f: num_features
+    """
+    type: Literal["audio"] = "audio"
+    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
+    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]


 Gemma3nImageInputs = Gemma3nImagePixelInputs
@@ -212,9 +224,9 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]

         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            input_features=MultiModalFieldConfig.batched("audio"),
             input_features_padded=MultiModalFieldConfig.batched("audio"),
-            input_features_mask=MultiModalFieldConfig.batched("audio"))
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
+        )

     def _get_prompt_updates(
         self,
@@ -422,6 +434,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
     dummy_inputs=Gemma3nDummyInputsBuilder)
 class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsTranscription):
+    merge_by_field_config = True
     supported_languages = ISO639_1_SUPPORTED_LANGS

     packed_modules_mapping = {
@@ -482,14 +495,6 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
             device=self.language_model.model.embed_tokens.weight.device,
             dtype=self.language_model.model.embed_tokens.weight.dtype)

-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
-
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        # TODO check if there are any
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -499,34 +504,22 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        pixel_values = pixel_values.contiguous()
-
-        return Gemma3nImagePixelInputs(
-            pixel_values=self._validate_pixel_values(pixel_values), )
+        return Gemma3nImagePixelInputs(pixel_values=pixel_values)

     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
-        input_features = kwargs.pop("input_features", None)
-        if input_features is None:
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
             return None

         input_features_mask = kwargs.pop("input_features_mask", None)
         if input_features_mask is None:
             return None

-        input_features_padded = kwargs.pop("input_features_padded", None)
-        if input_features_padded is None:
-            return None
-
         return Gemma3nAudioInputs(
-            input_features=input_features,
-            input_features_mask=input_features_mask,
             input_features_padded=input_features_padded,
+            input_features_mask=input_features_mask,
         )

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -539,7 +532,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
             ) and "image" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "image"] = self._parse_and_validate_image_input(**kwargs)
-            if input_key == "input_features" \
+            if input_key == "input_features_padded" \
                 and "audio" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "audio"] = self._parse_and_validate_audio_input(**kwargs)
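This file swaps the TypedDict input containers for TensorSchema subclasses whose fields carry symbolic shapes through `Annotated[..., TensorShape(...)]`. A simplified, self-contained sketch of that validation idea, written as an assumption about the pattern rather than vLLM's TensorSchema implementation: symbolic dimensions bind on first use and every annotated tensor must agree with the binding.

from typing import Optional

import torch


def check_shapes(tensors: dict[str, torch.Tensor],
                 schema: dict[str, tuple],
                 bindings: Optional[dict[str, int]] = None) -> dict[str, int]:
    """Bind symbolic dims (strings) to sizes and verify fixed dims (ints)."""
    bound = dict(bindings or {})
    for name, dims in schema.items():
        shape = tuple(tensors[name].shape)
        if len(shape) != len(dims):
            raise ValueError(f"{name}: expected {len(dims)} dims, got {len(shape)}")
        for dim, size in zip(dims, shape):
            if isinstance(dim, int):
                if size != dim:
                    raise ValueError(f"{name}: expected size {dim}, got {size}")
            elif bound.setdefault(dim, size) != size:
                raise ValueError(f"{name}: '{dim}' bound to {bound[dim]}, got {size}")
    return bound


# Mirrors the shapes declared for Gemma3nAudioInputs: ("bn", "s", "f") for the
# padded features and ("bn", "s") for the mask; sizes here are arbitrary.
features = torch.zeros(4, 128, 80)
mask = torch.ones(4, 128, dtype=torch.bool)
print(check_shapes({"input_features_padded": features, "input_features_mask": mask},
                   {"input_features_padded": ("bn", "s", "f"),
                    "input_features_mask": ("bn", "s")}))
# -> {'bn': 4, 's': 128, 'f': 80}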
vllm/model_executor/models/glm4_1v.py

@@ -1319,6 +1319,8 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
 )
 class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsLoRA, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1381,22 +1383,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Glm4vImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1407,11 +1393,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return Glm4vImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1419,11 +1400,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return Glm4vImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1440,11 +1416,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Glm4vVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1452,11 +1423,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Glm4vVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
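With the flag set, Glm4vForConditionalGeneration drops `_validate_and_reshape_mm_tensor`: the batched multimodal tensors are assumed to arrive already flattened into the 2D layout that helper used to produce. For reference, the reshape it performed was essentially the following (error handling omitted; only a sketch of the removed logic):

import torch


def flatten_mm_tensor(mm_input) -> torch.Tensor:
    # Condensed from the removed helper: accept a 2D tensor, a batched 3D
    # tensor, or a list of 2D tensors, and return a single 2D tensor.
    if isinstance(mm_input, torch.Tensor):
        if mm_input.ndim == 2:
            return mm_input
        return mm_input.reshape(-1, mm_input.shape[-1])
    return torch.concat(mm_input)


batched = torch.randn(2, 5, 16)  # (batch, tokens, hidden)
assert flatten_mm_tensor(batched).shape == (10, 16)
assert flatten_mm_tensor([torch.randn(5, 16), torch.randn(7, 16)]).shape == (12, 16)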
vllm/model_executor/models/glm4v.py

@@ -43,7 +43,6 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .chatglm import ChatGLMBaseModel, ChatGLMModel
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import flatten_bn


 class GLMVImagePixelInputs(TensorSchema):
@@ -529,8 +528,9 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
                                         info=GLM4VProcessingInfo,
                                         dummy_inputs=GLM4VDummyInputsBuilder)
-class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
-                       SupportsMultiModal):
+class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA,
+                       SupportsPP):
+    merge_by_field_config = True

     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
@@ -574,14 +574,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
         pixel_values = kwargs.pop("pixel_values", None)

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             expected_h = expected_w = self.config.vision_config["image_size"]
             return GLMVImagePixelInputs(type="pixel_values",
-                                        data=flatten_bn(pixel_values,
-                                                        concat=True),
+                                        data=pixel_values,
                                         resolve_bindings={
                                             "h": expected_h,
                                             "w": expected_w
@@ -598,6 +593,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer

+    get_input_embeddings = SupportsMultiModal.get_input_embeddings
+
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
vllm/model_executor/models/granite_speech.py

@@ -168,10 +168,8 @@ class GraniteSpeechMultiModalProcessor(
         # Calculate the number of audio tokens per entry in the batch;
         # This is used to split the batch back out after padding.
         audio_token_index = self.info.get_hf_config().audio_token_index
-        processed_outputs["audio_embed_sizes"] = [
-            torch.sum(indices == audio_token_index).item()
-            for indices in processed_outputs["input_ids"]
-        ]
+        processed_outputs["audio_embed_sizes"] = (
+            processed_outputs["input_ids"] == audio_token_index).sum(-1)

         return processed_outputs

@@ -527,6 +525,7 @@ class GraniteSpeechForConditionalGeneration(
         SupportsPP,
         SupportsLoRA,
 ):
+    merge_by_field_config = True

     packed_modules_mapping = {
         "qkv_proj": [
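The per-entry audio token count in GraniteSpeechMultiModalProcessor is vectorized: rather than looping over `input_ids` rows in Python, the count is taken with a single tensor comparison. A quick equivalence check with toy ids (the audio token index below is made up; the real one comes from the HF config):

import torch

audio_token_index = 99  # hypothetical value for illustration
input_ids = torch.tensor([[1, 99, 99, 2],
                          [99, 3, 4, 5]])

# Old form: per-row Python loop producing a list of ints.
sizes_list = [
    torch.sum(indices == audio_token_index).item() for indices in input_ids
]

# New form: one batched tensor op.
sizes_tensor = (input_ids == audio_token_index).sum(-1)

assert sizes_list == sizes_tensor.tolist() == [2, 1]

Note the new form keeps the result as a tensor rather than a list of Python ints, which the downstream splitting logic consumes directly.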