# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import functools
import math
from dataclasses import replace

import torch
import torch.nn.functional as F
from torch import nn

from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionMetadata,
    AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
    subclass_attention_backend_with_overrides,
)
from vllm.v1.kv_cache_interface import AttentionSpec

# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = {
    "af": "Afrikaans",
    "ar": "Arabic",
    "hy": "Armenian",
    "az": "Azerbaijani",
    "be": "Belarusian",
    "bs": "Bosnian",
    "bg": "Bulgarian",
    "ca": "Catalan",
    "zh": "Chinese",
    "hr": "Croatian",
    "cs": "Czech",
    "da": "Danish",
    "nl": "Dutch",
    "en": "English",
    "et": "Estonian",
    "fi": "Finnish",
    "fr": "French",
    "gl": "Galician",
    "de": "German",
    "el": "Greek",
    "he": "Hebrew",
    "hi": "Hindi",
    "hu": "Hungarian",
    "is": "Icelandic",
    "id": "Indonesian",
    "it": "Italian",
    "ja": "Japanese",
    "kn": "Kannada",
    "kk": "Kazakh",
    "ko": "Korean",
    "lv": "Latvian",
    "lt": "Lithuanian",
    "mk": "Macedonian",
    "ms": "Malay",
    "mr": "Marathi",
    "mi": "Maori",
    "ne": "Nepali",
    "no": "Norwegian",
    "fa": "Persian",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sr": "Serbian",
    "sk": "Slovak",
    "sl": "Slovenian",
    "es": "Spanish",
    "sw": "Swahili",
    "sv": "Swedish",
    "tl": "Tagalog",
    "ta": "Tamil",
    "th": "Thai",
    "tr": "Turkish",
    "uk": "Ukrainian",
    "ur": "Urdu",
    "vi": "Vietnamese",
    "cy": "Welsh",
}


def _pad1d(
    x: torch.Tensor,
    paddings: tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
) -> torch.Tensor:
    """Tiny wrapper around F.pad that allows reflect padding on small inputs.

    If the input is too short for the requested reflect padding, extra zero
    padding is inserted on the right before the reflection happens, and the
    extra positions are trimmed off again afterwards.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)
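

# Illustrative sketch (not part of the original file): how `_pad1d` handles a
# reflect pad that is larger than the input. For an input of length 2 with
# paddings (3, 3), two zeros are appended first (extra_pad = 3 - 2 + 1 = 2),
# the reflection is applied to the length-4 tensor, and the extra positions
# are trimmed again, so the output still has length 2 + 3 + 3 = 8:
#
#   x = torch.zeros(1, 1, 2)
#   _pad1d(x, (3, 3), mode="reflect").shape  # -> torch.Size([1, 1, 8])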


class WhisperCausalConv1d(nn.Conv1d):
    """Conv1d with causal padding applied in ``forward``.

    With the default ``padding=0``, the input is left-padded with zeros (plus a
    small amount of right padding to complete the final frame), so the layer
    produces ``ceil(T / stride)`` output frames for an input of length ``T``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self._stride = self.stride[0]
        self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1
        self._padding_total = self._effective_kernel_size - self._stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # n_frames == T / stride; the padding below makes the convolution emit
        # ceil(n_frames) frames.
        n_frames = (
            x.shape[-1] - self._effective_kernel_size + self._padding_total
        ) / self._stride + 1
        target_length = (math.ceil(n_frames) - 1) * self._stride + (
            self._effective_kernel_size - self._padding_total
        )
        extra_padding = target_length - x.shape[-1]
        # Pad by `_padding_total` on the left (causal) and just enough on the
        # right to complete the last frame.
        x = _pad1d(x, (self._padding_total, extra_padding), mode="constant")
        return super().forward(x)
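

# Illustrative sketch (not part of the original file): with kernel_size=3 and
# stride=2 the layer left-pads by `_padding_total = 1` and tops up the right
# side so the output always has ceil(T / stride) frames, e.g. for T = 7:
#
#   conv = WhisperCausalConv1d(1, 4, kernel_size=3, stride=2)
#   conv(torch.randn(1, 1, 7)).shape  # -> torch.Size([1, 4, 4]); 4 == ceil(7 / 2)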


@functools.lru_cache
def create_whisper_attention_backend_with_block_pooling(
    underlying_attn_backend: AttentionBackend, block_pool_size: int
) -> type[AttentionBackend]:
    prefix = "WhisperAttentionWithBlockPooling_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class WhisperAttentionWithBlockPoolingBuilder(underlying_builder):  # type: ignore
        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            assert kv_cache_spec.num_kv_heads % block_pool_size == 0
            # Stretch the block size and shrink the per-block KV head count by
            # the same factor, so the size of each KV-cache page is unchanged.
            kv_cache_spec = replace(
                kv_cache_spec,
                block_size=kv_cache_spec.block_size * block_pool_size,
                num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
            )
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Scale all token counts by `block_pool_size` and expand every slot
            # into `block_pool_size` consecutive physical slots (padding slots
            # stay at -1 thanks to the clamp).
            new_common_attn_metadata = copy.deepcopy(common_attn_metadata)
            new_common_attn_metadata.query_start_loc *= block_pool_size
            new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
            new_common_attn_metadata.seq_lens *= block_pool_size
            new_common_attn_metadata._seq_lens_cpu *= block_pool_size
            new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
            new_common_attn_metadata.num_actual_tokens *= block_pool_size
            new_common_attn_metadata.max_query_len *= block_pool_size
            new_common_attn_metadata.max_seq_len *= block_pool_size
            original_slot_mapping = common_attn_metadata.slot_mapping
            common_prefix_len *= block_pool_size
            new_common_attn_metadata.slot_mapping = (
                (
                    original_slot_mapping.unsqueeze(1) * block_pool_size
                    + torch.arange(block_pool_size, device=original_slot_mapping.device)
                )
                .flatten()
                .clamp(min=-1)
            )
            return super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
        raise NotImplementedError(
            f"{underlying_attn_backend} is not yet supported. "
            "Contributions to support more backends are much "
            "appreciated."
        )

    attn_backend = subclass_attention_backend_with_overrides(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        overrides={
            "get_builder_cls": lambda: WhisperAttentionWithBlockPoolingBuilder,
            "get_kv_cache_shape": lambda num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            cache_dtype_str: (
                2,
                num_blocks,
                # we stretch each block by `block_pool_size`
                block_size * block_pool_size,
                num_kv_heads // block_pool_size,
                head_size,
            ),  # TODO: generalize to other backends
        },
    )

    return attn_backend
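

# Illustrative sketch (not part of the original file): how the builder expands a
# slot mapping when block_pool_size = 2. Every logical slot becomes two
# consecutive physical slots, and padding slots (-1) stay at -1 after the clamp:
#
#   slots = torch.tensor([5, 6, -1])
#   (slots.unsqueeze(1) * 2 + torch.arange(2)).flatten().clamp(min=-1)
#   # -> tensor([10, 11, 12, 13, -1, -1])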


class WhisperAttentionWithBlockPooling(Attention):
    """Attention layer with block pooling."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        block_pool_size: int = 1,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        self.block_pool_size = block_pool_size
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        underlying_attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            attn_type=attn_type,
        )
        # Always wrap the selected backend with block pooling; any backend
        # passed in via `attn_backend` is ignored.
        attn_backend = create_whisper_attention_backend_with_block_pooling(
            underlying_attn_backend, block_pool_size
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=logits_soft_cap,
            per_layer_sliding_window=per_layer_sliding_window,
            prefix=prefix,
            attn_type=attn_type,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend,
            **extra_impl_args,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig):
        kv_cache_spec = super().get_kv_cache_spec(vllm_config)
        assert isinstance(kv_cache_spec, AttentionSpec)
        # Advertise `block_pool_size` times more KV heads; the pooled builder
        # divides the head count back out and stretches `block_size` instead,
        # so the bytes per KV-cache block stay the same.
        kv_cache_spec = replace(
            kv_cache_spec,
            num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads,
        )
        return kv_cache_spec
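

# Illustrative sketch (not part of the original file): rough shape of how the
# layer might be instantiated inside a model. The argument values are
# placeholders and `cache_config` is assumed to come from an initialized vLLM
# engine, so this is not meant to run standalone:
#
#   attn = WhisperAttentionWithBlockPooling(
#       num_heads=20,
#       head_size=64,
#       scale=64**-0.5,
#       num_kv_heads=20,
#       cache_config=cache_config,
#       prefix="decoder.layers.0.cross_attn.attn",
#       attn_type=AttentionType.ENCODER_DECODER,
#       block_pool_size=4,
#   )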