Add support for openpangu_pro_moe_v2, which is characterized by its different KV head size and sink KV in attention.

Signed-off-by: yuantao <2422264527@qq.com>
This commit is contained in:
yuantao 2025-11-15 12:00:40 +08:00
parent 9fc81ec765
commit 8de4315229
10 changed files with 1744 additions and 13 deletions

View File

@ -427,6 +427,7 @@ th {
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
| `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | |
| `PanguEmbeddedForCausalLM` |openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ |
| `PanguProMoEV2ForCausalLM` |openpangu-pro-moe-v2 | | ✅︎ | ✅︎ |
| `PanguUltraMoEForCausalLM` |openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ |
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ |
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ |

View File

@ -383,6 +383,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PanguEmbeddedForCausalLM": _HfExamplesInfo(
"FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True
),
"PanguProMoEV2ForCausalLM": _HfExamplesInfo(
"",
trust_remote_code=True,
is_available_online=False,
),
"PanguUltraMoEForCausalLM": _HfExamplesInfo(
"FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
trust_remote_code=True,

View File

@ -939,21 +939,42 @@ def unified_attention_with_output(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
sink_key: torch.Tensor | None = None,
sink_value: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
if sink_key is None and sink_value is None:
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
else:
assert sink_key is not None and sink_value is not None, (
"Currently, it is only supported when "
"sink_key and sink_value are both not None"
)
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
sink_key=sink_key,
sink_value=sink_value,
)
def unified_attention_with_output_fake(
@ -962,6 +983,8 @@ def unified_attention_with_output_fake(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
sink_key: torch.Tensor | None = None,
sink_value: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:

View File

@ -182,3 +182,136 @@ def triton_reshape_and_cache_flash(
num_warps=num_warps,
num_stages=num_stages,
)
@triton.jit
def reshape_and_cache_kernel_flash_diffkv(
    kv_ptr,  # [num_tokens, num_heads, head_size + head_size_v]
    kv_cache_ptr,  # [num_blocks, block_size, num_heads, head_size + head_size_v]
    slot_mapping_ptr,  # [num_tokens]
    k_scale,  # float32
    v_scale,  # float32
    # strides
    kv_stride: tl.int64,
    block_stride: tl.int64,
    page_stride: tl.int64,
    num_heads: tl.constexpr,
    head_size_kv: tl.constexpr,
    block_size: tl.constexpr,
    # FP8 flags
    FP8_KV_CACHE: tl.constexpr,
    # tune parameters
    TILE_SIZE: tl.constexpr,
):
    # Copy one token's concatenated K/V row into the paged KV cache.
    # Launch grid: axis 0 = token index, axis 1 = tile within the row, so
    # each program instance moves TILE_SIZE contiguous elements.
    # NOTE: k_scale/v_scale and FP8_KV_CACHE are accepted but unused here;
    # the host-side wrapper currently rejects fp8 caches before launch.
    token_idx = tl.program_id(axis=0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot_idx < 0:
        # Padding token that should be ignored.
        return

    tile_i = tl.program_id(axis=1)
    tile_offs = tl.arange(0, TILE_SIZE)
    tile_pos = tile_i * TILE_SIZE + tile_offs

    # Locate the destination page (block) and the row offset inside it.
    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size

    # Flat element offsets: source row in `kv`, destination row in the cache.
    src_kv_idx = token_idx * kv_stride
    tgt_idx = block_idx * block_stride + block_offset * page_stride

    # [TILE_SIZE]
    # Mask guards the final partial tile when the row length
    # (num_heads * head_size_kv) is not a multiple of TILE_SIZE.
    kv_tile = tl.load(
        kv_ptr + src_kv_idx + tile_pos, mask=tile_pos < (num_heads * head_size_kv)
    )

    tl.store(
        kv_cache_ptr + tgt_idx + tile_pos,
        kv_tile,
        mask=tile_pos < (num_heads * head_size_kv),
    )
    return
def triton_reshape_and_cache_flash_diffkv(
    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_heads, head_size_v]
    # [num_blocks, block_size, num_heads, head_size + head_size_v]
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,  # [num_tokens]
    kv_cache_dtype: str,  # "auto", "fp8"
    k_scale: torch.Tensor,  # float32
    v_scale: torch.Tensor,  # float32
):
    """Scatter per-token K/V with *different* head sizes into the paged cache.

    Key and value are concatenated along the last dim so each token is a
    single contiguous ``num_heads * (head_size + head_size_v)`` row, which the
    triton kernel then copies tile-by-tile into the page selected by
    ``slot_mapping``. fp8 KV caches are recognized but currently rejected.
    """
    # Concatenate K and V so the kernel copies one contiguous row per token.
    kv = torch.cat([key, value], dim=-1).contiguous()
    num_heads = kv.shape[1]
    head_size_kv = kv.shape[2]
    block_size = kv_cache.shape[1]
    n = num_heads * head_size_kv

    kv_stride = kv.stride()[0]
    block_stride = kv_cache.stride()[0]
    page_stride = kv_cache.stride()[1]

    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
    )
    kv_cache_torch_dtype = (
        current_platform.fp8_dtype()
        if kv_cache_dtype.startswith("fp8")
        else kv_cache.dtype
    )

    if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
        # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
        kv_cache = kv_cache.view(kv_cache_torch_dtype)
        # BUG FIX: the original compared the *string* `kv_cache_dtype` to
        # torch.uint8, which is always unequal, so the assert never fired.
        # Compare the resolved torch dtype instead.
        assert kv_cache_torch_dtype != torch.uint8, (
            "explicit fp8 cast and store to "
            "uint8 is not supported by triton reshape_and_cache_flash_diffkv"
        )

    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
    # BUG FIX: the message was missing the f-prefix, so the dtype placeholder
    # "{kv_cache_torch_dtype}" was emitted literally.
    assert not FP8_KV_CACHE, (
        "unsupported dtype of KV cache tensor, got "
        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: "
        "bfloat16, float16, float32."
    )

    # heuristics instead of autotuning
    TILE_SIZE = min(2048, triton.next_power_of_2(n))
    if current_platform.is_rocm() or current_platform.is_xpu():
        num_stages = 4
        num_warps = 8
    else:  # cuda
        num_stages = 10
        num_warps = 16
        if torch.cuda.get_device_capability(key.device)[0] < 9:
            # Pre-Hopper GPUs have fewer registers/SMEM; use smaller tiles.
            TILE_SIZE = min(512, TILE_SIZE)

    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
    # using cudagraphs
    grid = lambda meta: (
        slot_mapping.shape[0],
        triton.cdiv(n, meta["TILE_SIZE"]),
    )

    reshape_and_cache_kernel_flash_diffkv[grid](
        kv_ptr=kv,
        kv_cache_ptr=kv_cache,
        slot_mapping_ptr=slot_mapping,
        k_scale=k_scale,
        v_scale=v_scale,
        # strides
        kv_stride=kv_stride,
        block_stride=block_stride,
        page_stride=page_stride,
        num_heads=num_heads,
        head_size_kv=head_size_kv,
        block_size=block_size,
        # FP8 flags
        FP8_KV_CACHE=FP8_KV_CACHE,
        # autotune parameters
        TILE_SIZE=TILE_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )

View File

@ -30,15 +30,18 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.distributed import (
get_ep_group,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_gather,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@ -76,7 +79,13 @@ from vllm.model_executor.models.utils import (
maybe_prefix,
sequence_parallel_chunk,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.flash_sink_attn import FlashSinkAttentionBackend
from vllm.v1.kv_cache_interface import (
FullSinkAttentionSpec,
KVCacheSpec,
)
def check_ffn_act_fn(act_fn: str):
@ -86,6 +95,140 @@ def check_ffn_act_fn(act_fn: str):
)
class AttentionWithSink(Attention):
    """Attention layer that additionally accepts learned sink K/V tensors and
    supports a value head size (`head_size_v`) different from the key head
    size (`head_size`). The KV cache spec it reports accounts for both sizes.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        head_size_v: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        # Delegate everything except the value head size to the base
        # Attention layer; `head_size_v` is the only new piece of state.
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            cache_config,
            quant_config,
            logits_soft_cap,
            per_layer_sliding_window,
            prefix,
            attn_type,
            kv_sharing_target_layer_name,
            attn_backend,
            **extra_impl_args,
        )
        self.head_size_v = head_size_v

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For attention with sink, we have sink k, v
        sink_key: torch.Tensor | None = None,
        sink_value: torch.Tensor | None = None,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        # Remember the caller-visible dtype before any query quantization.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input():
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = output_shape if output_shape is not None else query.shape
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            # NOTE: key uses head_size while value/output use head_size_v —
            # this asymmetry is the point of this subclass.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size_v)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            if self.use_direct_call:
                # Eager path: call the backend impl directly with the metadata
                # pulled from the forward context.
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(
                    self,
                    query,
                    key,
                    value,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                    sink_key=sink_key,
                    sink_value=sink_value,
                )
            else:
                # Compiled path: go through the registered custom op so
                # torch.compile / CUDA graphs can capture it.
                torch.ops.vllm.unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    sink_key=sink_key,
                    sink_value=sink_value,
                )
            return output.view(-1, hidden_size)
        else:
            raise ValueError(
                "Unsupport Error, currently only flash_sink_attn "
                "backend with output buffer is supported"
            )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return a FullSinkAttentionSpec sized for the asymmetric K/V heads."""
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        # Only support for full attention now.
        assert self.sliding_window is None
        return FullSinkAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            dtype=self.kv_cache_torch_dtype,
        )
class OpenPanguMLP(nn.Module):
def __init__(
self,
@ -153,7 +296,15 @@ class OpenPanguMoE(nn.Module):
quant_config=None,
prefix=f"{prefix}.gate",
)
self.gate.e_score_correction_bias = None
if (
hasattr(config, "router_enable_expert_bias")
and config.router_enable_expert_bias
):
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(self.n_routed_experts, dtype=torch.float32)
)
else:
self.gate.e_score_correction_bias = None
# Load balancing settings.
eplb_config = parallel_config.eplb_config
@ -539,6 +690,276 @@ class OpenPanguEmbeddedAttention(nn.Module):
)
class OpenPanguSinkAttention(nn.Module):
    """Self-attention with learned "sink" KV entries and asymmetric head sizes.

    Keys use ``qk_nope_dim + qk_rope_dim`` channels while values use
    ``v_channels``; ``param_sink_number`` learned sink key (and optionally
    value) slots are prepended by the FlashSink attention backend.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        if self.total_num_heads % self.tp_size != 0:
            raise ValueError(
                f"total_num_heads {self.total_num_heads} "
                f"is not divisible by tp_size {self.tp_size}."
            )
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if (
            self.total_num_kv_heads > self.tp_size
            and self.total_num_kv_heads % self.tp_size != 0
        ):
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel ranks.
            raise ValueError(
                "Number of KV heads is greater than TP size, "
                f"but total_num_kv_heads {self.total_num_kv_heads} "
                f"is not divisible by tp_size {self.tp_size}."
            )
        elif self.total_num_kv_heads < self.tp_size:
            # TODO: Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel ranks.
            raise ValueError(
                f"Number of KV heads {self.total_num_kv_heads} is less than "
                f"TP size {self.tp_size}, KV heads replication is not support yet."
            )
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.qk_nope_dim = getattr(config, "qk_nope_dim", None)
        self.qk_rope_dim = getattr(config, "qk_rope_dim", None)
        self.v_channels = getattr(config, "v_channels", None)
        # Key head dim = non-rotary part + rotary part; value head dim differs.
        self.head_dim = self.qk_rope_dim + self.qk_nope_dim
        self.q_size = self.num_heads * self.head_dim
        self.k_size = self.num_kv_heads * self.head_dim
        self.v_size = self.num_kv_heads * self.v_channels
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.param_sink_number = getattr(config, "param_sink_number", 0)
        self.param_sink_with_value = getattr(config, "param_sink_with_value", False)
        self.param_sink_scalar = getattr(config, "param_sink_scalar", None)
        # NOTE(review): attribute name ("..._of_head_num") does not match the
        # config key ("param_sink_of_head_dim") — confirm which is intended.
        self.param_sink_of_head_num = getattr(config, "param_sink_of_head_dim", False)

        self.qkv_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[
                self.q_size * self.tp_size,
                self.k_size * self.tp_size,
                self.v_size * self.tp_size,
            ],
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.v_channels,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self._init_rotary_emb(
            config, rope_scaling=rope_scaling, quant_config=quant_config
        )

        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} "
                    "for interleaved_sliding_window is not supported."
                )
        else:
            sliding_window = None

        # The cache stores key (head_dim) and value (v_channels) together, so
        # tell the backend how much larger a cache row is than a key row.
        FlashSinkAttentionBackend.set_cache_head_size_ratio(
            (self.head_dim + self.v_channels) / self.head_dim
        )
        self.attn = AttentionWithSink(
            self.num_heads,
            self.head_dim,
            self.v_channels,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
            attn_backend=FlashSinkAttentionBackend,
        )

        if self.param_sink_number > 0:
            self.param_sink_key = torch.nn.Parameter(
                torch.empty(
                    (
                        self.param_sink_number,
                        self.num_kv_heads,
                        self.head_dim,
                    ),
                    device=torch.cuda.current_device(),
                    dtype=config.torch_dtype,
                )
            )
            set_weight_attrs(
                self.param_sink_key,
                {
                    "output_dim": 1,
                    "weight_loader": self.weight_loader,
                },
            )
            if self.param_sink_with_value:
                self.param_sink_value = torch.nn.Parameter(
                    torch.empty(
                        (
                            self.param_sink_number,
                            self.num_kv_heads,
                            self.v_channels,
                        ),
                        device=torch.cuda.current_device(),
                        dtype=config.torch_dtype,
                    )
                )
                set_weight_attrs(
                    self.param_sink_value,
                    {
                        "output_dim": 1,
                        "weight_loader": self.weight_loader,
                    },
                )
            else:
                # Sinks carry keys only: pair them with constant zero values.
                # BUG FIX: the original called torch.zeros(torch.empty(...)),
                # which raises TypeError because torch.zeros expects sizes,
                # not a tensor. Allocate the zero tensor directly.
                self.param_sink_value = torch.zeros(
                    (
                        self.param_sink_number,
                        self.num_kv_heads,
                        self.v_channels,
                    ),
                    device=torch.cuda.current_device(),
                    dtype=config.torch_dtype,
                )

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Load a (possibly TP-sharded) sink weight from a checkpoint tensor."""
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, nn.UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            # Narrow to this TP rank's shard along the KV-head dimension.
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
        # Keys are RMS-normalized per head before RoPE.
        k = self.k_layernorm(k.view(-1, self.num_kv_heads, self.head_dim))
        q, k = self.rotary_emb(positions, q, k)
        q = q.view(-1, self.q_size)
        k = k.view(-1, self.k_size)
        # BUG FIX: only touch self.param_sink_key when sinks are configured —
        # the attribute does not exist when param_sink_number == 0, so the
        # original unconditional read raised AttributeError in that case.
        if self.param_sink_number > 0:
            param_sink_key = self.param_sink_key
            if getattr(self, "k_layernorm", None) is not None:
                # Apply the same key normalization to the sink keys.
                param_sink_key = self.k_layernorm(param_sink_key)
            sink_kwargs = dict(
                sink_key=param_sink_key,
                sink_value=self.param_sink_value,
            )
        else:
            sink_kwargs = {}
        attn_output = self.attn(
            q,
            k,
            v,
            # Output rows have v_channels (not head_dim) per head.
            output_shape=torch.Size(
                [q.shape[0], q.shape[1] // self.head_dim * self.v_channels]
            ),
            **sink_kwargs,
        )
        attn_output = attn_output.reshape(-1, self.num_heads * self.v_channels)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: PretrainedConfig,
        rope_scaling: dict[str, Any] | None,
        quant_config: QuantizationConfig | None,
    ) -> None:
        # Non-NeoX (interleaved) rotary layout; only qk_rope_dim channels rotate.
        is_neox_style = False
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.qk_rope_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
class OpenPanguDecoderLayer(nn.Module):
def __init__(
self,
@ -567,6 +988,9 @@ class OpenPanguDecoderLayer(nn.Module):
and hasattr(config, "v_head_dim")
and hasattr(config, "kv_lora_rank")
)
self.use_sink_attention = (
hasattr(config, "param_sink_number") and config.param_sink_number > 0
)
if self.use_mla:
self.self_attn = OpenPanguMLAAttention(
config=config,
@ -585,6 +1009,37 @@ class OpenPanguDecoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
elif self.use_sink_attention:
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False
)
bias_o_proj = attention_bias
if hasattr(config, "qkv_bias"):
attention_bias = config.qkv_bias
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
raise ValueError(
f"is_causal={config.is_causal} is not support "
"for attention with sink"
)
self.self_attn = OpenPanguSinkAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
rope_theta=rope_theta,
rope_scaling=getattr(config, "rope_scaling", None),
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
bias_o_proj=bias_o_proj,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
else:
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False
@ -916,6 +1371,10 @@ class OpenPanguModel(nn.Module):
if name.endswith(".bias") and name not in params_dict:
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name.endswith("e_score_correction_bias"):
name = name.replace(
"e_score_correction_bias", "gate.e_score_correction_bias"
)
if name is None:
continue
if is_pp_missing_parameter(name, self):
@ -1060,3 +1519,7 @@ class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel):
class PanguUltraMoEForCausalLM(OpenPanguMoEModel):
    """openPangu-Ultra-MoE causal-LM architecture alias.

    Registered under this name in the model registry; all behavior comes
    from OpenPanguMoEModel.
    """

    pass
class PanguProMoEV2ForCausalLM(OpenPanguMoEModel):
    """openPangu-Pro-MoE-V2 causal-LM architecture alias.

    Registered under this name in the model registry; all behavior comes
    from OpenPanguMoEModel, which selects sink attention when the config
    sets param_sink_number > 0.
    """

    pass

View File

@ -150,6 +150,7 @@ _TEXT_GENERATION_MODELS = {
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"OuroForCausalLM": ("ouro", "OuroForCausalLM"),
"PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
"PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
"PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),

File diff suppressed because it is too large Load Diff

View File

@ -184,6 +184,11 @@ class Scheduler(SchedulerInterface):
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
)
sink_len = getattr(vllm_config.model_config.hf_config, "param_sink_number", 0)
if sink_len > 0:
assert sink_len % self.block_size == 0
num_sink_block = sink_len // self.block_size
self.kv_cache_manager.block_pool.free_block_queue.popleft_n(num_sink_block)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
def schedule(self) -> SchedulerOutput:

View File

@ -12,6 +12,7 @@ from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
FullAttentionSpec,
FullSinkAttentionSpec,
KVCacheSpec,
MambaSpec,
MLAAttentionSpec,
@ -305,7 +306,8 @@ class FullAttentionManager(SingleTypeKVCacheManager):
dcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
kv_cache_spec,
(FullAttentionSpec, FullSinkAttentionSpec, ChunkedLocalAttentionSpec),
), (
"FullAttentionManager can only be used for full attention "
"and chunked local attention groups"
@ -720,6 +722,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
FullSinkAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,

View File

@ -157,6 +157,98 @@ class FullAttentionSpec(AttentionSpec):
return merged_spec
@dataclass(frozen=True)
class FullSinkAttentionSpec(AttentionSpec):
    """KV cache spec for full attention with sink KV and asymmetric head sizes.

    Unlike FullAttentionSpec, key rows use ``head_size`` channels while value
    rows use ``head_size_v``, so page sizing accounts for their sum.
    """

    # Per-head value size; may differ from the key head size.
    head_size_v: int
    sliding_window: int | None = None
    attention_chunk_size: int | None = None
    """
    When the hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding
    window attention is regarded as full attention in the KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in the model runner.
    In this case, we use FullSinkAttentionSpec and record the sliding window
    size. Defaults to None when sliding window attention is not used.
    """

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Worst-case KV cache bytes for one request at max_model_len."""
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
        # Note(hc): each dcp rank only need save
        # (max_model_len//dcp_world_size) tokens locally.
        if dcp_world_size > 1:
            max_model_len = cdiv(max_model_len, dcp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
        """Collapse a set of window sizes to one value; None if the set is empty."""
        if len(window_sizes) == 0:
            return None
        elif len(window_sizes) == 1:
            return window_sizes.pop()
        else:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size."
            )

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullSinkAttentionSpec objects into a single
        FullSinkAttentionSpec object.
        """
        assert all(isinstance(spec, FullSinkAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "FullSinkAttentionSpec."
        )

        sliding_window = set(
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        )
        attention_chunk_size = set(
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        )
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            head_size_v=specs[0].head_size_v,
            dtype=specs[0].dtype,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Every base-spec field must agree across the group for the merge to
        # be meaningful.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec

    @property
    def page_size_bytes(self) -> int:
        # Each cache row packs a key (head_size) and a value (head_size_v),
        # hence the sum instead of FullAttentionSpec's 2 * head_size.
        return (
            self.block_size
            * self.num_kv_heads
            * (self.head_size + self.head_size_v)
            * get_dtype_size(self.dtype)
        )
@dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec):
# TODO(Lucas/Chen): less hacky way to do this