Support encoder-only models without KV-Cache (#21270)

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Maximilien de Bayser 2025-07-26 10:09:52 -03:00 committed by GitHub
parent f27fdfc3ed
commit 1cd6eaba54
17 changed files with 352 additions and 99 deletions

View File

@@ -3,12 +3,12 @@
import argparse
import datetime
import os
import re
from typing import Union
import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule

View File

@@ -1062,8 +1062,17 @@ class VllmRunner:
return [req_output.outputs.score for req_output in req_outputs]
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
executor = self.llm.llm_engine.model_executor
return executor.apply_model(func)
if hasattr(self.llm.llm_engine, "model_executor"):
# This works either in V0 or in V1 with
# VLLM_ENABLE_V1_MULTIPROCESSING=0
executor = self.llm.llm_engine.model_executor
return executor.apply_model(func)
# This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
def _apply_model(self):
return func(self.get_model())
return self.llm.llm_engine.collective_rpc(_apply_model)
def __enter__(self):
return self
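For reference, a minimal sketch of how the reworked apply_model is exercised from a test (vllm_runner is the fixture above; the count_parameters helper and model choice are illustrative, not from this diff). Under V1 with multiprocessing, the callable is pickled and shipped to the worker processes, which is why the tests below set VLLM_ALLOW_INSECURE_SERIALIZATION:

    import torch.nn as nn

    def count_parameters(model: nn.Module) -> int:
        # Runs inside the process that owns the model weights.
        return sum(p.numel() for p in model.parameters())

    def test_param_count(vllm_runner, monkeypatch):
        # Required so the callable can be serialized to V1 workers.
        monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        with vllm_runner("BAAI/bge-base-en-v1.5", dtype="float16") as m:
            counts = m.apply_model(count_parameters)
            assert all(c > 0 for c in counts)  # one entry per worker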

View File

@@ -22,10 +22,12 @@ REVISION_ROBERTA = os.environ.get("REVISION", "main")
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_model_loading_with_params(vllm_runner):
def test_model_loading_with_params(vllm_runner, monkeypatch):
"""
Test parameter weight loading with tp>1.
"""
# to use apply_model
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype="float16",
@@ -61,10 +63,12 @@ def test_model_loading_with_params(vllm_runner):
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_roberta_model_loading_with_params(vllm_runner):
def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
"""
Test parameter weight loading with tp>1.
"""
# to use apply_model
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
revision=REVISION_ROBERTA,
dtype="float16",
@@ -101,10 +105,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_facebook_roberta_model_loading_with_params(vllm_runner):
def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch):
"""
Test loading roberta-base model with no lm_head.
"""
# to use apply_model
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
model_name = "FacebookAI/roberta-base"
with vllm_runner(model_name=model_name,
dtype="float16",

View File

@@ -39,17 +39,9 @@ def v1(run_with_both_engines):
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
# [Encoder-only]
pytest.param(
"BAAI/bge-base-en-v1.5",
marks=[
# CPU only supports V1
pytest.mark.core_model,
pytest.mark.skip_v1
]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
# [Cross-Encoder]
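With the skip_v1 marks dropped, these encoder-only embedding models now run on the V1 engine as well. A quick smoke test outside pytest might look like this (a sketch using the public LLM API; the embedding dimension depends on the model):

    from vllm import LLM

    llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")
    (output, ) = llm.embed(["The quick brown fox jumps over the lazy dog"])
    # One embedding vector per prompt; bge-base produces 768 dimensions.
    print(len(output.outputs.embedding))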

View File

@@ -23,6 +23,14 @@ RERANK_MODELS = [
]
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass
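For context, run_with_both_engines parametrizes each test over the V0 and V1 engines. A minimal sketch of how such a fixture can be written (an assumption about the implementation; the real conftest.py version additionally honors the skip_v0/skip_v1 marks seen elsewhere in this diff):

    import pytest

    @pytest.fixture(params=[True, False], ids=["v1", "v0"])
    def run_with_both_engines(request, monkeypatch):
        # Toggle engine selection for everything the test spawns.
        monkeypatch.setenv("VLLM_USE_V1", "1" if request.param else "0")
        yield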
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:

View File

@@ -93,6 +93,7 @@ def create_common_attn_metadata(
max_query_len=max_query_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)

View File

@@ -13,7 +13,6 @@ UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
"state-spaces/mamba-130m-hf", # mamba1
"BAAI/bge-m3", # embedding
]
MODEL = "meta-llama/Llama-3.2-1B-Instruct"

View File

@@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import re
import pytest
import regex as re
import requests
import torch

View File

@@ -1649,7 +1649,8 @@ class EngineArgs:
if (self.max_num_seqs is None
and usage_context in default_max_num_seqs):
self.max_num_seqs = default_max_num_seqs[usage_context]
self.max_num_seqs = min(default_max_num_seqs[usage_context],
self.max_num_batched_tokens or sys.maxsize)
logger.debug("Setting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, use_context_value)
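The min() keeps the two batch limits consistent: each scheduled sequence contributes at least one token, so a max_num_seqs beyond max_num_batched_tokens would never be reachable. The `or sys.maxsize` turns the cap into a no-op while max_num_batched_tokens is still unset. A worked example with hypothetical values:

    import sys

    default_seqs = 1024                  # default for this usage context
    max_num_batched_tokens = 512         # user-provided token budget
    print(min(default_seqs, max_num_batched_tokens or sys.maxsize))   # 512

    max_num_batched_tokens = None        # not configured yet
    print(min(default_seqs, max_num_batched_tokens or sys.maxsize))   # 1024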

View File

@@ -12,7 +12,6 @@ from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
@@ -60,7 +59,6 @@ class BertEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -119,7 +117,6 @@ class BertPooler(Pooler):
return pooled_output
@support_torch_compile
class BertEncoder(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
@@ -337,6 +334,7 @@ class BertOutput(nn.Module):
return hidden_states
@support_torch_compile
class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True
@@ -368,13 +366,9 @@ class BertModel(nn.Module, SupportsQuant):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings(
input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states)
def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -447,7 +441,7 @@ class BertPoolingModel(BertModel):
return loaded_params
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
@@ -474,11 +468,13 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

View File

@@ -9,6 +9,7 @@ from torch import nn
from transformers import RobertaConfig
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -51,33 +52,12 @@ class RobertaEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)
# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
seq_lens_list = seq_lens.tolist()
new_pos_list = []
for positions, tokens in zip(position_ids.split(seq_lens_list),
input_ids.split(seq_lens_list)):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
create_position_ids_from_input_ids(tokens, self.padding_idx))
position_ids = torch.cat(new_pos_list)
# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
@@ -119,6 +99,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Fix RoBERTa positions here, outside of the CUDA graph.
# Because we need to extract the sequences from input_ids,
# the control flow is data dependent.
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> Union[BertModel, BertWithRope]:
@@ -175,6 +181,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id
self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,
@@ -216,6 +223,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
return self.roberta(input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
@@ -245,3 +255,36 @@ def create_position_ids_from_input_ids(input_ids,
past_key_values_length) * mask
return incremental_indices.long() + padding_idx
def replace_roberta_positions(input_ids: torch.Tensor,
position_ids: torch.Tensor,
padding_idx: int) -> None:
seq_lens: Optional[torch.Tensor] = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None: # can be None during warmup
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
# TODO: remove "seq_lens_tensor" after V0 is removed
seq_lens = getattr(attn_metadata, "seq_lens_tensor",
getattr(attn_metadata, "seq_lens", None))
if seq_lens is not None:
assert isinstance(seq_lens, torch.Tensor)
# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
token_list = torch.split(input_ids[:torch.sum(seq_lens)],
seq_lens.tolist())
offset = 0
for tokens in token_list:
length = tokens.shape[0]
position_ids[offset:offset+length] = \
create_position_ids_from_input_ids(tokens, padding_idx)
offset = offset + length
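A small worked example of what this recomputation does for one sequence, mirroring create_position_ids_from_input_ids above (RoBERTa's padding_idx is typically 1):

    import torch

    padding_idx = 1
    tokens = torch.tensor([0, 42, 7, 1, 1])          # 1 == <pad>
    mask = tokens.ne(padding_idx).int()              # [1, 1, 1, 0, 0]
    incremental = torch.cumsum(mask, dim=0) * mask   # [1, 2, 3, 0, 0]
    print(incremental.long() + padding_idx)          # [2, 3, 4, 1, 1]
    # Real tokens count up from padding_idx + 1, while padding positions
    # collapse back to padding_idx, matching the HF reference behavior.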

View File

@@ -130,6 +130,8 @@ class FlashAttentionMetadata:
prefix_scheduler_metadata: Optional[torch.Tensor] = None
max_num_splits: int = 0
causal: bool = True
def _get_sliding_window_configs(
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -213,6 +215,7 @@ class FlashAttentionMetadataBuilder(
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = common_attn_metadata.causal
# the overhead of the aot schedule is not worth it for spec-decode
aot_schedule = self.aot_schedule and not fast_build
@@ -288,7 +291,7 @@ class FlashAttentionMetadataBuilder(
max_query_len=max_query_len,
seqlens=seq_lens,
max_seq_len=max_seq_len,
causal=True)
causal=causal)
if self.use_full_cuda_graph:
assert scheduler_metadata is not None
@@ -326,7 +329,7 @@ class FlashAttentionMetadataBuilder(
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
)
causal=causal)
return attn_metadata
def can_run_in_cudagraph(
@@ -375,11 +378,14 @@ class FlashAttentionImpl(AttentionImpl):
FlashAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/decoder cross-attention "
"is not implemented for "
"FlashAttentionImpl")
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
@@ -422,6 +428,8 @@ class FlashAttentionImpl(AttentionImpl):
# Profiling run.
return output
attn_type = self.attn_type
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -432,6 +440,18 @@
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if attn_type in (AttentionType.ENCODER_ONLY, ):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata, layer)
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
@@ -483,7 +503,7 @@
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
@@ -524,6 +544,63 @@
)
return output
def _forward_encoder_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
layer: torch.nn.Module,
) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache.
Args:
query: shape = [num_encoder_tokens, num_heads, head_size]
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
output: shape = [num_encoder_tokens, num_heads, head_size]
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
# FP8 quantization is not yet supported for encoder attention
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"quantization is not supported for encoder attention")
# Use encoder-specific metadata for sequence information
cu_seqlens_q = attn_metadata.query_start_loc
cu_seqlens_k = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_query_len
descale_shape = (
cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr]
self.num_kv_heads)
# Call flash attention directly on Q, K, V tensors
flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=output,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=False, # Encoder attention is bidirectional
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output
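Note that cu_seqlens_q and cu_seqlens_k are the same tensor here: in encoder self-attention the queries and keys come from the identical packed token stream, so a single cumulative-length prefix describes both. A small illustration with made-up lengths:

    import torch

    seq_lens = torch.tensor([5, 3, 4], dtype=torch.int32)
    cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    print(cu_seqlens)  # tensor([0, 5, 8, 12], dtype=torch.int32)
    # Passed as both cu_seqlens_q and cu_seqlens_k, with causal=False so
    # every token can attend to every other token in its own sequence.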
def use_cascade_attention(
common_prefix_len: int,

View File

@@ -59,6 +59,8 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
causal: bool = True
M = TypeVar("M")
@@ -395,6 +397,7 @@ def make_local_attention_virtual_batches(
max_query_len=seqlens_q_local.max(),
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
)
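The new causal flag is what downstream backends use to choose the mask shape: True keeps the usual lower-triangular decoder mask, False requests encoder-style bidirectional attention. In mask terms (illustrative only; the kernels never materialize these):

    import torch

    n = 4
    causal_mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
    # causal=True: token i attends to positions 0..i only
    bidirectional_mask = torch.ones(n, n, dtype=torch.bool)
    # causal=False: every token attends to every token (encoder-only)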

View File

@@ -111,6 +111,12 @@ class EngineCore:
"compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls)
if len(kv_cache_config.kv_cache_groups) == 0:
# Encoder models without a KV cache cannot split a prompt across
# scheduler steps (nothing is cached between steps), so chunked
# prefill is not supported. Whether SSM models need the same
# treatment is an open question.
logger.info("Disabling chunked prefill for model without KVCache")
vllm_config.scheduler_config.chunked_prefill_enabled = False
self.scheduler: SchedulerInterface = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,

View File

@@ -330,6 +330,7 @@ class EagleProposer:
max_query_len=new_query_len_per_req.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
)
return spec_common_attn_metadata, token_indices

View File

@@ -4,6 +4,7 @@ from contextlib import contextmanager
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -59,6 +60,9 @@ class CPUModelRunner(GPUModelRunner):
self.scheduler_config,
self.lora_config, self.device)
def get_model(self) -> nn.Module:
return self.model
def warming_up_model(self) -> None:
logger.info("Warming up model for the compilation...")
# Only generate graph for the generic shape

View File

@@ -126,6 +126,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None
self.is_encoder_only_model = False
self.model_supports_multimodal_raw_input = (
model_config.model_supports_multimodal_raw_input)
self.max_model_len = model_config.max_model_len
@@ -735,6 +736,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata = None
attn_metadata: dict[str, Any] = {}
# Prepare encoder attention metadata separately
# (encoder layers are not in KV cache groups)
if self.is_encoder_only_model:
common_attn_metadata, encoder_attn_metadata = \
self._build_encoder_only_attn_metadata(
scheduler_output)
# Add encoder attention metadata for all encoder layers
attention_layers = get_layers_from_vllm_config(
self.vllm_config, Attention)
for layer_name, attn_module in attention_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_metadata[layer_name] = encoder_attn_metadata
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -760,6 +776,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_query_len=max_num_scheduled_tokens,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)
if self.speculative_config and \
@@ -2102,7 +2119,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_table_tensor=self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()[:num_reqs],
slot_mapping=self.input_batch.
block_table[kv_cache_group_id].slot_mapping[:num_tokens])
block_table[kv_cache_group_id].slot_mapping[:num_tokens],
causal=True)
attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(
@@ -2466,6 +2484,49 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def _initialize_single_attn_backend(
self, kv_cache_spec: KVCacheSpec
) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
if isinstance(kv_cache_spec, AttentionSpec):
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (f"Error with get_attn_backend: "
f"{kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
elif isinstance(kv_cache_spec, MambaSpec):
attn_backend_i = Mamba2AttentionBackend
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
kv_cache_spec,
self.vllm_config,
self.device,
)
if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")
return attn_backend_i, attn_metadata_builder_i
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.
@@ -2476,48 +2537,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, AttentionSpec):
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (f"Error with get_attn_backend: "
f"{kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
elif isinstance(kv_cache_spec, MambaSpec):
attn_backend_i = Mamba2AttentionBackend
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
kv_cache_spec,
self.vllm_config,
self.device,
)
if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")
attn_backend_i, attn_metadata_builder_i = \
self._initialize_single_attn_backend(kv_cache_spec)
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)
if len(self.attn_backends) > 0:
return
# Check if model is encoder-only
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
attn_specs = list[AttentionSpec]()
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for attn_module in attn_layers.values():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
assert attn_module.sliding_window is None, (
    "Sliding window attention is not supported "
    "for encoder-only models")
attn_specs.append(
FullAttentionSpec(block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla))
else:
raise ValueError("Expected only encoder-only layers")
if len(attn_specs) > 0:
assert len(attn_specs) == len(attn_layers), \
"All or none of the layers are expected to be encoder-only"
attn_backend, attn_metadata_builder = \
self._initialize_single_attn_backend(attn_specs[0])
self.attn_backends.append(attn_backend)
self.attn_metadata_builders.append(attn_metadata_builder)
self.is_encoder_only_model = True
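For orientation, this scan keys on the attn_type that model code passes when constructing its attention layers. A stripped-down sketch (a hypothetical toy module, not part of this diff) of what an encoder-only model declares:

    import torch.nn as nn
    from vllm.attention import Attention, AttentionType

    class ToySelfAttention(nn.Module):
        def __init__(self, num_heads: int = 12, head_size: int = 64):
            super().__init__()
            # attn_type=ENCODER_ONLY is what initialize_attn_backend
            # looks for when building the KV-cache-free attention spec.
            self.attn = Attention(num_heads=num_heads,
                                  head_size=head_size,
                                  scale=head_size**-0.5,
                                  attn_type=AttentionType.ENCODER_ONLY)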
def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None:
"""
@@ -2833,3 +2891,53 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
page_size_padded=page_size_padded)
return kv_cache_spec
def _build_encoder_only_attn_metadata(
self, scheduler_output: "SchedulerOutput") -> \
tuple[CommonAttentionMetadata, Any]:
"""Prepare encoder attention metadata for encoder-only models.
Args:
scheduler_output: Scheduler output
Returns:
tuple[CommonAttentionMetadata, Any]: The common metadata and the
built encoder attention metadata
"""
num_reqs = self.input_batch.num_reqs
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# Get the number of scheduled tokens for each request.
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
max_num_scheduled_tokens = max(tokens)
# Use the first attention metadata builder
# to create encoder attention metadata
builder = self.attn_metadata_builders[0]
dummy_block_table = torch.zeros((num_reqs, 1),
dtype=torch.int32,
device=self.device)
dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
dtype=torch.int32,
device=self.device)
common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens=self.seq_lens[:num_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
block_table_tensor=dummy_block_table,
slot_mapping=dummy_slot_mapping,
causal=False,
)
return common_metadata, builder.build(
common_prefix_len=0, # No cascade for encoder
common_attn_metadata=common_metadata,
)
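The dummy block table and slot mapping exist only to satisfy the CommonAttentionMetadata interface; encoder attention neither reads nor writes the KV cache, so zero-filled placeholders suffice. With three requests scheduled for 5, 3 and 4 tokens (made-up numbers), the relevant fields would look like:

    import torch

    tokens = [5, 3, 4]                       # per-request scheduled tokens
    query_start_loc = torch.tensor([0, 5, 8, 12], dtype=torch.int32)
    max_num_scheduled_tokens = max(tokens)   # 5
    num_reqs, total = len(tokens), sum(tokens)
    dummy_block_table = torch.zeros((num_reqs, 1), dtype=torch.int32)
    dummy_slot_mapping = torch.zeros((total, ), dtype=torch.int32)
    # causal=False tells the builder this batch is bidirectional.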