[Bugfix] Relax lang pin for voxtral (#21833)

Signed-off-by: Sanchit Gandhi <sgandhi3141@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Sanchit Gandhi 2025-07-31 04:38:52 +01:00 committed by GitHub
parent 9cb497bfa3
commit ec02e536df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 80 additions and 80 deletions

View File

@ -86,11 +86,7 @@ class OpenAISpeechToText(OpenAIServing):
audio_data: bytes,
) -> tuple[list[PromptType], float]:
# Validate request
# TODO language should be optional and can be guessed.
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
lang = request.language or "en"
self.model_cls.validate_language(lang)
language = self.model_cls.validate_language(request.language)
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.")
@ -112,7 +108,7 @@ class OpenAISpeechToText(OpenAIServing):
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=lang,
language=language,
task_type=self.task_type,
request_prompt=request.prompt)
prompts.append(prompt)

View File

@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, MutableSequence
from collections.abc import Iterable, Mapping, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable)
import numpy as np
import torch
from torch import Tensor
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from typing_extensions import Self, TypeIs
from vllm.config import ModelConfig, SpeechToTextConfig
@ -685,6 +686,8 @@ class SupportsQuant:
@runtime_checkable
class SupportsTranscription(Protocol):
"""The interface required for all models that support transcription."""
# Mapping from ISO 639-1 language codes to language names
supported_languages: ClassVar[Mapping[str, str]]
supports_transcription: ClassVar[Literal[True]] = True
@ -694,11 +697,22 @@ class SupportsTranscription(Protocol):
`True`.
"""
def __init_subclass__(cls, **kwargs):
    """Validate ``supported_languages`` when a subclass is created.

    Every key of ``cls.supported_languages`` must also be a key of the
    Whisper ``LANGUAGES`` map; otherwise class creation is aborted.

    Raises:
        ValueError: If ``cls.supported_languages`` contains a language
            code that is not a key of ``LANGUAGES``.
    """
    super().__init_subclass__(**kwargs)
    # Language codes in supported_languages that don't exist in the
    # full language map.
    invalid = set(cls.supported_languages) - set(LANGUAGES)
    if invalid:
        # Sorted so the message is deterministic regardless of set order.
        raise ValueError(
            f"{cls.__name__}.supported_languages contains invalid "
            f"language codes: {sorted(invalid)}.\n"
            f"Valid choices are: {sorted(LANGUAGES)}")
@classmethod
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig, language: str,
task_type: str,
model_config: ModelConfig,
language: Optional[str], task_type: str,
request_prompt: str) -> PromptType:
"""Get the prompt for the ASR model.
The model has control over the construction, as long as it
@ -706,9 +720,36 @@ class SupportsTranscription(Protocol):
...
@classmethod
def validate_language(cls, language: str) -> bool:
"""Check if the model supports a specific ISO639_1 language."""
...
def get_other_languages(cls) -> Mapping[str, str]:
    """Return the ``LANGUAGES`` entries absent from ``supported_languages``.

    The result maps language code to language name for every entry of the
    Whisper ``LANGUAGES`` map whose code is not a key of
    ``cls.supported_languages``.
    """
    remaining = {}
    for code, name in LANGUAGES.items():
        if code in cls.supported_languages:
            continue
        remaining[code] = name
    return remaining
@classmethod
def validate_language(cls, language: Optional[str]) -> Optional[str]:
    """
    Check the language given in a transcription request.

    ``None`` and codes listed in ``cls.supported_languages`` are returned
    unchanged. A code found only in ``cls.get_other_languages()`` is also
    returned, but a warning is logged first. Any other value raises
    ``ValueError``.
    """
    # Fast path: nothing requested, or a natively supported code.
    if language is None or language in cls.supported_languages:
        return language

    # Not a known code at all: reject the request.
    if language not in cls.get_other_languages():
        raise ValueError(
            f"Unsupported language: {language!r}. Must be one of "
            f"{list(cls.supported_languages.keys())}.")

    # Known code the model does not natively support: accept it, but warn.
    logger.warning(
        "Language %r is not natively supported by %s; "
        "results may be less accurate. Supported languages: %r",
        language,
        cls.__name__,
        list(cls.supported_languages.keys()),
    )
    return language
@classmethod
def get_speech_to_text_config(

View File

@ -26,8 +26,7 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
# yapf: disable
from vllm.model_executor.models.whisper import (
WhisperEncoder, WhisperForConditionalGeneration)
from vllm.model_executor.models.whisper import WhisperEncoder
# yapf: enable
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -50,6 +49,18 @@ from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
logger = init_logger(__name__)
# ISO 639-1 language code -> language name. Assigned as
# `supported_languages` on VoxtralForConditionalGeneration.
ISO639_1_SUPPORTED_LANGS = {
"ar": "Arabic",
"nl": "Dutch",
"en": "English",
"fr": "French",
"de": "German",
"hi": "Hindi",
"it": "Italian",
"pt": "Portuguese",
"es": "Spanish",
}
class VoxtralProcessorAdapter:
"""
@ -301,6 +312,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
dummy_inputs=VoxtralDummyInputsBuilder)
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsTranscription):
supported_languages = ISO639_1_SUPPORTED_LANGS
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@ -441,8 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# for speech-to-text transcription
def get_generation_prompt(cls, audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig, language: str,
task_type: str,
stt_config: SpeechToTextConfig,
language: Optional[str], task_type: str,
request_prompt: str) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate),
@ -457,11 +469,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict)
@classmethod
def validate_language(cls, language: str) -> bool:
# same as whisper
return WhisperForConditionalGeneration.validate_language(language)
@classmethod
def get_num_audio_tokens(cls, audio_duration_s: float,
stt_config: SpeechToTextConfig,

View File

@ -109,51 +109,6 @@ ISO639_1_SUPPORTED_LANGS = {
"vi": "Vietnamese",
"cy": "Welsh"
}
ISO639_1_OTHER_LANGS = {
"lo": "Lao",
"jw": "Javanese",
"tk": "Turkmen",
"yi": "Yiddish",
"so": "Somali",
"bn": "Bengali",
"nn": "Norwegian Nynorsk",
"si": "Sinhala",
"yo": "Yoruba",
"sa": "Sanskrit",
"mi": "Māori",
"fo": "Faroese", # codespell:ignore
"mt": "Maltese",
"tg": "Tajik",
"mg": "Malagasy",
"haw": "Hawaiian",
"km": "Khmer",
"br": "Breton",
"ps": "Pashto",
"ln": "Lingala",
"la": "Latin",
"ml": "Malayalam",
"sq": "Albanian",
"su": "Sundanese",
"eu": "Basque",
"ka": "Georgian",
"uz": "Uzbek",
"sn": "Shona",
"ht": "Haitian",
"as": "Assamese",
"mn": "Mongolian",
"te": "Telugu",
"pa": "Panjabi",
"tt": "Tatar",
"gu": "Gujarati",
"oc": "Occitan",
"ha": "Hausa",
"ba": "Bashkir",
"my": "Burmese",
"sd": "Sindhi",
"am": "Amharic",
"lb": "Luxembourgish",
"bo": "Tibetan"
}
class WhisperAudioInputs(TypedDict):
@ -807,22 +762,20 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
# Whisper only supports audio-conditioned generation.
supports_transcription_only = True
supported_languages = ISO639_1_SUPPORTED_LANGS
@classmethod
def validate_language(cls, language: str) -> bool:
if language in ISO639_1_SUPPORTED_LANGS:
return True
elif language in ISO639_1_OTHER_LANGS:
def validate_language(cls, language: Optional[str]) -> Optional[str]:
if language is None:
# TODO language should be optional and can be guessed.
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
logger.warning(
"The selected language %s has limited accuracy with"
" reported WER>=0.5. Results may be less accurate "
"for this choice.", language)
return True
else:
raise ValueError(f"Unsupported language: {language}."
"Language should be one of:" +
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
f"or {list(ISO639_1_OTHER_LANGS.values())}")
"Defaulting to language='en'. If you wish to transcribe "
"audio in a different language, pass the `language` field "
"in the TranscriptionRequest.")
language = "en"
return super().validate_language(language)
@classmethod
def get_generation_prompt(
@ -830,9 +783,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str,
language: Optional[str],
task_type: str,
request_prompt: str) -> PromptType:
if language is None:
raise ValueError(
"Language must be specified when creating the Whisper prompt")
prompt = {
"encoder_prompt": {
# Whisper does not support encoder prompt.