[Model] New model support for microsoft/Phi-4-mini-flash-reasoning (#20702)

Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
Congcong Chen 2025-07-12 06:02:10 -07:00 committed by GitHub
parent b639327ad9
commit 2c11a738b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1869 additions and 41 deletions

View File

@@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
}
@@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
at::Tensor z, out_z;
const bool has_z = z_.has_value();
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
if (has_z) {
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
}
out_z = z;
}
out_z = z;
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type);
@@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
});
}

View File

@@ -374,6 +374,7 @@ Specified using `--task generate`.
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/Phi-4-mini-flash-reasoning`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
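As a quick sanity check, here is a minimal offline-inference sketch for the new model (not part of this diff). It assumes a single GPU and mirrors the constraints exercised by the tests in this change: the V0 engine, the DIFFERENTIAL_FLASH_ATTN backend, trust_remote_code, and a reduced max_model_len. The prompt and sampling values are illustrative.

import os
os.environ["VLLM_USE_V1"] = "0"
os.environ["VLLM_ATTENTION_BACKEND"] = "DIFFERENTIAL_FLASH_ATTN"

from vllm import LLM, SamplingParams

llm = LLM(
    model="microsoft/Phi-4-mini-flash-reasoning",
    trust_remote_code=True,
    max_model_len=10240,
)
params = SamplingParams(temperature=0.6, max_tokens=512)
outputs = llm.generate(["How many r's are in 'strawberry'?"], params)
print(outputs[0].outputs[0].text)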

View File

@@ -248,6 +248,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True,
v0_only=True),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
max_model_len=10240),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",

View File

@@ -103,6 +103,9 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
_initialize_kv_caches_v1), monkeypatch.context() as m):
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
if model_arch == "Phi4FlashForCausalLM":
# Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
LLM(
model_info.default,
tokenizer=model_info.tokenizer,

View File

@@ -458,6 +458,31 @@ def test_bind_kv_cache():
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
def test_bind_kv_cache_kv_sharing():
from vllm.attention import Attention
ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
'layers.1.self_attn': Attention(32, 128, 0.1),
'layers.2.self_attn': Attention(32, 128, 0.1),
'layers.3.self_attn': Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
]
shared_kv_cache_layers = {
'layers.2.self_attn': 'layers.1.self_attn',
'layers.3.self_attn': 'layers.0.self_attn'
}
bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0]
def test_bind_kv_cache_non_attention():
from vllm.attention import Attention

View File

@@ -308,7 +308,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"BLOCK_SPARSE_FLASH_ATTN Backend.")
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
"Alibi not support for blocksparse flash attention.")

File diff suppressed because it is too large

View File

@@ -295,7 +295,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"DUAL_CHUNK_FLASH_ATTN backend.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)

View File

@@ -622,7 +622,8 @@ class FlashAttentionImpl(AttentionImpl):
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"FLASH_ATTN backend.")
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")

View File

@@ -1006,7 +1006,8 @@ class FlashInferImpl(AttentionImpl):
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"FLASHINFER backend.")
if use_irope:
logger.warning_once(
"Using irope in FlashInfer is not supported yet, it will fall"

View File

@@ -115,7 +115,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
) -> None:
super(AttentionImpl, self).__init__()
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"HPU_ATTN backend.")
if use_irope:
logger.warning_once(
"Using irope in HPU is not supported yet, it will fall back "

View File

@@ -501,7 +501,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"ROCM_FLASH backend.")
if use_irope:
logger.warning_once(
"Using irope in ROCm Flash Attention is not supported yet, it "

View File

@@ -394,7 +394,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"XFORMERS backend.")
if blocksparse_params is not None:
raise ValueError(
"XFormers does not support block-sparse attention.")

View File

@@ -160,10 +160,6 @@ class Attention(nn.Module):
self.attn_type = attn_type
if kv_sharing_target_layer_name is not None:
if not envs.VLLM_USE_V1:
raise NotImplementedError(
"Cross-layer KV sharing is not supported in V0.")
validate_kv_sharing_target(
prefix,
kv_sharing_target_layer_name,

View File

@@ -59,11 +59,12 @@ class LogitsProcessor(nn.Module):
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
prune_hidden_states: bool = True,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
if sampling_metadata is not None:
if sampling_metadata is not None and prune_hidden_states:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)

View File

@@ -0,0 +1,746 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
logger = init_logger(__name__)
class SwiGLUActivation(nn.Module):
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return x1 * nn.functional.silu(x2)
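# Since silu(x2) == x2 * sigmoid(x2), this is the usual SwiGLU-style gating
# x1 * x2 * sigmoid(x2).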
class SambaYMLP(nn.Module):
"""Gated Linear Unit.
Reference:
Language Modeling with Gated Convolutional Networks.
https://arxiv.org/pdf/1612.08083v3.pdf.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.fc1 = nn.Linear(config.hidden_size,
2 * config.intermediate_size,
bias=False)
self.fc2 = nn.Linear(config.intermediate_size,
config.hidden_size,
bias=False)
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
y = self.fc1(hidden_states)
gate, y = y.chunk(2, dim=-1)
y = y * self.activation_fn(gate)
return self.fc2(y)
def get_virtual_engine():
forward_context: ForwardContext = get_forward_context()
return forward_context.virtual_engine
class SambaYAttention(nn.Module):
def __init__(self,
config,
layer_idx: Optional[int] = None,
yoco_cross: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
super().__init__()
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing "
"a `layer_idx` is not recommended and will lead to errors "
"during the forward call if caching is used. Please make "
"sure to provide a `layer_idx` when creating this class.")
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.yoco_cross = yoco_cross
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError("hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size} and "
f"`num_heads`: {self.num_heads}).")
op_size = self.num_heads * self.head_dim + 2 * (
self.num_key_value_heads * self.head_dim)
self.out_proj = nn.Linear(self.num_heads * self.head_dim,
self.hidden_size,
bias=True)
if yoco_cross:
self.Wqkv = nn.Linear(self.hidden_size,
self.num_heads * self.head_dim,
bias=True)
else:
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
# disable sliding window for the second half of the model
sliding_window = config.interleaved_sliding_window[layer_idx]
if layer_idx >= config.num_hidden_layers // 2:
assert sliding_window is None, \
"sliding_window must be none for the second decoder"
else:
assert sliding_window is not None, \
"sliding_window must be set for the first decoder"
assert self.num_heads % 2 == 0, 'num_heads should be even'
assert self.num_key_value_heads % 2 == 0, 'num_key_value_heads should be even'
self.lambda_init = self.lambda_init_fn(layer_idx)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.subln = nn.RMSNorm(2 * self.head_dim,
eps=1e-5,
elementwise_affine=True)
params = {
'differential_flash_attention_config': {
'lambda_init': self.lambda_init,
'lambda_q1': self.lambda_q1,
'lambda_k1': self.lambda_k1,
'lambda_q2': self.lambda_q2,
'lambda_k2': self.lambda_k2,
"subln": self.subln,
}
}
if yoco_cross:
kv_shared_layer_index = config.num_hidden_layers // 2 + 1
kv_sharing_target_layer_name = \
f"model.layers.{kv_shared_layer_index}.self_attn.attn"
else:
kv_sharing_target_layer_name = None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.head_dim**-0.5,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
**params)
assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\
"DIFFERENTIAL_FLASH_ATTN required"
def lambda_init_fn(self, depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
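# Illustrative values: depth 0 -> 0.2, depth 5 -> ~0.666, approaching 0.8
# for deep layers; this follows the lambda_init schedule used by
# differential attention.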
def forward(
self,
hidden_states: torch.Tensor,
):
if not self.yoco_cross: # need to generate kv-cache
qkv = self.Wqkv(hidden_states)
q, k, v = qkv.split([
self.hidden_size, self.num_key_value_heads * self.head_dim,
self.num_key_value_heads * self.head_dim
],
dim=-1)
attn_output = self.attn(q, k, v)
else: # re-use the kv cache, full attention
q = self.Wqkv(hidden_states)
attn_output = self.attn(q, None, None)
attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
return self.out_proj(attn_output)
class Phi4Mamba(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random", # difference
dt_scale=1.0, # difference
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True, # Fused kernel options
layer_idx=None,
device=None,
dtype=None,
yoco_cross=False,
yoco_kv=False,
):
factory_kwargs = {"params_dtype": dtype} # difference
super().__init__()
self.yoco_cross = yoco_cross
self.yoco_kv = yoco_kv
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model /
16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.swiGluActivation = SwiGLUActivation()
if self.yoco_cross:
self.in_proj = MergedColumnParallelLinear(self.d_model,
[self.d_inner],
bias=bias,
**factory_kwargs)
self.out_proj = RowParallelLinear(self.d_inner,
self.d_model,
bias=bias,
**factory_kwargs)
return
self.conv1d = ColumnParallelLinear(
input_size=d_conv,
output_size=self.d_inner,
bias=conv_bias,
params_dtype=dtype,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow overriding it.
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(
self.d_model,
[self.d_inner] * 2,
bias=bias,
params_dtype=dtype,
)
# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
self.d_inner,
self.dt_rank + self.d_state * 2,
bias=False,
params_dtype=dtype,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(
self.dt_rank,
self.d_inner,
bias=True,
skip_bias_add=True,
params_dtype=dtype,
)
# # D "skip" parameter
# self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
self.A = nn.Parameter(
torch.empty(
self.d_inner,
self.d_state,
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))
self.out_proj = RowParallelLinear(
self.d_inner,
self.d_model,
bias=bias,
input_is_parallel=True,
params_dtype=dtype,
)
self.activation = "silu"
def forward(self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
yoco_key_values=None) -> torch.Tensor:
if self.yoco_cross:
out = self.in_proj(hidden_states)[0]
out = self.swiGluActivation(yoco_key_values, out)
out = self.out_proj(out)
return out[0], yoco_key_values
# 1. Gated MLP's linear projection
# projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
projected_states = self.in_proj(
hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
[self.dt_rank, self.d_state, self.d_state],
dim=-1,
)
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
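# i.e. the discretized selective-scan recurrence
#   h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t
#   y_t = C_t · h_t + D * x_t
# with dt_t = softplus(discrete_time_step + dt_proj bias).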
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
# z,
None if self.yoco_kv else gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
# z
# gate.transpose(0, 1),
None if self.yoco_kv else gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
if self.yoco_kv:
# gate = gate.transpose(-1,-2).contiguous()
yoco_key_values = scan_outputs.transpose(-2, -1)
scan_outputs = self.swiGluActivation(scan_outputs, gate)
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states, yoco_key_values
class SambaYDecoderLayer(nn.Module):
def __init__(
self,
config,
layer_idx,
cache_config,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.mlp = SambaYMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.yoco_mb = False
self.yoco_cross = False
if layer_idx >= config.num_hidden_layers // 2:
self.yoco_mb = True
self.yoco_cross = (layer_idx
>= (config.num_hidden_layers // 2 + 2))
self.use_mamba = config.mb_per_layer > 0 and \
layer_idx % config.mb_per_layer == 0
if self.use_mamba:
factory_kwargs = {"dtype": None}
self.attn = Phi4Mamba(config.hidden_size,
layer_idx=layer_idx,
yoco_cross=self.yoco_cross,
yoco_kv=self.yoco_mb,
**factory_kwargs)
else:
self.attn = SambaYAttention(config,
layer_idx=layer_idx,
yoco_cross=self.yoco_cross,
cache_config=cache_config,
prefix=f"{prefix}.self_attn")
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
ssm_output: Optional[torch.LongTensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.use_mamba:
assert mamba_cache_params is not None
else:
assert mamba_cache_params is None
residual = hidden_states
hidden_states = self.input_layernorm(
hidden_states.to(dtype=self.input_layernorm.weight.dtype))
if self.use_mamba:
attn_outputs, ssm_output = self.attn(hidden_states,
attn_metadata,
mamba_cache_params,
yoco_key_values=ssm_output)
residual = residual.to(torch.float32)
else:
attn_outputs = self.attn(hidden_states, )
hidden_states = residual + attn_outputs
residual = hidden_states
hidden_states = self.post_attention_layernorm(
hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, ssm_output
class SambaYModel(nn.Module):
def __init__(self,
config,
cache_config=None,
quant_config=None,
lora_config=None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
# Pipeline parallel is not supported since the second half of
# the layers share the kv cache.
if get_pp_group().world_size != 1:
raise ValueError("Pipeline Parallel not supported")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: SambaYDecoderLayer(config,
int(prefix.split('.')[-1]),
cache_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
mamba_state_idx = 0
ssm_output = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if i == self.config.num_hidden_layers // 2 + 2:
# profile run
kv_cache_idx = self.config.num_hidden_layers // 2 + 1
cache_layer = self.layers[kv_cache_idx]
kv_cache = cache_layer.attn.attn.kv_cache
if kv_cache[0].numel() == 0:
break
# Starting from this layer, we do not need to calculate
# the kv cache since we reuse the kv cache from the last layer.
# In the prefill phase, we can truncate
# the hidden states to save computation cost.
if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
selected_token_indices = torch.cumsum(
attn_metadata.seq_lens_tensor, dim=0) - 1
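# e.g. seq_lens_tensor = [3, 5] -> selected_token_indices = [2, 7],
# i.e. the last token position of each sequence in the flattened batch.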
hidden_states = hidden_states.index_select(
0, selected_token_indices)
ssm_output = ssm_output.index_select(
0, selected_token_indices)
if layer.use_mamba:
if i < self.config.num_hidden_layers // 2 or \
not layer.yoco_cross:
mamba_cache = mamba_cache_params.at_layer_idx(
mamba_state_idx)
mamba_state_idx += 1
else:
mamba_cache = mamba_cache_params.at_layer_idx(
mamba_state_idx - 1)
hidden_states, ssm_output = layer(hidden_states,
positions,
attn_metadata,
mamba_cache,
ssm_output=ssm_output)
else:
hidden_states, ssm_output = layer(
hidden_states,
positions,
attn_metadata,
None, # mamba_cache_params
ssm_output=ssm_output)
hidden_states = self.final_layernorm(
hidden_states.to(dtype=self.final_layernorm.weight.dtype))
return hidden_states
class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
quant_config = vllm_config.quant_config
scheduler_config = vllm_config.scheduler_config
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
# Prefix caching and chunked prefill are not supported for this model.
assert not cache_config.enable_prefix_caching, \
"Phi4Flash currently does not support prefix caching"
assert not scheduler_config.chunked_prefill_enabled, \
"Phi4Flash currently does not support chunked prefill"
super().__init__()
self.config = config
self.model_config = vllm_config.model_config
self.scheduler_config = scheduler_config
self.model = SambaYModel(config,
cache_config=cache_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size),
quant_config=quant_config,
)
self.embedding_bias = None
# Used to track and store the Mamba cache state between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logits_as_input=False)
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers \
// 2 // self.config.mb_per_layer + 1
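# For a hypothetical config with num_hidden_layers=32 and mb_per_layer=2
# this is 32 // 2 // 2 + 1 = 9: one state per Mamba layer in the first
# (self-decoder) half plus one for the second (YOCO) half.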
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
attn_metadata = get_forward_context().attn_metadata
# input_ids and hidden_states do not map one-to-one in the prefill
# stage due to the YOCO optimization.
hidden_states = self.model(input_ids, positions, attn_metadata,
mamba_cache_params, intermediate_tensors,
inputs_embeds)
return hidden_states
def _get_mamba_cache_shape(
self
) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
mamba_expand = self.config.mamba_expand # 2
mamba_d_conv = self.config.mamba_d_conv # 4
mamba_d_state = self.config.mamba_d_state # 16
conv_state_shape = (
mamba_expand * hidden_size // world_size,
mamba_d_conv - 1,
)
temporal_state_shape = (
mamba_expand * hidden_size // world_size,
mamba_d_state,
)
return conv_state_shape, temporal_state_shape
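# Example shapes under assumed values (hidden_size=3072, tensor-parallel
# size 1, mamba_expand=2, mamba_d_conv=4, mamba_d_state=16):
#   conv_state_shape     = (6144, 3)
#   temporal_state_shape = (6144, 16)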
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# If the shapes are the same, it means that we have already
# pruned the hidden states manually.
prune_hidden_states = hidden_states.size(
0) != sampling_metadata.selected_token_indices.size(0)
processed_logits = self.logits_processor(
self.lm_head,
hidden_states,
sampling_metadata,
self.embedding_bias,
prune_hidden_states=prune_hidden_states)
return processed_logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
):
weights = {name: weight for name, weight in weights}
adjusted_weights = {}
for name, weight in weights.items():
if "A_log" in name:
name = name.replace("A_log", "A")
weight = -torch.exp(weight.float())
if "inner_cross_attn." in name:
name = name.replace("inner_cross_attn.", "")
adjusted_weights[name] = weight
adjusted_weights["lm_head.weight"] = weights[
"model.embed_tokens.weight"]
loaded_params: set[str] = set()
for name, param in self.named_parameters():
weight = adjusted_weights.get(name)
if weight is not None and weight.shape != param.shape:
logger.warning("Shape mismatch: %s %s %s", name, weight.shape,
param.shape)
loaded_params.add(name)
missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
strict=False)
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
return loaded_params

View File

@@ -110,6 +110,7 @@ _TEXT_GENERATION_MODELS = {
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),

View File

@@ -316,6 +316,10 @@ class CudaPlatformBase(Platform):
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
logger.info("Using DifferentialFlashAttention backend.")
return ("vllm.attention.backends.differential_flash_attn."
"DifferentialFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:

View File

@@ -60,6 +60,7 @@ class _Backend(enum.Enum):
IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto()
DIFFERENTIAL_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()

View File

@@ -2888,8 +2888,9 @@ def get_mp_context():
def bind_kv_cache(
ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
shared_kv_cache_layers: Optional[dict[str, str]] = None
) -> None:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
@@ -2901,12 +2902,17 @@ bind_kv_cache(
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache
# tensor
# 5. Some models have attention layers that share kv cache with previous
# layers; this is specified through shared_kv_cache_layers
if shared_kv_cache_layers is None:
shared_kv_cache_layers = {}
from vllm.attention import AttentionType
from vllm.model_executor.models.utils import extract_layer_index
layer_need_kv_cache = [
layer_name for layer_name in ctx
if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type
in (AttentionType.DECODER, AttentionType.ENCODER_DECODER))
in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \
and ctx[layer_name].kv_sharing_target_layer_name is None
]
layer_index_sorted = sorted(
set(
@@ -2919,6 +2925,12 @@ bind_kv_cache(
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
if shared_kv_cache_layers is not None:
for layer_name, target_layer_name in shared_kv_cache_layers.items():
assert extract_layer_index(target_layer_name) < \
extract_layer_index(layer_name), \
"v0 doesn't support interleaving kv sharing"
ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any],

View File

@@ -1112,6 +1112,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
self.cross_layer_shared_graph_block_tables = np.zeros(
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
# Attention-free but stateful models like Mamba need a placeholder attn
# backend, as the attention metadata is needed to manage internal state.
# However we must bypass attention selection altogether for some models

View File

@@ -9,7 +9,8 @@ import torch
import torch.distributed
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
@@ -345,8 +346,29 @@ class Worker(LocalOrDistributedWorkerBase):
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
shared_kv_cache_layers: dict[str, str] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if (kv_tgt_layer :=
attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that the KV cache management logic acts as if this layer does
# not exist and doesn't allocate KV cache for it. This
# enables the memory saving of cross-layer KV sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or more simultaneous requests.
shared_kv_cache_layers[layer_name] = kv_tgt_layer
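# For the SambaY model added in this commit, every cross-decoder
# attention layer targets
# "model.layers.{num_hidden_layers // 2 + 1}.self_attn.attn", so those
# layers reuse that layer's KV cache instead of allocating their own.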
bind_kv_cache(self.compilation_config.static_forward_context,
self.gpu_cache)
self.gpu_cache, shared_kv_cache_layers)
def _warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes,