mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 15:29:09 +08:00
[v1] Add Whisper model support (encoder-decoder) (#21088)
Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
4db4426404
commit
37e8182bfe
@ -321,7 +321,6 @@ steps:
|
|||||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||||
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||||
- python3 offline_inference/encoder_decoder.py
|
|
||||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||||
- python3 offline_inference/basic/classify.py
|
- python3 offline_inference/basic/classify.py
|
||||||
- python3 offline_inference/basic/embed.py
|
- python3 offline_inference/basic/embed.py
|
||||||
@ -644,7 +643,7 @@ steps:
|
|||||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||||
- pip freeze | grep -E 'torch'
|
- pip freeze | grep -E 'torch'
|
||||||
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
|
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
|
||||||
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||||
|
|
||||||
- label: Multi-Modal Models Test (Extended) 1
|
- label: Multi-Modal Models Test (Extended) 1
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental]
|
||||||
@ -818,7 +817,8 @@ steps:
|
|||||||
# Avoid importing model tests that cause CUDA reinitialization error
|
# Avoid importing model tests that cause CUDA reinitialization error
|
||||||
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
|
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
|
||||||
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
|
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
|
||||||
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)'
|
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
|
||||||
|
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'
|
||||||
# test sequence parallel
|
# test sequence parallel
|
||||||
- pytest -v -s distributed/test_sequence_parallel.py
|
- pytest -v -s distributed/test_sequence_parallel.py
|
||||||
# this test fails consistently.
|
# this test fails consistently.
|
||||||
|
|||||||
@ -5,6 +5,8 @@ Demonstrate prompting of text-to-text
|
|||||||
encoder/decoder models, specifically BART and mBART.
|
encoder/decoder models, specifically BART and mBART.
|
||||||
|
|
||||||
This script is refactored to allow model selection via command-line arguments.
|
This script is refactored to allow model selection via command-line arguments.
|
||||||
|
|
||||||
|
NOTE: This example is not yet supported in V1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|||||||
@ -5,6 +5,7 @@ This example shows how to use vLLM for running offline inference with
|
|||||||
the explicit/implicit prompt format on enc-dec LMMs for text generation.
|
the explicit/implicit prompt format on enc-dec LMMs for text generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
@ -130,6 +131,8 @@ def run_mllama():
|
|||||||
|
|
||||||
|
|
||||||
def run_whisper():
|
def run_whisper():
|
||||||
|
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
|
|
||||||
engine_args = EngineArgs(
|
engine_args = EngineArgs(
|
||||||
model="openai/whisper-large-v3-turbo",
|
model="openai/whisper-large-v3-turbo",
|
||||||
max_model_len=448,
|
max_model_len=448,
|
||||||
|
|||||||
@ -63,6 +63,7 @@ def clear_cache():
|
|||||||
current_platform.is_cpu(),
|
current_platform.is_cpu(),
|
||||||
reason="CPU backend is not currently supported with encoder/decoder models"
|
reason="CPU backend is not currently supported with encoder/decoder models"
|
||||||
)
|
)
|
||||||
|
@pytest.mark.skip(reason="bart not supported in V1")
|
||||||
def test_encoder_decoder_e2e(
|
def test_encoder_decoder_e2e(
|
||||||
hf_runner,
|
hf_runner,
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
|
|||||||
@ -30,6 +30,7 @@ async def client(server):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||||
|
@pytest.mark.skip(reason="bart is not yet supported in V1")
|
||||||
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
|
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
|
||||||
completion = await client.completions.create(model=model_name,
|
completion = await client.completions.create(model=model_name,
|
||||||
prompt="Hello, my name is",
|
prompt="Hello, my name is",
|
||||||
|
|||||||
@ -178,6 +178,7 @@ def run_test(
|
|||||||
@pytest.mark.parametrize("max_tokens", [64])
|
@pytest.mark.parametrize("max_tokens", [64])
|
||||||
@pytest.mark.parametrize("num_logprobs", [5])
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
|
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
|
||||||
|
@pytest.mark.skip(reason="bart not supported in V1")
|
||||||
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
|
def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
|
||||||
dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
|
dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
|
||||||
|
|
||||||
@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
|
|||||||
@pytest.mark.parametrize("max_tokens", [64])
|
@pytest.mark.parametrize("max_tokens", [64])
|
||||||
@pytest.mark.parametrize("num_logprobs", [5])
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
|
@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM])
|
||||||
|
@pytest.mark.skip(reason="bart not supported in V1")
|
||||||
def test_models_distributed(hf_runner, vllm_runner,
|
def test_models_distributed(hf_runner, vllm_runner,
|
||||||
example_encoder_decoder_prompts,
|
example_encoder_decoder_prompts,
|
||||||
distributed_executor_backend, model, dtype,
|
distributed_executor_backend, model, dtype,
|
||||||
|
|||||||
@ -122,8 +122,7 @@ def run_test(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.core_model
|
@pytest.mark.core_model
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||||
"model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"])
|
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_models(vllm_runner, model) -> None:
|
def test_models(vllm_runner, model) -> None:
|
||||||
run_test(
|
run_test(
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from ...utils import dummy_hf_overrides
|
|||||||
|
|
||||||
ARCH_TO_SKIP = {
|
ARCH_TO_SKIP = {
|
||||||
"MolmoForCausalLM": "incompatible requirements",
|
"MolmoForCausalLM": "incompatible requirements",
|
||||||
|
"Florence2ForConditionalGeneration": "not supported in V1",
|
||||||
}
|
}
|
||||||
ARCH_NEEDS_EXTRAS = [
|
ARCH_NEEDS_EXTRAS = [
|
||||||
"InternVLChatModel",
|
"InternVLChatModel",
|
||||||
|
|||||||
@ -68,6 +68,12 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
|
|||||||
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
|
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
|
||||||
# L4 supports FA3.
|
# L4 supports FA3.
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
|
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
|
||||||
|
if model_arch == "Florence2ForConditionalGeneration":
|
||||||
|
# An encoder-decoder model that's V0-only. Just skip it
|
||||||
|
# since V0 is about to be removed.
|
||||||
|
pytest.skip("Skipping Florence2ForConditionalGeneration")
|
||||||
|
if model_arch == "WhisperForConditionalGeneration":
|
||||||
|
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||||
LLM(
|
LLM(
|
||||||
model_info.default,
|
model_info.default,
|
||||||
tokenizer=model_info.tokenizer,
|
tokenizer=model_info.tokenizer,
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
|||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
|
||||||
UNSUPPORTED_MODELS_V1 = [
|
UNSUPPORTED_MODELS_V1 = [
|
||||||
"openai/whisper-large-v3", # transcription
|
|
||||||
"facebook/bart-large-cnn", # encoder decoder
|
"facebook/bart-large-cnn", # encoder decoder
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
160
vllm/attention/layers/cross_attention.py
Normal file
160
vllm/attention/layers/cross_attention.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import functools
|
||||||
|
from copy import copy
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import CacheConfig
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadata, AttentionType)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.utils import cdiv
|
||||||
|
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend)
|
||||||
|
from vllm.v1.kv_cache_interface import CrossAttentionSpec
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_max_encoder_len(vllm_config: VllmConfig) -> int:
    """Return the maximum encoder length for the configured model.

    The value is obtained from
    ``MULTIMODAL_REGISTRY.get_encdec_max_encoder_len`` using the model
    config stored on ``vllm_config``.
    """
    model_config = vllm_config.model_config
    return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(model_config)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
                            block_table_tensor: torch.Tensor,
                            kv_cache_spec: CrossAttentionSpec,
                            device: torch.device) -> torch.Tensor:
    """Build the flat cross-attention slot mapping.

    For every request whose entry in ``encoder_seq_lens`` is non-zero, one
    slot index is produced per encoder position ``i``:
    ``block_id[i // block_size] * block_size + i % block_size``, where the
    block ids are read from that request's row of ``block_table_tensor``.
    The per-request results are concatenated in request-index order.

    Returns:
        A 1-D tensor of slot indices, or an empty int64 tensor on ``device``
        when no request has a non-zero encoder length.
    """
    block_size = kv_cache_spec.block_size
    per_request_slots = []

    # Only requests with a non-zero encoder length contribute slots. Most
    # parallel requests are expected to be running the decoder, so this
    # loop should stay short.
    for row in np.nonzero(encoder_seq_lens)[0]:
        num_encoder_tokens = encoder_seq_lens[row].item()

        # Keep only the leading blocks that cover the encoder sequence.
        blocks_required = cdiv(num_encoder_tokens, block_size)
        block_ids = block_table_tensor[row][:blocks_required]

        # Position i lives in block i // block_size at offset i % block_size.
        positions = torch.arange(num_encoder_tokens,
                                 dtype=torch.int64,
                                 device=device)
        block_numbers = block_ids[positions // block_size]
        per_request_slots.append(block_numbers * block_size +
                                 positions % block_size)

    if not per_request_slots:
        return torch.empty(0, dtype=torch.int64, device=device)
    return torch.cat(per_request_slots)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
def create_cross_attention_backend(
    underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
    """Create a cross-attention variant of ``underlying_attn_backend``.

    A metadata builder subclass is derived from the backend's own builder
    (``get_builder_cls()``) and passed to ``subclass_attention_backend``
    together with the ``"CrossAttention_"`` name prefix. Results are
    memoized per backend via ``functools.lru_cache``.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class CrossAttentionBuilder(base_builder_cls):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            """Rewrite the common metadata for cross-attention, then defer
            to the underlying builder."""
            # Work on a shallow copy so the caller's metadata object is not
            # reassigned in place.
            cross_metadata = copy(common_attn_metadata)
            cross_metadata.causal = False

            encoder_len = _get_max_encoder_len(self.vllm_config)
            cross_metadata.max_seq_len = encoder_len

            # Every request is given the maximum encoder length as its
            # sequence length, on both the builder device and the CPU.
            lens_shape = (cross_metadata.num_reqs, )
            cross_metadata.seq_lens = torch.full(
                lens_shape,
                encoder_len,
                dtype=torch.int32,
                device=self.device,
            )
            cross_metadata.seq_lens_cpu = torch.full(
                lens_shape,
                encoder_len,
                dtype=torch.int32,
                device="cpu",
            )

            cross_metadata.slot_mapping = _get_cross_slot_mapping(
                cross_metadata.encoder_seq_lens,
                cross_metadata.block_table_tensor,
                self.kv_cache_spec,
                self.device,
            )
            return super().build(common_prefix_len, cross_metadata,
                                 fast_build)

    return subclass_attention_backend(
        name_prefix="CrossAttention_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=CrossAttentionBuilder)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(Attention):
    """
    Attention layer for the cross-attention of encoder-decoder models.

    Decoder queries attend to encoder keys/values. The layer always
    initializes its parent with ``AttentionType.ENCODER_DECODER``.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        """Select the attention backend and initialize the parent layer.

        Args:
            num_heads: Forwarded to ``Attention.__init__``.
            head_size: Forwarded to ``Attention.__init__`` and used for
                backend selection.
            scale: Forwarded to ``Attention.__init__``.
            cache_config: Source of the KV-cache dtype and block size; when
                ``None`` the values ``"auto"`` and ``16`` are used.
            attn_type: If given, must equal
                ``AttentionType.ENCODER_DECODER``.
            **kwargs: Forwarded unchanged to ``Attention.__init__``.
        """
        # NOTE(review): in this file CacheConfig is imported from
        # `transformers`, yet the attributes read below are `cache_dtype`
        # and `block_size` -- confirm whether vLLM's CacheConfig was
        # intended for the annotation.
        dtype = torch.get_default_dtype()

        no_cache_config = cache_config is None
        kv_cache_dtype = ("auto"
                          if no_cache_config else cache_config.cache_dtype)
        block_size = 16 if no_cache_config else cache_config.block_size

        # In V0 cross attention is handled inside the backends, so no
        # dedicated backend is passed to the parent there.
        attn_backend = None
        if envs.VLLM_USE_V1:
            base_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
                                            block_size)
            attn_backend = create_cross_attention_backend(base_backend)

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_DECODER, (
                "CrossAttention only supports AttentionType.ENCODER_DECODER")

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_DECODER,
                         **kwargs)
|
||||||
@ -8,6 +8,7 @@ import enum
|
|||||||
import hashlib
|
import hashlib
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import textwrap
|
import textwrap
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
@ -41,6 +42,7 @@ from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
|||||||
from vllm.config.utils import ConfigType, config
|
from vllm.config.utils import ConfigType, config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import (
|
||||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||||
@ -3509,16 +3511,33 @@ class VllmConfig:
|
|||||||
|
|
||||||
disable_chunked_prefill_reasons: list[str] = []
|
disable_chunked_prefill_reasons: list[str] = []
|
||||||
|
|
||||||
if self.model_config and self.model_config.pooler_config:
|
if self.model_config:
|
||||||
pooling_type = self.model_config.pooler_config.pooling_type
|
if self.model_config.pooler_config:
|
||||||
if pooling_type is None or pooling_type.lower() != "last":
|
pooling_type = self.model_config.pooler_config.pooling_type
|
||||||
|
if pooling_type is None or pooling_type.lower() != "last":
|
||||||
|
disable_chunked_prefill_reasons.append(
|
||||||
|
"Only \"last\" pooling supports chunked "
|
||||||
|
"prefill and prefix caching; disabling both.")
|
||||||
|
elif self.model_config.is_encoder_decoder:
|
||||||
|
self.scheduler_config.max_num_encoder_input_tokens = \
|
||||||
|
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
|
||||||
|
logger.debug(
|
||||||
|
"Encoder-decoder model detected: setting "
|
||||||
|
"`max_num_encoder_input_tokens` to encoder length (%s)",
|
||||||
|
self.scheduler_config.max_num_encoder_input_tokens)
|
||||||
|
self.scheduler_config.disable_chunked_mm_input = True
|
||||||
disable_chunked_prefill_reasons.append(
|
disable_chunked_prefill_reasons.append(
|
||||||
"Only \"last\" pooling supports chunked "
|
"Encoder-decoder models do not support chunked prefill nor"
|
||||||
"prefill and prefix caching; disabling both.")
|
" prefix caching; disabling both.")
|
||||||
elif not getattr(self.model_config.hf_config, "is_causal", True):
|
if (self.model_config.architecture
|
||||||
disable_chunked_prefill_reasons.append(
|
== "WhisperForConditionalGeneration"
|
||||||
"Only models using causal attention supports chunked "
|
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
|
||||||
"prefill and prefix caching; disabling both.")
|
!= "spawn"):
|
||||||
|
logger.warning(
|
||||||
|
"Whisper is known to have issues with "
|
||||||
|
"forked workers. If startup is hanging, "
|
||||||
|
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
|
||||||
|
"to 'spawn'.")
|
||||||
|
|
||||||
if disable_chunked_prefill_reasons:
|
if disable_chunked_prefill_reasons:
|
||||||
for reason in disable_chunked_prefill_reasons:
|
for reason in disable_chunked_prefill_reasons:
|
||||||
|
|||||||
@ -600,7 +600,6 @@ class VoxtralEncoderModel(nn.Module):
|
|||||||
self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
|
self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(
|
prefix=maybe_prefix(
|
||||||
prefix, "whisper_encoder"),
|
prefix, "whisper_encoder"),
|
||||||
is_standalone_encoder=True,
|
|
||||||
init_in_fp32=True)
|
init_in_fp32=True)
|
||||||
mel_filters = mel_filter_bank(
|
mel_filters = mel_filter_bank(
|
||||||
num_frequency_bins=1 + self.config.window_size // 2,
|
num_frequency_bins=1 + self.config.window_size // 2,
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids
|
|||||||
|
|
||||||
from vllm.attention import Attention, AttentionType
|
from vllm.attention import Attention, AttentionType
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
|
from vllm.attention.layers.cross_attention import CrossAttention
|
||||||
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
|
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
|
||||||
VllmConfig)
|
VllmConfig)
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -43,7 +44,7 @@ from vllm.transformers_utils.processor import cached_get_processor
|
|||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
|
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
|
||||||
SupportsTranscription, SupportsV0Only)
|
SupportsTranscription)
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
|
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
|
||||||
make_layers)
|
make_layers)
|
||||||
|
|
||||||
@ -124,6 +125,34 @@ class WhisperAudioInputs(TensorSchema):
|
|||||||
TensorShape("b", "nmb", "t")]
|
TensorShape("b", "nmb", "t")]
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperEncoderAttention(MultiHeadAttention):
    """Whisper encoder multi-head attention that also accepts 2D inputs."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run the parent attention, adding a batch dimension when needed.

        Input shape: batch_size x seq_len x hidden_size
                     or seq_len x hidden_size

        When ``query`` is 2D, a leading dimension of size one is added to
        ``query``, ``key`` and ``value`` before the parent call and removed
        from the result afterwards.
        """
        if query.dim() != 2:
            # Already batched: defer to the parent unchanged.
            return super().forward(query, key, value)

        batched_out = super().forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
        )
        return batched_out.squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
class WhisperPositionalEmbedding(nn.Embedding):
|
class WhisperPositionalEmbedding(nn.Embedding):
|
||||||
|
|
||||||
def __init__(self, num_positions: int, embedding_dim: int):
|
def __init__(self, num_positions: int, embedding_dim: int):
|
||||||
@ -144,7 +173,6 @@ class WhisperAttention(nn.Module):
|
|||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
standalone_encoder: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
@ -180,14 +208,25 @@ class WhisperAttention(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.out_proj",
|
prefix=f"{prefix}.out_proj",
|
||||||
)
|
)
|
||||||
if standalone_encoder:
|
if attn_type == AttentionType.ENCODER:
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = WhisperEncoderAttention(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
)
|
)
|
||||||
else:
|
elif self.attn_type == AttentionType.ENCODER_DECODER:
|
||||||
|
self.attn = CrossAttention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
|
attn_type=self.attn_type,
|
||||||
|
)
|
||||||
|
else: # AttentionType.DECODER (regular decoder self-attention)
|
||||||
self.attn = Attention(
|
self.attn = Attention(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@ -332,11 +371,7 @@ class WhisperMLP(nn.Module):
|
|||||||
|
|
||||||
class WhisperEncoderLayer(nn.Module):
|
class WhisperEncoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
*,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
prefix: str = "",
|
|
||||||
is_standalone_encoder: bool = False):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
@ -350,7 +385,6 @@ class WhisperEncoderLayer(nn.Module):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
standalone_encoder=is_standalone_encoder,
|
|
||||||
)
|
)
|
||||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
self.mlp = WhisperMLP(
|
self.mlp = WhisperMLP(
|
||||||
@ -446,12 +480,10 @@ class WhisperEncoder(nn.Module):
|
|||||||
*,
|
*,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
is_standalone_encoder: bool = False,
|
|
||||||
init_in_fp32: bool = False):
|
init_in_fp32: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
embed_dim = config.d_model
|
embed_dim = config.d_model
|
||||||
self.is_standalone_encoder = is_standalone_encoder
|
|
||||||
self.num_mel_bins = config.num_mel_bins
|
self.num_mel_bins = config.num_mel_bins
|
||||||
self.max_source_positions = config.max_source_positions
|
self.max_source_positions = config.max_source_positions
|
||||||
self.embed_scale = (math.sqrt(embed_dim)
|
self.embed_scale = (math.sqrt(embed_dim)
|
||||||
@ -469,9 +501,7 @@ class WhisperEncoder(nn.Module):
|
|||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.encoder_layers,
|
config.encoder_layers,
|
||||||
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
|
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers"),
|
||||||
is_standalone_encoder=
|
|
||||||
is_standalone_encoder),
|
|
||||||
prefix=f"{prefix}.layers",
|
prefix=f"{prefix}.layers",
|
||||||
)
|
)
|
||||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||||
@ -752,7 +782,7 @@ class WhisperMultiModalProcessor(
|
|||||||
info=WhisperProcessingInfo,
|
info=WhisperProcessingInfo,
|
||||||
dummy_inputs=WhisperDummyInputsBuilder)
|
dummy_inputs=WhisperDummyInputsBuilder)
|
||||||
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||||
SupportsMultiModal, SupportsV0Only):
|
SupportsMultiModal):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"self_attn.qkv_proj": [
|
"self_attn.qkv_proj": [
|
||||||
"self_attn.q_proj",
|
"self_attn.q_proj",
|
||||||
@ -880,19 +910,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
|||||||
|
|
||||||
def get_multimodal_embeddings(self,
|
def get_multimodal_embeddings(self,
|
||||||
**kwargs: object) -> MultiModalEmbeddings:
|
**kwargs: object) -> MultiModalEmbeddings:
|
||||||
# TODO: This method does not obey the interface for SupportsMultiModal.
|
# Required as part of SupportsMultiModal interface.
|
||||||
# Refactor this once encoder/decoder support is implemented in V1.
|
|
||||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||||
return self.model.get_encoder_outputs(audio_input["input_features"])
|
return [self.model.get_encoder_outputs(audio_input["input_features"])]
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# TODO: This method just returns the decoder sequence embeddings since
|
# This method just returns the decoder sequence embeddings since
|
||||||
# Whisper does not have encoder text tokens. Refactor this once
|
# Whisper does not have encoder text tokens.
|
||||||
# encoder/decoder support is implemented in V1.
|
|
||||||
return self.model.decoder.get_input_embeddings(input_ids)
|
return self.model.decoder.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def _parse_and_validate_audio_input(
|
def _parse_and_validate_audio_input(
|
||||||
|
|||||||
@ -157,6 +157,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
|
|||||||
encoder_attention_heads=encoder_args["n_heads"],
|
encoder_attention_heads=encoder_args["n_heads"],
|
||||||
vocab_size=encoder_args["vocab_size"],
|
vocab_size=encoder_args["vocab_size"],
|
||||||
max_source_positions=encoder_args["max_source_positions"],
|
max_source_positions=encoder_args["max_source_positions"],
|
||||||
|
is_encoder_decoder=False, # Override WhisperConfig default
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if quant_config:
|
if quant_config:
|
||||||
|
|||||||
@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
|||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device) -> None:
|
vllm_config: VllmConfig, device: torch.device) -> None:
|
||||||
self.kv_cache_spec = kv_cache_spec
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
self.vllm_config = vllm_config
|
|
||||||
self.scheduler_config = vllm_config.scheduler_config
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
|
|
||||||
# For reorder
|
# For reorder
|
||||||
|
|||||||
@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder(
|
|||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
self.vllm_config = vllm_config
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
self.parallel_config = vllm_config.parallel_config
|
self.parallel_config = vllm_config.parallel_config
|
||||||
self.cache_config = vllm_config.cache_config
|
self.cache_config = vllm_config.cache_config
|
||||||
self.compilation_config = vllm_config.compilation_config
|
self.compilation_config = vllm_config.compilation_config
|
||||||
self.device = device
|
|
||||||
|
|
||||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||||
self.parallel_config)
|
self.parallel_config)
|
||||||
|
|||||||
@ -163,11 +163,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
self.device = device
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
self.vllm_config = vllm_config
|
|
||||||
self.cache_config = vllm_config.cache_config
|
self.cache_config = vllm_config.cache_config
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
self.kv_cache_spec = kv_cache_spec
|
|
||||||
self._workspace_buffer = None
|
self._workspace_buffer = None
|
||||||
self._prefill_wrapper = None # Wrapper for prefill/append
|
self._prefill_wrapper = None # Wrapper for prefill/append
|
||||||
self._decode_wrapper = None # Wrapper for decode (general shape)
|
self._decode_wrapper = None # Wrapper for decode (general shape)
|
||||||
|
|||||||
@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder(
|
|||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
|
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
self.parallel_config = vllm_config.parallel_config
|
self.parallel_config = vllm_config.parallel_config
|
||||||
self.cache_config = vllm_config.cache_config
|
self.cache_config = vllm_config.cache_config
|
||||||
self.device = device
|
|
||||||
|
|
||||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||||
self.parallel_config)
|
self.parallel_config)
|
||||||
|
|||||||
@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder(
|
|||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
assert isinstance(kv_cache_spec, MambaSpec)
|
assert isinstance(kv_cache_spec, MambaSpec)
|
||||||
self.kv_cache_spec = kv_cache_spec
|
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
|
|||||||
@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
|||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
assert isinstance(kv_cache_spec, MambaSpec)
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
self.kv_cache_spec = kv_cache_spec
|
|
||||||
self.device = device
|
|
||||||
self.vllm_config = vllm_config
|
|
||||||
self.layer_names = layer_names
|
|
||||||
|
|
||||||
|
assert isinstance(kv_cache_spec, MambaSpec)
|
||||||
self.compilation_config = vllm_config.compilation_config
|
self.compilation_config = vllm_config.compilation_config
|
||||||
self.decode_cudagraph_max_bs = min(
|
self.decode_cudagraph_max_bs = min(
|
||||||
self.vllm_config.scheduler_config.max_num_seqs,
|
self.vllm_config.scheduler_config.max_num_seqs,
|
||||||
|
|||||||
@ -236,11 +236,11 @@ class AiterFlashAttentionMetadataBuilder(
|
|||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
self.vllm_config = vllm_config
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
|
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
self.parallel_config = vllm_config.parallel_config
|
self.parallel_config = vllm_config.parallel_config
|
||||||
self.cache_config = vllm_config.cache_config
|
self.cache_config = vllm_config.cache_config
|
||||||
self.device = device
|
|
||||||
|
|
||||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||||
self.parallel_config)
|
self.parallel_config)
|
||||||
@ -248,7 +248,6 @@ class AiterFlashAttentionMetadataBuilder(
|
|||||||
self.parallel_config)
|
self.parallel_config)
|
||||||
self.headdim = self.model_config.get_head_size()
|
self.headdim = self.model_config.get_head_size()
|
||||||
self.block_size = kv_cache_spec.block_size
|
self.block_size = kv_cache_spec.block_size
|
||||||
self.kv_cache_spec = kv_cache_spec
|
|
||||||
# Sliding window size to be used with the AOT scheduler will be
|
# Sliding window size to be used with the AOT scheduler will be
|
||||||
# populated on first build() call.
|
# populated on first build() call.
|
||||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||||
|
|||||||
@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder(
|
|||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
assert isinstance(kv_cache_spec, MambaSpec)
|
assert isinstance(kv_cache_spec, MambaSpec)
|
||||||
self.kv_cache_spec = kv_cache_spec
|
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
|
|||||||
@ -165,7 +165,8 @@ class TreeAttentionMetadataBuilder(
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
self.kv_cache_spec = kv_cache_spec
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
|
|
||||||
self.block_size = kv_cache_spec.block_size
|
self.block_size = kv_cache_spec.block_size
|
||||||
|
|
||||||
spec_config = vllm_config.speculative_config
|
spec_config = vllm_config.speculative_config
|
||||||
|
|||||||
@ -66,9 +66,9 @@ class TritonAttentionMetadataBuilder(
|
|||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
self.device = device
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
|
|
||||||
self.block_size = kv_cache_spec.block_size
|
self.block_size = kv_cache_spec.block_size
|
||||||
self.kv_cache_spec = kv_cache_spec
|
|
||||||
|
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
self.num_heads_q = model_config.get_num_attention_heads(
|
self.num_heads_q = model_config.get_num_attention_heads(
|
||||||
|
|||||||
@ -72,6 +72,9 @@ class CommonAttentionMetadata:
|
|||||||
logits_indices_padded: Optional[torch.Tensor] = None
|
logits_indices_padded: Optional[torch.Tensor] = None
|
||||||
num_logits_indices: Optional[int] = None
|
num_logits_indices: Optional[int] = None
|
||||||
|
|
||||||
|
# Needed by CrossAttentionBuilder
|
||||||
|
encoder_seq_lens: Optional[np.ndarray] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UbatchSlice:
|
class UbatchSlice:
|
||||||
@ -193,6 +196,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
self.kv_cache_spec = kv_cache_spec
|
self.kv_cache_spec = kv_cache_spec
|
||||||
|
self.layer_names = layer_names
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.device = device
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def build(self,
|
def build(self,
|
||||||
|
|||||||
@ -206,8 +206,9 @@ class XFormersAttentionMetadataBuilder(
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
|
|
||||||
assert XFORMERS_AVAILABLE
|
assert XFORMERS_AVAILABLE
|
||||||
self.kv_cache_spec = kv_cache_spec
|
|
||||||
self.block_size = kv_cache_spec.block_size
|
self.block_size = kv_cache_spec.block_size
|
||||||
self._num_decodes = 0
|
self._num_decodes = 0
|
||||||
self._num_decode_tokens = 0
|
self._num_decode_tokens = 0
|
||||||
|
|||||||
@ -144,8 +144,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
|
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
|
||||||
# projector if needed). Currently, we assume that the encoder also
|
# projector if needed) for MM models as well as encoder-decoder
|
||||||
# has the Transformer architecture (e.g., ViT).
|
# transformers.
|
||||||
self.max_num_encoder_input_tokens = encoder_compute_budget
|
self.max_num_encoder_input_tokens = encoder_compute_budget
|
||||||
# NOTE: For the models without encoder (e.g., text-only models),
|
# NOTE: For the models without encoder (e.g., text-only models),
|
||||||
# the encoder cache will not be initialized because cache size is 0
|
# the encoder cache will not be initialized because cache size is 0
|
||||||
@ -775,15 +775,19 @@ class Scheduler(SchedulerInterface):
|
|||||||
# in the decoder's KV cache.
|
# in the decoder's KV cache.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# The same encoder input has already been scheduled in the current
|
if not self.is_encoder_decoder:
|
||||||
# step.
|
# We are not using the encoder cache for encoder-decoder models,
|
||||||
if request.mm_hashes[i] in mm_hashes_to_schedule:
|
# yet.
|
||||||
continue
|
if request.mm_hashes[i] in mm_hashes_to_schedule:
|
||||||
|
# The same encoder input has already been scheduled in the
|
||||||
|
# current step.
|
||||||
|
continue
|
||||||
|
|
||||||
if self.encoder_cache_manager.check_and_update_cache(request, i):
|
if self.encoder_cache_manager.check_and_update_cache(
|
||||||
# The encoder input is already computed and cached from a
|
request, i):
|
||||||
# previous step.
|
# The encoder input is already computed and cached from a
|
||||||
continue
|
# previous step.
|
||||||
|
continue
|
||||||
|
|
||||||
# If no encoder input chunking is allowed, we do not want to
|
# If no encoder input chunking is allowed, we do not want to
|
||||||
# partially schedule a multimodal item. If the scheduled range would
|
# partially schedule a multimodal item. If the scheduled range would
|
||||||
@ -1047,7 +1051,13 @@ class Scheduler(SchedulerInterface):
|
|||||||
mm_positions = request.mm_positions[input_id]
|
mm_positions = request.mm_positions[input_id]
|
||||||
start_pos = mm_positions.offset
|
start_pos = mm_positions.offset
|
||||||
num_tokens = mm_positions.length
|
num_tokens = mm_positions.length
|
||||||
if start_pos + num_tokens <= request.num_computed_tokens:
|
if self.is_encoder_decoder and request.num_computed_tokens > 0:
|
||||||
|
# With Whisper, as soon as we've generated a single token,
|
||||||
|
# we know we're done with the encoder input. Cross Attention
|
||||||
|
# KVs have been calculated and cached already.
|
||||||
|
self.encoder_cache_manager.free_encoder_input(
|
||||||
|
request, input_id)
|
||||||
|
elif start_pos + num_tokens <= request.num_computed_tokens:
|
||||||
# The encoder output is already processed and stored
|
# The encoder output is already processed and stored
|
||||||
# in the decoder's KV cache.
|
# in the decoder's KV cache.
|
||||||
self.encoder_cache_manager.free_encoder_input(
|
self.encoder_cache_manager.free_encoder_input(
|
||||||
|
|||||||
@ -325,7 +325,6 @@ class Processor:
|
|||||||
) -> tuple[Optional[str], EngineCoreRequest]:
|
) -> tuple[Optional[str], EngineCoreRequest]:
|
||||||
|
|
||||||
# TODO(woosuk): Support pooling models.
|
# TODO(woosuk): Support pooling models.
|
||||||
# TODO(woosuk): Support encoder-decoder models.
|
|
||||||
self._validate_lora(lora_request)
|
self._validate_lora(lora_request)
|
||||||
self._validate_params(params, lora_request)
|
self._validate_params(params, lora_request)
|
||||||
if trace_headers is not None:
|
if trace_headers is not None:
|
||||||
@ -384,10 +383,6 @@ class Processor:
|
|||||||
|
|
||||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||||
|
|
||||||
# TODO: Impl encoder-decoder
|
|
||||||
if encoder_inputs is not None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
sampling_params = None
|
sampling_params = None
|
||||||
pooling_params = None
|
pooling_params = None
|
||||||
if isinstance(params, SamplingParams):
|
if isinstance(params, SamplingParams):
|
||||||
|
|||||||
@ -61,12 +61,16 @@ from vllm.v1.attention.backends.utils import (
|
|||||||
create_fast_prefill_custom_backend,
|
create_fast_prefill_custom_backend,
|
||||||
reorder_batch_to_split_decodes_and_prefills)
|
reorder_batch_to_split_decodes_and_prefills)
|
||||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||||
ChunkedLocalAttentionSpec,
|
ChunkedLocalAttentionSpec,
|
||||||
|
CrossAttentionSpec,
|
||||||
EncoderOnlyAttentionSpec,
|
EncoderOnlyAttentionSpec,
|
||||||
FullAttentionSpec, KVCacheConfig,
|
FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec, KVCacheSpec,
|
KVCacheGroupSpec, KVCacheSpec,
|
||||||
MambaSpec, SlidingWindowSpec)
|
MambaSpec, SlidingWindowSpec)
|
||||||
|
# yapf: enable
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||||
DraftTokenIds, LogprobsLists, LogprobsTensors,
|
DraftTokenIds, LogprobsLists, LogprobsTensors,
|
||||||
ModelRunnerOutput, SamplerOutput)
|
ModelRunnerOutput, SamplerOutput)
|
||||||
@ -208,6 +212,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
||||||
model_config)
|
model_config)
|
||||||
|
|
||||||
|
if self.model_config.is_encoder_decoder:
|
||||||
|
# Maximum length of the encoder input, only for encoder-decoder
|
||||||
|
# models.
|
||||||
|
self.max_encoder_len = self.mm_registry.\
|
||||||
|
get_encdec_max_encoder_len(model_config)
|
||||||
|
else:
|
||||||
|
self.max_encoder_len = 0
|
||||||
|
|
||||||
# Sampler
|
# Sampler
|
||||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||||
|
|
||||||
@ -265,7 +277,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# the block_sizes in the kv cache config.
|
# the block_sizes in the kv cache config.
|
||||||
self.input_batch = InputBatch(
|
self.input_batch = InputBatch(
|
||||||
max_num_reqs=self.max_num_reqs,
|
max_num_reqs=self.max_num_reqs,
|
||||||
max_model_len=self.max_model_len,
|
# We need to use the encoder length for encoder-decoder
|
||||||
|
# because of KV cache for cross-attention.
|
||||||
|
max_model_len=max(self.max_model_len, self.max_encoder_len),
|
||||||
max_num_batched_tokens=self.max_num_tokens,
|
max_num_batched_tokens=self.max_num_tokens,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
@ -798,6 +812,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
src=self.input_batch.prev_sampled_token_ids[
|
src=self.input_batch.prev_sampled_token_ids[
|
||||||
prev_common_req_indices_tensor, 0])
|
prev_common_req_indices_tensor, 0])
|
||||||
|
|
||||||
|
def _get_encoder_seq_lens(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
kv_cache_spec: KVCacheSpec,
|
||||||
|
num_reqs: int,
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
if not isinstance(kv_cache_spec, CrossAttentionSpec):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build encoder_seq_lens array mapping request indices to
|
||||||
|
# encoder lengths for inputs scheduled in this batch
|
||||||
|
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
|
||||||
|
for req_id in scheduler_output.scheduled_encoder_inputs:
|
||||||
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
|
encoder_seq_lens[req_index] = self.max_encoder_len
|
||||||
|
|
||||||
|
return encoder_seq_lens
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@ -937,6 +969,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# in the same group share the same metadata.
|
# in the same group share the same metadata.
|
||||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||||
self.kv_cache_config.kv_cache_groups):
|
self.kv_cache_config.kv_cache_groups):
|
||||||
|
encoder_seq_lens = self._get_encoder_seq_lens(
|
||||||
|
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs)
|
||||||
|
|
||||||
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
if isinstance(kv_cache_group_spec.kv_cache_spec,
|
||||||
EncoderOnlyAttentionSpec):
|
EncoderOnlyAttentionSpec):
|
||||||
@ -981,6 +1015,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
logits_indices_padded=logits_indices_padded,
|
logits_indices_padded=logits_indices_padded,
|
||||||
num_logits_indices=logits_indices.size(0),
|
num_logits_indices=logits_indices.size(0),
|
||||||
causal=True,
|
causal=True,
|
||||||
|
encoder_seq_lens=encoder_seq_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.speculative_config and \
|
if self.speculative_config and \
|
||||||
@ -1253,10 +1288,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
|
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
|
||||||
return logits_indices_padded
|
return logits_indices_padded
|
||||||
|
|
||||||
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
def _batch_mm_kwargs_from_scheduler(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
|
||||||
|
"""Batch multimodal kwargs from scheduled encoder inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler_output: The scheduler output containing scheduled encoder
|
||||||
|
inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (mm_kwargs, mm_hashes_pos) where:
|
||||||
|
- mm_kwargs: List of multimodal kwargs items to be batched
|
||||||
|
- mm_hashes_pos: List of (mm_hash, position_info) tuples
|
||||||
|
"""
|
||||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||||
if not scheduled_encoder_inputs:
|
if not scheduled_encoder_inputs:
|
||||||
return
|
return [], []
|
||||||
# Batch the multi-modal inputs.
|
# Batch the multi-modal inputs.
|
||||||
mm_kwargs = list[MultiModalKwargsItem]()
|
mm_kwargs = list[MultiModalKwargsItem]()
|
||||||
# list of tuple (mm_hash, position_info)
|
# list of tuple (mm_hash, position_info)
|
||||||
@ -1270,6 +1319,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
mm_hashes_pos.append(
|
mm_hashes_pos.append(
|
||||||
(mm_hash, req_state.mm_positions[mm_input_id]))
|
(mm_hash, req_state.mm_positions[mm_input_id]))
|
||||||
|
|
||||||
|
return mm_kwargs, mm_hashes_pos
|
||||||
|
|
||||||
|
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||||
|
# Batch the multi-modal inputs using the helper method.
|
||||||
|
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
|
||||||
|
scheduler_output)
|
||||||
|
|
||||||
|
if not mm_kwargs:
|
||||||
|
return
|
||||||
|
|
||||||
# Batch mm inputs as much as we can: if a request in the batch has
|
# Batch mm inputs as much as we can: if a request in the batch has
|
||||||
# multiple modalities or a different modality than the previous one,
|
# multiple modalities or a different modality than the previous one,
|
||||||
# we process it separately to preserve item order.
|
# we process it separately to preserve item order.
|
||||||
@ -1360,6 +1419,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
mm_embeds.append(mm_embeds_item)
|
mm_embeds.append(mm_embeds_item)
|
||||||
return mm_embeds
|
return mm_embeds
|
||||||
|
|
||||||
|
def _extract_encoder_inputs(
|
||||||
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
"""Extract encoder inputs for encoder-decoder models.
|
||||||
|
|
||||||
|
This method extracts multimodal input features from scheduled encoder
|
||||||
|
inputs and formats them for the encoder-decoder model forward pass.
|
||||||
|
"""
|
||||||
|
# Batch the multi-modal inputs using the helper method.
|
||||||
|
mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
|
||||||
|
|
||||||
|
if not mm_kwargs:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Group MM kwargs by modality and extract features
|
||||||
|
encoder_features = {}
|
||||||
|
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||||
|
mm_kwargs,
|
||||||
|
device=self.device,
|
||||||
|
pin_memory=self.pin_memory,
|
||||||
|
):
|
||||||
|
# Add the grouped features to encoder_features dict
|
||||||
|
# This allows the model to receive them as kwargs (e.g.,
|
||||||
|
# input_features=...)
|
||||||
|
encoder_features.update(mm_kwargs_group)
|
||||||
|
|
||||||
|
return encoder_features
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
# get raw model out of the cudagraph wrapper.
|
# get raw model out of the cudagraph wrapper.
|
||||||
if isinstance(self.model, CUDAGraphWrapper):
|
if isinstance(self.model, CUDAGraphWrapper):
|
||||||
@ -1631,7 +1719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||||
# modal outputs after that to ensure the correct order
|
# modal outputs after that to ensure the correct order
|
||||||
if self.supports_mm_inputs and get_pp_group().is_first_rank:
|
if (self.supports_mm_inputs and get_pp_group().is_first_rank
|
||||||
|
and not self.model_config.is_encoder_decoder):
|
||||||
# Run the multimodal encoder if any.
|
# Run the multimodal encoder if any.
|
||||||
self._execute_mm_encoder(scheduler_output)
|
self._execute_mm_encoder(scheduler_output)
|
||||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||||
@ -1673,6 +1762,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||||
num_input_tokens, intermediate_tensors, True)
|
num_input_tokens, intermediate_tensors, True)
|
||||||
|
|
||||||
|
if (self.model_config.is_encoder_decoder
|
||||||
|
and scheduler_output.scheduled_encoder_inputs):
|
||||||
|
encoder_inputs = self._extract_encoder_inputs(scheduler_output)
|
||||||
|
model_kwargs.update(encoder_inputs)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
num_scheduled_tokens,
|
num_scheduled_tokens,
|
||||||
num_input_tokens,
|
num_input_tokens,
|
||||||
@ -2591,17 +2685,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||||
num_scheduled_tokens, remove_lora):
|
num_scheduled_tokens, remove_lora):
|
||||||
if self.supports_mm_inputs:
|
model_kwargs = self._init_model_kwargs(num_tokens)
|
||||||
|
if (self.supports_mm_inputs
|
||||||
|
and not self.model_config.is_encoder_decoder):
|
||||||
input_ids = None
|
input_ids = None
|
||||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
|
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
**self._init_model_kwargs(num_tokens),
|
**model_kwargs,
|
||||||
**self._dummy_mm_kwargs(num_reqs),
|
**self._dummy_mm_kwargs(num_reqs),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
input_ids = self.input_ids.gpu[:num_tokens]
|
input_ids = self.input_ids.gpu[:num_tokens]
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
model_kwargs = self._init_model_kwargs(num_tokens)
|
|
||||||
|
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
positions = self.mrope_positions.gpu[:, :num_tokens]
|
positions = self.mrope_positions.gpu[:, :num_tokens]
|
||||||
@ -2823,7 +2918,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
mm_budget = self.mm_budget
|
mm_budget = self.mm_budget
|
||||||
assert mm_budget is not None
|
assert mm_budget is not None
|
||||||
|
|
||||||
# TODO: handle encoder-decoder models once we support them.
|
|
||||||
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
|
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
|
||||||
# NOTE: Currently model is profiled with a single non-text
|
# NOTE: Currently model is profiled with a single non-text
|
||||||
# modality with the max possible input tokens even when
|
# modality with the max possible input tokens even when
|
||||||
@ -3170,7 +3264,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
"for more details.")
|
"for more details.")
|
||||||
self.input_batch = InputBatch(
|
self.input_batch = InputBatch(
|
||||||
max_num_reqs=self.max_num_reqs,
|
max_num_reqs=self.max_num_reqs,
|
||||||
max_model_len=self.max_model_len,
|
max_model_len=max(self.max_model_len, self.max_encoder_len),
|
||||||
max_num_batched_tokens=self.max_num_tokens,
|
max_num_batched_tokens=self.max_num_tokens,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
@ -3443,7 +3537,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||||
for layer_name, attn_module in attn_layers.items():
|
for layer_name, attn_module in attn_layers.items():
|
||||||
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
attn_spec = EncoderOnlyAttentionSpec(
|
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_kv_heads=attn_module.num_kv_heads,
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
head_size=attn_module.head_size,
|
head_size=attn_module.head_size,
|
||||||
@ -3485,7 +3579,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# TODO: Support other attention modules, e.g., cross-attention
|
|
||||||
# TODO(lucas): move the attention specs into the model layers like
|
# TODO(lucas): move the attention specs into the model layers like
|
||||||
# the attention backends
|
# the attention backends
|
||||||
if attn_module.attn_type == AttentionType.DECODER:
|
if attn_module.attn_type == AttentionType.DECODER:
|
||||||
@ -3513,12 +3606,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
head_size=attn_module.head_size,
|
head_size=attn_module.head_size,
|
||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla)
|
||||||
|
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||||
|
kv_cache_spec[layer_name] = CrossAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
|
head_size=attn_module.head_size,
|
||||||
|
dtype=self.kv_cache_dtype,
|
||||||
|
use_mla=use_mla)
|
||||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||||
AttentionType.ENCODER_ONLY):
|
AttentionType.ENCODER_ONLY):
|
||||||
# encoder-only attention does not need KV cache.
|
# encoder-only attention does not need KV cache.
|
||||||
continue
|
continue
|
||||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown attention type: {attn_module.attn_type}")
|
f"Unknown attention type: {attn_module.attn_type}")
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
|||||||
from vllm.model_executor.models.utils import extract_layer_index
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||||
from vllm.multimodal.registry import MultiModalRegistry
|
from vllm.multimodal.registry import MultiModalRegistry
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
|
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
|
||||||
@ -269,7 +270,17 @@ def bind_kv_cache(
|
|||||||
# One typical case is encoder-decoder model, e.g., bart.
|
# One typical case is encoder-decoder model, e.g., bart.
|
||||||
# The cross attention and self attention in the same decoder layer
|
# The cross attention and self attention in the same decoder layer
|
||||||
# has different layer_name but the same layer_index.
|
# has different layer_name but the same layer_index.
|
||||||
raise NotImplementedError
|
|
||||||
|
# TODO - analyze where runner_kv_caches is used and the right
|
||||||
|
# way to ensure it properly reflects multiple attention layers
|
||||||
|
# in the same decoder block.
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
# We know that the GPU runner is not impacted by this
|
||||||
|
# case. Some test code depends on runner_kv_caches, but
|
||||||
|
# not in a way that's impacted by ignoring this.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
layer_name = layer_names[0]
|
layer_name = layer_names[0]
|
||||||
runner_kv_caches.append(kv_caches[layer_name])
|
runner_kv_caches.append(kv_caches[layer_name])
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user