Migrate GraniteSpeechAudioInputs to TensorSchema (#21682)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Benji Beck 2025-07-27 22:37:20 -07:00 committed by GitHub
parent 304dcdf575
commit 75856bc2cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -25,7 +25,7 @@
"""Inference-only IBM Granite speech model."""
import math
from collections.abc import Iterable, Mapping
from typing import Annotated, Optional, Union

import torch
import torch.nn.functional as F
@ -48,6 +48,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .blip2 import Blip2QFormerModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@ -57,16 +58,24 @@ from .utils import (AutoWeightsLoader, embed_multimodal,
### Audio Input ### Audio Input
### Audio Input
class GraniteSpeechAudioInputs(TensorSchema):
    """
    Audio input features for Granite Speech model.

    Dimensions:
        - b: Batch size
        - nf: Number of audio features (variable length)
        - 160: Fixed feature dimension for Mel spectrogram features
    """

    # Audio input features, one (nf, 160) feature matrix per batch item.
    input_features: Annotated[torch.Tensor, TensorShape("b", "nf", 160)]

    # Mask for variable-length audio features (valid positions per item).
    input_features_mask: Annotated[torch.Tensor, TensorShape("b", "nf")]

    # Audio embedding size for each item in the batch.
    audio_embed_sizes: Annotated[list[int], TensorShape("b")]
class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
@ -581,6 +590,7 @@ class GraniteSpeechForConditionalGeneration(
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)
        audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)

        if input_features is None:
            return None