diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 6838fc227f355..a89af35cc9098 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -433,6 +433,7 @@ th { | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | | `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | | | `PanguEmbeddedForCausalLM` |openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ | +| `PanguProMoEV2ForCausalLM` |openpangu-pro-moe-v2 | | ✅︎ | ✅︎ | | `PanguUltraMoEForCausalLM` |openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ | | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. 
| ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 2922414cdaa6a..887a3a24ae44c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -394,6 +394,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "PanguEmbeddedForCausalLM": _HfExamplesInfo( "FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True ), + "PanguProMoEV2ForCausalLM": _HfExamplesInfo( + "", + trust_remote_code=True, + is_available_online=False, + ), "PanguUltraMoEForCausalLM": _HfExamplesInfo( "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1", trust_remote_code=True, diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 416b996df9f22..3529775170db3 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -42,6 +42,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): """ FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + FLASH_ATTN_DIFFKV = ( + "vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend" + ) TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 1d882eb87bfde..2d7a084bd1245 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -136,6 +136,7 @@ class Attention(nn.Module, AttentionLayerBase): attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, + head_size_v: int | None = None, **extra_impl_args, ) -> None: """ @@ -177,6 +178,7 @@ class Attention(nn.Module, AttentionLayerBase): self.num_heads = num_heads self.head_size = head_size + self.head_size_v = self.head_size if head_size_v is None else head_size_v self.num_kv_heads = num_kv_heads 
self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None @@ -234,8 +236,7 @@ class Attention(nn.Module, AttentionLayerBase): kv_sharing_target_layer_name, **extra_impl_args, ) - backend_name = self.attn_backend.get_name() - self.backend = AttentionBackendEnum.__members__.get(backend_name) + self.backend = AttentionBackendEnum[self.attn_backend.get_name()] self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how @@ -316,6 +317,10 @@ class Attention(nn.Module, AttentionLayerBase): query, _ = self.query_quant(query, self._q_scale) if self.use_output: + if output_shape is None: + output_shape = torch.Size( + (*query.shape[:-1], self.num_heads * self.head_size_v) + ) output_shape = output_shape if output_shape is not None else query.shape output = torch.empty(output_shape, dtype=output_dtype, device=query.device) hidden_size = output_shape[-1] @@ -323,11 +328,11 @@ class Attention(nn.Module, AttentionLayerBase): # NOTE(woosuk): We do this outside the custom op to minimize the # CPU overheads from the non-CUDA-graph regions. 
query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size_v) if key is not None: key = key.view(-1, self.num_kv_heads, self.head_size) if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size_v) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -402,6 +407,7 @@ class Attention(nn.Module, AttentionLayerBase): block_size=block_size, num_kv_heads=self.num_kv_heads, head_size=self.head_size, + head_size_v=self.head_size_v, dtype=self.kv_cache_torch_dtype, ) @@ -728,6 +734,7 @@ def unified_attention_with_output( output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) + self.impl.forward( self, query, diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py new file mode 100644 index 0000000000000..e5ed16ec14932 --- /dev/null +++ b/vllm/attention/layers/static_sink_attention.py @@ -0,0 +1,254 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools + +import torch + +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) +from vllm.attention.layer import Attention +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig, VllmConfig +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import direct_register_custom_op +from 
vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + subclass_attention_backend, +) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + KVCacheSpec, + SinkFullAttentionSpec, +) + +logger = init_logger(__name__) + + +@functools.lru_cache +def create_static_sink_attention_backend( + underlying_attn_backend: type[AttentionBackend], + sink_len: int = 0, +) -> type[AttentionBackend]: + prefix = "StaticSink_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class StaticSinkAttentionBuilder(underlying_builder): # type: ignore + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + model_config = vllm_config.model_config + scheduler_config = vllm_config.scheduler_config + self.sink_len = sink_len + self.block_size = vllm_config.cache_config.block_size + self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size + self.max_num_blocks = cdiv( + model_config.max_model_len, vllm_config.cache_config.block_size + ) + self.block_table_with_sink = torch.zeros( + ( + scheduler_config.max_num_seqs, + self.max_num_blocks + self.num_sink_blocks, + ), + device=device, + dtype=torch.int32, + ) + self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange( + 1, + self.num_sink_blocks + 1, + device=device, + dtype=torch.int32, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + common_attn_metadata.seq_lens[:] = ( + common_attn_metadata.seq_lens + self.sink_len + ) + common_attn_metadata.seq_lens[ + common_attn_metadata.seq_lens == self.sink_len + ] = 0 + common_attn_metadata.max_seq_len = ( + common_attn_metadata.max_seq_len + self.sink_len + ) + max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size) + num_reqs = common_attn_metadata.num_reqs + 
self.block_table_with_sink[ + :num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks + ] = common_attn_metadata.block_table_tensor[:, :max_num_blocks] + common_attn_metadata.block_table_tensor = self.block_table_with_sink[ + :num_reqs + ] + + return super().build(common_prefix_len, common_attn_metadata, fast_build) + + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=StaticSinkAttentionBuilder, + ) + + return attn_backend + + +@CustomOp.register("static_sink_attention") +class StaticSinkAttention(Attention, CustomOp): + """ + Attention with static sink tokens + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + sink_len: int, + attn_backend: type[AttentionBackend] | None = None, + cache_config: CacheConfig | None = None, + **kwargs, + ): + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + if attn_backend is not None: + underlying_attn_backend = attn_backend + else: + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) + attn_backend = create_static_sink_attention_backend( + underlying_attn_backend, + sink_len=sink_len, + ) + Attention.__init__( + self=self, + num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + **kwargs, + ) + CustomOp.__init__(self) + + self.sink_len = sink_len + self.block_size = block_size + self.sink_populated = False + self.sink_key = None + self.sink_value = None + + def update_sink_kv(self, sink_key, sink_value) -> None: + self.sink_key = sink_key + self.sink_value = sink_value + + def forward_native( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output_shape: torch.Size | None = None, + ) -> torch.Tensor: + assert 
self.sink_key is not None and self.sink_value is not None, ( + "sink_key and sink_value have not been prepared" + ) + if not self.sink_populated: + forward_context: ForwardContext = get_forward_context() + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name) + + return super().forward(query, key, value, output_shape) + + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output_shape: torch.Size | None = None, + ) -> torch.Tensor: + return self.forward_native(query, key, value, output_shape) + + def forward(self, *args, **kwargs): + return self._forward_method(*args, **kwargs) + + def populate_sink_kv(self, self_kv_cache): + sink_kv_slot_mapping = torch.arange( + self.block_size, + self.sink_len + self.block_size, + device=torch.cuda.current_device(), + dtype=torch.long, + ) + triton_reshape_and_cache_flash_diffkv( + self.sink_key, + self.sink_value, + self_kv_cache, + sink_kv_slot_mapping, + self.kv_cache_dtype, + self._k_scale, + self._v_scale, + ) + # We only populate the sink_key and sink_value once + self.sink_populated = True + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # Block size may get updated after model loading, refresh it + block_size = vllm_config.cache_config.block_size + # Should not be called for enc-dec or encoder-only attention. 
+ assert self.attn_type == AttentionType.DECODER + + return SinkFullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + head_size_v=self.head_size_v, + sink_len=self.sink_len, + dtype=self.kv_cache_torch_dtype, + ) + + +def maybe_populate_sink( + self_kv_cache: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + if self.sink_populated or self_kv_cache.numel() == 0: + return + self.populate_sink_kv(self_kv_cache) + + +def maybe_populate_sink_fake( + self_kv_cache: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="maybe_populate_sink", + op_func=maybe_populate_sink, + mutates_args=["self_kv_cache"], + fake_impl=maybe_populate_sink_fake, +) diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py index 5d2ba154ae018..a383de0ac76cc 100644 --- a/vllm/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -182,3 +182,174 @@ def triton_reshape_and_cache_flash( num_warps=num_warps, num_stages=num_stages, ) + + +@triton.jit +def reshape_and_cache_kernel_flash_diffkv( + key_ptr, # [num_tokens, num_heads, head_size] + value_ptr, # [num_tokens, num_heads, head_size_v] + kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v] + slot_mapping_ptr, # [num_tokens] + k_scale, # float32 + v_scale, # float32 + # strides + key_stride: tl.int64, + value_stride: tl.int64, + block_stride: tl.int64, + page_stride: tl.int64, + num_heads: tl.constexpr, + head_size_k: tl.constexpr, + head_size_v: tl.constexpr, + block_size: tl.constexpr, + # FP8 flags + FP8_KV_CACHE: tl.constexpr, + # tune parameters + TILE_SIZE: tl.constexpr, +): + token_idx = tl.program_id(axis=0) + slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64) + if slot_idx < 0: + # 
Padding token that should be ignored. + return + + tile_i = tl.program_id(axis=1) + tile_offs = tl.arange(0, TILE_SIZE) + + block_idx = slot_idx // block_size + block_offset = slot_idx % block_size + + src_key_idx = token_idx * key_stride + tile_i * head_size_k + src_value_idx = token_idx * value_stride + tile_i * head_size_v + + tgt_idx = ( + block_idx * block_stride + + block_offset * page_stride + + tile_i * (head_size_k + head_size_v) + ) + + # [TILE_SIZE] + key_load = tl.load(key_ptr + src_key_idx + tile_offs, mask=tile_offs < head_size_k) + if FP8_KV_CACHE: + # tl.store will do the correct implicit cast to fp8, + # based on the key_cache_ptr.dtype.element_ty + key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale) + else: + key_tile = key_load + + # [TILE_SIZE] + value_load = tl.load( + value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v + ) + if FP8_KV_CACHE: + if value_load.dtype.is_fp8(): + value_tile = value_load + else: + # tl.store will do the correct implicit cast to fp8, + # based on the value_cache_ptr.dtype.element_ty + value_tile = value_load / tl.load(v_scale) + else: + value_tile = value_load + + tl.store( + kv_cache_ptr + tgt_idx + tile_offs, + key_tile, + mask=tile_offs < head_size_k, + ) + tl.store( + kv_cache_ptr + tgt_idx + head_size_k + tile_offs, + value_tile, + mask=tile_offs < head_size_v, + ) + return + + +def triton_reshape_and_cache_flash_diffkv( + key: torch.Tensor, # [num_tokens, num_heads, head_size] + value: torch.Tensor, # [num_tokens, num_heads, head_size_v] + # [num_blocks, block_size, num_heads, head_size + head_size_v] + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, # [num_tokens] + kv_cache_dtype: str, # "auto", "fp8" + k_scale: torch.Tensor, # float32 + v_scale: torch.Tensor, # float32 +): + num_heads = key.shape[1] + head_size_k = key.shape[2] + head_size_v = value.shape[2] + block_size = kv_cache.shape[1] + + k_stride = key.stride()[0] + v_stride = value.stride()[0] + 
block_stride = kv_cache.stride()[0] + page_stride = kv_cache.stride()[1] + + assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( + f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." + ) + kv_cache_torch_dtype = ( + current_platform.fp8_dtype() + if kv_cache_dtype.startswith("fp8") + else kv_cache.dtype + ) + + if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): + # to avoid erroneous implicit cast in triton kernel (tl.store to uint8) + # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) + kv_cache = kv_cache.view(kv_cache_torch_dtype) + assert kv_cache_torch_dtype != torch.uint8, ( + "explicit fp8 cast and store to " + "uint8 is not supported by triton reshape_and_cache_flash_diffkv" + ) + + FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") + assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint8, + torch.float8_e4m3fnuz, + ], ( + "unsupported dtype of KV cache tensor, got " + f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " + "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz." 
+ ) + + # heuristics instead of autotuning + TILE_SIZE = max(head_size_k, head_size_v) + TILE_SIZE = triton.next_power_of_2(TILE_SIZE) + if current_platform.is_rocm() or current_platform.is_xpu(): + num_stages = 4 + num_warps = 8 + else: # cuda + num_stages = 10 + num_warps = 16 + + # TODO(ngl): maybe replace with static launch grid to avoid overhead if + # using cudagraphs + grid = lambda meta: ( + slot_mapping.shape[0], + num_heads, + ) + + reshape_and_cache_kernel_flash_diffkv[grid]( + key_ptr=key, + value_ptr=value, + kv_cache_ptr=kv_cache, + slot_mapping_ptr=slot_mapping, + k_scale=k_scale, + v_scale=v_scale, + # strides + key_stride=k_stride, + value_stride=v_stride, + block_stride=block_stride, + page_stride=page_stride, + num_heads=num_heads, + head_size_k=head_size_k, + head_size_v=head_size_v, + block_size=block_size, + # FP8 flags + FP8_KV_CACHE=FP8_KV_CACHE, + # autotune parameters + TILE_SIZE=TILE_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 47abd7bf0b68a..662ecef3ac8f6 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -29,13 +29,14 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionType -from vllm.attention.layer import Attention +from vllm.attention.layer import Attention, AttentionType +from vllm.attention.layers.static_sink_attention import StaticSinkAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.distributed import ( get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_gather, @@ -77,8 +78,11 @@ from vllm.model_executor.models.utils import ( maybe_prefix, sequence_parallel_chunk, ) +from vllm.model_executor.utils import 
set_weight_attrs +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import set_default_rope_theta +from vllm.v1.attention.backends.flash_attn_diffkv import FlashAttentionDiffKVBackend def check_ffn_act_fn(act_fn: str): @@ -155,7 +159,15 @@ class OpenPanguMoE(nn.Module): quant_config=None, prefix=f"{prefix}.gate", ) - self.gate.e_score_correction_bias = None + if ( + hasattr(config, "router_enable_expert_bias") + and config.router_enable_expert_bias + ): + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(self.n_routed_experts, dtype=torch.float32) + ) + else: + self.gate.e_score_correction_bias = None # Load balancing settings. eplb_config = parallel_config.eplb_config @@ -530,6 +542,264 @@ class OpenPanguEmbeddedAttention(nn.Module): ) +class OpenPanguSinkAttention(nn.Module): + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_parameters: dict[str, Any] | None = None, + max_position_embeddings: int = 8192, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: CacheConfig | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.total_num_heads = num_heads + if self.total_num_heads % self.tp_size != 0: + raise ValueError( + f"total_num_heads {self.total_num_heads} " + f"is not divisible by tp_size {self.tp_size}." 
+ ) + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = num_kv_heads + if ( + self.total_num_kv_heads > self.tp_size + and self.total_num_kv_heads % self.tp_size != 0 + ): + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel ranks. + raise ValueError( + "Number of KV heads is greater than TP size, " + f"but total_num_kv_heads {self.total_num_kv_heads} " + f"is not divisible by tp_size {self.tp_size}." + ) + elif self.total_num_kv_heads < self.tp_size: + # TODO: Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel ranks. + raise ValueError( + f"Number of KV heads {self.total_num_kv_heads} is less than " + f"TP size {self.tp_size}, KV heads replication is not support yet." + ) + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.qk_nope_dim = getattr(config, "qk_nope_dim", None) + self.qk_rope_dim = getattr(config, "qk_rope_dim", None) + self.v_channels = getattr(config, "v_channels", None) + self.head_dim = self.qk_rope_dim + self.qk_nope_dim + self.q_size = self.num_heads * self.head_dim + self.k_size = self.num_kv_heads * self.head_dim + self.v_size = self.num_kv_heads * self.v_channels + self.scaling = self.head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + + self.param_sink_number = getattr(config, "param_sink_number", 0) + self.param_sink_with_value = getattr(config, "param_sink_with_value", False) + self.param_sink_scalar = getattr(config, "param_sink_scalar", None) + self.param_sink_of_head_num = getattr(config, "param_sink_of_head_dim", False) + + self.qkv_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[ + self.q_size * self.tp_size, + self.k_size * self.tp_size, + self.v_size * self.tp_size, + ], + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + 
input_size=self.total_num_heads * self.v_channels, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + self._init_rotary_emb( + config, rope_parameters=rope_parameters, quant_config=quant_config + ) + + if hasattr(config, "interleaved_sliding_window"): + interleaved_sliding_window = config.interleaved_sliding_window + if isinstance(interleaved_sliding_window, int): + sliding_window = interleaved_sliding_window + elif isinstance(interleaved_sliding_window, list): + sw_idx = layer_idx % len(interleaved_sliding_window) + sliding_window = interleaved_sliding_window[sw_idx] + else: + raise ValueError( + f"{type(interleaved_sliding_window)} " + "for interleaved_sliding_window is not supported." + ) + else: + sliding_window = None + + FlashAttentionDiffKVBackend.set_head_size_v(self.v_channels) + self.attn = StaticSinkAttention( + self.num_heads, + self.head_dim, + self.scaling, + sink_len=self.param_sink_number, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + attn_backend=FlashAttentionDiffKVBackend, + head_size_v=self.v_channels, + ) + + if self.param_sink_number > 0: + self.param_sink_key = torch.nn.Parameter( + torch.empty( + ( + self.param_sink_number, + self.num_kv_heads, + self.head_dim, + ), + device=current_platform.current_device(), + dtype=config.torch_dtype, + ) + ) + set_weight_attrs( + self.param_sink_key, + { + "output_dim": 1, + "weight_loader": self.weight_loader, + }, + ) + + if self.param_sink_with_value: + self.param_sink_value = torch.nn.Parameter( + torch.empty( + ( + self.param_sink_number, + self.num_kv_heads, + self.v_channels, + ), + device=current_platform.current_device(), + dtype=config.torch_dtype, + ) + ) + set_weight_attrs( + self.param_sink_value, + { + "output_dim": 1, + 
"weight_loader": self.weight_loader, + }, + ) + else: + self.param_sink_value = torch.zeros( + ( + self.param_sink_number, + self.num_kv_heads, + self.v_channels, + ), + device=current_platform.current_device(), + dtype=config.torch_dtype, + ) + # To enable dummy run with out weight + self.post_weight_load() + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + output_dim = getattr(param, "output_dim", None) + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, nn.UninitializedParameter): + final_shape = list(loaded_weight.shape) + if output_dim is not None: + assert final_shape[output_dim] % self.tp_size == 0 + final_shape[output_dim] = final_shape[output_dim] // self.tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) + + param_data = param.data + if output_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[output_dim] + start_idx = self.tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). 
+ if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + k = self.k_layernorm(k.view(-1, self.num_kv_heads, self.head_dim)) + q, k = self.rotary_emb(positions, q, k) + + q = q.view(-1, self.q_size) + k = k.view(-1, self.k_size) + + attn_output = self.attn( + q, + k, + v, + output_shape=torch.Size( + [q.shape[0], q.shape[1] // self.head_dim * self.v_channels] + ), + ) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb( + self, + config: PretrainedConfig, + rope_parameters: dict[str, Any] | None, + quant_config: QuantizationConfig | None, + ) -> None: + is_neox_style = False + rope_parameters = {"partial_rotary_factor": self.qk_rope_dim / self.head_dim} + + self.rotary_emb = get_rope( + self.head_dim, + max_position=self.max_position_embeddings, + rope_parameters=rope_parameters, + is_neox_style=is_neox_style, + ) + + def post_weight_load(self) -> None: + if hasattr(self, "k_layernorm") and self.k_layernorm is not None: + param_sink_key = self.k_layernorm(self.param_sink_key) + else: + param_sink_key = self.param_sink_key + + self.attn.update_sink_kv(param_sink_key, self.param_sink_value) + + class OpenPanguDecoderLayer(nn.Module): def __init__( self, @@ -557,6 +827,9 @@ class OpenPanguDecoderLayer(nn.Module): and hasattr(config, "v_head_dim") and hasattr(config, "kv_lora_rank") ) + self.use_sink_attention = ( + hasattr(config, "param_sink_number") and config.param_sink_number > 0 + ) if self.use_mla: self.self_attn = OpenPanguMLAAttention( config=config, @@ -574,6 +847,42 @@ class OpenPanguDecoderLayer(nn.Module): quant_config=quant_config, prefix=f"{prefix}.self_attn", ) + elif self.use_sink_attention: + attention_bias 
= getattr(config, "attention_bias", False) or getattr( + config, "bias", False + ) + bias_o_proj = attention_bias + if hasattr(config, "qkv_bias"): + attention_bias = config.qkv_bias + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + raise ValueError( + f"is_causal={config.is_causal} is not support " + "for attention with sink" + ) + rope_parameters = getattr(config, "rope_scaling", None) + if rope_parameters is None: + rope_parameters = { + "rope_type": "default", + "rope_theta": config.rope_theta, + } + self.self_attn = OpenPanguSinkAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + rope_parameters=rope_parameters, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) else: attention_bias = getattr(config, "attention_bias", False) or getattr( config, "bias", False @@ -903,6 +1212,10 @@ class OpenPanguModel(nn.Module): if name.endswith(".bias") and name not in params_dict: continue name = maybe_remap_kv_scale_name(name, params_dict) + if name.endswith("e_score_correction_bias"): + name = name.replace( + "e_score_correction_bias", "gate.e_score_correction_bias" + ) if name is None: continue if is_pp_missing_parameter(name, self): @@ -912,8 +1225,17 @@ class OpenPanguModel(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + + self.post_weight_load() return loaded_params + def post_weight_load(self) -> None: + for name, module in self.named_modules(): + if module is self: + continue + if hasattr(module, "post_weight_load"): + module.post_weight_load() + class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA): packed_modules_mapping = { 
@@ -1047,3 +1369,7 @@ class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel): class PanguUltraMoEForCausalLM(OpenPanguMoEModel): pass + + +class PanguProMoEV2ForCausalLM(OpenPanguMoEModel): + pass diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index fd39afe259ae3..f183b3f554f1e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -164,6 +164,7 @@ _TEXT_GENERATION_MODELS = { "OrionForCausalLM": ("orion", "OrionForCausalLM"), "OuroForCausalLM": ("ouro", "OuroForCausalLM"), "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"), + "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"), "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), diff --git a/vllm/v1/attention/backends/flash_attn_diffkv.py b/vllm/v1/attention/backends/flash_attn_diffkv.py new file mode 100644 index 0000000000000..2e36740bd9e52 --- /dev/null +++ b/vllm/v1/attention/backends/flash_attn_diffkv.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FlashAttention backend supporting different K and V head sizes (DiffKV).""" + +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) +from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available + +if is_flash_attn_varlen_func_available(): + from vllm.attention.utils.fa_utils import flash_attn_varlen_func +from vllm.logger import init_logger +from vllm.v1.attention.backends.utils import get_kv_cache_layout + +from .flash_attn import ( + FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, + cascade_attention, +) + +logger = init_logger(__name__) + + +class 
FlashAttentionDiffKVBackend(FlashAttentionBackend): + # Default to 128 for this backend + head_size_v: int = 128 + + @classmethod + def set_head_size_v(cls, head_size_v: int) -> None: + cls.head_size_v = head_size_v + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_DIFFKV" + + @staticmethod + def get_impl_cls() -> type["FlashAttentionImpl"]: + return FlashAttentionDiffKVImpl + + # Do not modify the interface of get_kv_cache_shape, + # but consider head_size_v when returning result. + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return ( + num_blocks, + block_size, + num_kv_heads, + head_size + FlashAttentionDiffKVBackend.head_size_v, + ) + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets + # us from `get_kv_cache_shape` to the actual memory layout we want. 
+ cache_layout = get_kv_cache_layout() + if cache_layout == "NHD" and include_num_layers_dimension: + # (num_blocks, num_layers, block_size, + # num_kv_heads, head_size + head_size_v) + return (1, 0, 2, 3, 4) + elif cache_layout == "NHD": + stride_order = (0, 1, 2, 3) + elif cache_layout == "HND" and include_num_layers_dimension: + # (num_blocks, num_kv_heads, num_layers, + # block_size, head_size + head_size_v) + return (1, 3, 0, 2, 4) + elif cache_layout == "HND": + stride_order = (0, 2, 1, 3) + else: + raise ValueError(f"Unknown cache layout format {cache_layout}.") + return stride_order + + +class FlashAttentionDiffKVImpl(FlashAttentionImpl): + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size_v] + kv_cache: shape = + [num_blocks, block_size, num_kv_heads, head_size + head_size_v] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size_v] + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for FlashAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + + attn_type = self.attn_type + + # IMPORTANT! 
+ # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + + # Handle encoder attention differently - no KV cache needed + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) + + # For decoder and cross-attention, use KV cache as before + # Different head_size for K and V + key_cache = kv_cache[..., : self.head_size] + value_cache = kv_cache[..., self.head_size :] + + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. 
+ + # kv_cache update for different head_size K and V + triton_reshape_and_cache_flash_diffkv( + key, + value, + kv_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + # queries are quantized in the attention layer + dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( + self.kv_cache_dtype + ) + key_cache = key_cache.view(dtype) + value_cache = value_cache.view(dtype) + + if not attn_metadata.use_cascade: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + scheduler_metadata = attn_metadata.scheduler_metadata + + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) + + if self.dcp_world_size > 1: + self._forward_with_dcp( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + output[:num_actual_tokens], + attn_metadata, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return output + else: + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=attn_metadata.max_num_splits, + s_aux=self.sinks, + ) + return output + + # Cascade 
attention (rare case). + cascade_attention( + output[:num_actual_tokens], + query[:num_actual_tokens], + key_cache, + value_cache, + cu_query_lens=attn_metadata.query_start_loc, + max_query_len=attn_metadata.max_query_len, + cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, + prefix_kv_lens=attn_metadata.prefix_kv_lens, + suffix_kv_lens=attn_metadata.suffix_kv_lens, + max_kv_len=attn_metadata.max_seq_len, + softmax_scale=self.scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window, + logits_soft_cap=self.logits_soft_cap, + block_table=attn_metadata.block_table, + common_prefix_len=attn_metadata.common_prefix_len, + max_num_splits=attn_metadata.max_num_splits, + fa_version=self.vllm_flash_attn_version, + prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, + suffix_scheduler_metadata=attn_metadata.scheduler_metadata, + q_descale=layer._q_scale, + k_descale=layer._k_scale, + v_descale=layer._v_scale, + s_aux=self.sinks, + ) + return output diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 4aeb17a156bb3..14905b36754b4 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -15,6 +15,7 @@ from vllm.v1.kv_cache_interface import ( KVCacheSpec, MambaSpec, MLAAttentionSpec, + SinkFullAttentionSpec, SlidingWindowSpec, ) from vllm.v1.request import Request @@ -783,6 +784,24 @@ class CrossAttentionManager(SingleTypeKVCacheManager): raise NotImplementedError("CrossAttentionManager does not support caching") +class SinkFullAttentionManager(FullAttentionManager): + def __init__( + self, + kv_cache_spec: SinkFullAttentionSpec, + block_pool: BlockPool, + kv_cache_group_id: int, + dcp_world_size: int = 1, + pcp_world_size: int = 1, + ): + super().__init__( + kv_cache_spec, block_pool, kv_cache_group_id, dcp_world_size, pcp_world_size + ) + sink_len = kv_cache_spec.sink_len + assert sink_len is not None and sink_len > 0 and 
sink_len % self.block_size == 0 + num_sink_block = sink_len // self.block_size + self.sink_blocks = self.block_pool.free_block_queue.popleft_n(num_sink_block) + + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, @@ -790,6 +809,7 @@ spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, CrossAttentionSpec: CrossAttentionManager, + SinkFullAttentionSpec: SinkFullAttentionManager, } diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 7370f0aefafb4..7a7bb9036211b 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -89,12 +89,18 @@ class FullAttentionSpec(AttentionSpec): In this case, we use FullAttentionSpec and record the sliding window size. """ + head_size_v: int | None = None + sliding_window: int | None = None """ Default to None for not using sliding window attention. 
""" attention_chunk_size: int | None = None + def __post_init__(self): + if self.head_size_v is None: + object.__setattr__(self, "head_size_v", self.head_size) + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size @@ -142,6 +148,7 @@ class FullAttentionSpec(AttentionSpec): block_size=specs[0].block_size, num_kv_heads=specs[0].num_kv_heads, head_size=specs[0].head_size, + head_size_v=specs[0].head_size_v, dtype=specs[0].dtype, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), @@ -160,6 +167,15 @@ class FullAttentionSpec(AttentionSpec): ) return merged_spec + @property + def page_size_bytes(self) -> int: + return ( + self.block_size + * self.num_kv_heads + * (self.head_size + self.head_size_v) + * get_dtype_size(self.dtype) + ) + @dataclass(frozen=True) class MLAAttentionSpec(FullAttentionSpec): @@ -287,6 +303,56 @@ class CrossAttentionSpec(AttentionSpec): return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes +@dataclass(frozen=True) +class SinkFullAttentionSpec(FullAttentionSpec): + sink_len: int | None = None + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ + Merge a list of FullAttentionSpec objects into a single + FullAttentionSpec object. + """ + assert all(isinstance(spec, FullAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be FullAttentionSpec." 
+ ) + + sliding_window = set( + spec.sliding_window for spec in specs if spec.sliding_window is not None + ) + attention_chunk_size = set( + spec.attention_chunk_size + for spec in specs + if spec.attention_chunk_size is not None + ) + assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" + ) + merged_spec = cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + head_size_v=specs[0].head_size_v, + sink_len=specs[0].sink_len, + dtype=specs[0].dtype, + sliding_window=cls.merge_window_sizes(sliding_window), + attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), + ) + for spec in specs: + for f in fields(AttentionSpec): + assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( + "All attention layers in the same KV cache group must have " + "the same attention spec." + ) + assert (merged_spec.sliding_window is not None) + ( + merged_spec.attention_chunk_size is not None + ) <= 1, ( + "Model with both sliding window layers and chunked local attention " + "layers is not supported." + ) + return merged_spec + + @dataclass(frozen=True) class UniformTypeKVCacheSpecs(KVCacheSpec): """