diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py
index c2227a21a4b9a..01140a4bfea7e 100644
--- a/vllm/entrypoints/openai/speech_to_text.py
+++ b/vllm/entrypoints/openai/speech_to_text.py
@@ -86,11 +86,7 @@ class OpenAISpeechToText(OpenAIServing):
         audio_data: bytes,
     ) -> tuple[list[PromptType], float]:
         # Validate request
-        # TODO language should be optional and can be guessed.
-        # For now we default to en. See
-        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
-        lang = request.language or "en"
-        self.model_cls.validate_language(lang)
+        language = self.model_cls.validate_language(request.language)

         if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
             raise ValueError("Maximum file size exceeded.")
@@ -112,7 +108,7 @@ class OpenAISpeechToText(OpenAIServing):
                 audio=chunk,
                 stt_config=self.asr_config,
                 model_config=self.model_config,
-                language=lang,
+                language=language,
                 task_type=self.task_type,
                 request_prompt=request.prompt)
             prompts.append(prompt)
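Reviewer note: the serving layer no longer hardcodes the `"en"` fallback; `validate_language` is now the single entry point that both validates and returns the effective language. A minimal sketch of the new call contract, assuming any class implementing the updated `SupportsTranscription` interface (the `preprocess_sketch` function, `model_cls`, and `request` below are illustrative stand-ins, not code from this PR):

```python
from typing import Optional


def preprocess_sketch(model_cls, request) -> Optional[str]:
    # Before this PR the endpoint defaulted the language itself:
    #     lang = request.language or "en"
    #     model_cls.validate_language(lang)  # returned bool, value unused
    # After this PR, None is passed through and the *model* decides
    # whether to default, warn, or raise:
    language: Optional[str] = model_cls.validate_language(request.language)
    # The returned value is forwarded unchanged to get_generation_prompt().
    return language
```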
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 957b57276b4ca..b6d9877cd01b6 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections.abc import Iterable, MutableSequence
+from collections.abc import Iterable, Mapping, MutableSequence
 from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                     Union, overload, runtime_checkable)

 import numpy as np
 import torch
 from torch import Tensor
+from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from typing_extensions import Self, TypeIs

 from vllm.config import ModelConfig, SpeechToTextConfig
@@ -685,6 +686,8 @@ class SupportsQuant:

 @runtime_checkable
 class SupportsTranscription(Protocol):
     """The interface required for all models that support transcription."""
+    # Mapping from ISO639_1 language codes to language names
+    supported_languages: ClassVar[Mapping[str, str]]

     supports_transcription: ClassVar[Literal[True]] = True
@@ -694,11 +697,22 @@ class SupportsTranscription(Protocol):
     `True`.
     """

+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # Reject language codes in supported_languages that don't
+        # exist in the full Whisper language map.
+        invalid = set(cls.supported_languages) - set(LANGUAGES.keys())
+        if invalid:
+            raise ValueError(
+                f"{cls.__name__}.supported_languages contains invalid "
+                f"language codes: {sorted(invalid)}. "
+                f"Valid choices are: {sorted(LANGUAGES.keys())}")
+
     @classmethod
     def get_generation_prompt(cls, audio: np.ndarray,
                               stt_config: SpeechToTextConfig,
-                              model_config: ModelConfig, language: str,
-                              task_type: str,
+                              model_config: ModelConfig,
+                              language: Optional[str], task_type: str,
                               request_prompt: str) -> PromptType:
         """Get the prompt for the ASR model.
         The model has control over the construction, as long as it
@@ -706,9 +720,36 @@ class SupportsTranscription(Protocol):
         ...

     @classmethod
-    def validate_language(cls, language: str) -> bool:
-        """Check if the model supports a specific ISO639_1 language."""
-        ...
+    def get_other_languages(cls) -> Mapping[str, str]:
+        # Other possible language codes from the Whisper map.
+        return {
+            k: v
+            for k, v in LANGUAGES.items() if k not in cls.supported_languages
+        }
+
+    @classmethod
+    def validate_language(cls, language: Optional[str]) -> Optional[str]:
+        """
+        Ensure the language specified in the transcription request
+        is a valid ISO 639-1 language code. If the request language is
+        valid but not natively supported by the model, trigger a
+        warning (but not an exception).
+        """
+        if language is None or language in cls.supported_languages:
+            return language
+        elif language in cls.get_other_languages():
+            logger.warning(
+                "Language %r is not natively supported by %s; "
+                "results may be less accurate. Supported languages: %r",
+                language,
+                cls.__name__,
+                list(cls.supported_languages.keys()),
+            )
+            return language
+        else:
+            raise ValueError(
+                f"Unsupported language: {language!r}. Must be one of "
+                f"{list(cls.supported_languages.keys())}.")

     @classmethod
     def get_speech_to_text_config(
diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py
index 97cab628317e4..6b06c0ac6683f 100644
--- a/vllm/model_executor/models/voxtral.py
+++ b/vllm/model_executor/models/voxtral.py
@@ -26,8 +26,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsPP
 # yapf: disable
-from vllm.model_executor.models.whisper import (
-    WhisperEncoder, WhisperForConditionalGeneration)
+from vllm.model_executor.models.whisper import WhisperEncoder
 # yapf: enable
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -50,6 +49,18 @@ from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,

 logger = init_logger(__name__)

+ISO639_1_SUPPORTED_LANGS = {
+    "ar": "Arabic",
+    "nl": "Dutch",
+    "en": "English",
+    "fr": "French",
+    "de": "German",
+    "hi": "Hindi",
+    "it": "Italian",
+    "pt": "Portuguese",
+    "es": "Spanish",
+}
+

 class VoxtralProcessorAdapter:
     """
@@ -301,6 +312,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
                         dummy_inputs=VoxtralDummyInputsBuilder)
 class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsPP, SupportsTranscription):
+    supported_languages = ISO639_1_SUPPORTED_LANGS

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -441,8 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     # for speech-to-text transcription
     def get_generation_prompt(cls, audio: np.ndarray,
                               model_config: ModelConfig,
-                              stt_config: SpeechToTextConfig, language: str,
-                              task_type: str,
+                              stt_config: SpeechToTextConfig,
+                              language: Optional[str], task_type: str,
                               request_prompt: str) -> PromptType:
         tokenizer = cached_tokenizer_from_config(model_config)
         audio = Audio(audio, int(stt_config.sample_rate),
@@ -457,11 +469,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         prompts_dict["prompt_token_ids"] = tokenized.tokens
         return cast(PromptType, prompts_dict)

-    @classmethod
-    def validate_language(cls, language: str) -> bool:
-        # same as whisper
-        return WhisperForConditionalGeneration.validate_language(language)
-
     @classmethod
     def get_num_audio_tokens(cls, audio_duration_s: float,
                              stt_config: SpeechToTextConfig,
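With the shared implementation now living on `SupportsTranscription`, a model opts in by declaring its natively supported languages, and `__init_subclass__` rejects typos at class-definition time rather than at request time. A hedged sketch of the resulting behavior (`ToyASRModel` is hypothetical; only `SupportsTranscription` and its methods come from this PR):

```python
from vllm.model_executor.models.interfaces import SupportsTranscription


class ToyASRModel(SupportsTranscription):
    """Hypothetical model, used only to illustrate the interface."""
    supported_languages = {"en": "English", "fr": "French"}


ToyASRModel.validate_language(None)  # -> None (model may guess or default)
ToyASRModel.validate_language("fr")  # -> "fr"
ToyASRModel.validate_language("lo")  # -> "lo", logs an accuracy warning
ToyASRModel.validate_language("zz")  # raises ValueError: not ISO 639-1

# By contrast, declaring supported_languages = {"zz": "Zed"} would raise
# ValueError as soon as the class statement executes, via __init_subclass__.
```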
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index d98dab5fac0e4..d7bafb9ef84d9 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -109,51 +109,6 @@ ISO639_1_SUPPORTED_LANGS = {
     "vi": "Vietnamese",
     "cy": "Welsh"
 }
-ISO639_1_OTHER_LANGS = {
-    "lo": "Lao",
-    "jw": "Javanese",
-    "tk": "Turkmen",
-    "yi": "Yiddish",
-    "so": "Somali",
-    "bn": "Bengali",
-    "nn": "Norwegian Nynorsk",
-    "si": "Sinhala",
-    "yo": "Yoruba",
-    "sa": "Sanskrit",
-    "mi": "Māori",
-    "fo": "Faroese",  # codespell:ignore
-    "mt": "Maltese",
-    "tg": "Tajik",
-    "mg": "Malagasy",
-    "haw": "Hawaiian",
-    "km": "Khmer",
-    "br": "Breton",
-    "ps": "Pashto",
-    "ln": "Lingala",
-    "la": "Latin",
-    "ml": "Malayalam",
-    "sq": "Albanian",
-    "su": "Sundanese",
-    "eu": "Basque",
-    "ka": "Georgian",
-    "uz": "Uzbek",
-    "sn": "Shona",
-    "ht": "Haitian",
-    "as": "Assamese",
-    "mn": "Mongolian",
-    "te": "Telugu",
-    "pa": "Panjabi",
-    "tt": "Tatar",
-    "gu": "Gujarati",
-    "oc": "Occitan",
-    "ha": "Hausa",
-    "ba": "Bashkir",
-    "my": "Burmese",
-    "sd": "Sindhi",
-    "am": "Amharic",
-    "lb": "Luxembourgish",
-    "bo": "Tibetan"
-}


 class WhisperAudioInputs(TypedDict):
@@ -807,22 +762,20 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,

     # Whisper only supports audio-conditioned generation.
     supports_transcription_only = True
+    supported_languages = ISO639_1_SUPPORTED_LANGS

     @classmethod
-    def validate_language(cls, language: str) -> bool:
-        if language in ISO639_1_SUPPORTED_LANGS:
-            return True
-        elif language in ISO639_1_OTHER_LANGS:
+    def validate_language(cls, language: Optional[str]) -> Optional[str]:
+        if language is None:
+            # TODO language should be optional and can be guessed.
+            # For now we default to en. See
+            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
             logger.warning(
-                "The selected language %s has limited accuracy with"
-                " reported WER>=0.5. Results may be less accurate "
-                "for this choice.", language)
-            return True
-        else:
-            raise ValueError(f"Unsupported language: {language}."
-                             "Language should be one of:" +
-                             f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
-                             f"or {list(ISO639_1_OTHER_LANGS.values())}")
+                "Defaulting to language='en'. If you wish to transcribe "
+                "audio in a different language, pass the `language` field "
+                "in the TranscriptionRequest.")
+            language = "en"
+        return super().validate_language(language)

     @classmethod
     def get_generation_prompt(
@@ -830,9 +783,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
             audio: np.ndarray,
             model_config: ModelConfig,  # not needed here
             stt_config: SpeechToTextConfig,
-            language: str,
+            language: Optional[str],
             task_type: str,
             request_prompt: str) -> PromptType:
+        if language is None:
+            raise ValueError(
+                "Language must be specified when creating the Whisper prompt")
         prompt = {
             "encoder_prompt": {
                 # Whisper does not support encoder prompt.
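For completeness, the user-visible effect on the OpenAI-compatible transcriptions endpoint, sketched with the official `openai` client (the server URL, model name, and audio file are placeholders; the commented behavior follows the code above):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as f:
    # language omitted -> Whisper warns and defaults to "en"
    # language="lo"    -> accepted, with a reduced-accuracy warning
    # language="zz"    -> request fails with an "Unsupported language" error
    transcription = client.audio.transcriptions.create(
        model="openai/whisper-large-v3",
        file=f,
    )
print(transcription.text)
```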