[Model, Core] Support Granite Speech & LoRA for STT (#24455)

Alex Brooks 2025-11-05 00:33:48 -07:00 committed by GitHub
parent d43ad5a757
commit b7cbc25416
GPG Key ID: B5690EEEBB952194
5 changed files with 169 additions and 8 deletions

View File

@ -761,6 +761,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | |
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-speech-3.3-2b`, `ibm-granite/granite-speech-3.3-8b`, etc. | ✅︎ | ✅︎ |
### Pooling Models

View File

@ -65,6 +65,41 @@ async def test_basic_audio(mary_had_lamb, model_name):
        assert out_usage["seconds"] == 16, out_usage["seconds"]


@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb):
    """Ensure STT (transcribe) requests can pass LoRA through to generate."""
    model_name = "ibm-granite/granite-speech-3.3-2b"
    lora_model_name = "speech"
    server_args = [
        "--enforce-eager",
        "--enable-lora",
        "--max-lora-rank",
        "64",
        "--lora-modules",
        f"{lora_model_name}={model_name}",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "1",
    ]
    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        transcription = await client.audio.transcriptions.create(
            model=lora_model_name,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0,
        )
        out = json.loads(transcription)
        out_text = out["text"]
        out_usage = out["usage"]
        assert "mary had a little lamb" in out_text
        assert out_usage["seconds"] == 16, out_usage["seconds"]


@pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo):
    # Gemma accuracy on some of the audio samples we use is particularly bad,

View File

@ -48,6 +48,40 @@ async def test_non_asr_model(foscolo):
assert err["message"] == "The model does not support Translations API"
@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb):
"""Ensure STT (translate) requests can pass LoRA through to generate."""
# NOTE - careful to call this test before the module scoped server
# fixture, otherwise it'll OOMkill the CI
model_name = "ibm-granite/granite-speech-3.3-2b"
lora_model_name = "speech"
server_args = [
"--enforce-eager",
"--enable-lora",
"--max-lora-rank",
"64",
"--lora-modules",
f"{lora_model_name}={model_name}",
"--max-model-len",
"2048",
"--max-num-seqs",
"1",
]
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
translation = await client.audio.translations.create(
model=lora_model_name,
file=mary_had_lamb,
extra_body=dict(language="en", to_language="es"),
response_format="text",
temperature=0.0,
)
out = json.loads(translation)["text"].strip().lower()
assert "mary tenía un pequeño cordero" in out
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
@pytest.mark.asyncio
async def test_basic_audio(foscolo, client_and_model):

View File

@ -170,11 +170,6 @@ class OpenAISpeechToText(OpenAIServing):
        try:
            lora_request = self._maybe_get_adapters(request)

            if lora_request:
                return self.create_error_response(
                    f"Currently do not support LoRA for {self.task_type.title()}."
                )

            prompts, duration_s = await self._preprocess_speech_to_text(
                request=request,
                audio_data=audio_data,
@ -199,7 +194,7 @@ class OpenAISpeechToText(OpenAIServing):
                # It will not display special tokens like <|startoftranscript|>
                request.prompt,
                params=sampling_params,
                lora_request=None,
                lora_request=lora_request,
            )

            list_result_generator = [
@ -207,6 +202,7 @@ class OpenAISpeechToText(OpenAIServing):
                    prompt,
                    sampling_params,
                    request_id,
                    lora_request=lora_request,
                )
                for prompt in prompts
            ]
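
With the adapter now threaded through to generate, a transcription or translation request can target a served LoRA adapter simply by naming it in the model field. A minimal client-side sketch, assuming a vLLM server launched with the LoRA flags used in the tests above (--enable-lora, --lora-modules speech=ibm-granite/granite-speech-3.3-2b); the base URL, API key, and audio filename are placeholders:

from openai import OpenAI

# Placeholder URL/key; assumes a local vLLM server started with --enable-lora
# and --lora-modules speech=ibm-granite/granite-speech-3.3-2b.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("mary_had_lamb.wav", "rb") as audio_file:
    transcription = client.audio.transcriptions.create(
        model="speech",  # the served LoRA module name, not the base model
        file=audio_file,
        language="en",
        response_format="text",
        temperature=0.0,
    )
print(transcription)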

View File

@ -26,15 +26,17 @@
import math
from collections.abc import Iterable, Mapping
from typing import Annotated
from typing import Annotated, Literal, cast
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import BatchFeature, PretrainedConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -57,6 +59,8 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import cached_get_tokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .blip2 import Blip2QFormerModel
@ -65,9 +69,22 @@ from .interfaces import (
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsTranscription,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
# NOTE lang support is based on what is written here:
# https://huggingface.co/ibm-granite/granite-speech-3.3-2b
# Though this may vary from model to model, and also many langs
# work pretty well with zero shot.
ISO639_1_SUPPORTED_LANGS = {
"en": "English",
"fr": "French",
"de": "German",
"pt": "Portuguese",
"es": "Spanish",
}
### Audio Input
class GraniteSpeechAudioInputs(TensorSchema):
@ -545,8 +562,10 @@ class GraniteSpeechForConditionalGeneration(
    SupportsMultiModal,
    SupportsPP,
    SupportsLoRA,
    SupportsTranscription,
):
    merge_by_field_config = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    packed_modules_mapping = {
        "qkv_proj": [
@ -816,3 +835,79 @@ class GraniteSpeechForConditionalGeneration(
            connector="projector",
            tower_model="encoder",
        )
    ### Support for speech-to-text Transcription
    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Get the generation prompt to be used for transcription requests."""
        # Audio placeholders don't use an index, so value doesn't matter
        audio_tok = cls.get_placeholder_str("audio", 0)

        if task_type == "translate":
            full_lang_name_to = cls.supported_languages.get(to_language, to_language)
            user_prompt = f"{audio_tok}translate the speech to {full_lang_name_to}"  # noqa: E501
        elif task_type == "transcribe":
            user_prompt = (
                f"{audio_tok}can you transcribe the speech into a written format?"  # noqa: E501
            )
        else:
            raise ValueError(f"Unsupported task type {task_type}")

        tokenizer = cached_get_tokenizer(model_config.model)
        chat = [dict(role="user", content=user_prompt)]
        prompt = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )
        prompt_token_ids = tokenizer.encode(prompt)

        prompt = {
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": {"audio": audio},
        }
        return cast(PromptType, prompt)
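
For orientation, a small sketch of the user prompts this hook builds before chat templating; the "<|audio|>" placeholder below is an assumed value used only for illustration, while the real string comes from get_placeholder_str:

# Illustrative only; the placeholder string is an assumption, not necessarily
# the model's actual audio token.
audio_tok = "<|audio|>"
transcribe_prompt = f"{audio_tok}can you transcribe the speech into a written format?"
translate_prompt = f"{audio_tok}translate the speech to Spanish"  # to_language="es"
# Each string is then wrapped in the tokenizer's chat template
# (add_generation_prompt=True), tokenized, and paired with the raw audio array.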
    # Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122  # noqa: E501
    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Get the number of audio tokens for an audio duration in sec."""
        processor = cached_get_processor(model_config.model)
        hop_length = processor.audio_processor.melspec_kwargs["hop_length"]
        proj_win_size = processor.audio_processor.projector_window_size
        ds_rate = processor.audio_processor.projector_downsample_rate
        effective_window_size = proj_win_size // ds_rate

        raw_length = audio_duration_s * stt_config.sample_rate
        # mel sequence length computation
        mel_length = raw_length // hop_length + 1
        # encoder frame takes two mel features
        encoder_length = mel_length // 2
        nblocks = math.ceil(encoder_length / proj_win_size)
        # projector output length
        return nblocks * effective_window_size
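
As a worked example of the arithmetic above, using illustrative processor values (the real hop length, projector window size, and downsample rate come from the model's audio processor and may differ):

import math

# Assumed values for illustration only.
hop_length, proj_win_size, ds_rate = 160, 15, 5
sample_rate, audio_duration_s = 16_000, 16.0

raw_length = audio_duration_s * sample_rate              # 256_000 samples
mel_length = raw_length // hop_length + 1                # 1_601 mel frames
encoder_length = mel_length // 2                         # 800 encoder frames
nblocks = math.ceil(encoder_length / proj_win_size)      # 54 projector blocks
num_audio_tokens = nblocks * (proj_win_size // ds_rate)  # 54 * 3 = 162 tokens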
    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Get the stt config for this model."""
        # Default settings are reasonable for this model and we don't currently
        # expose this information in the model configs, but this may change in
        # the future
        return SpeechToTextConfig()