mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 20:22:18 +08:00
[Model] Define merge_by_field_config MM interface (U-Z) (#26261)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
4dbdf4a294
commit
5f7e8a916a
@ -69,18 +69,16 @@ class UltravoxAudioFeatureInputs(TensorSchema):
|
|||||||
type: Literal["audio_features"]
|
type: Literal["audio_features"]
|
||||||
data: Annotated[
|
data: Annotated[
|
||||||
Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]],
|
Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]],
|
||||||
TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"}),
|
TensorShape("bn", "nmb", "t"),
|
||||||
]
|
]
|
||||||
lens: Annotated[
|
lens: Annotated[torch.Tensor, TensorShape("bn")]
|
||||||
Union[torch.Tensor, list[torch.Tensor]],
|
"""
|
||||||
TensorShape("b", "n", dynamic_dims={"n"}),
|
Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
|
||||||
]
|
"""
|
||||||
"""Length of the audio frames. Used for attention mask in WhisperEncoder."""
|
token_len: Annotated[torch.Tensor, TensorShape("bn")]
|
||||||
token_len: Annotated[
|
"""Length of the audio tokens per chunk. Used for flattening the audio features."""
|
||||||
Union[torch.Tensor, list[torch.Tensor]],
|
num_chunks: Annotated[torch.Tensor, TensorShape("n")]
|
||||||
TensorShape("b", "n", dynamic_dims={"n"}),
|
"""Number of chunks per audio. Used for flattening the audio features."""
|
||||||
]
|
|
||||||
"""Length of the audio tokens. Used for flattening the audio features."""
|
|
||||||
|
|
||||||
|
|
||||||
class UltravoxAudioEmbeddingInputs(TensorSchema):
|
class UltravoxAudioEmbeddingInputs(TensorSchema):
|
||||||
@ -421,6 +419,8 @@ class ModifiedWhisperEncoder(WhisperEncoder):
|
|||||||
dummy_inputs=UltravoxDummyInputsBuilder,
|
dummy_inputs=UltravoxDummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||||
|
merge_by_field_config = True
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
@ -519,6 +519,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
|||||||
audio_embeds = kwargs.pop("audio_embeds", None)
|
audio_embeds = kwargs.pop("audio_embeds", None)
|
||||||
audio_lens = kwargs.pop("audio_lens", None)
|
audio_lens = kwargs.pop("audio_lens", None)
|
||||||
audio_token_len = kwargs.pop("audio_token_len", None)
|
audio_token_len = kwargs.pop("audio_token_len", None)
|
||||||
|
audio_num_chunks = kwargs.pop("audio_num_chunks", None)
|
||||||
|
|
||||||
if audio_features is None and audio_embeds is None:
|
if audio_features is None and audio_embeds is None:
|
||||||
return None
|
return None
|
||||||
@ -529,6 +530,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
|||||||
data=audio_features,
|
data=audio_features,
|
||||||
lens=audio_lens,
|
lens=audio_lens,
|
||||||
token_len=audio_token_len,
|
token_len=audio_token_len,
|
||||||
|
num_chunks=audio_num_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
if audio_embeds is not None:
|
if audio_embeds is not None:
|
||||||
@ -547,9 +549,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
|||||||
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
|
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
|
||||||
audio_features = pad_and_concat_to_dim3(audio_input["data"])
|
audio_features = pad_and_concat_to_dim3(audio_input["data"])
|
||||||
|
|
||||||
# [B1, B2] -> [B1+B2]
|
audio_lens = audio_input["lens"]
|
||||||
audio_lens = flatten_bn(audio_input["lens"], concat=True)
|
audio_token_len = audio_input["token_len"]
|
||||||
audio_token_len = flatten_bn(audio_input["token_len"], concat=True)
|
|
||||||
|
|
||||||
embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
|
embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
|
||||||
|
|
||||||
@ -568,7 +569,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
|||||||
|
|
||||||
# Return one tensor per input audio
|
# Return one tensor per input audio
|
||||||
embed_lens = [
|
embed_lens = [
|
||||||
token_len_item.sum().item() for token_len_item in audio_input["token_len"]
|
chunk_lens.sum().item()
|
||||||
|
for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist())
|
||||||
]
|
]
|
||||||
return flattened_embeddings.split(embed_lens)
|
return flattened_embeddings.split(embed_lens)
|
||||||
|
|
||||||
@ -663,6 +665,7 @@ def pad_and_concat_to_dim3(
|
|||||||
if features.ndim > 3:
|
if features.ndim > 3:
|
||||||
# Flatten [B, N, 80, M] -> [B * N, 80, M]
|
# Flatten [B, N, 80, M] -> [B * N, 80, M]
|
||||||
features = flatten_bn(features)
|
features = flatten_bn(features)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
features = [pad_and_concat_to_dim3(f) for f in features]
|
features = [pad_and_concat_to_dim3(f) for f in features]
|
||||||
|
|||||||
@ -61,7 +61,7 @@ from vllm.transformers_utils.tokenizer import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
|
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
|
||||||
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
|
from .utils import init_vllm_registered_model, maybe_prefix
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -337,6 +337,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
|
|||||||
class VoxtralForConditionalGeneration(
|
class VoxtralForConditionalGeneration(
|
||||||
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
|
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
|
||||||
):
|
):
|
||||||
|
merge_by_field_config = True
|
||||||
|
|
||||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
@ -445,7 +447,6 @@ class VoxtralForConditionalGeneration(
|
|||||||
f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
|
f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_arrays = flatten_bn(audio_arrays)
|
|
||||||
if isinstance(audio_arrays, torch.Tensor):
|
if isinstance(audio_arrays, torch.Tensor):
|
||||||
audio_arrays = list(audio_arrays.unbind(0))
|
audio_arrays = list(audio_arrays.unbind(0))
|
||||||
return audio_arrays
|
return audio_arrays
|
||||||
|
|||||||
@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import (
|
||||||
MultiModalDataDict,
|
MultiModalDataDict,
|
||||||
MultiModalFieldConfig,
|
MultiModalFieldConfig,
|
||||||
@ -51,6 +51,7 @@ from vllm.multimodal.processing import (
|
|||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.transformers_utils.processor import cached_get_processor
|
from vllm.transformers_utils.processor import cached_get_processor
|
||||||
|
from vllm.utils.jsontree import json_map_leaves
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
|
||||||
@ -135,7 +136,10 @@ class WhisperAudioInputs(TensorSchema):
|
|||||||
- t: Time frames (M)
|
- t: Time frames (M)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input_features: Annotated[Optional[NestedTensors], TensorShape("b", "nmb", "t")]
|
input_features: Annotated[
|
||||||
|
Optional[list[torch.Tensor]],
|
||||||
|
TensorShape("b", "nmb", "t"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class WhisperEncoderAttention(MultiHeadAttention):
|
class WhisperEncoderAttention(MultiHeadAttention):
|
||||||
@ -781,6 +785,7 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
|
|||||||
class WhisperForConditionalGeneration(
|
class WhisperForConditionalGeneration(
|
||||||
nn.Module, SupportsTranscription, SupportsMultiModal
|
nn.Module, SupportsTranscription, SupportsMultiModal
|
||||||
):
|
):
|
||||||
|
merge_by_field_config = True
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"self_attn.qkv_proj": [
|
"self_attn.qkv_proj": [
|
||||||
"self_attn.q_proj",
|
"self_attn.q_proj",
|
||||||
@ -936,12 +941,7 @@ class WhisperForConditionalGeneration(
|
|||||||
input_features = kwargs.pop("input_features", None)
|
input_features = kwargs.pop("input_features", None)
|
||||||
|
|
||||||
if input_features is not None:
|
if input_features is not None:
|
||||||
if not isinstance(input_features, (torch.Tensor, list)):
|
input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)
|
||||||
raise ValueError(
|
|
||||||
"Incorrect type of audio features. "
|
|
||||||
f"Got type: {type(input_features)}"
|
|
||||||
)
|
|
||||||
input_features = torch.cat([feat.to(self.dtype) for feat in input_features])
|
|
||||||
|
|
||||||
return WhisperAudioInputs(input_features=input_features)
|
return WhisperAudioInputs(input_features=input_features)
|
||||||
|
|
||||||
|
|||||||
@ -677,6 +677,9 @@ class MultiModalFieldConfig:
|
|||||||
self.field = field
|
self.field = field
|
||||||
self.modality = modality
|
self.modality = modality
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"
|
||||||
|
|
||||||
def build_elems(
|
def build_elems(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user