mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 09:37:04 +08:00
Merge 762ca9e38a3a06bbc1bdb0c909ba5d0a521754f8 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
5ed23c8017
@ -134,6 +134,7 @@ class MMEncoderAttention(CustomOp):
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens=cu_seqlens,
|
||||
softmax_scale=self.scale,
|
||||
)
|
||||
if is_reshaped:
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
@ -172,6 +173,7 @@ class MMEncoderAttention(CustomOp):
|
||||
batch_size=bsz,
|
||||
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
|
||||
fa_version=self._fa_version,
|
||||
softmax_scale=self.scale,
|
||||
)
|
||||
if is_reshaped:
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
|
||||
@ -29,6 +29,7 @@ def flash_attn_maxseqlen_wrapper(
|
||||
fa_version: int | None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
softmax_scale: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
kwargs = {}
|
||||
if is_rocm_aiter:
|
||||
@ -57,6 +58,7 @@ def flash_attn_maxseqlen_wrapper(
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs,
|
||||
)
|
||||
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
|
||||
@ -67,11 +69,12 @@ def flash_attn_maxseqlen_wrapper_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int | None,
|
||||
fa_version: int,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
@ -92,6 +95,7 @@ def vit_flash_attn_wrapper(
|
||||
fa_version: int | None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
softmax_scale: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
|
||||
q,
|
||||
@ -102,16 +106,22 @@ def vit_flash_attn_wrapper(
|
||||
fa_version,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
softmax_scale,
|
||||
)
|
||||
|
||||
|
||||
def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
def apply_sdpa(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
softmax_scale: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Input shape:
|
||||
(batch_size x seq_len x num_heads x head_size)
|
||||
"""
|
||||
q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
|
||||
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
|
||||
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=softmax_scale)
|
||||
output = einops.rearrange(output, "b h s d -> b s h d ")
|
||||
return output
|
||||
|
||||
@ -123,6 +133,7 @@ def torch_sdpa_wrapper(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
softmax_scale: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Never remove the contiguous logic for ROCm
|
||||
# Without it, hallucinations occur with the backend
|
||||
@ -141,7 +152,7 @@ def torch_sdpa_wrapper(
|
||||
k_chunks = torch.split(k, lens, dim=1)
|
||||
v_chunks = torch.split(v, lens, dim=1)
|
||||
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||
output_i = apply_sdpa(q_i, k_i, v_i)
|
||||
output_i = apply_sdpa(q_i, k_i, v_i, softmax_scale)
|
||||
outputs.append(output_i)
|
||||
context_layer = torch.cat(outputs, dim=1)
|
||||
return context_layer
|
||||
@ -152,6 +163,7 @@ def torch_sdpa_wrapper_fake(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
@ -168,5 +180,6 @@ def vit_torch_sdpa_wrapper(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
softmax_scale: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
|
||||
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens, softmax_scale)
|
||||
|
||||
@ -26,7 +26,7 @@ from vllm.utils.math_utils import round_up
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
else:
|
||||
VllmConfig = object
|
||||
|
||||
@ -829,6 +829,12 @@ class CompilationConfig:
|
||||
if self.backend == "":
|
||||
self.backend = current_platform.get_compile_backend()
|
||||
|
||||
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
|
||||
if model_config.is_encoder_decoder:
|
||||
# Does not yet support encoder-decoder models.
|
||||
self.inductor_compile_config["combo_kernels"] = False
|
||||
self.inductor_compile_config["benchmark_combo_kernel"] = False
|
||||
|
||||
def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
|
||||
"""
|
||||
Initialize the backend for the compilation config from a vllm config.
|
||||
|
||||
@ -521,6 +521,7 @@ class VllmConfig:
|
||||
if self.model_config is not None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.model_config.verify_dual_chunk_attention_config(self.load_config)
|
||||
self.compilation_config.verify_with_model_config(self.model_config)
|
||||
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
|
||||
@ -753,14 +754,19 @@ class VllmConfig:
|
||||
if (
|
||||
self.model_config
|
||||
and self.model_config.architecture == "WhisperForConditionalGeneration"
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
|
||||
):
|
||||
logger.warning(
|
||||
"Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'."
|
||||
)
|
||||
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn":
|
||||
logger.warning(
|
||||
"Whisper is known to have issues with "
|
||||
"forked workers. If startup is hanging, "
|
||||
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||
"to 'spawn'."
|
||||
)
|
||||
if self.optimization_level > OptimizationLevel.O0:
|
||||
self.compilation_config.compile_mm_encoder = True
|
||||
logger.info(
|
||||
"Enabling encoder compilation for Whisper for better performance."
|
||||
)
|
||||
|
||||
if (
|
||||
self.kv_events_config is not None
|
||||
|
||||
@ -816,7 +816,7 @@ class VoxtralEncoderModel(nn.Module):
|
||||
input_embeds, chunks_per_example = self.prepare_inputs_for_conv(input_features)
|
||||
|
||||
# [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
|
||||
out = self.whisper_encoder([input_embeds])
|
||||
out = self.whisper_encoder(input_embeds)
|
||||
|
||||
# Re-concatenate the chunks
|
||||
chunk_idx = 0
|
||||
|
||||
@ -24,6 +24,7 @@ from vllm.attention.backends.abstract import (
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.layers.cross_attention import CrossAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@ -81,6 +82,11 @@ class WhisperPosEmbedType(enum.Enum):
|
||||
LEARNED = "learned"
|
||||
|
||||
|
||||
def should_torch_compile_encoder(vllm_config: VllmConfig) -> bool:
|
||||
"""Callable to be passed to `@support_torch_compile`'s `enable_if` argument."""
|
||||
return vllm_config.compilation_config.compile_mm_encoder
|
||||
|
||||
|
||||
class WhisperAudioInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
@ -98,6 +104,9 @@ class WhisperAudioInputs(TensorSchema):
|
||||
class WhisperEncoderAttention(MMEncoderAttention):
|
||||
"""Multi-headed attention for Whisper encoder with 2D tensor support."""
|
||||
|
||||
def __init__(self, num_heads: int, head_size: int, scale: float, num_kv_heads: int):
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@ -347,7 +356,13 @@ class WhisperMLP(nn.Module):
|
||||
|
||||
|
||||
class WhisperEncoderLayer(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
skip_overflow_clamp: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
is_causal = getattr(config, "is_causal", False)
|
||||
@ -355,6 +370,7 @@ class WhisperEncoderLayer(nn.Module):
|
||||
block_pool_size = getattr(config, "block_pool_size", 1)
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self._skip_overflow_clamp = skip_overflow_clamp
|
||||
|
||||
self.embed_dim = config.d_model
|
||||
self.self_attn = WhisperAttention(
|
||||
@ -390,7 +406,9 @@ class WhisperEncoderLayer(nn.Module):
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
hidden_states = cast_overflow_tensors(hidden_states)
|
||||
# Not compatible with torch.compile
|
||||
if not self._skip_overflow_clamp:
|
||||
hidden_states = cast_overflow_tensors(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -454,6 +472,9 @@ class WhisperDecoderLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={"input_features": 0}, enable_if=should_torch_compile_encoder
|
||||
)
|
||||
class WhisperEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
|
||||
@ -479,7 +500,9 @@ class WhisperEncoder(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.encoder_layers,
|
||||
lambda prefix: WhisperEncoderLayer(
|
||||
vllm_config=vllm_config, prefix=f"{prefix}.layers"
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.layers",
|
||||
skip_overflow_clamp=should_torch_compile_encoder(vllm_config),
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
@ -511,37 +534,18 @@ class WhisperEncoder(nn.Module):
|
||||
sinusoids(*self.embed_positions.weight.shape)
|
||||
)
|
||||
|
||||
def forward_conv(
|
||||
self, input_features: torch.Tensor | list[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
hidden_states = []
|
||||
input_is_batched = False
|
||||
for features in input_features:
|
||||
embeds = nn.functional.gelu(self.conv1(features))
|
||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||
def forward_conv(self, input_features: torch.Tensor) -> torch.Tensor:
|
||||
embeds = nn.functional.gelu(self.conv1(input_features))
|
||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||
embeds = embeds.transpose(-1, -2)
|
||||
|
||||
if self.pos_embed_type in (
|
||||
WhisperPosEmbedType.SINUSOIDAL,
|
||||
WhisperPosEmbedType.LEARNED,
|
||||
):
|
||||
embeds = embeds.transpose(-1, -2)
|
||||
embeds = (
|
||||
embeds + self.embed_positions.weight[: embeds.size(-2), :]
|
||||
).to(embeds.dtype)
|
||||
elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
|
||||
embeds = embeds.transpose(-1, -2).to(embeds.dtype)
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
|
||||
if self.pos_embed_type == WhisperPosEmbedType.NOPE:
|
||||
return embeds
|
||||
|
||||
hidden_states.append(embeds)
|
||||
input_is_batched = embeds.ndim > 2
|
||||
# Input to MHA must be B x T x D
|
||||
if input_is_batched:
|
||||
# Models using WhisperEncoder may handle batching internally.
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
|
||||
hidden_states = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
|
||||
embeds.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@ -551,7 +555,7 @@ class WhisperEncoder(nn.Module):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
||||
def forward(self, input_features: torch.Tensor):
|
||||
hidden_states = self.forward_conv(input_features)
|
||||
return self.forward_layers(hidden_states)
|
||||
|
||||
@ -630,7 +634,7 @@ class WhisperModel(nn.Module):
|
||||
|
||||
def get_encoder_outputs(
|
||||
self,
|
||||
input_features: torch.Tensor | list[torch.Tensor] | None,
|
||||
input_features: torch.Tensor | None,
|
||||
) -> torch.Tensor | None:
|
||||
if input_features is None:
|
||||
return None
|
||||
@ -951,7 +955,6 @@ class WhisperForConditionalGeneration(
|
||||
|
||||
if input_features is not None:
|
||||
input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)
|
||||
|
||||
return WhisperAudioInputs(input_features=input_features)
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user