Support encoder-only models without KV-Cache (#21270)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent f27fdfc3ed
commit 1cd6eaba54
@@ -3,12 +3,12 @@
import argparse
import datetime
import os
import re
from typing import Union

import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
@@ -1062,8 +1062,17 @@ class VllmRunner:
return [req_output.outputs.score for req_output in req_outputs]

def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
executor = self.llm.llm_engine.model_executor
return executor.apply_model(func)
if hasattr(self.llm.llm_engine, "model_executor"):
# This works either in V0 or in V1 with
# VLLM_ENABLE_V1_MULTIPROCESSING=0
executor = self.llm.llm_engine.model_executor
return executor.apply_model(func)

# This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1
def _apply_model(self):
return func(self.get_model())

return self.llm.llm_engine.collective_rpc(_apply_model)

def __enter__(self):
return self
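Note on the apply_model change above: with the default V1 multiprocessing engine there is no in-process model_executor, so the function has to be shipped to the worker processes via collective_rpc, and pickling an arbitrary closure over RPC is only allowed when VLLM_ALLOW_INSECURE_SERIALIZATION=1. That is why the tests below now set that variable. A minimal illustrative sketch of how a test in this suite might use the new path (model name and assertions are examples, not part of the diff):

import torch

def test_apply_model_sketch(vllm_runner, monkeypatch):
    # Needed for the collective_rpc fallback path shown in the hunk above.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    def check_dtype(model: torch.nn.Module) -> torch.dtype:
        # Runs inside each worker process; one value per worker is returned.
        return next(model.parameters()).dtype

    with vllm_runner("BAAI/bge-base-en-v1.5", dtype="float16") as runner:
        assert all(d == torch.float16 for d in runner.apply_model(check_dtype))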
@@ -22,10 +22,12 @@ REVISION_ROBERTA = os.environ.get("REVISION", "main")

@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_model_loading_with_params(vllm_runner):
def test_model_loading_with_params(vllm_runner, monkeypatch):
"""
Test parameter weight loading with tp>1.
"""
# to use apply_model
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype="float16",

@@ -61,10 +63,12 @@ def test_model_loading_with_params(vllm_runner):

@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_roberta_model_loading_with_params(vllm_runner):
def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
"""
Test parameter weight loading with tp>1.
"""
# to use apply_model
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
revision=REVISION_ROBERTA,
dtype="float16",

@@ -101,10 +105,12 @@ def test_roberta_model_loading_with_params(vllm_runner):

@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_facebook_roberta_model_loading_with_params(vllm_runner):
def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch):
"""
Test loading roberta-base model with no lm_head.
"""
# to use apply_model
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
model_name = "FacebookAI/roberta-base"
with vllm_runner(model_name=model_name,
dtype="float16",
@@ -39,17 +39,9 @@ def v1(run_with_both_engines):
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base",
marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]),
# [Encoder-only]
pytest.param(
"BAAI/bge-base-en-v1.5",
marks=[
# CPU only supports V1
pytest.mark.core_model,
pytest.mark.skip_v1
]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2",
marks=[pytest.mark.skip_v1]),
pytest.param("intfloat/multilingual-e5-small",
marks=[pytest.mark.skip_v1]),
pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]),
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
pytest.param("intfloat/multilingual-e5-small"),
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
marks=[pytest.mark.skip_v1]),
# [Cross-Encoder]

@@ -23,6 +23,14 @@ RERANK_MODELS = [
]


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
@@ -93,6 +93,7 @@ def create_common_attn_metadata(
max_query_len=max_query_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)


@@ -13,7 +13,6 @@ UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
"state-spaces/mamba-130m-hf", # mamba1
"BAAI/bge-m3", # embedding
]

MODEL = "meta-llama/Llama-3.2-1B-Instruct"
@@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import re

import pytest
import regex as re
import requests
import torch
@@ -1649,7 +1649,8 @@ class EngineArgs:

if (self.max_num_seqs is None
and usage_context in default_max_num_seqs):
self.max_num_seqs = default_max_num_seqs[usage_context]
self.max_num_seqs = min(default_max_num_seqs[usage_context],
self.max_num_batched_tokens or sys.maxsize)

logger.debug("Setting max_num_seqs to %d for %s usage context.",
self.max_num_seqs, use_context_value)
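The hunk above caps the usage-context default for max_num_seqs at max_num_batched_tokens, since a batch can never contain more sequences than tokens. A small worked sketch with illustrative values (not taken from the diff):

import sys

default_max_num_seqs = {"LLM_CLASS": 256}   # hypothetical default for illustration
max_num_batched_tokens = 128                # e.g. a small encoder-only configuration

# Old behaviour: max_num_seqs would be set to 256 even though only 128 tokens fit.
# New behaviour:
max_num_seqs = min(default_max_num_seqs["LLM_CLASS"],
                   max_num_batched_tokens or sys.maxsize)   # -> 128
# When max_num_batched_tokens is None, `or sys.maxsize` leaves the default unchanged.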
@@ -12,7 +12,6 @@ from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,

@@ -60,7 +59,6 @@ class BertEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:

@@ -119,7 +117,6 @@ class BertPooler(Pooler):
return pooled_output


@support_torch_compile
class BertEncoder(nn.Module):

def __init__(self, vllm_config: VllmConfig, prefix: str = ""):

@@ -337,6 +334,7 @@ class BertOutput(nn.Module):
return hidden_states


@support_torch_compile
class BertModel(nn.Module, SupportsQuant):

is_pooling_model = True

@@ -368,13 +366,9 @@ class BertModel(nn.Module, SupportsQuant):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
attn_metadata = get_forward_context().attn_metadata
assert hasattr(attn_metadata, "seq_lens_tensor")
hidden_states = self.embeddings(
input_ids=input_ids,
seq_lens=attn_metadata.seq_lens_tensor,
position_ids=position_ids,
token_type_ids=token_type_ids)
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
return self.encoder(hidden_states)

def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

@@ -447,7 +441,7 @@ class BertPoolingModel(BertModel):
return loaded_params


class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
class BertEmbeddingModel(nn.Module, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for

@@ -474,11 +468,13 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
@@ -9,6 +9,7 @@ from torch import nn
from transformers import RobertaConfig

from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool,
DispatchPooler, Pooler)
from vllm.model_executor.layers.vocab_parallel_embedding import (

@@ -51,33 +52,12 @@ class RobertaEmbedding(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
seq_lens_list = seq_lens.tolist()
new_pos_list = []
for positions, tokens in zip(position_ids.split(seq_lens_list),
input_ids.split(seq_lens_list)):
# Verify assumption that incoming position are
# always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
new_pos_list.append(
create_position_ids_from_input_ids(tokens, self.padding_idx))
position_ids = torch.cat(new_pos_list)

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:

@@ -119,6 +99,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
_pooler: An instance of Pooler used for pooling operations.
"""

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:

# Fix Roberta positions here outside of the CUDA graph.
# Because we need the to extract the sequences from
# input_ids the control flow is data dependent.
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)

return self.model(input_ids=input_ids,
position_ids=positions,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)

def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> Union[BertModel, BertWithRope]:

@@ -175,6 +181,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.padding_idx = vllm_config.model_config.hf_config.pad_token_id

self.num_labels = config.num_labels
self.roberta = BertModel(vllm_config=vllm_config,

@@ -216,6 +223,9 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
replace_roberta_positions(input_ids=input_ids,
position_ids=positions,
padding_idx=self.padding_idx)
return self.roberta(input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,

@@ -245,3 +255,36 @@ def create_position_ids_from_input_ids(input_ids,
past_key_values_length) * mask

return incremental_indices.long() + padding_idx


def replace_roberta_positions(input_ids: torch.Tensor,
position_ids: torch.Tensor,
padding_idx: int) -> None:

seq_lens: Optional[torch.Tensor] = None
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None: # can be None during warmup
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
# TODO: remove "seq_lens_tensor" after V0 is removed
seq_lens = getattr(attn_metadata, "seq_lens_tensor",
getattr(attn_metadata, "seq_lens", None))

if seq_lens is not None:
assert isinstance(seq_lens, torch.Tensor)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
token_list = torch.split(input_ids[:torch.sum(seq_lens)],
seq_lens.tolist())

offset = 0
for tokens in token_list:
length = tokens.shape[0]
position_ids[offset:offset+length] = \
create_position_ids_from_input_ids(tokens, padding_idx)
offset = offset + length
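The new replace_roberta_positions helper rewrites the flat position_ids buffer in place so that each packed sequence restarts at padding_idx + 1, matching HF's create_position_ids_from_input_ids; doing this in the model's outer forward keeps the data-dependent loop out of any captured CUDA graph. A small worked example of the effect (values are illustrative; padding_idx = 1 as in the standard RoBERTa config; seq_lens would normally come from the attention metadata rather than be passed directly):

import torch

# Two sequences of lengths 4 and 3, packed back to back.
input_ids = torch.tensor([0, 42, 43, 2, 0, 7, 2])
position_ids = torch.arange(7)          # [0, 1, 2, 3, 4, 5, 6] as handed to the model
seq_lens = torch.tensor([4, 3])
padding_idx = 1

# With seq_lens [4, 3] taken from the attention metadata, the rewrite produces
#   sequence 1 -> [2, 3, 4, 5]
#   sequence 2 -> [2, 3, 4]
# i.e. position_ids becomes torch.tensor([2, 3, 4, 5, 2, 3, 4]):
# each sequence starts at padding_idx + 1, and any real padding tokens
# would keep position padding_idx.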
@@ -130,6 +130,8 @@ class FlashAttentionMetadata:
prefix_scheduler_metadata: Optional[torch.Tensor] = None
max_num_splits: int = 0

causal: bool = True


def _get_sliding_window_configs(
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:

@@ -213,6 +215,7 @@ class FlashAttentionMetadataBuilder(
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
causal = common_attn_metadata.causal

# the overhead of the aot schedule is not worth it for spec-decode
aot_schedule = self.aot_schedule and not fast_build

@@ -288,7 +291,7 @@ class FlashAttentionMetadataBuilder(
max_query_len=max_query_len,
seqlens=seq_lens,
max_seq_len=max_seq_len,
causal=True)
causal=causal)

if self.use_full_cuda_graph:
assert scheduler_metadata is not None

@@ -326,7 +329,7 @@ class FlashAttentionMetadataBuilder(
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
max_num_splits=max_num_splits,
)
causal=causal)
return attn_metadata

def can_run_in_cudagraph(

@@ -375,11 +378,14 @@ class FlashAttentionImpl(AttentionImpl):

FlashAttentionBackend.validate_head_size(head_size)

if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/decoder cross-attention "
"is not implemented for "
"FlashAttentionImpl")

self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():

@@ -422,6 +428,8 @@ class FlashAttentionImpl(AttentionImpl):
# Profiling run.
return output

attn_type = self.attn_type

# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead

@@ -432,6 +440,18 @@ class FlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.

num_actual_tokens = attn_metadata.num_actual_tokens

# Handle encoder attention differently - no KV cache needed
if attn_type in (AttentionType.ENCODER_ONLY, ):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata, layer)

# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)

if self.kv_sharing_target_layer_name is None:

@@ -483,7 +503,7 @@ class FlashAttentionImpl(AttentionImpl):
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,

@@ -524,6 +544,63 @@ class FlashAttentionImpl(AttentionImpl):
)
return output

def _forward_encoder_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
layer: torch.nn.Module,
) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache.

Args:
query: shape = [num_encoder_tokens, num_heads, head_size]
key: shape = [num_encoder_tokens, num_kv_heads, head_size]
value: shape = [num_encoder_tokens, num_kv_heads, head_size]
output: shape = [num_encoder_tokens, num_heads, head_size]
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"quantization is not supported for encoder attention")

# Use encoder-specific metadata for sequence information
cu_seqlens_q = attn_metadata.query_start_loc
cu_seqlens_k = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_query_len

descale_shape = (
cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
self.num_kv_heads)

# Call flash attention directly on Q, K, V tensors
flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=output,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=False,  # Encoder attention is bidirectional
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)

return output


def use_cascade_attention(
common_prefix_len: int,
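In the encoder path above, queries and keys cover the same packed tokens, so both cu_seqlens arguments come from query_start_loc and attention is bidirectional (causal=False). For reference, a plain-PyTorch sketch of the same computation, looping over packed sequences with scaled_dot_product_attention; it assumes num_heads == num_kv_heads and is purely illustrative, not the kernel the diff calls:

import torch
import torch.nn.functional as F

def encoder_attention_reference(q, k, v, cu_seqlens):
    # q, k, v: [total_tokens, num_heads, head_size], sequences packed back to back;
    # cu_seqlens: [num_seqs + 1] cumulative sequence boundaries (query_start_loc).
    out = torch.empty_like(q)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        qi = q[start:end].transpose(0, 1)   # [num_heads, seq_len, head_size]
        ki = k[start:end].transpose(0, 1)
        vi = v[start:end].transpose(0, 1)
        # is_causal=False: every token attends to every token in its own sequence.
        oi = F.scaled_dot_product_attention(qi, ki, vi, is_causal=False)
        out[start:end] = oi.transpose(0, 1)
    return out

# Example: two sequences of lengths 4 and 3 packed into 7 tokens.
q = torch.randn(7, 12, 64)
k = torch.randn(7, 12, 64)
v = torch.randn(7, 12, 64)
out = encoder_attention_reference(q, k, v, torch.tensor([0, 4, 7]))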
@@ -59,6 +59,8 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor

causal: bool = True


M = TypeVar("M")

@@ -395,6 +397,7 @@ def make_local_attention_virtual_batches(
max_query_len=seqlens_q_local.max(),
block_table_tensor=block_table_local,
slot_mapping=common_attn_metadata.slot_mapping,
causal=True,
)
@@ -111,6 +111,12 @@ class EngineCore:
"compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls)

if len(kv_cache_config.kv_cache_groups) == 0:
# Encoder models without KV cache don't support
# chunked prefill. But do SSM models?
logger.info("Disabling chunked prefill for model without KVCache")
vllm_config.scheduler_config.chunked_prefill_enabled = False

self.scheduler: SchedulerInterface = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,

@@ -330,6 +330,7 @@ class EagleProposer:
max_query_len=new_query_len_per_req.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
)

return spec_common_attn_metadata, token_indices

@@ -4,6 +4,7 @@ from contextlib import contextmanager
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger

@@ -59,6 +60,9 @@ class CPUModelRunner(GPUModelRunner):
self.scheduler_config,
self.lora_config, self.device)

def get_model(self) -> nn.Module:
return self.model

def warming_up_model(self) -> None:
logger.info("Warming up model for the compilation...")
# Only generate graph for the generic shape
@@ -126,6 +126,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None
self.is_encoder_only_model = False
self.model_supports_multimodal_raw_input = (
model_config.model_supports_multimodal_raw_input)
self.max_model_len = model_config.max_model_len

@@ -735,6 +736,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_common_attn_metadata = None

attn_metadata: dict[str, Any] = {}

# Prepare encoder attention metadata separately
# (encoder layers are not in KV cache groups)
if self.is_encoder_only_model:
common_attn_metadata, encoder_attn_metadata = \
self._build_encoder_only_attn_metadata(
scheduler_output)

# Add encoder attention metadata for all encoder layers
attention_layers = get_layers_from_vllm_config(
self.vllm_config, Attention)
for layer_name, attn_module in attention_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_metadata[layer_name] = encoder_attn_metadata

# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(

@@ -760,6 +776,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_query_len=max_num_scheduled_tokens,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping,
causal=True,
)

if self.speculative_config and \

@@ -2102,7 +2119,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_table_tensor=self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()[:num_reqs],
slot_mapping=self.input_batch.
block_table[kv_cache_group_id].slot_mapping[:num_tokens])
block_table[kv_cache_group_id].slot_mapping[:num_tokens],
causal=True)

attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(

@@ -2466,6 +2484,49 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))

def _initialize_single_attn_backend(
self, kv_cache_spec: KVCacheSpec
) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
if isinstance(kv_cache_spec, AttentionSpec):
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (f"Error with get_attn_backend: "
f"{kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
elif isinstance(kv_cache_spec, MambaSpec):
attn_backend_i = Mamba2AttentionBackend
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")

attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
kv_cache_spec,
self.vllm_config,
self.device,
)

if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")
return attn_backend_i, attn_metadata_builder_i

def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.

@@ -2476,48 +2537,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, AttentionSpec):
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (f"Error with get_attn_backend: "
f"{kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
elif isinstance(kv_cache_spec, MambaSpec):
attn_backend_i = Mamba2AttentionBackend
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")

attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
kv_cache_spec,
self.vllm_config,
self.device,
)

if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")

attn_backend_i, attn_metadata_builder_i = \
self._initialize_single_attn_backend(kv_cache_spec)
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)

if len(self.attn_backends) > 0:
return

# Check if model is encoder-only
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
attn_specs = list[AttentionSpec]()
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for attn_module in attn_layers.values():

if attn_module.attn_type == AttentionType.ENCODER_ONLY:
assert attn_module.sliding_window is None, "Sliding "
"window attention is not supported for encoder-only models"

attn_specs.append(
FullAttentionSpec(block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla))
else:
raise ValueError("Expected only encoder-only layers")

if len(attn_specs) > 0:
assert len(attn_specs) == len(attn_layers), \
"All or none of the layers are expected to be encoder-only"

attn_backend, attn_metadata_builder = \
self._initialize_single_attn_backend(attn_specs[0])
self.attn_backends.append(attn_backend)
self.attn_metadata_builders.append(attn_metadata_builder)
self.is_encoder_only_model = True

def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None:
"""
@@ -2833,3 +2891,53 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
page_size_padded=page_size_padded)

return kv_cache_spec

def _build_encoder_only_attn_metadata(
self, scheduler_output: "SchedulerOutput") -> \
tuple[CommonAttentionMetadata, Any]:
"""Prepare encoder attention metadata for encoder-only models.

Args:
scheduler_output: Scheduler output

Returns:
dict[str, Any]: Encoder attention metadata
"""
num_reqs = self.input_batch.num_reqs
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

# Get the number of scheduled tokens for each request.
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
max_num_scheduled_tokens = max(tokens)

# Use the first attention metadata builder
# to create encoder attention metadata
builder = self.attn_metadata_builders[0]

dummy_block_table = torch.zeros((num_reqs, 1),
dtype=torch.int32,
device=self.device)
dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
dtype=torch.int32,
device=self.device)

common_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens=self.seq_lens[:num_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
block_table_tensor=dummy_block_table,
slot_mapping=dummy_slot_mapping,
causal=False,
)

return common_metadata, builder.build(
common_prefix_len=0,  # No cascade for encoder
common_attn_metadata=common_metadata,
)
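Taken together, these changes let encoder-only models (BERT/RoBERTa embedders and cross-encoders) run on the V1 engine without allocating a KV cache: the runner detects ENCODER_ONLY attention layers, builds non-causal attention metadata with dummy block tables and slot mappings, and FlashAttention operates directly on the Q/K/V tensors. An illustrative end-to-end sketch of what this enables; the model name, task argument, and embed() call reflect recent vLLM releases and are not taken from this diff:

from vllm import LLM

# Encoder-only embedding model; with this commit no KV-cache blocks are needed.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed", dtype="float16")
outputs = llm.embed(["vLLM can now run encoder-only models without a KV cache."])
print(len(outputs[0].outputs.embedding))  # embedding dimension of the model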