adapt voxtral (#31095)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent b10d47e0e0, commit 3faa8bee57
@@ -111,4 +111,5 @@ async def test_online_serving(client, audio_assets: AudioTestAssets):
|
||||
|
||||
assert len(chat_completion.choices) == 1
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.message.content == "In the first audio clip, you hear a brief"
|
||||
assert choice.finish_reason == "length"
|
||||
|
||||
@@ -860,6 +860,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
# disable this temporarily until we support HF format
|
||||
is_available_online=False,
|
||||
),
|
||||
"VoxtralStreamingGeneration": _HfExamplesInfo(
|
||||
"<place-holder>",
|
||||
# disable this temporarily until we support HF format
|
||||
is_available_online=False,
|
||||
),
|
||||
# [Encoder-decoder]
|
||||
"WhisperForConditionalGeneration": _HfExamplesInfo(
|
||||
"openai/whisper-large-v3-turbo",
|
||||
|
||||
@@ -1542,6 +1542,10 @@ class ModelConfig:
|
||||
def is_multimodal_raw_input_only_model(self) -> bool:
|
||||
return self._model_info.supports_multimodal_raw_input_only
|
||||
|
||||
@property
|
||||
def requires_raw_input_tokens(self) -> bool:
|
||||
return self._model_info.requires_raw_input_tokens
|
||||
|
||||
@property
|
||||
def is_cross_encoder(self) -> bool:
|
||||
return (
|
||||
|
||||
@@ -94,6 +94,12 @@ class SupportsMultiModal(Protocol):
|
||||
`multimodal_config.mm_encoder_tp_mode="data"`.
|
||||
"""
|
||||
|
||||
requires_raw_input_tokens: ClassVar[bool] = False
|
||||
"""
|
||||
A flag that indicates this model processes input id tokens
|
||||
in their raw form and not input embeddings.
|
||||
"""
|
||||
|
||||
merge_by_field_config: ClassVar[bool | None] = None
|
||||
"""
|
||||
[DEPRECATED] A flag that indicates which implementation of
|
||||
@@ -306,6 +312,10 @@ def supports_multimodal_raw_input_only(model: type[object] | object) -> bool:
|
||||
return getattr(model, "supports_multimodal_raw_input_only", False)
|
||||
|
||||
|
||||
def requires_raw_input_tokens(model: type[object] | object) -> bool:
|
||||
return getattr(model, "requires_raw_input_tokens", False)
|
||||
|
||||
|
||||
def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
|
||||
return getattr(model, "supports_encoder_tp_data", False)
|
||||
|
||||
|
||||
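For orientation, a minimal self-contained sketch of how the new flag and helper interact (the toy class below is illustrative, not a real vLLM model):

def requires_raw_input_tokens(model) -> bool:
    # Mirrors the helper added above: a plain getattr with a False default,
    # so it works on both classes and instances.
    return getattr(model, "requires_raw_input_tokens", False)

class ToyStreamingModel:
    # A model opts in exactly like VoxtralStreamingGeneration does in this diff.
    requires_raw_input_tokens = True

assert requires_raw_input_tokens(ToyStreamingModel) is True
assert requires_raw_input_tokens(object()) is False
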
@@ -46,6 +46,7 @@ from .interfaces import (
|
||||
has_noops,
|
||||
is_attention_free,
|
||||
is_hybrid,
|
||||
requires_raw_input_tokens,
|
||||
supports_cross_encoding,
|
||||
supports_mamba_prefix_caching,
|
||||
supports_multimodal,
|
||||
@@ -422,6 +423,7 @@ _MULTIMODAL_MODELS = {
|
||||
),
|
||||
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
||||
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
|
||||
"VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"), # noqa: E501
|
||||
# [Encoder-decoder]
|
||||
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
|
||||
}
|
||||
@@ -539,6 +541,7 @@ class _ModelInfo:
|
||||
supports_cross_encoding: bool
|
||||
supports_multimodal: bool
|
||||
supports_multimodal_raw_input_only: bool
|
||||
requires_raw_input_tokens: bool
|
||||
supports_multimodal_encoder_tp_data: bool
|
||||
supports_pp: bool
|
||||
has_inner_state: bool
|
||||
@@ -562,6 +565,7 @@ class _ModelInfo:
|
||||
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
|
||||
model
|
||||
),
|
||||
requires_raw_input_tokens=requires_raw_input_tokens(model),
|
||||
supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
|
||||
model
|
||||
),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
@@ -116,10 +117,7 @@ class VoxtralProcessorAdapter:
|
||||
self,
|
||||
audio_length: int,
|
||||
) -> int:
|
||||
pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
|
||||
audio_length, self.sampling_rate
|
||||
)
|
||||
return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))
|
||||
return ceil(audio_length / (self.sampling_rate // self.frame_rate))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
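As a rough sanity check of the simplified token-count formula above, a small worked example (the sampling rate and token frame rate are assumptions for illustration, not values read from this diff):

from math import ceil

sampling_rate = 16_000                    # Hz (assumed)
frame_rate = 12.5                         # audio tokens per second (assumed)
audio_length = 10 * sampling_rate         # a 10 s clip

samples_per_token = sampling_rate // frame_rate       # 1280.0 samples per token
num_tokens = ceil(audio_length / samples_per_token)   # 125 tokens
print(samples_per_token, num_tokens)
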
@@ -158,7 +156,14 @@ class VoxtralProcessorAdapter:
|
||||
assert audio.ndim == 1
|
||||
|
||||
# pad if necessary
|
||||
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||
# TODO(Patrick) - remove once mistral-common is bumped
|
||||
sig = inspect.signature(self._audio_processor.pad)
|
||||
if "is_online_streaming" in sig.parameters:
|
||||
audio = self._audio_processor.pad(
|
||||
audio, self.sampling_rate, is_online_streaming=False
|
||||
)
|
||||
else:
|
||||
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||
|
||||
audio_tokens = [self.begin_audio_token_id] + [
|
||||
self.audio_token_id
|
||||
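The shim above is a generic pattern for staying compatible with two releases of a dependency: probe the signature, and only pass the new keyword when it exists. A standalone sketch (the two pad functions are stand-ins, not mistral-common APIs):

import inspect

def pad_old(audio, sampling_rate):                             # "old" signature
    return audio

def pad_new(audio, sampling_rate, is_online_streaming=False):  # "new" signature
    return audio

for pad in (pad_old, pad_new):
    # Only pass the new keyword if the installed version accepts it.
    if "is_online_streaming" in inspect.signature(pad).parameters:
        result = pad([0.0, 0.0], 16_000, is_online_streaming=False)
    else:
        result = pad([0.0, 0.0], 16_000)
    print(len(result))
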
@@ -510,6 +515,7 @@ class VoxtralForConditionalGeneration(
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
remapping_rules = [
|
||||
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
|
||||
(r"mm_whisper_embeddings\.(.*)", r"\1"),
|
||||
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
|
||||
(
|
||||
@@ -535,13 +541,16 @@ class VoxtralForConditionalGeneration(
|
||||
def llm_weights_generator():
|
||||
nonlocal loaded_weights
|
||||
for name, w in weights:
|
||||
is_encoder = (
|
||||
name.startswith("mm_whisper_embeddings")
|
||||
and not name.startswith("mm_whisper_embeddings.tok_embeddings")
|
||||
and not name.startswith(
|
||||
"mm_whisper_embeddings.audio_language_projection"
|
||||
is_encoder = False
|
||||
for k in [
|
||||
"mm_whisper_embeddings",
|
||||
"mm_streams_embeddings.embedding_module",
|
||||
]:
|
||||
is_encoder |= (
|
||||
name.startswith(k)
|
||||
and not name.startswith(f"{k}.tok_embeddings")
|
||||
and not name.startswith(f"{k}.audio_language_projection")
|
||||
)
|
||||
)
|
||||
|
||||
for pattern, repl in remapping_rules:
|
||||
if re.fullmatch(pattern, name):
|
||||
@@ -676,6 +685,7 @@ class VoxtralEncoderModel(nn.Module):
|
||||
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
||||
|
||||
mistral_remapping = [
|
||||
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
|
||||
(
|
||||
r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
|
||||
r"whisper_encoder.conv1.\1",
|
||||
@@ -684,6 +694,14 @@ class VoxtralEncoderModel(nn.Module):
|
||||
r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
|
||||
r"whisper_encoder.conv2.\1",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)",
|
||||
r"whisper_encoder.conv1.\1",
|
||||
), # noqa: E501
|
||||
(
|
||||
r"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)",
|
||||
r"whisper_encoder.conv2.\1",
|
||||
), # noqa: E501
|
||||
(
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501
|
||||
r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
|
||||
|
||||
vllm/model_executor/models/voxtral_streaming.py (new file, 243 lines)
@@ -0,0 +1,243 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from collections.abc import Mapping
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.voxtral import (
|
||||
VoxtralDummyInputsBuilder,
|
||||
VoxtralForConditionalGeneration,
|
||||
VoxtralMultiModalProcessor,
|
||||
VoxtralProcessingInfo,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import _I, BaseMultiModalProcessorCache
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalKwargsOptionalItems,
|
||||
)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (
|
||||
MultiModalPromptUpdates,
|
||||
PlaceholderFeaturesInfo,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import (
|
||||
_flatten_embeddings,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
info: _I,
|
||||
dummy_inputs: BaseDummyInputsBuilder[_I],
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
) -> None:
|
||||
# streaming can't make use of a cache yet
|
||||
super().__init__(info, dummy_inputs, cache=None)
|
||||
|
||||
def _maybe_apply_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
prompt_ids: list[int],
|
||||
mm_kwargs: MultiModalKwargsOptionalItems,
|
||||
mm_prompt_updates: MultiModalPromptUpdates,
|
||||
is_update_applied: bool,
|
||||
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
|
||||
# there are no placeholder audio tokens for streaming
|
||||
# so we need to build the placeholder positions manually
|
||||
|
||||
# in streaming there is always only one audio input
|
||||
audios = mm_kwargs.get("audio", [])
|
||||
assert len(audios) == 1, (
|
||||
f"Expected only one audio input for streaming, got {mm_kwargs=}"
|
||||
)
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
audio_config = tokenizer.instruct.audio_encoder.audio_config
|
||||
|
||||
num_audio_samples = audios[0]["audio_arrays"].data.shape[0]
|
||||
length = audio_config.num_audio_tokens(num_audio_samples)
|
||||
|
||||
features_info = PlaceholderFeaturesInfo(
|
||||
modality="audio",
|
||||
item_idx=0,
|
||||
start_idx=0,
|
||||
tokens=length
|
||||
* [0], # only used for length computation, so we can take dummy inputs
|
||||
is_embed=None,
|
||||
)
|
||||
return prompt_ids, {"audio": [features_info]}
|
||||
|
||||
|
||||
class TimeEmbedding(torch.nn.Module):
|
||||
"""Sinusoidal Embedding for encoding time"""
|
||||
|
||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
inv_freq = torch.exp(
|
||||
-math.log(self.theta)
|
||||
* torch.arange(self.dim // 2).float()
|
||||
/ (self.dim // 2)
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
||||
t = t[..., None] # (B,) -> (B, 1) or (B, T) -> (B, T, 1)
|
||||
inv_freq = self.inv_freq.to(device=t.device, dtype=t.dtype)
|
||||
emb = (
|
||||
t * inv_freq
|
||||
) # (B, 1) x (D/2,) -> (B, D/2) or (B, T, 1) x (D/2,) -> (B, T, D/2)
|
||||
return torch.cat((emb.cos(), emb.sin()), dim=-1) # (B, D) or (B, T, D)
|
||||
|
||||
|
||||
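A standalone numeric sketch of the sinusoidal scheme used by TimeEmbedding above, with a small made-up dimension so the shapes are easy to follow:

import math
import torch

dim, theta = 8, 10000.0
# Frequencies f_i = exp(-log(theta) * i / (dim / 2)), as in TimeEmbedding.__init__.
inv_freq = torch.exp(-math.log(theta) * torch.arange(dim // 2).float() / (dim // 2))
t = torch.tensor([6.0])           # e.g. n_delay_tokens as a 1-element batch
phase = t[..., None] * inv_freq   # (1, dim // 2)
emb = torch.cat((phase.cos(), phase.sin()), dim=-1)
print(emb.shape)                  # torch.Size([1, 8])
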
@MULTIMODAL_REGISTRY.register_processor(
|
||||
VoxtralStreamingMultiModalProcessor,
|
||||
info=VoxtralProcessingInfo,
|
||||
dummy_inputs=VoxtralDummyInputsBuilder,
|
||||
)
|
||||
class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
|
||||
requires_raw_input_tokens = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.time_embedding: TimeEmbedding = TimeEmbedding(
|
||||
dim=self.config.text_config.hidden_size
|
||||
)
|
||||
|
||||
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
|
||||
_n_delay_tokens = (
|
||||
audio_config.frame_rate * audio_config.transcription_delay_ms / 1000
|
||||
)
|
||||
assert _n_delay_tokens.is_integer(), (
|
||||
f"n_delay_tokens must be integer, got {_n_delay_tokens}"
|
||||
)
|
||||
|
||||
self.n_delay_tokens = int(_n_delay_tokens)
|
||||
|
||||
@property
|
||||
def audio_config(self):
|
||||
return self.tokenizer.instruct.audio_encoder.audio_config
|
||||
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Pass post-conv embeddings directly as input"""
|
||||
# for streaming we simply flatten the multimodal embeddings into a single tensor;
# the input ids are handled separately in forward()
|
||||
assert multimodal_embeddings is not None
|
||||
assert len(multimodal_embeddings) > 0, (
|
||||
"For streaming you must provide a multimodal_embedding at every step."
|
||||
)
|
||||
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
|
||||
return mm_embeds_flat
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
assert inputs_embeds is not None
|
||||
assert input_ids is not None
|
||||
|
||||
pool_size = self.config.audio_config.block_pool_size
|
||||
inputs_embeds = inputs_embeds.view(
|
||||
inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
|
||||
)
|
||||
|
||||
audio_hidden_states = self.whisper_encoder.whisper_encoder.forward_layers(
|
||||
inputs_embeds
|
||||
)
|
||||
|
||||
num_tokens, audio_hidden_size = audio_hidden_states.shape
|
||||
assert num_tokens % self.downsample_factor == 0
|
||||
audio_hidden_states = audio_hidden_states.reshape(
|
||||
num_tokens // self.downsample_factor,
|
||||
audio_hidden_size * self.downsample_factor,
|
||||
)
|
||||
audio_text_embeds = self.audio_language_adapter(audio_hidden_states)
|
||||
|
||||
text_embeds = self.language_model.embed_input_ids(input_ids)
|
||||
|
||||
# sum pool text and audio embeddings
|
||||
inputs_embeds = audio_text_embeds + text_embeds
|
||||
|
||||
time_tensor = torch.tensor(
|
||||
[self.n_delay_tokens],
|
||||
device=inputs_embeds.device,
|
||||
dtype=inputs_embeds.dtype,
|
||||
)
|
||||
inputs_embeds = inputs_embeds + self.time_embedding(time_tensor)
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def embed_multimodal(
|
||||
self, **kwargs
|
||||
) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
|
||||
"""Transform audio waveforms -> initial whisper post-conv embeddings"""
|
||||
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
|
||||
|
||||
assert audio_inputs is not None, (
|
||||
"For streaming you must provide an audio input at every step."
|
||||
)
|
||||
|
||||
multiple_of = self.audio_config.raw_audio_length_per_tok
|
||||
assert all(
|
||||
(this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs
|
||||
), (
|
||||
f"Every input audio waveform has to be a multiple of {multiple_of}, but"
|
||||
f" one is {this_audio} with {(this_audio / multiple_of)=}."
|
||||
)
|
||||
|
||||
mel_features = [
|
||||
self.whisper_encoder.compute_whisper_melspec(audio).to(
|
||||
self.whisper_encoder.dtype
|
||||
)
|
||||
for audio in audio_inputs
|
||||
]
|
||||
seq_lens = [mel.shape[1] for mel in mel_features]
|
||||
# [total_num_20ms_frames, hidden_size]
|
||||
audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
|
||||
mel_features
|
||||
)[0]
|
||||
conv_stride = self.whisper_encoder.whisper_encoder.total_stride
|
||||
audio_embeddings_per_sample = audio_embeddings.split(
|
||||
[s // conv_stride for s in seq_lens], dim=0
|
||||
)
|
||||
|
||||
# audio_embeddings per sample need to be divisible by block_pool_size
|
||||
pool_size = self.config.audio_config.block_pool_size
|
||||
assert all(
|
||||
(this_shape := sample.shape[0]) % pool_size == 0
|
||||
for sample in audio_embeddings_per_sample
|
||||
), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}."
|
||||
|
||||
audio_embeddings_per_sample = [
|
||||
e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
|
||||
for e in audio_embeddings_per_sample
|
||||
]
|
||||
return audio_embeddings_per_sample
|
||||
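To keep the reshapes in embed_multimodal() and forward() straight, a plain-arithmetic walk-through with made-up sizes (hidden width, pool size and downsample factor here are assumptions, not values read from a real config):

# Assumed sizes, for illustration only.
pool_size = 4           # block_pool_size
downsample = 4          # downsample_factor
hidden = 1280           # whisper encoder width
decoder_tokens = 2      # tokens the language model sees this step

# embed_multimodal(): post-conv frames are regrouped pool_size at a time,
# so each decoder token carries one "stacked" audio embedding.
post_conv_frames = decoder_tokens * pool_size                 # 8
mm_embeds_shape = (decoder_tokens, hidden * pool_size)        # (2, 5120)

# forward(): the view() undoes that grouping before the encoder layers run,
# then the reshape regroups by the downsample factor for the adapter.
encoder_in_shape = (post_conv_frames, hidden)                             # (8, 1280)
adapter_in_shape = (post_conv_frames // downsample, hidden * downsample)  # (2, 5120)
print(mm_embeds_shape, encoder_in_shape, adapter_in_shape)
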
@@ -1,9 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Annotated, Literal, cast
|
||||
|
||||
import numpy as np
|
||||
@@ -16,7 +18,10 @@ from transformers import (
|
||||
)
|
||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||
|
||||
from vllm.attention.layer import Attention, AttentionType
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.layers.cross_attention import CrossAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
@@ -34,6 +39,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.whisper_utils import (
|
||||
ISO639_1_SUPPORTED_LANGS,
|
||||
WhisperAttentionWithBlockPooling,
|
||||
WhisperCausalConv1d,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
@@ -64,67 +74,11 @@ from .utils import (
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
|
||||
|
||||
ISO639_1_SUPPORTED_LANGS = {
|
||||
"af": "Afrikaans",
|
||||
"ar": "Arabic",
|
||||
"hy": "Armenian",
|
||||
"az": "Azerbaijani",
|
||||
"be": "Belarusian",
|
||||
"bs": "Bosnian",
|
||||
"bg": "Bulgarian",
|
||||
"ca": "Catalan",
|
||||
"zh": "Chinese",
|
||||
"hr": "Croatian",
|
||||
"cs": "Czech",
|
||||
"da": "Danish",
|
||||
"nl": "Dutch",
|
||||
"en": "English",
|
||||
"et": "Estonian",
|
||||
"fi": "Finnish",
|
||||
"fr": "French",
|
||||
"gl": "Galician",
|
||||
"de": "German",
|
||||
"el": "Greek",
|
||||
"he": "Hebrew",
|
||||
"hi": "Hindi",
|
||||
"hu": "Hungarian",
|
||||
"is": "Icelandic",
|
||||
"id": "Indonesian",
|
||||
"it": "Italian",
|
||||
"ja": "Japanese",
|
||||
"kn": "Kannada",
|
||||
"kk": "Kazakh",
|
||||
"ko": "Korean",
|
||||
"lv": "Latvian",
|
||||
"lt": "Lithuanian",
|
||||
"mk": "Macedonian",
|
||||
"ms": "Malay",
|
||||
"mr": "Marathi",
|
||||
"mi": "Maori",
|
||||
"ne": "Nepali",
|
||||
"no": "Norwegian",
|
||||
"fa": "Persian",
|
||||
"pl": "Polish",
|
||||
"pt": "Portuguese",
|
||||
"ro": "Romanian",
|
||||
"ru": "Russian",
|
||||
"sr": "Serbian",
|
||||
"sk": "Slovak",
|
||||
"sl": "Slovenian",
|
||||
"es": "Spanish",
|
||||
"sw": "Swahili",
|
||||
"sv": "Swedish",
|
||||
"tl": "Tagalog",
|
||||
"ta": "Tamil",
|
||||
"th": "Thai",
|
||||
"tr": "Turkish",
|
||||
"uk": "Ukrainian",
|
||||
"ur": "Urdu",
|
||||
"vi": "Vietnamese",
|
||||
"cy": "Welsh",
|
||||
}
|
||||
class WhisperPosEmbedType(enum.Enum):
|
||||
SINUSOIDAL = "sinusoidal"
|
||||
NOPE = "nope"
|
||||
LEARNED = "learned"
|
||||
|
||||
|
||||
class WhisperAudioInputs(TensorSchema):
|
||||
@@ -184,6 +138,8 @@ class WhisperAttention(nn.Module):
|
||||
num_heads: int,
|
||||
bias: bool = True,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
per_layer_sliding_window: int | None = None,
|
||||
block_pool_size: int = 1,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
@@ -242,7 +198,14 @@ class WhisperAttention(nn.Module):
|
||||
attn_type=self.attn_type,
|
||||
)
|
||||
else: # AttentionType.DECODER (regular decoder self-attention)
|
||||
self.attn = Attention(
|
||||
if block_pool_size > 1:
|
||||
attn_cls = partial(
|
||||
WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
|
||||
)
|
||||
else:
|
||||
attn_cls = Attention
|
||||
|
||||
self.attn = attn_cls(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
@@ -251,6 +214,7 @@ class WhisperAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=self.attn_type,
|
||||
per_layer_sliding_window=per_layer_sliding_window,
|
||||
)
|
||||
|
||||
def _init_qkv(
|
||||
@@ -386,6 +350,9 @@ class WhisperEncoderLayer(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
is_causal = getattr(config, "is_causal", False)
|
||||
sliding_window = getattr(config, "sliding_window", None)
|
||||
block_pool_size = getattr(config, "block_pool_size", 1)
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
@@ -393,7 +360,9 @@ class WhisperEncoderLayer(nn.Module):
|
||||
self.self_attn = WhisperAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
attn_type=AttentionType.ENCODER,
|
||||
attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER,
|
||||
block_pool_size=block_pool_size,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
@@ -492,12 +461,21 @@ class WhisperEncoder(nn.Module):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
embed_dim = config.d_model
|
||||
|
||||
self.pos_embed_type = WhisperPosEmbedType(
|
||||
getattr(config, "pos_embed", "sinusoidal")
|
||||
)
|
||||
self.num_mel_bins = config.num_mel_bins
|
||||
self.max_source_positions = config.max_source_positions
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
|
||||
is_causal = getattr(config, "is_causal", False)
|
||||
Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
|
||||
|
||||
self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
|
||||
self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
|
||||
|
||||
self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.encoder_layers,
|
||||
lambda prefix: WhisperEncoderLayer(
|
||||
@@ -507,29 +485,54 @@ class WhisperEncoder(nn.Module):
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
maybe_fp32_init_ctx = (
|
||||
set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
|
||||
)
|
||||
|
||||
with (
|
||||
torch.no_grad(),
|
||||
maybe_fp32_init_ctx,
|
||||
if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
|
||||
raise ValueError(
|
||||
"Only NOPE position embeddings are supported "
|
||||
f"for causal models, but got {self.pos_embed_type}"
|
||||
)
|
||||
elif self.pos_embed_type in (
|
||||
WhisperPosEmbedType.SINUSOIDAL,
|
||||
WhisperPosEmbedType.LEARNED,
|
||||
):
|
||||
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
|
||||
self.embed_positions.weight.copy_(
|
||||
sinusoids(*self.embed_positions.weight.shape)
|
||||
maybe_fp32_init_ctx = (
|
||||
set_default_torch_dtype(torch.float32)
|
||||
if init_in_fp32
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
||||
with (
|
||||
torch.no_grad(),
|
||||
maybe_fp32_init_ctx,
|
||||
):
|
||||
self.embed_positions = nn.Embedding(
|
||||
self.max_source_positions, embed_dim
|
||||
)
|
||||
self.embed_positions.weight.copy_(
|
||||
sinusoids(*self.embed_positions.weight.shape)
|
||||
)
|
||||
|
||||
def forward_conv(
|
||||
self, input_features: torch.Tensor | list[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
hidden_states = []
|
||||
input_is_batched = False
|
||||
for features in input_features:
|
||||
embeds = nn.functional.gelu(self.conv1(features))
|
||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||
embeds = embeds.transpose(-1, -2)
|
||||
embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
|
||||
embeds.dtype
|
||||
)
|
||||
|
||||
if self.pos_embed_type in (
|
||||
WhisperPosEmbedType.SINUSOIDAL,
|
||||
WhisperPosEmbedType.LEARNED,
|
||||
):
|
||||
embeds = embeds.transpose(-1, -2)
|
||||
embeds = (
|
||||
embeds + self.embed_positions.weight[: embeds.size(-2), :]
|
||||
).to(embeds.dtype)
|
||||
elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
|
||||
embeds = embeds.transpose(-1, -2).to(embeds.dtype)
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
|
||||
|
||||
hidden_states.append(embeds)
|
||||
input_is_batched = embeds.ndim > 2
|
||||
# Input to MHA must be B x T x D
|
||||
@@ -539,12 +542,19 @@ class WhisperEncoder(nn.Module):
|
||||
else:
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
||||
hidden_states = self.forward_conv(input_features)
|
||||
return self.forward_layers(hidden_states)
|
||||
|
||||
|
||||
class WhisperDecoder(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
|
||||
vllm/model_executor/models/whisper_utils.py (new file, 299 lines)
@@ -0,0 +1,299 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import functools
|
||||
import math
|
||||
from dataclasses import replace
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend_with_overrides,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
|
||||
ISO639_1_SUPPORTED_LANGS = {
|
||||
"af": "Afrikaans",
|
||||
"ar": "Arabic",
|
||||
"hy": "Armenian",
|
||||
"az": "Azerbaijani",
|
||||
"be": "Belarusian",
|
||||
"bs": "Bosnian",
|
||||
"bg": "Bulgarian",
|
||||
"ca": "Catalan",
|
||||
"zh": "Chinese",
|
||||
"hr": "Croatian",
|
||||
"cs": "Czech",
|
||||
"da": "Danish",
|
||||
"nl": "Dutch",
|
||||
"en": "English",
|
||||
"et": "Estonian",
|
||||
"fi": "Finnish",
|
||||
"fr": "French",
|
||||
"gl": "Galician",
|
||||
"de": "German",
|
||||
"el": "Greek",
|
||||
"he": "Hebrew",
|
||||
"hi": "Hindi",
|
||||
"hu": "Hungarian",
|
||||
"is": "Icelandic",
|
||||
"id": "Indonesian",
|
||||
"it": "Italian",
|
||||
"ja": "Japanese",
|
||||
"kn": "Kannada",
|
||||
"kk": "Kazakh",
|
||||
"ko": "Korean",
|
||||
"lv": "Latvian",
|
||||
"lt": "Lithuanian",
|
||||
"mk": "Macedonian",
|
||||
"ms": "Malay",
|
||||
"mr": "Marathi",
|
||||
"mi": "Maori",
|
||||
"ne": "Nepali",
|
||||
"no": "Norwegian",
|
||||
"fa": "Persian",
|
||||
"pl": "Polish",
|
||||
"pt": "Portuguese",
|
||||
"ro": "Romanian",
|
||||
"ru": "Russian",
|
||||
"sr": "Serbian",
|
||||
"sk": "Slovak",
|
||||
"sl": "Slovenian",
|
||||
"es": "Spanish",
|
||||
"sw": "Swahili",
|
||||
"sv": "Swedish",
|
||||
"tl": "Tagalog",
|
||||
"ta": "Tamil",
|
||||
"th": "Thai",
|
||||
"tr": "Turkish",
|
||||
"uk": "Ukrainian",
|
||||
"ur": "Urdu",
|
||||
"vi": "Vietnamese",
|
||||
"cy": "Welsh",
|
||||
}
|
||||
|
||||
|
||||
def _pad1d(
|
||||
x: torch.Tensor,
|
||||
paddings: tuple[int, int],
|
||||
mode: str = "constant",
|
||||
value: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""Tiny wrapper around F.pad, just to allow for
|
||||
reflect padding on small input.
|
||||
If this is the case, we insert extra 0 padding
|
||||
to the right before the reflection happens.
|
||||
"""
|
||||
length = x.shape[-1]
|
||||
padding_left, padding_right = paddings
|
||||
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
||||
if mode == "reflect":
|
||||
max_pad = max(padding_left, padding_right)
|
||||
extra_pad = 0
|
||||
if length <= max_pad:
|
||||
extra_pad = max_pad - length + 1
|
||||
x = F.pad(x, (0, extra_pad))
|
||||
padded = F.pad(x, paddings, mode, value)
|
||||
end = padded.shape[-1] - extra_pad
|
||||
return padded[..., :end]
|
||||
else:
|
||||
return F.pad(x, paddings, mode, value)
|
||||
|
||||
|
||||
class WhisperCausalConv1d(nn.Conv1d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
self._stride = self.stride[0]
|
||||
self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1
|
||||
self._padding_total = self._effective_kernel_size - self._stride
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
n_frames = (
|
||||
x.shape[-1] - self._effective_kernel_size + self._padding_total
|
||||
) / self._stride + 1
|
||||
target_length = (math.ceil(n_frames) - 1) * self._stride + (
|
||||
self._effective_kernel_size - self._padding_total
|
||||
)
|
||||
extra_padding = target_length - x.shape[-1]
|
||||
x = _pad1d(x, (self._padding_total, extra_padding), mode="constant")
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
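A tiny runnable illustration of the left-padding idea behind WhisperCausalConv1d (kernel 3, stride 1, no dilation, made-up input): pad kernel_size - stride zeros on the left so output t only depends on inputs up to t, and the output keeps the input length.

import torch
import torch.nn.functional as F

x = torch.arange(1.0, 6.0).view(1, 1, 5)   # (batch, channels, time) = (1, 1, 5)
w = torch.ones(1, 1, 3)                    # kernel_size = 3, stride = 1
padding_total = 3 - 1                      # effective_kernel_size - stride
y = F.conv1d(F.pad(x, (padding_total, 0)), w)
print(y.shape)   # torch.Size([1, 1, 5]) - same length as the input
print(y)         # tensor([[[ 1.,  3.,  6.,  9., 12.]]]) - y[t] uses only x[:t + 1]
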
@functools.lru_cache
|
||||
def create_whisper_attention_backend_with_block_pooling(
|
||||
underlying_attn_backend: AttentionBackend, block_pool_size: int
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = "WhisperAttentionWithBlockPooling_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class WhisperAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
assert kv_cache_spec.num_kv_heads % block_pool_size == 0
|
||||
kv_cache_spec = replace(
|
||||
kv_cache_spec,
|
||||
block_size=kv_cache_spec.block_size * block_pool_size,
|
||||
num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
|
||||
)
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AttentionMetadata:
|
||||
new_common_attn_metadata = copy.deepcopy(common_attn_metadata)
|
||||
new_common_attn_metadata.query_start_loc *= block_pool_size
|
||||
new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
|
||||
new_common_attn_metadata.seq_lens *= block_pool_size
|
||||
new_common_attn_metadata._seq_lens_cpu *= block_pool_size
|
||||
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
|
||||
new_common_attn_metadata.num_actual_tokens *= block_pool_size
|
||||
new_common_attn_metadata.max_query_len *= block_pool_size
|
||||
new_common_attn_metadata.max_seq_len *= block_pool_size
|
||||
original_slot_mapping = common_attn_metadata.slot_mapping
|
||||
common_prefix_len *= block_pool_size
|
||||
new_common_attn_metadata.slot_mapping = (
|
||||
(
|
||||
original_slot_mapping.unsqueeze(1) * block_pool_size
|
||||
+ torch.arange(block_pool_size, device=original_slot_mapping.device)
|
||||
)
|
||||
.flatten()
|
||||
.clamp(min=-1)
|
||||
)
|
||||
return super().build(
|
||||
common_prefix_len, new_common_attn_metadata, fast_build
|
||||
)
|
||||
|
||||
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
|
||||
raise NotImplementedError(
|
||||
f"{underlying_attn_backend} is not yet supported."
|
||||
"Contributions to support more backends are much "
|
||||
"appreciated."
|
||||
)
|
||||
|
||||
attn_backend = subclass_attention_backend_with_overrides(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
overrides={
|
||||
"get_builder_cls": lambda: WhisperAttentionWithBlockPoolingBuilder,
|
||||
"get_kv_cache_shape": lambda num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
cache_dtype_str: (
|
||||
2,
|
||||
num_blocks,
|
||||
# we stretch each block by `block_pool_size`
|
||||
block_size * block_pool_size,
|
||||
num_kv_heads // block_pool_size,
|
||||
head_size,
|
||||
), # TODO: generalize to other backends
|
||||
},
|
||||
)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
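The slot_mapping transform above is easiest to see with small numbers; a standalone sketch (block_pool_size and the slot values are made up):

import torch

block_pool_size = 4
slot_mapping = torch.tensor([0, 1, -1])    # -1 marks a skipped/padded slot

stretched = (
    slot_mapping.unsqueeze(1) * block_pool_size
    + torch.arange(block_pool_size)
).flatten().clamp(min=-1)
print(stretched)
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7, -1, -1, -1, -1])
# Each logical slot expands into block_pool_size physical slots; clamp keeps -1 as -1.
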
class WhisperAttentionWithBlockPooling(Attention):
|
||||
"""Attention layer with block pooling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int | None = None,
|
||||
alibi_slopes: list[float] | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
logits_soft_cap: float | None = None,
|
||||
per_layer_sliding_window: int | None = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
block_pool_size: int = 1,
|
||||
attn_backend: type[AttentionBackend] | None = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
self.block_pool_size = block_pool_size
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
underlying_attn_backend = get_attn_backend(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
attn_type=attn_type,
|
||||
)
|
||||
attn_backend = create_whisper_attention_backend_with_block_pooling(
|
||||
underlying_attn_backend, block_pool_size
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=alibi_slopes,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
per_layer_sliding_window=per_layer_sliding_window,
|
||||
prefix=prefix,
|
||||
attn_type=attn_type,
|
||||
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
|
||||
attn_backend=attn_backend,
|
||||
**extra_impl_args,
|
||||
)
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig):
|
||||
kv_cache_spec = super().get_kv_cache_spec(vllm_config)
|
||||
assert isinstance(kv_cache_spec, AttentionSpec)
|
||||
kv_cache_spec = replace(
|
||||
kv_cache_spec,
|
||||
num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads,
|
||||
)
|
||||
return kv_cache_spec
|
||||
@@ -184,18 +184,42 @@ def _remap_mistral_audio_args(config: dict) -> dict:
|
||||
whisper_args = config["multimodal"].pop("whisper_model_args")
|
||||
encoder_args = whisper_args["encoder_args"]
|
||||
downsample_args = whisper_args["downsample_args"]
|
||||
downsample_factor = downsample_args["downsample_factor"]
|
||||
|
||||
# make sure that k/v blocks can be allocated with
|
||||
# unified k/v cache class and pool whisper k/v cache blocks
|
||||
# with downsample_factor:1 ratio
|
||||
if encoder_args.get("causal"):
|
||||
block_pool_size = downsample_factor
|
||||
config["projection_size"] = downsample_factor * encoder_args["dim"]
|
||||
else:
|
||||
block_pool_size = 1
|
||||
|
||||
_maybe_sliding_window = encoder_args.get("ragged_attention", None)
|
||||
if _maybe_sliding_window is None:
|
||||
sliding_window = None
|
||||
elif _maybe_sliding_window.isdigit():
|
||||
sliding_window = int(_maybe_sliding_window)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")
|
||||
|
||||
architecture = (
|
||||
"VoxtralStreamingGeneration"
|
||||
if encoder_args.get("causal")
|
||||
else "VoxtralForConditionalGeneration"
|
||||
)
|
||||
|
||||
quant_config = config.get("quantization_config")
|
||||
config = {
|
||||
"model_type": "whixtral",
|
||||
"architectures": ["VoxtralForConditionalGeneration"],
|
||||
"model_type": "voxtral",
|
||||
"architectures": [architecture],
|
||||
"text_config": PretrainedConfig.from_dict(config),
|
||||
"audio_config": WhisperConfig(
|
||||
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
|
||||
window_size=encoder_args["audio_encoding_args"]["window_size"],
|
||||
sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
|
||||
hop_length=encoder_args["audio_encoding_args"]["hop_length"],
|
||||
downsample_factor=downsample_args["downsample_factor"],
|
||||
downsample_factor=downsample_factor,
|
||||
d_model=encoder_args["dim"],
|
||||
encoder_layers=encoder_args["n_layers"],
|
||||
encoder_ffn_dim=encoder_args["hidden_dim"],
|
||||
@@ -203,6 +227,10 @@ def _remap_mistral_audio_args(config: dict) -> dict:
|
||||
vocab_size=encoder_args["vocab_size"],
|
||||
max_source_positions=encoder_args["max_source_positions"],
|
||||
is_encoder_decoder=False, # Override WhisperConfig default
|
||||
is_causal=encoder_args.get("causal", False),
|
||||
sliding_window=sliding_window,
|
||||
block_pool_size=block_pool_size,
|
||||
pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
|
||||
),
|
||||
}
|
||||
if quant_config:
|
||||
|
||||
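A condensed sketch of the branching the remap above performs, with made-up encoder_args (the field names follow the diff; the values do not come from a real checkpoint):

encoder_args = {"causal": True, "dim": 1280, "ragged_attention": "128"}
downsample_factor = 4

block_pool_size = downsample_factor if encoder_args.get("causal") else 1
raw_window = encoder_args.get("ragged_attention")
sliding_window = int(raw_window) if raw_window is not None and raw_window.isdigit() else None
architecture = (
    "VoxtralStreamingGeneration"
    if encoder_args.get("causal")
    else "VoxtralForConditionalGeneration"
)
print(block_pool_size, sliding_window, architecture)
# 4 128 VoxtralStreamingGeneration
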
@@ -835,6 +835,15 @@ def subclass_attention_backend(
|
||||
)
|
||||
|
||||
|
||||
def subclass_attention_backend_with_overrides(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
overrides: dict[str, Any],
|
||||
) -> type[AttentionBackend]:
|
||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||
return type(name, (attention_backend_cls,), overrides)
|
||||
|
||||
|
||||
def split_decodes_prefills_and_extends(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
|
||||
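subclass_attention_backend_with_overrides is plain dynamic subclassing via type(); a toy version showing the same mechanism (the classes here are stand-ins, not vLLM backends):

class ToyBackend:
    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int) -> tuple[int, int, int]:
        return (2, num_blocks, block_size)

def subclass_with_overrides(name_prefix, backend_cls, overrides):
    # Same shape as the helper above: build a named subclass with attribute overrides.
    return type(name_prefix + backend_cls.__name__, (backend_cls,), overrides)

Pooled = subclass_with_overrides(
    "Pooled_",
    ToyBackend,
    # Stretch each block by a pool size of 4, mirroring the override in the diff.
    {"get_kv_cache_shape": staticmethod(
        lambda num_blocks, block_size: (2, num_blocks, block_size * 4)
    )},
)
print(Pooled.__name__, Pooled.get_kv_cache_shape(8, 16))   # Pooled_ToyBackend (2, 8, 64)
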
@@ -2457,6 +2457,17 @@ class GPUModelRunner(
|
||||
return round_up(num_scheduled_tokens, tp_size)
|
||||
return num_scheduled_tokens
|
||||
|
||||
def _prepare_mm_inputs(
|
||||
self, num_tokens: int
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor]:
|
||||
if self.model.requires_raw_input_tokens:
|
||||
input_ids = self.input_ids.gpu[:num_tokens]
|
||||
else:
|
||||
input_ids = None
|
||||
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
|
||||
return input_ids, inputs_embeds
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
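The runner-side contract this introduces, in a self-contained sketch (the toy model and buffers are illustrative, not GPUModelRunner internals): models that set requires_raw_input_tokens keep receiving input_ids alongside inputs_embeds, everyone else gets embeddings only.

class ToyModel:
    requires_raw_input_tokens = True

def prepare_mm_inputs(model, input_ids_buf, inputs_embeds_buf, num_tokens):
    # Same shape as _prepare_mm_inputs above: ids only for raw-input-token models.
    if getattr(model, "requires_raw_input_tokens", False):
        input_ids = input_ids_buf[:num_tokens]
    else:
        input_ids = None
    return input_ids, inputs_embeds_buf[:num_tokens]

ids, embeds = prepare_mm_inputs(ToyModel(), list(range(8)), [0.0] * 8, 4)
print(ids, len(embeds))   # [0, 1, 2, 3] 4
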
@@ -2499,8 +2510,7 @@ class GPUModelRunner(
|
||||
# TODO(woosuk): Avoid the copy. Optimize.
|
||||
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)
|
||||
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||||
input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens)
|
||||
model_kwargs = {
|
||||
**self._init_model_kwargs(num_scheduled_tokens),
|
||||
**self._extract_mm_kwargs(scheduler_output),
|
||||
@@ -4220,8 +4230,8 @@ class GPUModelRunner(
|
||||
assert num_tokens_padded <= self.max_num_tokens
|
||||
model_kwargs = self._init_model_kwargs(num_tokens_padded)
|
||||
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
|
||||
input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded)
|
||||
|
||||
model_kwargs = {
|
||||
**model_kwargs,
|
||||
**self._dummy_mm_kwargs(num_reqs),
|
||||
|
||||