[Bugfix] Relax lang pin for voxtral (#21833)
Signed-off-by: Sanchit Gandhi <sgandhi3141@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 9cb497bfa3
commit ec02e536df
@@ -86,11 +86,7 @@ class OpenAISpeechToText(OpenAIServing):
         audio_data: bytes,
     ) -> tuple[list[PromptType], float]:
         # Validate request
-        # TODO language should be optional and can be guessed.
-        # For now we default to en. See
-        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
-        lang = request.language or "en"
-        self.model_cls.validate_language(lang)
+        language = self.model_cls.validate_language(request.language)
 
         if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
             raise ValueError("Maximum file size exceeded.")
@@ -112,7 +108,7 @@ class OpenAISpeechToText(OpenAIServing):
                 audio=chunk,
                 stt_config=self.asr_config,
                 model_config=self.model_config,
-                language=lang,
+                language=language,
                 task_type=self.task_type,
                 request_prompt=request.prompt)
             prompts.append(prompt)
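The serving-layer change above stops hard-coding language="en": the raw request.language (possibly None) is handed to the model class, and whatever validate_language returns is what later gets passed to get_generation_prompt. Below is a minimal standalone sketch of that flow, using hypothetical class and function names rather than vLLM's actual ones.

# Minimal sketch of the new control flow (hypothetical names, not vLLM's
# actual classes): the serving layer forwards the raw request language and
# uses whatever the model class hands back.
from typing import Optional


class FakeASRModel:
    supported_languages = {"en": "English", "fr": "French"}

    @classmethod
    def validate_language(cls, language: Optional[str]) -> Optional[str]:
        # The model decides how to handle a missing language; here we simply
        # pass None through, as the updated protocol now allows.
        if language is not None and language not in cls.supported_languages:
            raise ValueError(f"Unsupported language: {language!r}")
        return language


def preprocess(request_language: Optional[str]) -> Optional[str]:
    # Mirrors `language = self.model_cls.validate_language(request.language)`
    return FakeASRModel.validate_language(request_language)


print(preprocess(None))  # None -- no more hard-coded "en" at this layer
print(preprocess("fr"))  # "fr"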
@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from collections.abc import Iterable, MutableSequence
+from collections.abc import Iterable, Mapping, MutableSequence
 from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                     Union, overload, runtime_checkable)
 
 import numpy as np
 import torch
 from torch import Tensor
+from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from typing_extensions import Self, TypeIs
 
 from vllm.config import ModelConfig, SpeechToTextConfig
@@ -685,6 +686,8 @@ class SupportsQuant:
 @runtime_checkable
 class SupportsTranscription(Protocol):
     """The interface required for all models that support transcription."""
+    # Mapping from ISO639_1 language codes: language names
+    supported_languages: ClassVar[Mapping[str, str]]
 
     supports_transcription: ClassVar[Literal[True]] = True
 
@@ -694,11 +697,22 @@ class SupportsTranscription(Protocol):
         `True`.
     """
 
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # language codes in supported_languages
+        # that don't exist in the full language map
+        invalid = set(cls.supported_languages) - set(LANGUAGES.keys())
+        if invalid:
+            raise ValueError(
+                f"{cls.__name__}.supported_languages contains invalid "
+                f"language codes: {sorted(invalid)}\n. "
+                f"Valid choices are: {sorted(LANGUAGES.keys())}")
+
     @classmethod
     def get_generation_prompt(cls, audio: np.ndarray,
                               stt_config: SpeechToTextConfig,
-                              model_config: ModelConfig, language: str,
-                              task_type: str,
+                              model_config: ModelConfig,
+                              language: Optional[str], task_type: str,
                               request_prompt: str) -> PromptType:
         """Get the prompt for the ASR model.
         The model has control over the construction, as long as it
@@ -706,9 +720,36 @@ class SupportsTranscription(Protocol):
         ...
 
     @classmethod
-    def validate_language(cls, language: str) -> bool:
-        """Check if the model supports a specific ISO639_1 language."""
-        ...
+    def get_other_languages(cls) -> Mapping[str, str]:
+        # other possible language codes from the whisper map
+        return {
+            k: v
+            for k, v in LANGUAGES.items() if k not in cls.supported_languages
+        }
+
+    @classmethod
+    def validate_language(cls, language: Optional[str]) -> Optional[str]:
+        """
+        Ensure the language specified in the transcription request
+        is a valid ISO 639-1 language code. If the request language is
+        valid, but not natively supported by the model, trigger a
+        warning (but not an exception).
+        """
+        if language is None or language in cls.supported_languages:
+            return language
+        elif language in cls.get_other_languages():
+            logger.warning(
+                "Language %r is not natively supported by %s; "
+                "results may be less accurate. Supported languages: %r",
+                language,
+                cls.__name__,
+                list(cls.supported_languages.keys()),
+            )
+            return language
+        else:
+            raise ValueError(
+                f"Unsupported language: {language!r}. Must be one of "
+                f"{list(cls.supported_languages.keys())}.")
 
     @classmethod
     def get_speech_to_text_config(
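For context on how the new shared validation behaves, here is a standalone sketch of the accept / warn / reject logic added above. It assumes the transformers package is installed (for the Whisper LANGUAGES map) and uses a toy two-language subset and a plain function instead of the protocol classmethod; it is not vLLM code.

# Simplified mirror of SupportsTranscription.validate_language (assumes
# `transformers` is installed for the Whisper LANGUAGES map).
import logging
from typing import Optional

from transformers.models.whisper.tokenization_whisper import LANGUAGES

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

SUPPORTED = {"en": "English", "fr": "French"}  # example per-model subset


def validate_language(language: Optional[str]) -> Optional[str]:
    # None and natively supported codes pass through unchanged.
    if language is None or language in SUPPORTED:
        return language
    # A valid Whisper code outside the subset is allowed, with a warning.
    if language in LANGUAGES:
        logger.warning("Language %r is not natively supported; "
                       "results may be less accurate.", language)
        return language
    # Anything else is rejected outright.
    raise ValueError(f"Unsupported language: {language!r}. "
                     f"Must be one of {list(SUPPORTED)}.")


print(validate_language(None))  # None passes through
print(validate_language("fr"))  # natively supported
print(validate_language("ja"))  # warns, then returns "ja"
# validate_language("xx")       # would raise ValueError

The __init_subclass__ hook in the diff adds a second safeguard: a model that declares a language code missing from the Whisper map now fails at class-definition time rather than at request time.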
@@ -26,8 +26,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsPP
 # yapf: disable
-from vllm.model_executor.models.whisper import (
-    WhisperEncoder, WhisperForConditionalGeneration)
+from vllm.model_executor.models.whisper import WhisperEncoder
 # yapf: enable
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -50,6 +49,18 @@ from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
 
 logger = init_logger(__name__)
 
+ISO639_1_SUPPORTED_LANGS = {
+    "ar": "Arabic",
+    "nl": "Dutch",
+    "en": "English",
+    "fr": "French",
+    "de": "German",
+    "hi": "Hindi",
+    "it": "Italian",
+    "pt": "Portuguese",
+    "es": "Spanish",
+}
+
 
 class VoxtralProcessorAdapter:
     """
@@ -301,6 +312,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
                                         dummy_inputs=VoxtralDummyInputsBuilder)
 class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsPP, SupportsTranscription):
+    supported_languages = ISO639_1_SUPPORTED_LANGS
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -441,8 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     # for speech-to-text transcription
     def get_generation_prompt(cls, audio: np.ndarray,
                               model_config: ModelConfig,
-                              stt_config: SpeechToTextConfig, language: str,
-                              task_type: str,
+                              stt_config: SpeechToTextConfig,
+                              language: Optional[str], task_type: str,
                               request_prompt: str) -> PromptType:
         tokenizer = cached_tokenizer_from_config(model_config)
         audio = Audio(audio, int(stt_config.sample_rate),
@@ -457,11 +469,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         prompts_dict["prompt_token_ids"] = tokenized.tokens
         return cast(PromptType, prompts_dict)
 
-    @classmethod
-    def validate_language(cls, language: str) -> bool:
-        # same as whisper
-        return WhisperForConditionalGeneration.validate_language(language)
-
     @classmethod
     def get_num_audio_tokens(cls, audio_duration_s: float,
                              stt_config: SpeechToTextConfig,
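With supported_languages declared on the class and the old validate_language override deleted, Voxtral now relies entirely on the shared base-class validation (and no longer needs to import WhisperForConditionalGeneration). A quick illustration of which codes are accepted natively, accepted with a warning, or rejected for this nine-language subset; it assumes transformers is installed and only reproduces the set-membership logic, it is not vLLM code.

from transformers.models.whisper.tokenization_whisper import LANGUAGES

# Voxtral's natively supported subset, as declared in the diff above.
VOXTRAL_LANGS = {"ar", "nl", "en", "fr", "de", "hi", "it", "pt", "es"}

for code in ("it", "ja", "xx"):
    if code in VOXTRAL_LANGS:
        print(f"{code}: accepted (natively supported)")
    elif code in LANGUAGES:
        print(f"{code}: accepted with a warning (valid Whisper code)")
    else:
        print(f"{code}: rejected with ValueError")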
@@ -109,51 +109,6 @@ ISO639_1_SUPPORTED_LANGS = {
     "vi": "Vietnamese",
     "cy": "Welsh"
 }
-ISO639_1_OTHER_LANGS = {
-    "lo": "Lao",
-    "jw": "Javanese",
-    "tk": "Turkmen",
-    "yi": "Yiddish",
-    "so": "Somali",
-    "bn": "Bengali",
-    "nn": "Norwegian Nynorsk",
-    "si": "Sinhala",
-    "yo": "Yoruba",
-    "sa": "Sanskrit",
-    "mi": "Māori",
-    "fo": "Faroese",  # codespell:ignore
-    "mt": "Maltese",
-    "tg": "Tajik",
-    "mg": "Malagasy",
-    "haw": "Hawaiian",
-    "km": "Khmer",
-    "br": "Breton",
-    "ps": "Pashto",
-    "ln": "Lingala",
-    "la": "Latin",
-    "ml": "Malayalam",
-    "sq": "Albanian",
-    "su": "Sundanese",
-    "eu": "Basque",
-    "ka": "Georgian",
-    "uz": "Uzbek",
-    "sn": "Shona",
-    "ht": "Haitian",
-    "as": "Assamese",
-    "mn": "Mongolian",
-    "te": "Telugu",
-    "pa": "Panjabi",
-    "tt": "Tatar",
-    "gu": "Gujarati",
-    "oc": "Occitan",
-    "ha": "Hausa",
-    "ba": "Bashkir",
-    "my": "Burmese",
-    "sd": "Sindhi",
-    "am": "Amharic",
-    "lb": "Luxembourgish",
-    "bo": "Tibetan"
-}
 
 
 class WhisperAudioInputs(TypedDict):
@@ -807,22 +762,20 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
 
     # Whisper only supports audio-conditioned generation.
     supports_transcription_only = True
+    supported_languages = ISO639_1_SUPPORTED_LANGS
 
     @classmethod
-    def validate_language(cls, language: str) -> bool:
-        if language in ISO639_1_SUPPORTED_LANGS:
-            return True
-        elif language in ISO639_1_OTHER_LANGS:
+    def validate_language(cls, language: Optional[str]) -> Optional[str]:
+        if language is None:
+            # TODO language should be optional and can be guessed.
+            # For now we default to en. See
+            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
             logger.warning(
-                "The selected language %s has limited accuracy with"
-                " reported WER>=0.5. Results may be less accurate "
-                "for this choice.", language)
-            return True
-        else:
-            raise ValueError(f"Unsupported language: {language}."
-                             "Language should be one of:" +
-                             f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
-                             f"or {list(ISO639_1_OTHER_LANGS.values())}")
+                "Defaulting to language='en'. If you wish to transcribe "
+                "audio in a different language, pass the `language` field "
+                "in the TranscriptionRequest.")
+            language = "en"
+        return super().validate_language(language)
 
     @classmethod
     def get_generation_prompt(
@@ -830,9 +783,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
             audio: np.ndarray,
             model_config: ModelConfig,  # not needed here
             stt_config: SpeechToTextConfig,
-            language: str,
+            language: Optional[str],
             task_type: str,
             request_prompt: str) -> PromptType:
+        if language is None:
+            raise ValueError(
+                "Language must be specified when creating the Whisper prompt")
         prompt = {
             "encoder_prompt": {
                 # Whisper does not support encoder prompt.
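Whisper keeps only a thin override: a missing language falls back to "en" with a warning and is then handed to the shared base-class validation, while get_generation_prompt gains a defensive check because it still needs a concrete language token for its decoder prompt. A tiny standalone sketch of the fallback, using a hypothetical helper name rather than vLLM's API:

from typing import Optional


def whisper_default_language(language: Optional[str]) -> str:
    # Whisper still needs a concrete language token in its decoder prompt,
    # so an unspecified language falls back to "en" (vLLM logs a warning);
    # the real override then defers to super().validate_language().
    if language is None:
        language = "en"
    return language


assert whisper_default_language(None) == "en"
assert whisper_default_language("fr") == "fr"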