[Model, Core] Support Granite Speech & LoRA for STT (#24455)
Parent: d43ad5a757 · Commit: b7cbc25416
@ -761,6 +761,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | |
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-speech-3.3-2b`, `ibm-granite/granite-speech-3.3-8b`, etc. | ✅︎ | ✅︎ |
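Since the new row marks both LoRA and PP support for Granite Speech, the sketch below shows how the bundled speech LoRA might be exercised offline through vLLM's Python API. The audio file name, the `<|audio|>` placeholder string, and pointing `LoRARequest` at the HF repo (which ships the adapter alongside the base weights) are assumptions that mirror the tests in this PR, not verified documentation.

```python
# Offline sketch (assumptions noted above); server flags in the LoRA tests
# below correspond to enable_lora/max_lora_rank/enforce_eager here.
import librosa  # any loader that yields a 16 kHz mono float array works
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL_ID = "ibm-granite/granite-speech-3.3-2b"

# Build a chat prompt containing the model's audio placeholder (assumed to be
# "<|audio|>" for Granite Speech).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
chat = [{"role": "user",
         "content": "<|audio|>can you transcribe the speech into a written format?"}]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

audio, sr = librosa.load("mary_had_lamb.ogg", sr=16_000)  # hypothetical local file

llm = LLM(model=MODEL_ID, enable_lora=True, max_lora_rank=64,
          max_model_len=2048, enforce_eager=True)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"audio": (audio, sr)}},
    SamplingParams(temperature=0.0, max_tokens=200),
    # The HF repo bundles the speech LoRA with the base weights, so the
    # adapter path simply reuses the model id (assumption).
    lora_request=LoRARequest("speech", 1, MODEL_ID),
)
print(outputs[0].outputs[0].text)
```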
### Pooling Models
@ -65,6 +65,41 @@ async def test_basic_audio(mary_had_lamb, model_name):
        assert out_usage["seconds"] == 16, out_usage["seconds"]


@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb):
    """Ensure STT (transcribe) requests can pass LoRA through to generate."""
    model_name = "ibm-granite/granite-speech-3.3-2b"
    lora_model_name = "speech"
    server_args = [
        "--enforce-eager",
        "--enable-lora",
        "--max-lora-rank",
        "64",
        "--lora-modules",
        f"{lora_model_name}={model_name}",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "1",
    ]

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        transcription = await client.audio.transcriptions.create(
            model=lora_model_name,
            file=mary_had_lamb,
            language="en",
            response_format="text",
            temperature=0.0,
        )
        out = json.loads(transcription)
        out_text = out["text"]
        out_usage = out["usage"]
        assert "mary had a little lamb" in out_text
        assert out_usage["seconds"] == 16, out_usage["seconds"]


@pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo):
    # Gemma accuracy on some of the audio samples we use is particularly bad,
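Outside the test harness, the same request can be issued against an already-running server with the plain OpenAI SDK. The sketch below assumes the server was launched with the flags shown in the test above (adapter registered as `speech`); the base URL, API key, and audio file name are placeholders.

```python
# Client-only sketch; endpoint, API key, and audio path are assumptions.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    with open("mary_had_lamb.ogg", "rb") as audio_file:
        transcription = await client.audio.transcriptions.create(
            model="speech",  # the LoRA adapter name from --lora-modules
            file=audio_file,
            language="en",
            response_format="text",
            temperature=0.0,
        )
    print(transcription)


asyncio.run(main())
```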
@ -48,6 +48,40 @@ async def test_non_asr_model(foscolo):
    assert err["message"] == "The model does not support Translations API"


@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb):
    """Ensure STT (translate) requests can pass LoRA through to generate."""
    # NOTE - careful to call this test before the module scoped server
    # fixture, otherwise it'll OOMkill the CI
    model_name = "ibm-granite/granite-speech-3.3-2b"
    lora_model_name = "speech"
    server_args = [
        "--enforce-eager",
        "--enable-lora",
        "--max-lora-rank",
        "64",
        "--lora-modules",
        f"{lora_model_name}={model_name}",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "1",
    ]

    # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        translation = await client.audio.translations.create(
            model=lora_model_name,
            file=mary_had_lamb,
            extra_body=dict(language="en", to_language="es"),
            response_format="text",
            temperature=0.0,
        )
        out = json.loads(translation)["text"].strip().lower()
        assert "mary tenía un pequeño cordero" in out


# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
@pytest.mark.asyncio
async def test_basic_audio(foscolo, client_and_model):
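For reference, the non-standard `language`/`to_language` fields passed via `extra_body` above ultimately drive the prompt construction added later in this diff. The tiny illustration below (the `<|audio|>` placeholder string is an assumption) shows the user prompt the translate request is expected to produce.

```python
# Illustration only; mirrors get_generation_prompt() further down in this diff.
ISO639_1_SUPPORTED_LANGS = {
    "en": "English", "fr": "French", "de": "German",
    "pt": "Portuguese", "es": "Spanish",
}

audio_tok = "<|audio|>"  # assumed Granite Speech audio placeholder
to_language = "es"
full_lang_name_to = ISO639_1_SUPPORTED_LANGS.get(to_language, to_language)
user_prompt = f"{audio_tok}translate the speech to {full_lang_name_to}"
assert user_prompt == "<|audio|>translate the speech to Spanish"
```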
@ -170,11 +170,6 @@ class OpenAISpeechToText(OpenAIServing):
         try:
             lora_request = self._maybe_get_adapters(request)
 
-            if lora_request:
-                return self.create_error_response(
-                    f"Currently do not support LoRA for {self.task_type.title()}."
-                )
-
             prompts, duration_s = await self._preprocess_speech_to_text(
                 request=request,
                 audio_data=audio_data,
@ -199,7 +194,7 @@ class OpenAISpeechToText(OpenAIServing):
             # It will not display special tokens like <|startoftranscript|>
             request.prompt,
             params=sampling_params,
-            lora_request=None,
+            lora_request=lora_request,
         )
 
         list_result_generator = [
@ -207,6 +202,7 @@ class OpenAISpeechToText(OpenAIServing):
                 prompt,
                 sampling_params,
                 request_id,
+                lora_request=lora_request,
             )
             for prompt in prompts
         ]
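For context, the adapter resolution that now flows through to `generate` is conceptually just a name lookup. The sketch below is a simplified, hypothetical stand-in for that step; the real `OpenAIServing._maybe_get_adapters` differs in detail.

```python
# Simplified, hypothetical stand-in for the adapter lookup.
from vllm.lora.request import LoRARequest


def resolve_lora(requested_model: str,
                 served_adapters: dict[str, LoRARequest]) -> LoRARequest | None:
    # Requests that name a registered LoRA adapter (e.g. model="speech")
    # resolve to its LoRARequest; base-model requests resolve to None.
    return served_adapters.get(requested_model)
```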
@ -26,15 +26,17 @@
import math
from collections.abc import Iterable, Mapping
-from typing import Annotated
+from typing import Annotated, Literal, cast

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import BatchFeature, PretrainedConfig

-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -57,6 +59,8 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import cached_get_tokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .blip2 import Blip2QFormerModel
@ -65,9 +69,22 @@ from .interfaces import (
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsTranscription,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix

# NOTE lang support is based on what is written here:
# https://huggingface.co/ibm-granite/granite-speech-3.3-2b
# Though this may vary from model to model, and also many langs
# work pretty well with zero shot.
ISO639_1_SUPPORTED_LANGS = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "pt": "Portuguese",
    "es": "Spanish",
}


### Audio Input
class GraniteSpeechAudioInputs(TensorSchema):
@ -545,8 +562,10 @@ class GraniteSpeechForConditionalGeneration(
    SupportsMultiModal,
    SupportsPP,
    SupportsLoRA,
    SupportsTranscription,
):
    merge_by_field_config = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    packed_modules_mapping = {
        "qkv_proj": [
@ -816,3 +835,79 @@ class GraniteSpeechForConditionalGeneration(
            connector="projector",
            tower_model="encoder",
        )

    ### Support for speech-to-text Transcription
    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Get the generation prompt to be used for transcription requests."""
        # Audio placeholders don't use an index, so value doesn't matter
        audio_tok = cls.get_placeholder_str("audio", 0)

        if task_type == "translate":
            full_lang_name_to = cls.supported_languages.get(to_language, to_language)
            user_prompt = f"{audio_tok}translate the speech to {full_lang_name_to}"  # noqa: E501
        elif task_type == "transcribe":
            user_prompt = (
                f"{audio_tok}can you transcribe the speech into a written format?"  # noqa: E501
            )
        else:
            raise ValueError(f"Unsupported task type {task_type}")

        tokenizer = cached_get_tokenizer(model_config.model)
        chat = [dict(role="user", content=user_prompt)]
        prompt = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )

        prompt_token_ids = tokenizer.encode(prompt)
        prompt = {
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": {"audio": audio},
        }
        return cast(PromptType, prompt)

    # Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122  # noqa: E501
    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Get the number of audio tokens for an audio duration in sec."""
        processor = cached_get_processor(model_config.model)
        hop_length = processor.audio_processor.melspec_kwargs["hop_length"]
        proj_win_size = processor.audio_processor.projector_window_size
        ds_rate = processor.audio_processor.projector_downsample_rate
        effective_window_size = proj_win_size // ds_rate

        raw_length = audio_duration_s * stt_config.sample_rate

        # mel sequence length computation
        mel_length = raw_length // hop_length + 1
        # encoder frame takes two mel features
        encoder_length = mel_length // 2
        nblocks = math.ceil(encoder_length / proj_win_size)
        # projector output length
        return nblocks * effective_window_size
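As a rough worked example of the token arithmetic above (the hop length, projector window size, and downsample rate are illustrative placeholders, not values read from the actual Granite Speech processor config), 16 s of 16 kHz audio works out as follows.

```python
import math

# Illustrative hyperparameters only; real values come from the HF processor.
sample_rate = 16_000
hop_length = 160
proj_win_size = 15
ds_rate = 5

audio_duration_s = 16.0
raw_length = audio_duration_s * sample_rate          # 256_000 samples
mel_length = raw_length // hop_length + 1            # 1_601 mel frames
encoder_length = mel_length // 2                     # 800 encoder frames
nblocks = math.ceil(encoder_length / proj_win_size)  # 54 projector blocks
effective_window_size = proj_win_size // ds_rate     # 3 tokens per block
print(nblocks * effective_window_size)               # 162 audio tokens
```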
    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Get the stt config for this model."""
        # Default settings are reasonable for this model and we don't currently
        # expose this information in the model configs, but this may change in
        # the future
        return SpeechToTextConfig()
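Taken together, the three classmethods added above give the OpenAI-compatible speech-to-text frontend what it needs from a `SupportsTranscription` model: an STT config, an audio-token estimate, and a generation prompt. The sketch below shows roughly how they compose; the helper name is hypothetical, and the caller must supply a real `ModelConfig`, so this is not the actual frontend code.

```python
# Rough composition sketch; build_transcription_request is a hypothetical helper.
import numpy as np

from vllm.config import ModelConfig
from vllm.model_executor.models.granite_speech import (
    GraniteSpeechForConditionalGeneration as GraniteSpeech,
)


def build_transcription_request(audio: np.ndarray, model_config: ModelConfig):
    stt_config = GraniteSpeech.get_speech_to_text_config(model_config, "transcribe")
    # Duration in seconds = samples / sample rate; used to budget prompt tokens.
    est_audio_tokens = GraniteSpeech.get_num_audio_tokens(
        len(audio) / stt_config.sample_rate, stt_config, model_config
    )
    prompt = GraniteSpeech.get_generation_prompt(
        audio=audio,
        model_config=model_config,
        stt_config=stt_config,
        language="en",
        task_type="transcribe",
        request_prompt="",
        to_language=None,
    )
    return prompt, est_audio_tokens
```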