Add support for openpangu_pro_moe_v2, which is characterized by a different KV head size for keys and values and by sink KV in attention.
Signed-off-by: yuantao <2422264527@qq.com>
parent 9fc81ec765
commit 8de4315229
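The "sink KV" mentioned above refers to a small set of learned key/value vectors that every query can attend to in addition to the per-token KV cache, combined with different head sizes for keys and values. Below is a toy PyTorch sketch of the idea; the shapes are made up, one KV head per query head is assumed, and it is only an illustration, not the fused flash_sink_attn backend added by this commit.

import torch
import torch.nn.functional as F

def sink_attention_reference(q, k, v, sink_k, sink_v, scale):
    # q:      [num_heads, q_len, head_size]
    # k:      [num_heads, kv_len, head_size]      keys use head_size
    # v:      [num_heads, kv_len, head_size_v]    values use a different head_size_v
    # sink_k: [num_heads, num_sinks, head_size]   learned, shared across sequences
    # sink_v: [num_heads, num_sinks, head_size_v]
    k = torch.cat([sink_k, k], dim=1)  # sink positions sit in front of the real context
    v = torch.cat([sink_v, v], dim=1)
    # Causal masking of the real tokens is omitted for brevity; SDPA permits the
    # value head size to differ from the query/key head size.
    return F.scaled_dot_product_attention(q, k, v, scale=scale)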
@@ -427,6 +427,7 @@ th {
 | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
 | `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | |
 | `PanguEmbeddedForCausalLM` | openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ |
+| `PanguProMoEV2ForCausalLM` | openpangu-pro-moe-v2 | | ✅︎ | ✅︎ |
 | `PanguUltraMoEForCausalLM` | openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ |
 | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ |
 | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ |
@@ -383,6 +383,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "PanguEmbeddedForCausalLM": _HfExamplesInfo(
         "FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True
     ),
+    "PanguProMoEV2ForCausalLM": _HfExamplesInfo(
+        "",
+        trust_remote_code=True,
+        is_available_online=False,
+    ),
     "PanguUltraMoEForCausalLM": _HfExamplesInfo(
         "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
         trust_remote_code=True,
@@ -939,21 +939,42 @@ def unified_attention_with_output(
     value: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
+    sink_key: torch.Tensor | None = None,
+    sink_value: torch.Tensor | None = None,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
 ) -> None:
     attn_metadata, self, kv_cache = get_attention_context(layer_name)
-    self.impl.forward(
-        self,
-        query,
-        key,
-        value,
-        kv_cache,
-        attn_metadata,
-        output=output,
-        output_scale=output_scale,
-        output_block_scale=output_block_scale,
-    )
+    if sink_key is None and sink_value is None:
+        self.impl.forward(
+            self,
+            query,
+            key,
+            value,
+            kv_cache,
+            attn_metadata,
+            output=output,
+            output_scale=output_scale,
+            output_block_scale=output_block_scale,
+        )
+    else:
+        assert sink_key is not None and sink_value is not None, (
+            "Currently, it is only supported when "
+            "sink_key and sink_value are both not None"
+        )
+        self.impl.forward(
+            self,
+            query,
+            key,
+            value,
+            kv_cache,
+            attn_metadata,
+            output=output,
+            output_scale=output_scale,
+            output_block_scale=output_block_scale,
+            sink_key=sink_key,
+            sink_value=sink_value,
+        )
 
 
 def unified_attention_with_output_fake(
@@ -962,6 +983,8 @@ def unified_attention_with_output_fake(
     value: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
+    sink_key: torch.Tensor | None = None,
+    sink_value: torch.Tensor | None = None,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
 ) -> None:
@@ -182,3 +182,136 @@ def triton_reshape_and_cache_flash(
         num_warps=num_warps,
         num_stages=num_stages,
     )
+
+
+@triton.jit
+def reshape_and_cache_kernel_flash_diffkv(
+    kv_ptr,  # [num_tokens, num_heads, head_size + head_size_v]
+    kv_cache_ptr,  # [num_blocks, block_size, num_heads, head_size + head_size_v]
+    slot_mapping_ptr,  # [num_tokens]
+    k_scale,  # float32
+    v_scale,  # float32
+    # strides
+    kv_stride: tl.int64,
+    block_stride: tl.int64,
+    page_stride: tl.int64,
+    num_heads: tl.constexpr,
+    head_size_kv: tl.constexpr,
+    block_size: tl.constexpr,
+    # FP8 flags
+    FP8_KV_CACHE: tl.constexpr,
+    # tune parameters
+    TILE_SIZE: tl.constexpr,
+):
+    token_idx = tl.program_id(axis=0)
+    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
+    if slot_idx < 0:
+        # Padding token that should be ignored.
+        return
+
+    tile_i = tl.program_id(axis=1)
+    tile_offs = tl.arange(0, TILE_SIZE)
+    tile_pos = tile_i * TILE_SIZE + tile_offs
+
+    block_idx = slot_idx // block_size
+    block_offset = slot_idx % block_size
+
+    src_kv_idx = token_idx * kv_stride
+
+    tgt_idx = block_idx * block_stride + block_offset * page_stride
+
+    # [TILE_SIZE]
+    kv_tile = tl.load(
+        kv_ptr + src_kv_idx + tile_pos, mask=tile_pos < (num_heads * head_size_kv)
+    )
+
+    tl.store(
+        kv_cache_ptr + tgt_idx + tile_pos,
+        kv_tile,
+        mask=tile_pos < (num_heads * head_size_kv),
+    )
+    return
+
+
+def triton_reshape_and_cache_flash_diffkv(
+    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
+    value: torch.Tensor,  # [num_tokens, num_heads, head_size_v]
+    # [num_blocks, block_size, num_heads, head_size + head_size_v]
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,  # [num_tokens]
+    kv_cache_dtype: str,  # "auto", "fp8"
+    k_scale: torch.Tensor,  # float32
+    v_scale: torch.Tensor,  # float32
+):
+    kv = torch.cat([key, value], dim=-1).contiguous()
+    num_heads = kv.shape[1]
+    head_size_kv = kv.shape[2]
+    block_size = kv_cache.shape[1]
+    n = num_heads * head_size_kv
+
+    kv_stride = kv.stride()[0]
+    block_stride = kv_cache.stride()[0]
+    page_stride = kv_cache.stride()[1]
+
+    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
+        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
+    )
+    kv_cache_torch_dtype = (
+        current_platform.fp8_dtype()
+        if kv_cache_dtype.startswith("fp8")
+        else kv_cache.dtype
+    )
+
+    if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
+        # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
+        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
+        kv_cache = kv_cache.view(kv_cache_torch_dtype)
+        assert kv_cache_dtype != torch.uint8, (
+            "explicit fp8 cast and store to "
+            "uint8 is not supported by triton reshape_and_cache_flash_diffkv"
+        )
+
+    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
+    assert not FP8_KV_CACHE, (
+        "unsupported dtype of KV cache tensor, got "
+        "{kv_cache_torch_dtype}. Supported kv cache dtypes: bfloat16, float16, float32."
+    )
+
+    # heuristics instead of autotuning
+    TILE_SIZE = min(2048, triton.next_power_of_2(n))
+    if current_platform.is_rocm() or current_platform.is_xpu():
+        num_stages = 4
+        num_warps = 8
+    else:  # cuda
+        num_stages = 10
+        num_warps = 16
+        if torch.cuda.get_device_capability(key.device)[0] < 9:
+            TILE_SIZE = min(512, TILE_SIZE)
+
+    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
+    # using cudagraphs
+    grid = lambda meta: (
+        slot_mapping.shape[0],
+        triton.cdiv(n, meta["TILE_SIZE"]),
+    )
+
+    reshape_and_cache_kernel_flash_diffkv[grid](
+        kv_ptr=kv,
+        kv_cache_ptr=kv_cache,
+        slot_mapping_ptr=slot_mapping,
+        k_scale=k_scale,
+        v_scale=v_scale,
+        # strides
+        kv_stride=kv_stride,
+        block_stride=block_stride,
+        page_stride=page_stride,
+        num_heads=num_heads,
+        head_size_kv=head_size_kv,
+        block_size=block_size,
+        # FP8 flags
+        FP8_KV_CACHE=FP8_KV_CACHE,
+        # autotune parameters
+        TILE_SIZE=TILE_SIZE,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
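A rough usage sketch of the new wrapper follows, assuming bf16 tensors on a CUDA device; the tensor shapes follow the comments in the signature, and the import path is an assumption since the diff does not show the module's filename.

import torch
# assumed location of the wrapper; adjust to the actual module in the tree
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash_diffkv,
)

num_tokens, num_heads = 8, 4
head_size, head_size_v = 128, 96  # different key and value head sizes
num_blocks, block_size = 16, 16

key = torch.randn(num_tokens, num_heads, head_size, device="cuda", dtype=torch.bfloat16)
value = torch.randn(num_tokens, num_heads, head_size_v, device="cuda", dtype=torch.bfloat16)
kv_cache = torch.zeros(
    num_blocks, block_size, num_heads, head_size + head_size_v,
    device="cuda", dtype=torch.bfloat16,
)
# one flat slot index per token: block_idx * block_size + block_offset
slot_mapping = torch.arange(num_tokens, device="cuda", dtype=torch.int64)
k_scale = torch.ones(1, device="cuda", dtype=torch.float32)
v_scale = torch.ones(1, device="cuda", dtype=torch.float32)

triton_reshape_and_cache_flash_diffkv(
    key, value, kv_cache, slot_mapping, kv_cache_dtype="auto",
    k_scale=k_scale, v_scale=v_scale,
)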
@@ -30,15 +30,18 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionType
+from vllm.attention.backends.abstract import AttentionBackend
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.distributed import (
     get_ep_group,
     get_pp_group,
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_gather,
 )
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -76,7 +79,13 @@ from vllm.model_executor.models.utils import (
     maybe_prefix,
     sequence_parallel_chunk,
 )
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backends.flash_sink_attn import FlashSinkAttentionBackend
+from vllm.v1.kv_cache_interface import (
+    FullSinkAttentionSpec,
+    KVCacheSpec,
+)
 
 
 def check_ffn_act_fn(act_fn: str):
@@ -86,6 +95,140 @@ def check_ffn_act_fn(act_fn: str):
         )
 
 
+class AttentionWithSink(Attention):
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        head_size_v: int,
+        scale: float,
+        num_kv_heads: int | None = None,
+        alibi_slopes: list[float] | None = None,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        logits_soft_cap: float | None = None,
+        per_layer_sliding_window: int | None = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: str | None = None,
+        attn_backend: type[AttentionBackend] | None = None,
+        **extra_impl_args,
+    ) -> None:
+        super().__init__(
+            num_heads,
+            head_size,
+            scale,
+            num_kv_heads,
+            alibi_slopes,
+            cache_config,
+            quant_config,
+            logits_soft_cap,
+            per_layer_sliding_window,
+            prefix,
+            attn_type,
+            kv_sharing_target_layer_name,
+            attn_backend,
+            **extra_impl_args,
+        )
+        self.head_size_v = head_size_v
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        # For attention with sink, we have sink k, v
+        sink_key: torch.Tensor | None = None,
+        sink_value: torch.Tensor | None = None,
+        output_shape: torch.Size | None = None,
+    ) -> torch.Tensor:
+        """
+        The KV cache is stored inside this class and is accessed via
+        `self.kv_cache`.
+
+        Attention metadata (`attn_metadata`) is set using a context manager in
+        the model runner's `execute_model` method. It is accessed via forward
+        context using
+        `vllm.forward_context.get_forward_context().attn_metadata`.
+        """
+        if self.calculate_kv_scales:
+            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
+        output_dtype = query.dtype
+        if self.query_quant is not None:
+            # quantizing with a simple torch operation enables
+            # torch.compile to fuse this into previous ops
+            # which reduces overheads during decoding.
+            # Otherwise queries are quantized using custom ops
+            # which causes decoding overheads
+            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
+
+            # check if query quantization is supported
+            if self.impl.supports_quant_query_input():
+                query, _ = self.query_quant(query, self._q_scale)
+
+        if self.use_output:
+            output_shape = output_shape if output_shape is not None else query.shape
+            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
+            hidden_size = output_shape[-1]
+            # Reshape the query, key, and value tensors.
+            # NOTE(woosuk): We do this outside the custom op to minimize the
+            # CPU overheads from the non-CUDA-graph regions.
+            query = query.view(-1, self.num_heads, self.head_size)
+            output = output.view(-1, self.num_heads, self.head_size_v)
+            if key is not None:
+                key = key.view(-1, self.num_kv_heads, self.head_size)
+            if value is not None:
+                value = value.view(-1, self.num_kv_heads, self.head_size_v)
+            if self.use_direct_call:
+                forward_context: ForwardContext = get_forward_context()
+                attn_metadata = forward_context.attn_metadata
+                if isinstance(attn_metadata, dict):
+                    attn_metadata = attn_metadata[self.layer_name]
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                self.impl.forward(
+                    self,
+                    query,
+                    key,
+                    value,
+                    self_kv_cache,
+                    attn_metadata,
+                    output=output,
+                    sink_key=sink_key,
+                    sink_value=sink_value,
+                )
+            else:
+                torch.ops.vllm.unified_attention_with_output(
+                    query,
+                    key,
+                    value,
+                    output,
+                    self.layer_name,
+                    sink_key=sink_key,
+                    sink_value=sink_value,
+                )
+            return output.view(-1, hidden_size)
+        else:
+            raise ValueError(
+                "Unsupport Error, currently only flash_sink_attn "
+                "backend with output buffer is supported"
+            )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Block size may get updated after model loading, refresh it
+        block_size = vllm_config.cache_config.block_size
+        # Should not be called for enc-dec or encoder-only attention.
+        assert self.attn_type == AttentionType.DECODER
+        # Only support for full attention now.
+        assert self.sliding_window is None
+        return FullSinkAttentionSpec(
+            block_size=block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            head_size_v=self.head_size_v,
+            dtype=self.kv_cache_torch_dtype,
+        )
+
+
 class OpenPanguMLP(nn.Module):
     def __init__(
         self,
@@ -153,7 +296,15 @@ class OpenPanguMoE(nn.Module):
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
-        self.gate.e_score_correction_bias = None
+        if (
+            hasattr(config, "router_enable_expert_bias")
+            and config.router_enable_expert_bias
+        ):
+            self.gate.e_score_correction_bias = nn.Parameter(
+                torch.empty(self.n_routed_experts, dtype=torch.float32)
+            )
+        else:
+            self.gate.e_score_correction_bias = None
 
         # Load balancing settings.
         eplb_config = parallel_config.eplb_config
@@ -539,6 +690,276 @@ class OpenPanguEmbeddedAttention(nn.Module):
         )
 
 
+class OpenPanguSinkAttention(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: dict[str, Any] | None = None,
+        max_position_embeddings: int = 8192,
+        quant_config: QuantizationConfig | None = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: CacheConfig | None = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
+        super().__init__()
+        layer_idx = extract_layer_index(prefix)
+        self.hidden_size = hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.total_num_heads = num_heads
+        if self.total_num_heads % self.tp_size != 0:
+            raise ValueError(
+                f"total_num_heads {self.total_num_heads} "
+                f"is not divisible by tp_size {self.tp_size}."
+            )
+        self.num_heads = self.total_num_heads // self.tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if (
+            self.total_num_kv_heads > self.tp_size
+            and self.total_num_kv_heads % self.tp_size != 0
+        ):
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel ranks.
+            raise ValueError(
+                "Number of KV heads is greater than TP size, "
+                f"but total_num_kv_heads {self.total_num_kv_heads} "
+                f"is not divisible by tp_size {self.tp_size}."
+            )
+        elif self.total_num_kv_heads < self.tp_size:
+            # TODO: Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel ranks.
+            raise ValueError(
+                f"Number of KV heads {self.total_num_kv_heads} is less than "
+                f"TP size {self.tp_size}, KV heads replication is not support yet."
+            )
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+        self.qk_nope_dim = getattr(config, "qk_nope_dim", None)
+        self.qk_rope_dim = getattr(config, "qk_rope_dim", None)
+        self.v_channels = getattr(config, "v_channels", None)
+        self.head_dim = self.qk_rope_dim + self.qk_nope_dim
+        self.q_size = self.num_heads * self.head_dim
+        self.k_size = self.num_kv_heads * self.head_dim
+        self.v_size = self.num_kv_heads * self.v_channels
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.param_sink_number = getattr(config, "param_sink_number", 0)
+        self.param_sink_with_value = getattr(config, "param_sink_with_value", False)
+        self.param_sink_scalar = getattr(config, "param_sink_scalar", None)
+        self.param_sink_of_head_num = getattr(config, "param_sink_of_head_dim", False)
+
+        self.qkv_proj = MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[
+                self.q_size * self.tp_size,
+                self.k_size * self.tp_size,
+                self.v_size * self.tp_size,
+            ],
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+
+        self.o_proj = RowParallelLinear(
+            input_size=self.total_num_heads * self.v_channels,
+            output_size=hidden_size,
+            bias=bias_o_proj,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+        self._init_rotary_emb(
+            config, rope_scaling=rope_scaling, quant_config=quant_config
+        )
+
+        if hasattr(config, "interleaved_sliding_window"):
+            interleaved_sliding_window = config.interleaved_sliding_window
+            if isinstance(interleaved_sliding_window, int):
+                sliding_window = interleaved_sliding_window
+            elif isinstance(interleaved_sliding_window, list):
+                sw_idx = layer_idx % len(interleaved_sliding_window)
+                sliding_window = interleaved_sliding_window[sw_idx]
+            else:
+                raise ValueError(
+                    f"{type(interleaved_sliding_window)} "
+                    "for interleaved_sliding_window is not supported."
+                )
+        else:
+            sliding_window = None
+
+        FlashSinkAttentionBackend.set_cache_head_size_ratio(
+            (self.head_dim + self.v_channels) / self.head_dim
+        )
+        self.attn = AttentionWithSink(
+            self.num_heads,
+            self.head_dim,
+            self.v_channels,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            per_layer_sliding_window=sliding_window,
+            attn_type=attn_type,
+            prefix=f"{prefix}.attn",
+            attn_backend=FlashSinkAttentionBackend,
+        )
+
+        if self.param_sink_number > 0:
+            self.param_sink_key = torch.nn.Parameter(
+                torch.empty(
+                    (
+                        self.param_sink_number,
+                        self.num_kv_heads,
+                        self.head_dim,
+                    ),
+                    device=torch.cuda.current_device(),
+                    dtype=config.torch_dtype,
+                )
+            )
+            set_weight_attrs(
+                self.param_sink_key,
+                {
+                    "output_dim": 1,
+                    "weight_loader": self.weight_loader,
+                },
+            )
+
+            if self.param_sink_with_value:
+                self.param_sink_value = torch.nn.Parameter(
+                    torch.empty(
+                        (
+                            self.param_sink_number,
+                            self.num_kv_heads,
+                            self.v_channels,
+                        ),
+                        device=torch.cuda.current_device(),
+                        dtype=config.torch_dtype,
+                    )
+                )
+                set_weight_attrs(
+                    self.param_sink_value,
+                    {
+                        "output_dim": 1,
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+            else:
+                self.param_sink_value = torch.zeros(
+                    torch.empty(
+                        (
+                            self.param_sink_number,
+                            self.num_kv_heads,
+                            self.v_channels,
+                        ),
+                        device=torch.cuda.current_device(),
+                        dtype=config.torch_dtype,
+                    )
+                )
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        output_dim = getattr(param, "output_dim", None)
+
+        is_sharded_weight = getattr(param, "is_sharded_weight", False)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow
+        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
+
+        # Special case for GGUF
+        is_gguf_weight = getattr(param, "is_gguf_weight", False)
+        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if is_gguf_weight_type:
+            param.weight_type = loaded_weight.item()
+
+        # Materialize GGUF UninitializedParameter
+        if is_gguf_weight and isinstance(param, nn.UninitializedParameter):
+            final_shape = list(loaded_weight.shape)
+            if output_dim is not None:
+                assert final_shape[output_dim] % self.tp_size == 0
+                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
+            param.materialize(final_shape, dtype=loaded_weight.dtype)
+
+        param_data = param.data
+        if output_dim is not None and not is_sharded_weight:
+            shard_size = param_data.shape[output_dim]
+            start_idx = self.tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
+        k = self.k_layernorm(k.view(-1, self.num_kv_heads, self.head_dim))
+        q, k = self.rotary_emb(positions, q, k)
+
+        q = q.view(-1, self.q_size)
+        k = k.view(-1, self.k_size)
+        param_sink_key = self.param_sink_key
+        if (
+            self.param_sink_number > 0
+            and hasattr(self, "k_layernorm")
+            and self.k_layernorm is not None
+        ):
+            param_sink_key = self.k_layernorm(param_sink_key)
+
+        attn_output = self.attn(
+            q,
+            k,
+            v,
+            output_shape=torch.Size(
+                [q.shape[0], q.shape[1] // self.head_dim * self.v_channels]
+            ),
+            **(
+                dict(
+                    sink_key=param_sink_key,
+                    sink_value=self.param_sink_value,
+                )
+                if self.param_sink_number > 0
+                else {}
+            ),
+        )
+        attn_output = attn_output.reshape(-1, self.num_heads * self.v_channels)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def _init_rotary_emb(
+        self,
+        config: PretrainedConfig,
+        rope_scaling: dict[str, Any] | None,
+        quant_config: QuantizationConfig | None,
+    ) -> None:
+        is_neox_style = False
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.qk_rope_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+        )
+
+
 class OpenPanguDecoderLayer(nn.Module):
     def __init__(
         self,
@@ -567,6 +988,9 @@ class OpenPanguDecoderLayer(nn.Module):
             and hasattr(config, "v_head_dim")
             and hasattr(config, "kv_lora_rank")
         )
+        self.use_sink_attention = (
+            hasattr(config, "param_sink_number") and config.param_sink_number > 0
+        )
         if self.use_mla:
             self.self_attn = OpenPanguMLAAttention(
                 config=config,
@@ -585,6 +1009,37 @@ class OpenPanguDecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=f"{prefix}.self_attn",
             )
+        elif self.use_sink_attention:
+            attention_bias = getattr(config, "attention_bias", False) or getattr(
+                config, "bias", False
+            )
+            bias_o_proj = attention_bias
+            if hasattr(config, "qkv_bias"):
+                attention_bias = config.qkv_bias
+            if getattr(config, "is_causal", True):
+                attn_type = AttentionType.DECODER
+            else:
+                raise ValueError(
+                    f"is_causal={config.is_causal} is not support "
+                    "for attention with sink"
+                )
+            self.self_attn = OpenPanguSinkAttention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                num_kv_heads=getattr(
+                    config, "num_key_value_heads", config.num_attention_heads
+                ),
+                rope_theta=rope_theta,
+                rope_scaling=getattr(config, "rope_scaling", None),
+                max_position_embeddings=max_position_embeddings,
+                quant_config=quant_config,
+                bias=attention_bias,
+                bias_o_proj=bias_o_proj,
+                cache_config=cache_config,
+                prefix=f"{prefix}.self_attn",
+                attn_type=attn_type,
+            )
         else:
             attention_bias = getattr(config, "attention_bias", False) or getattr(
                 config, "bias", False
@@ -916,6 +1371,10 @@ class OpenPanguModel(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 name = maybe_remap_kv_scale_name(name, params_dict)
+                if name.endswith("e_score_correction_bias"):
+                    name = name.replace(
+                        "e_score_correction_bias", "gate.e_score_correction_bias"
+                    )
                 if name is None:
                     continue
                 if is_pp_missing_parameter(name, self):
@@ -1060,3 +1519,7 @@ class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel):
 
 class PanguUltraMoEForCausalLM(OpenPanguMoEModel):
     pass
+
+
+class PanguProMoEV2ForCausalLM(OpenPanguMoEModel):
+    pass
@@ -150,6 +150,7 @@ _TEXT_GENERATION_MODELS = {
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
     "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
+    "PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
    "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
     "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
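With the architecture registered above, a checkpoint whose config declares PanguProMoEV2ForCausalLM should load through vLLM's usual entry point. A minimal sketch; the local path is a placeholder, since this change references no public checkpoint.

from vllm import LLM, SamplingParams

# placeholder path: no Hugging Face repo is listed for this model yet
llm = LLM(model="/path/to/openpangu-pro-moe-v2", trust_remote_code=True)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)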
vllm/v1/attention/backends/flash_sink_attn.py (new file, 1005 lines)
File diff suppressed because it is too large
@@ -184,6 +184,11 @@ class Scheduler(SchedulerInterface):
             enable_kv_cache_events=self.enable_kv_cache_events,
             dcp_world_size=self.dcp_world_size,
         )
+        sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0)
+        if sink_len > 0:
+            assert sink_len % self.block_size == 0
+            num_sink_block = sink_len // self.block_size
+            self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block)
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1
 
     def schedule(self) -> SchedulerOutput:
@@ -12,6 +12,7 @@ from vllm.v1.kv_cache_interface import (
     ChunkedLocalAttentionSpec,
     CrossAttentionSpec,
     FullAttentionSpec,
+    FullSinkAttentionSpec,
     KVCacheSpec,
     MambaSpec,
     MLAAttentionSpec,
@@ -305,7 +306,8 @@ class FullAttentionManager(SingleTypeKVCacheManager):
         dcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(
-            kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
+            kv_cache_spec,
+            (FullAttentionSpec, FullSinkAttentionSpec, ChunkedLocalAttentionSpec),
         ), (
             "FullAttentionManager can only be used for full attention "
             "and chunked local attention groups"
@@ -720,6 +722,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
 
 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
+    FullSinkAttentionSpec: FullAttentionManager,
     MLAAttentionSpec: FullAttentionManager,
     SlidingWindowSpec: SlidingWindowManager,
     ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
@@ -157,6 +157,98 @@ class FullAttentionSpec(AttentionSpec):
         return merged_spec
 
 
+@dataclass(frozen=True)
+class FullSinkAttentionSpec(AttentionSpec):
+    head_size_v: int
+    sliding_window: int | None = None
+    attention_chunk_size: int | None = None
+
+    """
+    When hybrid allocator is disabled and the model contains both full
+    attention layers and sliding window attention layers, sliding
+    window attention are regarded as full attention in KV cache manager
+    (blocks are allocated for all tokens), while computed as sliding window
+    attention in model runner.
+    In this case, we use FullSinkAttentionSpec and record the sliding window size.
+    Default to None for not using sliding window attention.
+    """
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        max_model_len = vllm_config.model_config.max_model_len
+        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
+        # Note(hc): each dcp rank only need save
+        # (max_model_len//dcp_world_size) tokens locally.
+        if dcp_world_size > 1:
+            max_model_len = cdiv(max_model_len, dcp_world_size)
+        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+
+    @classmethod
+    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
+        if len(window_sizes) == 0:
+            return None
+        elif len(window_sizes) == 1:
+            return window_sizes.pop()
+        else:
+            raise ValueError(
+                "All attention layers in the same KV cache group must have the "
+                "same window size."
+            )
+
+    @classmethod
+    def merge(cls, specs: list[Self]) -> Self:
+        """
+        Merge a list of FullSinkAttentionSpec objects into a single
+        FullSinkAttentionSpec object.
+        """
+        assert all(isinstance(spec, FullSinkAttentionSpec) for spec in specs), (
+            "All attention layers in the same KV cache group must be "
+            "FullSinkAttentionSpec."
+        )
+
+        sliding_window = set(
+            spec.sliding_window for spec in specs if spec.sliding_window is not None
+        )
+        attention_chunk_size = set(
+            spec.attention_chunk_size
+            for spec in specs
+            if spec.attention_chunk_size is not None
+        )
+        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
+            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
+        )
+        merged_spec = cls(
+            block_size=specs[0].block_size,
+            num_kv_heads=specs[0].num_kv_heads,
+            head_size=specs[0].head_size,
+            head_size_v=specs[0].head_size_v,
+            dtype=specs[0].dtype,
+            sliding_window=cls.merge_window_sizes(sliding_window),
+            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
+        )
+        for spec in specs:
+            for f in fields(AttentionSpec):
+                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
+                    "All attention layers in the same KV cache group must have "
+                    "the same attention spec."
+                )
+        assert (merged_spec.sliding_window is not None) + (
+            merged_spec.attention_chunk_size is not None
+        ) <= 1, (
+            "Model with both sliding window layers and chunked local attention "
+            "layers is not supported."
+        )
+        return merged_spec
+
+    @property
+    def page_size_bytes(self) -> int:
+        return (
+            self.block_size
+            * self.num_kv_heads
+            * (self.head_size + self.head_size_v)
+            * get_dtype_size(self.dtype)
+        )
+
+
 @dataclass(frozen=True)
 class MLAAttentionSpec(FullAttentionSpec):
     # TODO(Lucas/Chen): less hacky way to do this
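As a quick sanity check of the page_size_bytes formula above, here is the arithmetic under assumed values; these numbers are illustrative and not taken from the model config.

block_size, num_kv_heads = 16, 8
head_size, head_size_v = 192, 128  # assumed; key and value head sizes differ
dtype_size = 2  # bf16
page_size_bytes = block_size * num_kv_heads * (head_size + head_size_v) * dtype_size
print(page_size_bytes)  # 81920 bytes per KV cache block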