diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 9cdf644c3cc52..2ce1cd9a943bf 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -427,6 +427,7 @@ th { | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | | `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | | | `PanguEmbeddedForCausalLM` |openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ | +| `PanguProMoEV2ForCausalLM` |openpangu-pro-moe-v2 | | ✅︎ | ✅︎ | | `PanguUltraMoEForCausalLM` |openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ | | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. 
| ✅︎ | ✅︎ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 644d0619215fb..b8b9dc9c43799 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -383,6 +383,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "PanguEmbeddedForCausalLM": _HfExamplesInfo( "FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True ), + "PanguProMoEV2ForCausalLM": _HfExamplesInfo( + "", + trust_remote_code=True, + is_available_online=False, + ), "PanguUltraMoEForCausalLM": _HfExamplesInfo( "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1", trust_remote_code=True, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 37f9a4b383ce9..deabf7a0b0770 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -939,21 +939,42 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, + sink_key: torch.Tensor | None = None, + sink_value: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) - self.impl.forward( - self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale, - ) + if sink_key is None and sink_value is None: + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) + else: + assert sink_key is not None and sink_value is not None, ( + "Currently, it is only supported when " + "sink_key and sink_value are both not None" + ) + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + sink_key=sink_key, + sink_value=sink_value, + ) def unified_attention_with_output_fake( @@ -962,6 +983,8 @@ def unified_attention_with_output_fake( 
value: torch.Tensor, output: torch.Tensor, layer_name: str, + sink_key: torch.Tensor | None = None, + sink_value: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py index 5d2ba154ae018..d79e209e303b0 100644 --- a/vllm/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -182,3 +182,136 @@ def triton_reshape_and_cache_flash( num_warps=num_warps, num_stages=num_stages, ) + + +@triton.jit +def reshape_and_cache_kernel_flash_diffkv( + kv_ptr, # [num_tokens, num_heads, head_size + head_size_v] + kv_cache_ptr, # [num_blocks, block_size, num_heads, head_size + head_size_v] + slot_mapping_ptr, # [num_tokens] + k_scale, # float32 + v_scale, # float32 + # strides + kv_stride: tl.int64, + block_stride: tl.int64, + page_stride: tl.int64, + num_heads: tl.constexpr, + head_size_kv: tl.constexpr, + block_size: tl.constexpr, + # FP8 flags + FP8_KV_CACHE: tl.constexpr, + # tune parameters + TILE_SIZE: tl.constexpr, +): + token_idx = tl.program_id(axis=0) + slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64) + if slot_idx < 0: + # Padding token that should be ignored. 
+ return + + tile_i = tl.program_id(axis=1) + tile_offs = tl.arange(0, TILE_SIZE) + tile_pos = tile_i * TILE_SIZE + tile_offs + + block_idx = slot_idx // block_size + block_offset = slot_idx % block_size + + src_kv_idx = token_idx * kv_stride + + tgt_idx = block_idx * block_stride + block_offset * page_stride + + # [TILE_SIZE] + kv_tile = tl.load( + kv_ptr + src_kv_idx + tile_pos, mask=tile_pos < (num_heads * head_size_kv) + ) + + tl.store( + kv_cache_ptr + tgt_idx + tile_pos, + kv_tile, + mask=tile_pos < (num_heads * head_size_kv), + ) + return + + +def triton_reshape_and_cache_flash_diffkv( + key: torch.Tensor, # [num_tokens, num_heads, head_size] + value: torch.Tensor, # [num_tokens, num_heads, head_size_v] + # [num_blocks, block_size, num_heads, head_size + head_size_v] + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, # [num_tokens] + kv_cache_dtype: str, # "auto", "fp8" + k_scale: torch.Tensor, # float32 + v_scale: torch.Tensor, # float32 +): + kv = torch.cat([key, value], dim=-1).contiguous() + num_heads = kv.shape[1] + head_size_kv = kv.shape[2] + block_size = kv_cache.shape[1] + n = num_heads * head_size_kv + + kv_stride = kv.stride()[0] + block_stride = kv_cache.stride()[0] + page_stride = kv_cache.stride()[1] + + assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( + f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." + ) + kv_cache_torch_dtype = ( + current_platform.fp8_dtype() + if kv_cache_dtype.startswith("fp8") + else kv_cache.dtype + ) + + if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): + # to avoid erounous implicit cast in triton kernel (tl.store to uint8) + # (e.g. 
explicit cast to fp8e4m3fnuz is not supported in triton 3.4) + kv_cache = kv_cache.view(kv_cache_torch_dtype) + assert kv_cache_torch_dtype != torch.uint8, ( + "explicit fp8 cast and store to " + "uint8 is not supported by triton reshape_and_cache_flash_diffkv" + ) + + FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") + assert not FP8_KV_CACHE, ( + "unsupported dtype of KV cache tensor, got " + f"{kv_cache_torch_dtype}. Supported kv cache dtypes: bfloat16, float16, float32." + ) + + # heuristics instead of autotuning + TILE_SIZE = min(2048, triton.next_power_of_2(n)) + if current_platform.is_rocm() or current_platform.is_xpu(): + num_stages = 4 + num_warps = 8 + else: # cuda + num_stages = 10 + num_warps = 16 + if torch.cuda.get_device_capability(key.device)[0] < 9: + TILE_SIZE = min(512, TILE_SIZE) + + # TODO(ngl): maybe replace with static launch grid to avoid overhead if + # using cudagraphs + grid = lambda meta: ( + slot_mapping.shape[0], + triton.cdiv(n, meta["TILE_SIZE"]), + ) + + reshape_and_cache_kernel_flash_diffkv[grid]( + kv_ptr=kv, + kv_cache_ptr=kv_cache, + slot_mapping_ptr=slot_mapping, + k_scale=k_scale, + v_scale=v_scale, + # strides + kv_stride=kv_stride, + block_stride=block_stride, + page_stride=page_stride, + num_heads=num_heads, + head_size_kv=head_size_kv, + block_size=block_size, + # FP8 flags + FP8_KV_CACHE=FP8_KV_CACHE, + # autotune parameters + TILE_SIZE=TILE_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index d13a745beffeb..f7249475b49de 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -30,15 +30,18 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionType +from vllm.attention.backends.abstract import AttentionBackend from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ParallelConfig,
VllmConfig from vllm.distributed import ( get_ep_group, get_pp_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_gather, ) +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -76,7 +79,13 @@ from vllm.model_executor.models.utils import ( maybe_prefix, sequence_parallel_chunk, ) +from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.flash_sink_attn import FlashSinkAttentionBackend +from vllm.v1.kv_cache_interface import ( + FullSinkAttentionSpec, + KVCacheSpec, +) def check_ffn_act_fn(act_fn: str): @@ -86,6 +95,140 @@ def check_ffn_act_fn(act_fn: str): ) +class AttentionWithSink(Attention): + def __init__( + self, + num_heads: int, + head_size: int, + head_size_v: int, + scale: float, + num_kv_heads: int | None = None, + alibi_slopes: list[float] | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + logits_soft_cap: float | None = None, + per_layer_sliding_window: int | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + attn_backend: type[AttentionBackend] | None = None, + **extra_impl_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + cache_config, + quant_config, + logits_soft_cap, + per_layer_sliding_window, + prefix, + attn_type, + kv_sharing_target_layer_name, + attn_backend, + **extra_impl_args, + ) + self.head_size_v = head_size_v + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + # For attention with sink, we have sink k, v + sink_key: torch.Tensor | None = None, + sink_value: 
torch.Tensor | None = None, + output_shape: torch.Size | None = None, + ) -> torch.Tensor: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + + Attention metadata (`attn_metadata`) is set using a context manager in + the model runner's `execute_model` method. It is accessed via forward + context using + `vllm.forward_context.get_forward_context().attn_metadata`. + """ + if self.calculate_kv_scales: + torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) + output_dtype = query.dtype + if self.query_quant is not None: + # quantizing with a simple torch operation enables + # torch.compile to fuse this into previous ops + # which reduces overheads during decoding. + # Otherwise queries are quantized using custom ops + # which causes decoding overheads + assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} + + # check if query quantization is supported + if self.impl.supports_quant_query_input(): + query, _ = self.query_quant(query, self._q_scale) + + if self.use_output: + output_shape = output_shape if output_shape is not None else query.shape + output = torch.empty(output_shape, dtype=output_dtype, device=query.device) + hidden_size = output_shape[-1] + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. 
+ query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size_v) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size_v) + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward( + self, + query, + key, + value, + self_kv_cache, + attn_metadata, + output=output, + sink_key=sink_key, + sink_value=sink_value, + ) + else: + torch.ops.vllm.unified_attention_with_output( + query, + key, + value, + output, + self.layer_name, + sink_key=sink_key, + sink_value=sink_value, + ) + return output.view(-1, hidden_size) + else: + raise ValueError( + "Unsupport Error, currently only flash_sink_attn " + "backend with output buffer is supported" + ) + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # Block size may get updated after model loading, refresh it + block_size = vllm_config.cache_config.block_size + # Should not be called for enc-dec or encoder-only attention. + assert self.attn_type == AttentionType.DECODER + # Only support for full attention now. 
+ assert self.sliding_window is None + return FullSinkAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + head_size_v=self.head_size_v, + dtype=self.kv_cache_torch_dtype, + ) + + class OpenPanguMLP(nn.Module): def __init__( self, @@ -153,7 +296,15 @@ class OpenPanguMoE(nn.Module): quant_config=None, prefix=f"{prefix}.gate", ) - self.gate.e_score_correction_bias = None + if ( + hasattr(config, "router_enable_expert_bias") + and config.router_enable_expert_bias + ): + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(self.n_routed_experts, dtype=torch.float32) + ) + else: + self.gate.e_score_correction_bias = None # Load balancing settings. eplb_config = parallel_config.eplb_config @@ -539,6 +690,276 @@ class OpenPanguEmbeddedAttention(nn.Module): ) +class OpenPanguSinkAttention(nn.Module): + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: dict[str, Any] | None = None, + max_position_embeddings: int = 8192, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: CacheConfig | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.total_num_heads = num_heads + if self.total_num_heads % self.tp_size != 0: + raise ValueError( + f"total_num_heads {self.total_num_heads} " + f"is not divisible by tp_size {self.tp_size}." 
+ ) + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = num_kv_heads + if ( + self.total_num_kv_heads > self.tp_size + and self.total_num_kv_heads % self.tp_size != 0 + ): + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel ranks. + raise ValueError( + "Number of KV heads is greater than TP size, " + f"but total_num_kv_heads {self.total_num_kv_heads} " + f"is not divisible by tp_size {self.tp_size}." + ) + elif self.total_num_kv_heads < self.tp_size: + # TODO: Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel ranks. + raise ValueError( + f"Number of KV heads {self.total_num_kv_heads} is less than " + f"TP size {self.tp_size}, KV heads replication is not support yet." + ) + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.qk_nope_dim = getattr(config, "qk_nope_dim", None) + self.qk_rope_dim = getattr(config, "qk_rope_dim", None) + self.v_channels = getattr(config, "v_channels", None) + self.head_dim = self.qk_rope_dim + self.qk_nope_dim + self.q_size = self.num_heads * self.head_dim + self.k_size = self.num_kv_heads * self.head_dim + self.v_size = self.num_kv_heads * self.v_channels + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.param_sink_number = getattr(config, "param_sink_number", 0) + self.param_sink_with_value = getattr(config, "param_sink_with_value", False) + self.param_sink_scalar = getattr(config, "param_sink_scalar", None) + self.param_sink_of_head_num = getattr(config, "param_sink_of_head_dim", False) + + self.qkv_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[ + self.q_size * self.tp_size, + self.k_size * self.tp_size, + self.v_size * self.tp_size, + ], + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + 
input_size=self.total_num_heads * self.v_channels, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + self._init_rotary_emb( + config, rope_scaling=rope_scaling, quant_config=quant_config + ) + + if hasattr(config, "interleaved_sliding_window"): + interleaved_sliding_window = config.interleaved_sliding_window + if isinstance(interleaved_sliding_window, int): + sliding_window = interleaved_sliding_window + elif isinstance(interleaved_sliding_window, list): + sw_idx = layer_idx % len(interleaved_sliding_window) + sliding_window = interleaved_sliding_window[sw_idx] + else: + raise ValueError( + f"{type(interleaved_sliding_window)} " + "for interleaved_sliding_window is not supported." + ) + else: + sliding_window = None + + FlashSinkAttentionBackend.set_cache_head_size_ratio( + (self.head_dim + self.v_channels) / self.head_dim + ) + self.attn = AttentionWithSink( + self.num_heads, + self.head_dim, + self.v_channels, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=sliding_window, + attn_type=attn_type, + prefix=f"{prefix}.attn", + attn_backend=FlashSinkAttentionBackend, + ) + + if self.param_sink_number > 0: + self.param_sink_key = torch.nn.Parameter( + torch.empty( + ( + self.param_sink_number, + self.num_kv_heads, + self.head_dim, + ), + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + ) + set_weight_attrs( + self.param_sink_key, + { + "output_dim": 1, + "weight_loader": self.weight_loader, + }, + ) + + if self.param_sink_with_value: + self.param_sink_value = torch.nn.Parameter( + torch.empty( + ( + self.param_sink_number, + self.num_kv_heads, + self.v_channels, + ), + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + ) + set_weight_attrs( + self.param_sink_value, + { + "output_dim": 1, + "weight_loader": 
self.weight_loader, + }, + ) + else: + self.param_sink_value = torch.zeros( + ( + self.param_sink_number, + self.num_kv_heads, + self.v_channels, + ), + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + output_dim = getattr(param, "output_dim", None) + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, nn.UninitializedParameter): + final_shape = list(loaded_weight.shape) + if output_dim is not None: + assert final_shape[output_dim] % self.tp_size == 0 + final_shape[output_dim] = final_shape[output_dim] // self.tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) + + param_data = param.data + if output_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[output_dim] + start_idx = self.tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8).
+ if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + k = self.k_layernorm(k.view(-1, self.num_kv_heads, self.head_dim)) + q, k = self.rotary_emb(positions, q, k) + + q = q.view(-1, self.q_size) + k = k.view(-1, self.k_size) + param_sink_key = getattr(self, "param_sink_key", None) + if ( + self.param_sink_number > 0 + and hasattr(self, "k_layernorm") + and self.k_layernorm is not None + ): + param_sink_key = self.k_layernorm(param_sink_key) + + attn_output = self.attn( + q, + k, + v, + output_shape=torch.Size( + [q.shape[0], q.shape[1] // self.head_dim * self.v_channels] + ), + **( + dict( + sink_key=param_sink_key, + sink_value=self.param_sink_value, + ) + if self.param_sink_number > 0 + else {} + ), + ) + attn_output = attn_output.reshape(-1, self.num_heads * self.v_channels) + output, _ = self.o_proj(attn_output) + return output + + def _init_rotary_emb( + self, + config: PretrainedConfig, + rope_scaling: dict[str, Any] | None, + quant_config: QuantizationConfig | None, + ) -> None: + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.qk_rope_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + + class OpenPanguDecoderLayer(nn.Module): def __init__( self, @@ -567,6 +988,9 @@ class OpenPanguDecoderLayer(nn.Module): and hasattr(config, "v_head_dim") and hasattr(config, "kv_lora_rank") ) + self.use_sink_attention = ( + hasattr(config, "param_sink_number") and config.param_sink_number > 0 + ) if self.use_mla: self.self_attn = OpenPanguMLAAttention( config=config, @@ -585,6 +1009,37 @@ class OpenPanguDecoderLayer(nn.Module):
quant_config=quant_config, prefix=f"{prefix}.self_attn", ) + elif self.use_sink_attention: + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False + ) + bias_o_proj = attention_bias + if hasattr(config, "qkv_bias"): + attention_bias = config.qkv_bias + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + raise ValueError( + f"is_causal={config.is_causal} is not support " + "for attention with sink" + ) + self.self_attn = OpenPanguSinkAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + rope_theta=rope_theta, + rope_scaling=getattr(config, "rope_scaling", None), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) else: attention_bias = getattr(config, "attention_bias", False) or getattr( config, "bias", False @@ -916,6 +1371,10 @@ class OpenPanguModel(nn.Module): if name.endswith(".bias") and name not in params_dict: continue name = maybe_remap_kv_scale_name(name, params_dict) + if name.endswith("e_score_correction_bias"): + name = name.replace( + "e_score_correction_bias", "gate.e_score_correction_bias" + ) if name is None: continue if is_pp_missing_parameter(name, self): @@ -1060,3 +1519,7 @@ class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel): class PanguUltraMoEForCausalLM(OpenPanguMoEModel): pass + + +class PanguProMoEV2ForCausalLM(OpenPanguMoEModel): + pass diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4af8fa01f562b..c50be12883897 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -150,6 +150,7 @@ _TEXT_GENERATION_MODELS = { "OrionForCausalLM": ("orion", "OrionForCausalLM"), 
"OuroForCausalLM": ("ouro", "OuroForCausalLM"), "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"), + "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"), "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), diff --git a/vllm/v1/attention/backends/flash_sink_attn.py b/vllm/v1/attention/backends/flash_sink_attn.py new file mode 100644 index 0000000000000..580f0f544102c --- /dev/null +++ b/vllm/v1/attention/backends/flash_sink_attn.py @@ -0,0 +1,1005 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with FlashAttention.""" + +from typing import ClassVar + +import numpy as np +import torch + +from vllm import envs +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, + MultipleOf, + is_quantized_kv_cache, +) +from vllm.attention.layer import Attention +from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.ops.triton_reshape_and_cache_flash import ( + triton_reshape_and_cache_flash_diffkv, +) +from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8, + get_flash_attn_version, + is_flash_attn_varlen_func_available, +) + +if is_flash_attn_varlen_func_available(): + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_sinks, + flash_attn_varlen_func, + get_scheduler_metadata, + ) +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.cache import CacheDType +from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.platforms.interface import DeviceCapability +from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + 
get_dcp_local_seq_lens, + get_kv_cache_layout, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +from .flash_attn import FlashAttentionMetadata + +logger = init_logger(__name__) + + +class FlashSinkAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + # NOTE(tdoublep): while in principle, FA supports + # MultipleOf(16), these are the block sizes that do not + # suffer from the NaN propagation problem described here: + # https://github.com/Dao-AILab/flash-attention/issues/1974 + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64] + # TODO: Remove hard code + cache_head_size_ratio: float = 2.0 + + @staticmethod + def get_name() -> str: + return "FLASH_SINK_ATTN" + + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + """FlashSinkAttention supports all attention types.""" + from vllm.attention import AttentionType + + return attn_type in ( + AttentionType.DECODER, + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + AttentionType.ENCODER_DECODER, + ) + + @staticmethod + def get_impl_cls() -> type["FlashSinkAttentionImpl"]: + return FlashSinkAttentionImpl + + @staticmethod + def get_builder_cls() -> type["FlashSinkAttentionMetadataBuilder"]: + return FlashSinkAttentionMetadataBuilder + + @classmethod + def set_cache_head_size_ratio(cls, ratio: float) -> None: + cls.cache_head_size_ratio = ratio + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return ( + num_blocks, + block_size, + num_kv_heads, + int(head_size * FlashSinkAttentionBackend.cache_head_size_ratio), + ) + + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets + # us 
from `get_kv_cache_shape` to the actual memory layout we want. + cache_layout = get_kv_cache_layout() + if cache_layout == "NHD": + stride_order = (0, 1, 2, 3) + elif cache_layout == "HND": + stride_order = (0, 2, 1, 3) + else: + raise ValueError(f"Unknown cache layout format {cache_layout}.") + return stride_order + + @staticmethod + def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return True + if kv_cache_dtype.startswith("fp8"): + return flash_attn_supports_fp8() + return kv_cache_dtype in ["auto"] + + @classmethod + def supports_sink(cls) -> bool: + if not is_flash_attn_varlen_func_available(): + return False + return flash_attn_supports_sinks() + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + return capability >= DeviceCapability(8, 0) + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + if has_sink and device_capability < DeviceCapability(9, 0): + return "sink not supported on compute capability < 9.0" + return None + + +def _get_sliding_window_configs( + vllm_config: VllmConfig, +) -> set[tuple[int, int] | None]: + """Get the set of all sliding window configs used in the model.""" + sliding_window_configs: set[tuple[int, int] | None] = set() + layers = get_layers_from_vllm_config(vllm_config, Attention) + for layer in layers.values(): + assert isinstance(layer.impl, FlashSinkAttentionImpl) + 
sliding_window_configs.add(layer.impl.sliding_window) + return sliding_window_configs + + +class FlashSinkAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata] +): + # FA3: + # Supports full cudagraphs for all cases. + # + # FA2: + # For FA2, a graph is captured with max_query_len=1, (which is what we + # capture by default for num_tokens <= max_num_seqs when there is no + # spec-decode) then these graphs will not work for mixed prefill-decode + # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling + # in FA2. + # In summary if we are running with spec decodes the graphs would + # work for mixed prefill-decode and uniform-decode. But for non-spec decodes + # the graphs would not work for mixed prefill-decode; sorta the inverse + # of UNIFORM_SINGLE_TOKEN_DECODE. + # There's probably a better way to describe this using `AttentionCGSupport` + # but for now just set it to `UNIFORM_BATCH` to get use to drop down + # to FULL_AND_PIECEWISE. + # TODO(luka, lucas): audit FA2 as part of: + # https://github.com/vllm-project/vllm/issues/22945 + _cudagraph_support = ( + AttentionCGSupport.ALWAYS + if get_flash_attn_version() == 3 + else AttentionCGSupport.UNIFORM_BATCH + ) + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + + self.num_heads_q = self.model_config.get_num_attention_heads( + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) + self.kv_cache_dtype = kv_cache_spec.dtype + self.headdim = self.model_config.get_head_size() + self.block_size = kv_cache_spec.block_size + + self.max_num_splits = 0 # No upper bound on 
the number of splits. + self.aot_schedule = get_flash_attn_version() == 3 + + try: + from vllm.distributed.parallel_state import get_dcp_group + + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + + self.dcp_kv_cache_interleave_size = ( + self.parallel_config.dcp_kv_cache_interleave_size + ) + + self.use_full_cuda_graph = ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) + self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size + + if self.use_full_cuda_graph and self.aot_schedule: + self.scheduler_metadata = torch.zeros( + vllm_config.scheduler_config.max_num_seqs + 1, + dtype=torch.int32, + device=self.device, + ) + # When using cuda graph, we need to set the upper bound of the + # number of splits so that large enough intermediate buffers are + # pre-allocated during capture. + self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + + # Sliding window size to be used with the AOT scheduler will be + # populated on first build() call. + self.aot_sliding_window: tuple[int, int] | None = None + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashAttentionMetadata: + """ + fast_build disables AOT scheduling, used when there will be few + iterations i.e. 
spec-decode + """ + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.max_seq_len + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + causal = common_attn_metadata.causal + + # the overhead of the aot schedule is not worth it for spec-decode + aot_schedule = self.aot_schedule and not fast_build + + if self.aot_sliding_window is None: + self.aot_sliding_window = (-1, -1) + # For the AOT scheduler we need the sliding window value to be + # constant for all layers to. We have to populate this on the first + # build() call so the layers are constructed (cannot populate) + # in __init__. + if aot_schedule: + sliding_window_configs = _get_sliding_window_configs(self.vllm_config) + if len(sliding_window_configs) == 1: + sliding_window_config = sliding_window_configs.pop() + if sliding_window_config is not None: + self.aot_sliding_window = sliding_window_config + elif len(sliding_window_configs) > 1: + self.aot_schedule = False + aot_schedule = False + + max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible + if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. 
+ max_num_splits = self.max_num_splits + + if vllm_is_batch_invariant(): + max_num_splits = 1 + + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): + cache_dtype = self.cache_config.cache_dtype + if cache_dtype.startswith("fp8"): + qkv_dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn( + cache_dtype + ) + else: + qkv_dtype = self.kv_cache_dtype + if aot_schedule: + return get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + num_heads_q=self.num_heads_q * self.dcp_world_size, + num_heads_kv=self.num_heads_kv, + headdim=self.headdim, + cache_seqlens=seqlens, + qkv_dtype=qkv_dtype, + cu_seqlens_q=cu_query_lens, + page_size=self.block_size, + causal=causal, + window_size=self.aot_sliding_window, + num_splits=max_num_splits, + ) + return None + + use_cascade = common_prefix_len > 0 + max_dcp_context_kv_len = 0 + dcp_context_kv_lens = None + + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + if self.dcp_world_size > 1: + query_kv_lens_cpu = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu + + dcp_context_kv_lens_cpu = get_dcp_local_seq_lens( + dcp_context_kv_lens_cpu, + self.dcp_world_size, + self.dcp_rank, + self.dcp_kv_cache_interleave_size, + ) + dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) + max_dcp_context_kv_len = dcp_context_kv_lens.max().item() + + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=dcp_context_kv_lens, + max_seq_len=max_dcp_context_kv_len, + causal=False, + ) + elif use_cascade: + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, 
device=self.device + ) + suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( + self.device, non_blocking=True + ) + prefix_scheduler_metadata = schedule( + batch_size=1, + cu_query_lens=cu_prefix_query_lens, + max_query_len=num_actual_tokens, + seqlens=prefix_kv_lens, + max_seq_len=common_prefix_len, + causal=False, + ) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - common_prefix_len, + causal=True, + ) + else: + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=causal, + ) + # For FA3 + full cudagraph + if self.use_full_cuda_graph and scheduler_metadata is not None: + n = scheduler_metadata.shape[0] + self.scheduler_metadata[:n] = scheduler_metadata + # NOTE(woosuk): We should zero out the rest of the scheduler + # metadata to guarantee the correctness. Otherwise, some thread + # blocks may use the invalid scheduler metadata and overwrite the + # output buffer. 
+ self.scheduler_metadata[n:] = 0 + scheduler_metadata = self.scheduler_metadata[:n] + + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + max_dcp_context_kv_len=max_dcp_context_kv_len, + dcp_context_kv_lens=dcp_context_kv_lens, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + scheduler_metadata=scheduler_metadata, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + prefix_scheduler_metadata=prefix_scheduler_metadata, + max_num_splits=max_num_splits, + causal=causal, + ) + return attn_metadata + + def use_cascade_attention(self, *args, **kwargs) -> bool: + return use_cascade_attention(*args, **kwargs) + + +class FlashSinkAttentionImpl(AttentionImpl): + can_return_lse_for_decode: bool = True + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + sinks: torch.Tensor | None = None, + head_size_v: int | None = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.head_size_v = head_size_v + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + elif attn_type == AttentionType.ENCODER_ONLY: + self.sliding_window = (sliding_window - 1, sliding_window - 1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In 
flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.attn_type = attn_type + self.vllm_flash_attn_version = get_flash_attn_version() + # Cache the batch invariant result for use in forward passes + self.batch_invariant_enabled = vllm_is_batch_invariant() + + if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): + raise NotImplementedError( + "FlashAttention does not support fp8 kv-cache on this device." + ) + + self.sinks = sinks + if self.sinks is not None: + assert flash_attn_supports_sinks(), ( + "Sinks are only supported in FlashAttention 3" + ) + assert self.sinks.shape[0] == num_heads, ( + "Sinks must have the same number of heads as the number of " + "heads in the layer" + ) + + def supports_quant_query_input(self) -> bool: + return False + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + sink_key: torch.Tensor | None = None, + sink_value: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass with FlashSinkAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size_v] + kv_cache: shape = + [num_blocks, block_size, num_kv_heads, head_size + head_size_v] + attn_metadata: Metadata for attention. 
+ sink_key: shape = [sink_len, num_kv_heads, head_size] + sink_value: shape = [sink_len, num_kv_heads, head_size_v] + Returns: + shape = [num_tokens, num_heads * head_size_v] + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + assert sink_key is not None and sink_value is not None, ( + "sink_key and sink_value must be provided" + ) + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not supported yet " + "for FlashSinkAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + + attn_type = self.attn_type + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + sink_len = sink_key.shape[0] + + # Handle encoder attention differently - no KV cache needed + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + sink_key, + sink_value, + ) + + # For decoder and cross-attention, use KV cache as before + + # key and value may be None in the case of cross attention. 
They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + + # store sink_key and sink_value in head blocks + key_cache = kv_cache[..., : self.head_size] + value_cache = kv_cache[..., self.head_size :] + block_size = key_cache.shape[1] + assert sink_len % block_size == 0 + num_sink_blocks = sink_len // block_size + sink_kv_slot_mapping = torch.arange( + sink_len, + device=attn_metadata.slot_mapping.device, + dtype=attn_metadata.slot_mapping.dtype, + ) + triton_reshape_and_cache_flash_diffkv( + sink_key, + sink_value, + kv_cache, + sink_kv_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + triton_reshape_and_cache_flash_diffkv( + key, + value, + kv_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + # queries are quantized in the attention layer + dtype = FlashSinkAttentionBackend.get_fp8_dtype_for_flashattn( + self.kv_cache_dtype + ) + key_cache = key_cache.view(dtype) + value_cache = value_cache.view(dtype) + + if not attn_metadata.use_cascade: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + sink_len + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + sink_len + block_table = attn_metadata.block_table + scheduler_metadata = attn_metadata.scheduler_metadata + sink_block_table = torch.arange( + 
num_sink_blocks, device=block_table.device, dtype=block_table.dtype + ) + sink_block_table = sink_block_table[None, :].expand( + block_table.shape[0], -1 + ) + block_table = torch.cat((sink_block_table, block_table), dim=1) + + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) + + if self.dcp_world_size > 1: + raise ValueError( + "Decode context parallel is not supported yet " + f"for dcp_world_size = {self.dcp_world_size}" + ) + else: + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=attn_metadata.max_num_splits, + s_aux=self.sinks, + ) + return output + + # Cascade attention (rare case). 
+ cascade_attention( + output[:num_actual_tokens], + query[:num_actual_tokens], + key_cache, + value_cache, + cu_query_lens=attn_metadata.query_start_loc, + max_query_len=attn_metadata.max_query_len, + cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, + prefix_kv_lens=attn_metadata.prefix_kv_lens, + suffix_kv_lens=attn_metadata.suffix_kv_lens, + max_kv_len=attn_metadata.max_seq_len, + softmax_scale=self.scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window, + logits_soft_cap=self.logits_soft_cap, + block_table=attn_metadata.block_table, + common_prefix_len=attn_metadata.common_prefix_len, + max_num_splits=attn_metadata.max_num_splits, + fa_version=self.vllm_flash_attn_version, + prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, + suffix_scheduler_metadata=attn_metadata.scheduler_metadata, + q_descale=layer._q_scale, + k_descale=layer._k_scale, + v_descale=layer._v_scale, + s_aux=self.sinks, + sink_len=sink_len, + ) + return output + + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + layer: torch.nn.Module, + sink_key: torch.Tensor, + sink_value: torch.Tensor, + ) -> torch.Tensor: + """Forward pass for encoder attention without KV cache. 
+ + Args: + query: shape = [num_encoder_tokens, num_heads, head_size] + key: shape = [num_encoder_tokens, num_kv_heads, head_size] + value: shape = [num_encoder_tokens, num_kv_heads, head_size_v] + output: shape = [num_encoder_tokens, num_heads, head_size] + attn_metadata: Encoder attention metadata + layer: The attention layer + sink_key: shape = [sink_len, num_kv_heads, head_size] + sink_value: shape = [sink_len, num_kv_heads, head_size_v] + """ + # For encoder attention, process FP8 quantization if needed + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError( + "quantization is not supported for encoder attention" + ) + + # Use encoder-specific metadata for sequence information + sink_len = sink_key.shape[0] + key_list = [] + value_list = [] + for seq_id in range(attn_metadata.block_table.shape[0]): + seq_start = attn_metadata.query_start_loc[seq_id] + seq_end = attn_metadata.query_start_loc[seq_id + 1] + key_list.append( + torch.cat( + [ + sink_key, + key[seq_start:seq_end], + ], + dim=0, + ) + ) + value_list.append( + torch.cat( + [ + sink_value, + value[seq_start:seq_end], + ], + dim=0, + ) + ) + key = torch.cat(key_list, dim=0).contiguous() + value = torch.cat(value_list, dim=0).contiguous() + + cu_seqlens_q = attn_metadata.query_start_loc + cu_seqlens_k = attn_metadata.seq_lens + sink_len + cu_seqlens_k = torch.nn.functional.pad( + torch.cumsum(cu_seqlens_k, dim=-1), [1, 0], value=0 + ).to(torch.int32) + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + sink_len + + descale_shape = ( + cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] + self.num_kv_heads, + ) + + # Call flash attention directly on Q, K, V tensors + flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=output, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=False, # Encoder attention is bidirectional + 
alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=1 if self.batch_invariant_enabled else 0, + ) + + return output + + +def use_cascade_attention( + common_prefix_len: int, + query_lens: np.ndarray, + num_query_heads: int, + num_kv_heads: int, + use_alibi: bool, + use_sliding_window: bool, + use_local_attention: bool, + num_sms: int, + dcp_world_size: int, +) -> bool: + """Decide whether to use cascade attention. + + This function 1) checks whether cascade attention is supported with the + given configuration, and 2) heuristically decides whether using cascade + attention can improve performance. + """ + # Too short common prefix. Probably not worth using cascade attention. + # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold. + # NOTE(woosuk): This is the common case. We should return False as soon as + # possible to avoid any unnecessary computation. + if common_prefix_len < 256: + return False + # Cascade attention is currently not supported with these variants. + if use_alibi or use_sliding_window or use_local_attention: + return False + # Too few queries. Probably not worth using cascade attention. + # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. + num_reqs = len(query_lens) + if num_reqs < 8: + return False + # disable cascade attention for DCP + if dcp_world_size > 1: + return False + + # Heuristics to decide whether using cascade attention is beneficial. + # 1. When FlashDecoding is not used for normal attention, cascade attention + # is likely to be faster since it saves memory bandwidth. 
+ num_queries_per_kv = num_query_heads // num_kv_heads + # The criteria for using FlashDecoding can be found in the following link: + # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 + use_flash_decoding = ( + num_queries_per_kv > 1 + and not use_sliding_window + and not use_alibi + and np.all(query_lens == 1) + ) + if not use_flash_decoding: + # Use cascade attention. + return True + + # 2. When FlashDecoding is used for normal attention, it is not clear + # whether cascade attention is beneficial, because FlashDecoding can + # launch more CTAs than cascade attention. + # We use a simple performance model to compare the two methods. + # NOTE(woosuk): The performance model is very rough and may not be + # accurate. + num_tokens = num_reqs + # NOTE(woosuk): These are default tile sizes. flash-attn might use + # different tile sizes (e.g., 64 or 256) depending on the configuration. + q_tile_size = 128 + kv_tile_size = 128 + num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size) + + cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size) + cascade_waves = cdiv(cascade_ctas, num_sms) + cascade_time = cascade_waves * num_prefix_tiles + + flash_decoding_ctas = ( + num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) + ) + flash_decoding_ctas *= num_prefix_tiles + flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) + + # Use cascade attention if it is faster than FlashDecoding. 
+ return cascade_time < flash_decoding_time + + +def cascade_attention( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cu_query_lens: torch.Tensor, + max_query_len: int, + cu_prefix_query_lens: torch.Tensor, + prefix_kv_lens: torch.Tensor, + suffix_kv_lens: torch.Tensor, + max_kv_len: int, + softmax_scale: float, + alibi_slopes: torch.Tensor | None, + sliding_window: tuple[int, int], + logits_soft_cap: float, + block_table: torch.Tensor, + common_prefix_len: int, + max_num_splits: int, + fa_version: int, + prefix_scheduler_metadata: torch.Tensor | None = None, + suffix_scheduler_metadata: torch.Tensor | None = None, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + s_aux: torch.Tensor | None = None, + sink_len: int | None = None, +) -> torch.Tensor: + assert alibi_slopes is None, "Cascade attention does not support ALiBi." + # TODO: Support sliding window. + assert sliding_window == (-1, -1), ( + "Cascade attention does not support sliding window." + ) + assert sink_len is not None, "sink_len must be provided." + + num_tokens = query.shape[0] + block_size = key_cache.shape[-3] + assert common_prefix_len % block_size == 0 + num_common_kv_blocks = common_prefix_len // block_size + assert num_common_kv_blocks > 0 + descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) + + num_sink_blocks = sink_len // block_size + block_table = block_table + num_sink_blocks + block_table[block_table == num_sink_blocks] = 0 + sink_block_table = ( + torch.arange( + num_sink_blocks, device=block_table.device, dtype=block_table.dtype + ) + + 1 + ) + sink_block_table = sink_block_table[None, :].expand(block_table.shape[0], -1) + block_table = torch.cat((sink_block_table, block_table), dim=1) + + # Process shared prefix. 
+ prefix_output, prefix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_prefix_query_lens, + seqused_k=prefix_kv_lens + sink_len, + max_seqlen_q=num_tokens, + max_seqlen_k=common_prefix_len + sink_len, + softmax_scale=softmax_scale, + causal=False, + window_size=sliding_window, + block_table=block_table[:1], + softcap=logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=prefix_scheduler_metadata, + fa_version=fa_version, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, + # s_aux is incorporated into prefix_lse inside the GPU kernel, + # enabling its effect during the final attention merge. + s_aux=s_aux, + num_splits=1 if vllm_is_batch_invariant() else max_num_splits, + ) + + descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) + + # Process suffix per query. + suffix_output, suffix_lse = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + seqused_k=suffix_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len - common_prefix_len, + softmax_scale=softmax_scale, + causal=True, + window_size=sliding_window, + block_table=block_table[:, num_sink_blocks + num_common_kv_blocks :], + softcap=logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=suffix_scheduler_metadata, + fa_version=fa_version, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, + num_splits=1 if vllm_is_batch_invariant() else max_num_splits, + ) + + # Merge prefix and suffix outputs, and store the result in output. 
+ merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c640c40a455d0..041393d097b9a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -184,6 +184,11 @@ class Scheduler(SchedulerInterface): enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, ) + sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0) + if sink_len > 0: + assert sink_len % self.block_size == 0 + num_sink_block = sink_len // self.block_size + self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 def schedule(self) -> SchedulerOutput: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 14ac83028ee44..5aba227fa0176 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -12,6 +12,7 @@ from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, + FullSinkAttentionSpec, KVCacheSpec, MambaSpec, MLAAttentionSpec, @@ -305,7 +306,8 @@ class FullAttentionManager(SingleTypeKVCacheManager): dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( - kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) + kv_cache_spec, + (FullAttentionSpec, FullSinkAttentionSpec, ChunkedLocalAttentionSpec), ), ( "FullAttentionManager can only be used for full attention " "and chunked local attention groups" @@ -720,6 +722,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, + FullSinkAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: 
ChunkedLocalAttentionManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 7f33eb7e699c7..5ff99acac52a2 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -157,6 +157,98 @@ class FullAttentionSpec(AttentionSpec): return merged_spec +@dataclass(frozen=True) +class FullSinkAttentionSpec(AttentionSpec): + head_size_v: int + sliding_window: int | None = None + attention_chunk_size: int | None = None + + """ + When hybrid allocator is disabled and the model contains both full + attention layers and sliding window attention layers, sliding + window attention are regarded as full attention in KV cache manager + (blocks are allocated for all tokens), while computed as sliding window + attention in model runner. + In this case, we use FullSinkAttentionSpec and record the sliding window size. + Default to None for not using sliding window attention. + """ + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): each dcp rank only need save + # (max_model_len//dcp_world_size) tokens locally. + if dcp_world_size > 1: + max_model_len = cdiv(max_model_len, dcp_world_size) + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + + @classmethod + def merge_window_sizes(cls, window_sizes: set[int]) -> int | None: + if len(window_sizes) == 0: + return None + elif len(window_sizes) == 1: + return window_sizes.pop() + else: + raise ValueError( + "All attention layers in the same KV cache group must have the " + "same window size." + ) + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + """ + Merge a list of FullSinkAttentionSpec objects into a single + FullSinkAttentionSpec object. 
+ """ + assert all(isinstance(spec, FullSinkAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be " + "FullSinkAttentionSpec." + ) + + sliding_window = set( + spec.sliding_window for spec in specs if spec.sliding_window is not None + ) + attention_chunk_size = set( + spec.attention_chunk_size + for spec in specs + if spec.attention_chunk_size is not None + ) + assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" + ) + merged_spec = cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + head_size_v=specs[0].head_size_v, + dtype=specs[0].dtype, + sliding_window=cls.merge_window_sizes(sliding_window), + attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), + ) + for spec in specs: + for f in fields(AttentionSpec): + assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( + "All attention layers in the same KV cache group must have " + "the same attention spec." + ) + assert (merged_spec.sliding_window is not None) + ( + merged_spec.attention_chunk_size is not None + ) <= 1, ( + "Model with both sliding window layers and chunked local attention " + "layers is not supported." + ) + return merged_spec + + @property + def page_size_bytes(self) -> int: + return ( + self.block_size + * self.num_kv_heads + * (self.head_size + self.head_size_v) + * get_dtype_size(self.dtype) + ) + + @dataclass(frozen=True) class MLAAttentionSpec(FullAttentionSpec): # TODO(Lucas/Chen): less hacky way to do this