mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-29 19:57:17 +08:00
Merge 6304606fad6969cbb7654266336182883246c59a into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
6f7c3f2b40
@ -433,6 +433,7 @@ th {
|
||||
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
|
||||
| `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | |
|
||||
| `PanguEmbeddedForCausalLM` |openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ |
|
||||
| `PanguProMoEV2ForCausalLM` |openpangu-pro-moe-v2 | | ✅︎ | ✅︎ |
|
||||
| `PanguUltraMoEForCausalLM` |openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ |
|
||||
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ |
|
||||
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
|
||||
@ -394,6 +394,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"PanguEmbeddedForCausalLM": _HfExamplesInfo(
|
||||
"FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True
|
||||
),
|
||||
"PanguProMoEV2ForCausalLM": _HfExamplesInfo(
|
||||
"",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False,
|
||||
),
|
||||
"PanguUltraMoEForCausalLM": _HfExamplesInfo(
|
||||
"FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
|
||||
trust_remote_code=True,
|
||||
|
||||
@ -42,6 +42,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
|
||||
"""
|
||||
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
FLASH_ATTN_DIFFKV = (
|
||||
"vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend"
|
||||
)
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
|
||||
|
||||
@ -136,6 +136,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
attn_backend: type[AttentionBackend] | None = None,
|
||||
head_size_v: int | None = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
@ -177,6 +178,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.head_size_v = self.head_size if head_size_v is None else head_size_v
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||
@ -234,8 +236,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_sharing_target_layer_name,
|
||||
**extra_impl_args,
|
||||
)
|
||||
backend_name = self.attn_backend.get_name()
|
||||
self.backend = AttentionBackendEnum.__members__.get(backend_name)
|
||||
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
@ -316,6 +317,10 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
query, _ = self.query_quant(query, self._q_scale)
|
||||
|
||||
if self.use_output:
|
||||
if output_shape is None:
|
||||
output_shape = torch.Size(
|
||||
(*query.shape[:-1], self.num_heads * self.head_size_v)
|
||||
)
|
||||
output_shape = output_shape if output_shape is not None else query.shape
|
||||
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
@ -323,11 +328,11 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
# CPU overheads from the non-CUDA-graph regions.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.head_size_v)
|
||||
if key is not None:
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
if self.use_direct_call:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
@ -402,6 +407,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
head_size_v=self.head_size_v,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
)
|
||||
|
||||
@ -728,6 +734,7 @@ def unified_attention_with_output(
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||
|
||||
self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
|
||||
254
vllm/attention/layers/static_sink_attention.py
Normal file
254
vllm/attention/layers/static_sink_attention.py
Normal file
@ -0,0 +1,254 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_diffkv,
|
||||
)
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheSpec,
|
||||
SinkFullAttentionSpec,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_static_sink_attention_backend(
|
||||
underlying_attn_backend: type[AttentionBackend],
|
||||
sink_len: int = 0,
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = "StaticSink_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class StaticSinkAttentionBuilder(underlying_builder): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
model_config = vllm_config.model_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.sink_len = sink_len
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
|
||||
self.max_num_blocks = cdiv(
|
||||
model_config.max_model_len, vllm_config.cache_config.block_size
|
||||
)
|
||||
self.block_table_with_sink = torch.zeros(
|
||||
(
|
||||
scheduler_config.max_num_seqs,
|
||||
self.max_num_blocks + self.num_sink_blocks,
|
||||
),
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange(
|
||||
1,
|
||||
self.num_sink_blocks + 1,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AttentionMetadata:
|
||||
common_attn_metadata.seq_lens[:] = (
|
||||
common_attn_metadata.seq_lens + self.sink_len
|
||||
)
|
||||
common_attn_metadata.seq_lens[
|
||||
common_attn_metadata.seq_lens == self.sink_len
|
||||
] = 0
|
||||
common_attn_metadata.max_seq_len = (
|
||||
common_attn_metadata.max_seq_len + self.sink_len
|
||||
)
|
||||
max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size)
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
self.block_table_with_sink[
|
||||
:num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks
|
||||
] = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
|
||||
common_attn_metadata.block_table_tensor = self.block_table_with_sink[
|
||||
:num_reqs
|
||||
]
|
||||
|
||||
return super().build(common_prefix_len, common_attn_metadata, fast_build)
|
||||
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=StaticSinkAttentionBuilder,
|
||||
)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
@CustomOp.register("static_sink_attention")
|
||||
class StaticSinkAttention(Attention, CustomOp):
|
||||
"""
|
||||
Attention with static sink tokens
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
sink_len: int,
|
||||
attn_backend: type[AttentionBackend] | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
dtype = torch.get_default_dtype()
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
if attn_backend is not None:
|
||||
underlying_attn_backend = attn_backend
|
||||
else:
|
||||
underlying_attn_backend = get_attn_backend(
|
||||
head_size, dtype, kv_cache_dtype, block_size
|
||||
)
|
||||
attn_backend = create_static_sink_attention_backend(
|
||||
underlying_attn_backend,
|
||||
sink_len=sink_len,
|
||||
)
|
||||
Attention.__init__(
|
||||
self=self,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
cache_config=cache_config,
|
||||
attn_backend=attn_backend,
|
||||
**kwargs,
|
||||
)
|
||||
CustomOp.__init__(self)
|
||||
|
||||
self.sink_len = sink_len
|
||||
self.block_size = block_size
|
||||
self.sink_populated = False
|
||||
self.sink_key = None
|
||||
self.sink_value = None
|
||||
|
||||
def update_sink_kv(self, sink_key, sink_value) -> None:
|
||||
self.sink_key = sink_key
|
||||
self.sink_value = sink_value
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output_shape: torch.Size | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.sink_key is not None and self.sink_value is not None, (
|
||||
"sink_key and sink_value have not been prepared"
|
||||
)
|
||||
if not self.sink_populated:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
|
||||
|
||||
return super().forward(query, key, value, output_shape)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output_shape: torch.Size | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.forward_native(query, key, value, output_shape)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self._forward_method(*args, **kwargs)
|
||||
|
||||
def populate_sink_kv(self, self_kv_cache):
|
||||
sink_kv_slot_mapping = torch.arange(
|
||||
self.block_size,
|
||||
self.sink_len + self.block_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=torch.long,
|
||||
)
|
||||
triton_reshape_and_cache_flash_diffkv(
|
||||
self.sink_key,
|
||||
self.sink_value,
|
||||
self_kv_cache,
|
||||
sink_kv_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
self._k_scale,
|
||||
self._v_scale,
|
||||
)
|
||||
# We only populate the sink_key and sink_value once
|
||||
self.sink_populated = True
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
# Block size may get updated after model loading, refresh it
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
# Should not be called for enc-dec or encoder-only attention.
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
|
||||
return SinkFullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
head_size_v=self.head_size_v,
|
||||
sink_len=self.sink_len,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
)
|
||||
|
||||
|
||||
def maybe_populate_sink(
|
||||
self_kv_cache: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
if self.sink_populated or self_kv_cache.numel() == 0:
|
||||
return
|
||||
self.populate_sink_kv(self_kv_cache)
|
||||
|
||||
|
||||
def maybe_populate_sink_fake(
|
||||
self_kv_cache: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="maybe_populate_sink",
|
||||
op_func=maybe_populate_sink,
|
||||
mutates_args=["self_kv_cache"],
|
||||
fake_impl=maybe_populate_sink_fake,
|
||||
)
|
||||
@ -182,3 +182,174 @@ def triton_reshape_and_cache_flash(
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def reshape_and_cache_kernel_flash_diffkv(
|
||||
key_ptr, # [num_tokens, num_heads, head_size]
|
||||
value_ptr, # [num_tokens, num_heads, head_size_v]
|
||||
kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v]
|
||||
slot_mapping_ptr, # [num_tokens]
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
# strides
|
||||
key_stride: tl.int64,
|
||||
value_stride: tl.int64,
|
||||
block_stride: tl.int64,
|
||||
page_stride: tl.int64,
|
||||
num_heads: tl.constexpr,
|
||||
head_size_k: tl.constexpr,
|
||||
head_size_v: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE: tl.constexpr,
|
||||
# tune parameters
|
||||
TILE_SIZE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(axis=0)
|
||||
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
|
||||
if slot_idx < 0:
|
||||
# Padding token that should be ignored.
|
||||
return
|
||||
|
||||
tile_i = tl.program_id(axis=1)
|
||||
tile_offs = tl.arange(0, TILE_SIZE)
|
||||
|
||||
block_idx = slot_idx // block_size
|
||||
block_offset = slot_idx % block_size
|
||||
|
||||
src_key_idx = token_idx * key_stride + tile_i * head_size_k
|
||||
src_value_idx = token_idx * value_stride + tile_i * head_size_v
|
||||
|
||||
tgt_idx = (
|
||||
block_idx * block_stride
|
||||
+ block_offset * page_stride
|
||||
+ tile_i * (head_size_k + head_size_v)
|
||||
)
|
||||
|
||||
# [TILE_SIZE]
|
||||
key_load = tl.load(key_ptr + src_key_idx + tile_offs, mask=tile_offs < head_size_k)
|
||||
if FP8_KV_CACHE:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the key_cache_ptr.dtype.element_ty
|
||||
key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
|
||||
else:
|
||||
key_tile = key_load
|
||||
|
||||
# [TILE_SIZE]
|
||||
value_load = tl.load(
|
||||
value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v
|
||||
)
|
||||
if FP8_KV_CACHE:
|
||||
if value_load.dtype.is_fp8():
|
||||
value_tile = value_load
|
||||
else:
|
||||
# tl.store will do the correct implicit cast to fp8,
|
||||
# based on the value_cache_ptr.dtype.element_ty
|
||||
value_tile = value_load / tl.load(v_scale)
|
||||
else:
|
||||
value_tile = value_load
|
||||
|
||||
tl.store(
|
||||
kv_cache_ptr + tgt_idx + tile_offs,
|
||||
key_tile,
|
||||
mask=tile_offs < head_size_k,
|
||||
)
|
||||
tl.store(
|
||||
kv_cache_ptr + tgt_idx + head_size_k + tile_offs,
|
||||
value_tile,
|
||||
mask=tile_offs < head_size_v,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def triton_reshape_and_cache_flash_diffkv(
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads, head_size_v]
|
||||
# [num_blocks, block_size, num_heads, head_size + head_size_v]
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor, # [num_tokens]
|
||||
kv_cache_dtype: str, # "auto", "fp8"
|
||||
k_scale: torch.Tensor, # float32
|
||||
v_scale: torch.Tensor, # float32
|
||||
):
|
||||
num_heads = key.shape[1]
|
||||
head_size_k = key.shape[2]
|
||||
head_size_v = value.shape[2]
|
||||
block_size = kv_cache.shape[1]
|
||||
|
||||
k_stride = key.stride()[0]
|
||||
v_stride = value.stride()[0]
|
||||
block_stride = kv_cache.stride()[0]
|
||||
page_stride = kv_cache.stride()[1]
|
||||
|
||||
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
|
||||
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
|
||||
)
|
||||
kv_cache_torch_dtype = (
|
||||
current_platform.fp8_dtype()
|
||||
if kv_cache_dtype.startswith("fp8")
|
||||
else kv_cache.dtype
|
||||
)
|
||||
|
||||
if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
|
||||
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
|
||||
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
|
||||
kv_cache = kv_cache.view(kv_cache_torch_dtype)
|
||||
assert kv_cache_dtype != torch.uint8, (
|
||||
"explicit fp8 cast and store to "
|
||||
"uint8 is not supported by triton reshape_and_cache_flash_diffkv"
|
||||
)
|
||||
|
||||
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
|
||||
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.uint8,
|
||||
torch.float8_e4m3fnuz,
|
||||
], (
|
||||
"unsupported dtype of KV cache tensor, got "
|
||||
"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
|
||||
"fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
|
||||
)
|
||||
|
||||
# heuristics instead of autotuning
|
||||
TILE_SIZE = max(head_size_k, head_size_v)
|
||||
TILE_SIZE = triton.next_power_of_2(TILE_SIZE)
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
num_stages = 4
|
||||
num_warps = 8
|
||||
else: # cuda
|
||||
num_stages = 10
|
||||
num_warps = 16
|
||||
|
||||
# TODO(ngl): maybe replace with static launch grid to avoid overhead if
|
||||
# using cudagraphs
|
||||
grid = lambda meta: (
|
||||
slot_mapping.shape[0],
|
||||
num_heads,
|
||||
)
|
||||
|
||||
reshape_and_cache_kernel_flash_diffkv[grid](
|
||||
key_ptr=key,
|
||||
value_ptr=value,
|
||||
kv_cache_ptr=kv_cache,
|
||||
slot_mapping_ptr=slot_mapping,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
# strides
|
||||
key_stride=k_stride,
|
||||
value_stride=v_stride,
|
||||
block_stride=block_stride,
|
||||
page_stride=page_stride,
|
||||
num_heads=num_heads,
|
||||
head_size_k=head_size_k,
|
||||
head_size_v=head_size_v,
|
||||
block_size=block_size,
|
||||
# FP8 flags
|
||||
FP8_KV_CACHE=FP8_KV_CACHE,
|
||||
# autotune parameters
|
||||
TILE_SIZE=TILE_SIZE,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
|
||||
@ -29,13 +29,14 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.layer import Attention, AttentionType
|
||||
from vllm.attention.layers.static_sink_attention import StaticSinkAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
tensor_model_parallel_all_gather,
|
||||
@ -77,8 +78,11 @@ from vllm.model_executor.models.utils import (
|
||||
maybe_prefix,
|
||||
sequence_parallel_chunk,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import set_default_rope_theta
|
||||
from vllm.v1.attention.backends.flash_attn_diffkv import FlashAttentionDiffKVBackend
|
||||
|
||||
|
||||
def check_ffn_act_fn(act_fn: str):
|
||||
@ -155,7 +159,15 @@ class OpenPanguMoE(nn.Module):
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
self.gate.e_score_correction_bias = None
|
||||
if (
|
||||
hasattr(config, "router_enable_expert_bias")
|
||||
and config.router_enable_expert_bias
|
||||
):
|
||||
self.gate.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(self.n_routed_experts, dtype=torch.float32)
|
||||
)
|
||||
else:
|
||||
self.gate.e_score_correction_bias = None
|
||||
|
||||
# Load balancing settings.
|
||||
eplb_config = parallel_config.eplb_config
|
||||
@ -530,6 +542,264 @@ class OpenPanguEmbeddedAttention(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class OpenPanguSinkAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_parameters: dict[str, Any] | None = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
bias: bool = False,
|
||||
bias_o_proj: bool = False,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.hidden_size = hidden_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.total_num_heads = num_heads
|
||||
if self.total_num_heads % self.tp_size != 0:
|
||||
raise ValueError(
|
||||
f"total_num_heads {self.total_num_heads} "
|
||||
f"is not divisible by tp_size {self.tp_size}."
|
||||
)
|
||||
self.num_heads = self.total_num_heads // self.tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if (
|
||||
self.total_num_kv_heads > self.tp_size
|
||||
and self.total_num_kv_heads % self.tp_size != 0
|
||||
):
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel ranks.
|
||||
raise ValueError(
|
||||
"Number of KV heads is greater than TP size, "
|
||||
f"but total_num_kv_heads {self.total_num_kv_heads} "
|
||||
f"is not divisible by tp_size {self.tp_size}."
|
||||
)
|
||||
elif self.total_num_kv_heads < self.tp_size:
|
||||
# TODO: Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel ranks.
|
||||
raise ValueError(
|
||||
f"Number of KV heads {self.total_num_kv_heads} is less than "
|
||||
f"TP size {self.tp_size}, KV heads replication is not support yet."
|
||||
)
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
|
||||
self.qk_nope_dim = getattr(config, "qk_nope_dim", None)
|
||||
self.qk_rope_dim = getattr(config, "qk_rope_dim", None)
|
||||
self.v_channels = getattr(config, "v_channels", None)
|
||||
self.head_dim = self.qk_rope_dim + self.qk_nope_dim
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.k_size = self.num_kv_heads * self.head_dim
|
||||
self.v_size = self.num_kv_heads * self.v_channels
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.param_sink_number = getattr(config, "param_sink_number", 0)
|
||||
self.param_sink_with_value = getattr(config, "param_sink_with_value", False)
|
||||
self.param_sink_scalar = getattr(config, "param_sink_scalar", None)
|
||||
self.param_sink_of_head_num = getattr(config, "param_sink_of_head_dim", False)
|
||||
|
||||
self.qkv_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[
|
||||
self.q_size * self.tp_size,
|
||||
self.k_size * self.tp_size,
|
||||
self.v_size * self.tp_size,
|
||||
],
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.v_channels,
|
||||
output_size=hidden_size,
|
||||
bias=bias_o_proj,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
self._init_rotary_emb(
|
||||
config, rope_parameters=rope_parameters, quant_config=quant_config
|
||||
)
|
||||
|
||||
if hasattr(config, "interleaved_sliding_window"):
|
||||
interleaved_sliding_window = config.interleaved_sliding_window
|
||||
if isinstance(interleaved_sliding_window, int):
|
||||
sliding_window = interleaved_sliding_window
|
||||
elif isinstance(interleaved_sliding_window, list):
|
||||
sw_idx = layer_idx % len(interleaved_sliding_window)
|
||||
sliding_window = interleaved_sliding_window[sw_idx]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{type(interleaved_sliding_window)} "
|
||||
"for interleaved_sliding_window is not supported."
|
||||
)
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
FlashAttentionDiffKVBackend.set_head_size_v(self.v_channels)
|
||||
self.attn = StaticSinkAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
sink_len=self.param_sink_number,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_backend=FlashAttentionDiffKVBackend,
|
||||
head_size_v=self.v_channels,
|
||||
)
|
||||
|
||||
if self.param_sink_number > 0:
|
||||
self.param_sink_key = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
(
|
||||
self.param_sink_number,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
),
|
||||
device=current_platform.current_device(),
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.param_sink_key,
|
||||
{
|
||||
"output_dim": 1,
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
|
||||
if self.param_sink_with_value:
|
||||
self.param_sink_value = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
(
|
||||
self.param_sink_number,
|
||||
self.num_kv_heads,
|
||||
self.v_channels,
|
||||
),
|
||||
device=current_platform.current_device(),
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.param_sink_value,
|
||||
{
|
||||
"output_dim": 1,
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.param_sink_value = torch.zeros(
|
||||
(
|
||||
self.param_sink_number,
|
||||
self.num_kv_heads,
|
||||
self.v_channels,
|
||||
),
|
||||
device=current_platform.current_device(),
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
# To enable dummy run with out weight
|
||||
self.post_weight_load()
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
|
||||
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
|
||||
# Special case for GGUF
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
if is_gguf_weight_type:
|
||||
param.weight_type = loaded_weight.item()
|
||||
|
||||
# Materialize GGUF UninitializedParameter
|
||||
if is_gguf_weight and isinstance(param, nn.UninitializedParameter):
|
||||
final_shape = list(loaded_weight.shape)
|
||||
if output_dim is not None:
|
||||
assert final_shape[output_dim] % self.tp_size == 0
|
||||
final_shape[output_dim] = final_shape[output_dim] // self.tp_size
|
||||
param.materialize(final_shape, dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
if output_dim is not None and not is_sharded_weight:
|
||||
shard_size = param_data.shape[output_dim]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
|
||||
# Special case for loading scales off disk, which often do not
|
||||
# have a shape (such as in the case of AutoFP8).
|
||||
if len(loaded_weight.shape) == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
|
||||
k = self.k_layernorm(k.view(-1, self.num_kv_heads, self.head_dim))
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
q = q.view(-1, self.q_size)
|
||||
k = k.view(-1, self.k_size)
|
||||
|
||||
attn_output = self.attn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output_shape=torch.Size(
|
||||
[q.shape[0], q.shape[1] // self.head_dim * self.v_channels]
|
||||
),
|
||||
)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def _init_rotary_emb(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
rope_parameters: dict[str, Any] | None,
|
||||
quant_config: QuantizationConfig | None,
|
||||
) -> None:
|
||||
is_neox_style = False
|
||||
rope_parameters = {"partial_rotary_factor": self.qk_rope_dim / self.head_dim}
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
|
||||
def post_weight_load(self) -> None:
|
||||
if hasattr(self, "k_layernorm") and self.k_layernorm is not None:
|
||||
param_sink_key = self.k_layernorm(self.param_sink_key)
|
||||
else:
|
||||
param_sink_key = self.param_sink_key
|
||||
|
||||
self.attn.update_sink_kv(param_sink_key, self.param_sink_value)
|
||||
|
||||
|
||||
class OpenPanguDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -557,6 +827,9 @@ class OpenPanguDecoderLayer(nn.Module):
|
||||
and hasattr(config, "v_head_dim")
|
||||
and hasattr(config, "kv_lora_rank")
|
||||
)
|
||||
self.use_sink_attention = (
|
||||
hasattr(config, "param_sink_number") and config.param_sink_number > 0
|
||||
)
|
||||
if self.use_mla:
|
||||
self.self_attn = OpenPanguMLAAttention(
|
||||
config=config,
|
||||
@ -574,6 +847,42 @@ class OpenPanguDecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
elif self.use_sink_attention:
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False
|
||||
)
|
||||
bias_o_proj = attention_bias
|
||||
if hasattr(config, "qkv_bias"):
|
||||
attention_bias = config.qkv_bias
|
||||
if getattr(config, "is_causal", True):
|
||||
attn_type = AttentionType.DECODER
|
||||
else:
|
||||
raise ValueError(
|
||||
f"is_causal={config.is_causal} is not support "
|
||||
"for attention with sink"
|
||||
)
|
||||
rope_parameters = getattr(config, "rope_scaling", None)
|
||||
if rope_parameters is None:
|
||||
rope_parameters = {
|
||||
"rope_type": "default",
|
||||
"rope_theta": config.rope_theta,
|
||||
}
|
||||
self.self_attn = OpenPanguSinkAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
),
|
||||
rope_parameters=rope_parameters,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
bias_o_proj=bias_o_proj,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
)
|
||||
else:
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False
|
||||
@ -903,6 +1212,10 @@ class OpenPanguModel(nn.Module):
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name.endswith("e_score_correction_bias"):
|
||||
name = name.replace(
|
||||
"e_score_correction_bias", "gate.e_score_correction_bias"
|
||||
)
|
||||
if name is None:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
@ -912,8 +1225,17 @@ class OpenPanguModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
self.post_weight_load()
|
||||
return loaded_params
|
||||
|
||||
def post_weight_load(self) -> None:
    """Run per-module post-weight-load hooks.

    After all weights have been loaded, walk every submodule and invoke
    its ``post_weight_load`` hook if it defines one, so modules can
    finalize state derived from their loaded weights. The model itself
    is skipped: this method would otherwise match its own ``hasattr``
    check and recurse.
    """
    # named_modules() yields (name, module) pairs; only the module is used.
    for _, module in self.named_modules():
        if module is self:
            continue
        if hasattr(module, "post_weight_load"):
            module.post_weight_load()
|
||||
|
||||
|
||||
class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA):
|
||||
packed_modules_mapping = {
|
||||
@ -1047,3 +1369,7 @@ class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel):
|
||||
|
||||
class PanguUltraMoEForCausalLM(OpenPanguMoEModel):
    """Architecture alias for openPangu-Ultra-MoE checkpoints.

    Behaviorally identical to OpenPanguMoEModel; exists so the model
    registry can resolve the ``PanguUltraMoEForCausalLM`` architecture
    name declared by the checkpoints.
    """

    pass
|
||||
|
||||
|
||||
class PanguProMoEV2ForCausalLM(OpenPanguMoEModel):
    """Architecture alias for openPangu-Pro-MoE-V2 checkpoints.

    Behaviorally identical to OpenPanguMoEModel; exists so the model
    registry can resolve the ``PanguProMoEV2ForCausalLM`` architecture
    name declared by the checkpoints.
    """

    pass
|
||||
|
||||
@ -164,6 +164,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||
"OuroForCausalLM": ("ouro", "OuroForCausalLM"),
|
||||
"PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
|
||||
"PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
|
||||
"PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
|
||||
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
|
||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||
|
||||
269
vllm/v1/attention/backends/flash_attn_diffkv.py
Normal file
269
vllm/v1/attention/backends/flash_attn_diffkv.py
Normal file
@ -0,0 +1,269 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash_diffkv,
|
||||
)
|
||||
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
|
||||
|
||||
if is_flash_attn_varlen_func_available():
|
||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
|
||||
from .flash_attn import (
|
||||
FlashAttentionBackend,
|
||||
FlashAttentionImpl,
|
||||
FlashAttentionMetadata,
|
||||
cascade_attention,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttentionDiffKVBackend(FlashAttentionBackend):
    """FlashAttention backend for models whose K and V heads differ in size.

    K and V share one cache tensor: the last dimension packs
    ``head_size + head_size_v`` elements per KV head.
    """

    # Default V head size for this backend; override via set_head_size_v().
    head_size_v: int = 128

    @classmethod
    def set_head_size_v(cls, head_size_v: int) -> None:
        """Configure the V head size used when shaping the KV cache."""
        cls.head_size_v = head_size_v

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_DIFFKV"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        return FlashAttentionDiffKVImpl

    # Keeps the get_kv_cache_shape interface of the parent backend, but
    # widens the last dimension to hold both K and V entries.
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        packed_head = head_size + FlashAttentionDiffKVBackend.head_size_v
        return (num_blocks, block_size, num_kv_heads, packed_head)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Permutation mapping get_kv_cache_shape to the desired memory layout."""
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            # layered: (num_blocks, num_layers, block_size,
            #           num_kv_heads, head_size + head_size_v)
            return (1, 0, 2, 3, 4) if include_num_layers_dimension else (0, 1, 2, 3)
        if cache_layout == "HND":
            # layered: (num_blocks, num_kv_heads, num_layers,
            #           block_size, head_size + head_size_v)
            return (1, 3, 0, 2, 4) if include_num_layers_dimension else (0, 2, 1, 3)
        raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||
|
||||
|
||||
class FlashAttentionDiffKVImpl(FlashAttentionImpl):
    # Attention impl for caches where K and V head sizes differ: the KV
    # cache packs [K (head_size) | V (head_size_v)] along its last dim,
    # matching FlashAttentionDiffKVBackend.get_kv_cache_shape.
    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention for differing K/V head sizes.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size_v]
            kv_cache: shape =
                [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size_v]
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            # NOTE(review): this reuses the parent's encoder path;
            # presumably it tolerates head_size != head_size_v — confirm.
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        # For decoder and cross-attention, use KV cache as before
        # Different head_size for K and V: the packed last dim is split so
        # K occupies the first head_size elements, V the remaining head_size_v.
        key_cache = kv_cache[..., : self.head_size]
        value_cache = kv_cache[..., self.head_size :]

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.

            # kv_cache update for different head_size K and V
            triton_reshape_and_cache_flash_diffkv(
                key,
                value,
                kv_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype
            )
            # Reinterpret (no copy) the cache views as the fp8 storage dtype.
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            # (num_sequences, num_kv_heads): shape flash-attn expects for
            # the per-sequence descale factors (see NOTE in docstring).
            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            if self.dcp_world_size > 1:
                # Decode-context-parallel path.
                self._forward_with_dcp(
                    query[:num_actual_tokens],
                    key[:num_actual_tokens],
                    value[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                )
                return output
            else:
                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
                    s_aux=self.sinks,
                )
                return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            max_num_splits=attn_metadata.max_num_splits,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
            s_aux=self.sinks,
        )
        return output
|
||||
@ -15,6 +15,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
KVCacheSpec,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
SinkFullAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
@ -783,6 +784,24 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
raise NotImplementedError("CrossAttentionManager does not support caching")
|
||||
|
||||
|
||||
class SinkFullAttentionManager(FullAttentionManager):
    """Full-attention manager that pre-reserves blocks for attention sinks.

    On construction, enough blocks to hold ``sink_len`` token slots are
    popped from the block pool's free queue and kept in ``self.sink_blocks``.
    """

    def __init__(
        self,
        kv_cache_spec: SinkFullAttentionSpec,
        block_pool: BlockPool,
        kv_cache_group_id: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ):
        super().__init__(
            kv_cache_spec,
            block_pool,
            kv_cache_group_id,
            dcp_world_size,
            pcp_world_size,
        )
        # The sink region must be non-empty and block-aligned so it can be
        # carved out of the pool as whole blocks.
        sink_len = kv_cache_spec.sink_len
        assert sink_len is not None and sink_len > 0
        assert sink_len % self.block_size == 0
        num_sink_blocks = sink_len // self.block_size
        self.sink_blocks = self.block_pool.free_block_queue.popleft_n(
            num_sink_blocks
        )
|
||||
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
MLAAttentionSpec: FullAttentionManager,
|
||||
@ -790,6 +809,7 @@ spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||
MambaSpec: MambaManager,
|
||||
CrossAttentionSpec: CrossAttentionManager,
|
||||
SinkFullAttentionSpec: SinkFullAttentionManager,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -89,12 +89,18 @@ class FullAttentionSpec(AttentionSpec):
|
||||
In this case, we use FullAttentionSpec and record the sliding window size.
|
||||
"""
|
||||
|
||||
head_size_v: int | None = None
|
||||
|
||||
sliding_window: int | None = None
|
||||
"""
|
||||
Default to None for not using sliding window attention.
|
||||
"""
|
||||
attention_chunk_size: int | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.head_size_v is None:
|
||||
object.__setattr__(self, "head_size_v", self.head_size)
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
@ -142,6 +148,7 @@ class FullAttentionSpec(AttentionSpec):
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
head_size_v=specs[0].head_size_v,
|
||||
dtype=specs[0].dtype,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
@ -160,6 +167,15 @@ class FullAttentionSpec(AttentionSpec):
|
||||
)
|
||||
return merged_spec
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
return (
|
||||
self.block_size
|
||||
* self.num_kv_heads
|
||||
* (self.head_size + self.head_size_v)
|
||||
* get_dtype_size(self.dtype)
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MLAAttentionSpec(FullAttentionSpec):
|
||||
@ -287,6 +303,56 @@ class CrossAttentionSpec(AttentionSpec):
|
||||
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SinkFullAttentionSpec(FullAttentionSpec):
    """FullAttentionSpec variant for layers with attention-sink tokens.

    ``sink_len`` is the number of sink token slots reserved in the KV
    cache; ``None`` means no sinks.
    """

    sink_len: int | None = None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of SinkFullAttentionSpec objects into a single
        SinkFullAttentionSpec object.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )
        # The merged spec takes specs[0].sink_len, so every layer in the
        # group must agree on the sink length.
        assert len({spec.sink_len for spec in specs}) <= 1, (
            "All attention layers in the same KV cache group must have "
            "the same sink_len."
        )

        sliding_window = set(
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        )
        attention_chunk_size = set(
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        )
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            head_size_v=specs[0].head_size_v,
            sink_len=specs[0].sink_len,
            dtype=specs[0].dtype,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Base AttentionSpec fields must be identical across the group.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec
return merged_spec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UniformTypeKVCacheSpecs(KVCacheSpec):
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user