mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 15:42:32 +08:00
[Model] New model support for microsoft/Phi-4-mini-flash-reasoning (#20702)
Signed-off-by: Congcong Chen <congcongchen@microsoft.com>
This commit is contained in:
parent
b639327ad9
commit
2c11a738b3
@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
|
||||
constexpr bool kIsVariableB = true;
|
||||
constexpr bool kIsVariableC = true;
|
||||
constexpr bool kHasZ = true;
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
|
||||
at::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
if (has_z) {
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = delta;
|
||||
TORCH_CHECK(ssm_states.scalar_type() == input_type);
|
||||
@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -374,6 +374,7 @@ Specified using `--task generate`.
|
||||
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
|
||||
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
|
||||
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
|
||||
@ -248,6 +248,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
|
||||
trust_remote_code=True,
|
||||
v0_only=True),
|
||||
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
v0_only=True,
|
||||
max_model_len=10240),
|
||||
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
|
||||
trust_remote_code=True),
|
||||
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
|
||||
|
||||
@ -103,6 +103,9 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
|
||||
_initialize_kv_caches_v1), monkeypatch.context() as m):
|
||||
if model_info.v0_only:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
if model_arch == "Phi4FlashForCausalLM":
|
||||
# Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
|
||||
LLM(
|
||||
model_info.default,
|
||||
tokenizer=model_info.tokenizer,
|
||||
|
||||
@ -458,6 +458,31 @@ def test_bind_kv_cache():
|
||||
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
|
||||
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
|
||||
|
||||
def test_bind_kv_cache_kv_sharing():
|
||||
from vllm.attention import Attention
|
||||
|
||||
ctx = {
|
||||
'layers.0.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.1.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.2.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.3.self_attn': Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = [
|
||||
torch.zeros((1, )),
|
||||
torch.zeros((1, )),
|
||||
torch.zeros((1, )),
|
||||
torch.zeros((1, )),
|
||||
]
|
||||
shared_kv_cache_layers = {
|
||||
'layers.2.self_attn': 'layers.1.self_attn',
|
||||
'layers.3.self_attn': 'layers.0.self_attn'
|
||||
}
|
||||
bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
|
||||
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
|
||||
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
|
||||
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1]
|
||||
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0]
|
||||
|
||||
def test_bind_kv_cache_non_attention():
|
||||
from vllm.attention import Attention
|
||||
|
||||
|
||||
@ -308,7 +308,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"BLOCK_SPARSE_FLASH_ATTN Backend.")
|
||||
assert blocksparse_params is not None
|
||||
assert alibi_slopes is None, ValueError(
|
||||
"Alibi not support for blocksparse flash attention.")
|
||||
|
||||
1000
vllm/attention/backends/differential_flash_attn.py
Normal file
1000
vllm/attention/backends/differential_flash_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -295,7 +295,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
||||
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"DUAL_CHUNK_FLASH_ATTN backend.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
|
||||
@ -622,7 +622,8 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"FLASH_ATTN backend.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
|
||||
@ -1006,7 +1006,8 @@ class FlashInferImpl(AttentionImpl):
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"FLASHINFER backend.")
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in FlashInfer is not supported yet, it will fall"
|
||||
|
||||
@ -115,7 +115,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
||||
) -> None:
|
||||
super(AttentionImpl, self).__init__()
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"HPU_ATTN backend.")
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in HPU is not supported yet, it will fall back "
|
||||
|
||||
@ -501,7 +501,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"ROCM_FLASH backend.")
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in ROCm Flash Attention is not supported yet, it "
|
||||
|
||||
@ -394,7 +394,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"XFORMERS backend.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"XFormers does not support block-sparse attention.")
|
||||
|
||||
@ -160,10 +160,6 @@ class Attention(nn.Module):
|
||||
self.attn_type = attn_type
|
||||
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
if not envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Cross-layer KV sharing is not supported in V0.")
|
||||
|
||||
validate_kv_sharing_target(
|
||||
prefix,
|
||||
kv_sharing_target_layer_name,
|
||||
|
||||
@ -59,11 +59,12 @@ class LogitsProcessor(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: Optional[SamplingMetadata] = None,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
prune_hidden_states: bool = True,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if self.logits_as_input:
|
||||
logits = hidden_states
|
||||
else:
|
||||
if sampling_metadata is not None:
|
||||
if sampling_metadata is not None and prune_hidden_states:
|
||||
hidden_states = _prune_hidden_states(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
|
||||
746
vllm/model_executor/models/phi4flash.py
Normal file
746
vllm/model_executor/models/phi4flash.py
Normal file
@ -0,0 +1,746 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
||||
SupportsV0Only)
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import make_layers, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SwiGLUActivation(nn.Module):
|
||||
|
||||
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
||||
return x1 * nn.functional.silu(x2)
|
||||
|
||||
|
||||
class SambaYMLP(nn.Module):
|
||||
"""Gated Linear Unit.
|
||||
|
||||
Reference:
|
||||
Language Modeling with Gated Convolutional Networks.
|
||||
https://arxiv.org/pdf/1612.08083v3.pdf.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.fc1 = nn.Linear(config.hidden_size,
|
||||
2 * config.intermediate_size,
|
||||
bias=False)
|
||||
self.fc2 = nn.Linear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
y = self.fc1(hidden_states)
|
||||
gate, y = y.chunk(2, dim=-1)
|
||||
y = y * self.activation_fn(gate)
|
||||
return self.fc2(y)
|
||||
|
||||
|
||||
def get_virtual_engine():
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
return forward_context.virtual_engine
|
||||
|
||||
|
||||
class SambaYAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
layer_idx: Optional[int] = None,
|
||||
yoco_cross: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing "
|
||||
"a `layer_idx` is not recommended and will lead to errors "
|
||||
"during the forward call if caching is used. Please make "
|
||||
"sure to provide a `layer_idx` when creating this class.")
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.yoco_cross = yoco_cross
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError("hidden_size must be divisible by num_heads "
|
||||
f"(got `hidden_size`: {self.hidden_size} and "
|
||||
f"`num_heads`: {self.num_heads}).")
|
||||
|
||||
op_size = self.num_heads * self.head_dim + 2 * (
|
||||
self.num_key_value_heads * self.head_dim)
|
||||
self.out_proj = nn.Linear(self.num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=True)
|
||||
if yoco_cross:
|
||||
self.Wqkv = nn.Linear(self.hidden_size,
|
||||
self.num_heads * self.head_dim,
|
||||
bias=True)
|
||||
else:
|
||||
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
|
||||
|
||||
# disable sliding window for the second half of the model
|
||||
sliding_window = config.interleaved_sliding_window[layer_idx]
|
||||
if layer_idx >= config.num_hidden_layers // 2:
|
||||
assert sliding_window is None, \
|
||||
"sliding_window must be none for the second decoder"
|
||||
else:
|
||||
assert sliding_window is not None, \
|
||||
"sliding_window must be set for the first decoder"
|
||||
|
||||
assert self.num_heads % 2 == 0, 'num_heads should be even'
|
||||
assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'
|
||||
|
||||
self.lambda_init = self.lambda_init_fn(layer_idx)
|
||||
self.lambda_q1 = nn.Parameter(
|
||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
||||
std=0.1))
|
||||
self.lambda_k1 = nn.Parameter(
|
||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
||||
std=0.1))
|
||||
self.lambda_q2 = nn.Parameter(
|
||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
||||
std=0.1))
|
||||
self.lambda_k2 = nn.Parameter(
|
||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
||||
std=0.1))
|
||||
self.subln = nn.RMSNorm(2 * self.head_dim,
|
||||
eps=1e-5,
|
||||
elementwise_affine=True)
|
||||
|
||||
params = {
|
||||
'differential_flash_attention_config': {
|
||||
'lambda_init': self.lambda_init,
|
||||
'lambda_q1': self.lambda_q1,
|
||||
'lambda_k1': self.lambda_k1,
|
||||
'lambda_q2': self.lambda_q2,
|
||||
'lambda_k2': self.lambda_k2,
|
||||
"subln": self.subln,
|
||||
}
|
||||
}
|
||||
|
||||
if yoco_cross:
|
||||
kv_shared_layer_index = config.num_hidden_layers // 2 + 1
|
||||
kv_sharing_target_layer_name = \
|
||||
f"model.layers.{kv_shared_layer_index}.self_attn.attn"
|
||||
else:
|
||||
kv_sharing_target_layer_name = None
|
||||
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.head_dim**-0.5,
|
||||
num_kv_heads=self.num_key_value_heads,
|
||||
cache_config=cache_config,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
|
||||
**params)
|
||||
assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\
|
||||
"DIFFERENTIAL_FLASH_ATTN required"
|
||||
|
||||
def lambda_init_fn(self, depth):
|
||||
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
|
||||
if not self.yoco_cross: # need to generate kv-cache
|
||||
qkv = self.Wqkv(hidden_states)
|
||||
q, k, v = qkv.split([
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim
|
||||
],
|
||||
dim=-1)
|
||||
attn_output = self.attn(q, k, v)
|
||||
else: # re-use the kv cache, full attention
|
||||
q = self.Wqkv(hidden_states)
|
||||
attn_output = self.attn(q, None, None)
|
||||
attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
|
||||
return self.out_proj(attn_output)
|
||||
|
||||
|
||||
class Phi4Mamba(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
d_state=16,
|
||||
d_conv=4,
|
||||
expand=2,
|
||||
dt_rank="auto",
|
||||
dt_min=0.001,
|
||||
dt_max=0.1,
|
||||
dt_init="random", # difference
|
||||
dt_scale=1.0, # difference
|
||||
dt_init_floor=1e-4,
|
||||
conv_bias=True,
|
||||
bias=False,
|
||||
use_fast_path=True, # Fused kernel options
|
||||
layer_idx=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
yoco_cross=False,
|
||||
yoco_kv=False,
|
||||
):
|
||||
factory_kwargs = {"params_dtype": dtype} # difference
|
||||
super().__init__()
|
||||
self.yoco_cross = yoco_cross
|
||||
self.yoco_kv = yoco_kv
|
||||
self.d_model = d_model
|
||||
self.d_state = d_state
|
||||
self.d_conv = d_conv
|
||||
self.expand = expand
|
||||
self.d_inner = int(self.expand * self.d_model)
|
||||
self.dt_rank = math.ceil(self.d_model /
|
||||
16) if dt_rank == "auto" else dt_rank
|
||||
self.use_fast_path = use_fast_path
|
||||
self.layer_idx = layer_idx
|
||||
self.swiGluActivation = SwiGLUActivation()
|
||||
if self.yoco_cross:
|
||||
self.in_proj = MergedColumnParallelLinear(self.d_model,
|
||||
[self.d_inner],
|
||||
bias=bias,
|
||||
**factory_kwargs)
|
||||
self.out_proj = RowParallelLinear(self.d_inner,
|
||||
self.d_model,
|
||||
bias=bias,
|
||||
**factory_kwargs)
|
||||
return
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=d_conv,
|
||||
output_size=self.d_inner,
|
||||
bias=conv_bias,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `set_weight_attrs`
|
||||
# doesn't allow to override it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
self.d_model,
|
||||
[self.d_inner] * 2,
|
||||
bias=bias,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
self.x_proj = RowParallelLinear(
|
||||
self.d_inner,
|
||||
self.dt_rank + self.d_state * 2,
|
||||
bias=False,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
|
||||
# time step projection (discretization) -
|
||||
# In the forward we need to apply dt_proj without the bias,
|
||||
# as the bias is added in the selective scan kernel.
|
||||
self.dt_proj = ColumnParallelLinear(
|
||||
self.dt_rank,
|
||||
self.d_inner,
|
||||
bias=True,
|
||||
skip_bias_add=True,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
|
||||
# # D "skip" parameter
|
||||
# self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
|
||||
self.A = nn.Parameter(
|
||||
torch.empty(
|
||||
self.d_inner,
|
||||
self.d_state,
|
||||
dtype=torch.float32,
|
||||
))
|
||||
self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.d_inner,
|
||||
self.d_model,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
self.activation = "silu"
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
yoco_key_values=None) -> torch.Tensor:
|
||||
|
||||
if self.yoco_cross:
|
||||
out = self.in_proj(hidden_states)[0]
|
||||
out = self.swiGluActivation(yoco_key_values, out)
|
||||
out = self.out_proj(out)
|
||||
return out[0], yoco_key_values
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
# projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
projected_states = self.in_proj(
|
||||
hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1)
|
||||
hidden_states, gate = projected_states.chunk(2, dim=-2)
|
||||
|
||||
# 2. Convolution sequence transformation
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if attn_metadata.query_start_loc is not None \
|
||||
and attn_metadata.context_lens_tensor is not None:
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
hidden_states = causal_conv1d_fn(
|
||||
hidden_states,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=mamba_cache_params.conv_state,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
hidden_states = causal_conv1d_update(
|
||||
hidden_states.transpose(0, 1),
|
||||
mamba_cache_params.conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=mamba_cache_params.state_indices_tensor)
|
||||
hidden_states = hidden_states.transpose(0, 1)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
# 3.a. input varying initialization of time_step, B and C
|
||||
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
|
||||
|
||||
time_step, B, C = torch.split(
|
||||
ssm_parameters,
|
||||
[self.dt_rank, self.d_state, self.d_state],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
|
||||
|
||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
||||
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
|
||||
self.dt_proj, "bias") else None)
|
||||
|
||||
if attn_metadata.query_start_loc is not None \
|
||||
and attn_metadata.context_lens_tensor is not None:
|
||||
scan_outputs = selective_scan_fn(
|
||||
hidden_states,
|
||||
mamba_cache_params.ssm_state,
|
||||
discrete_time_step,
|
||||
self.A,
|
||||
B.transpose(-2, -1),
|
||||
C.transpose(-2, -1),
|
||||
self.D.float(),
|
||||
# z,
|
||||
None if self.yoco_kv else gate,
|
||||
time_proj_bias,
|
||||
delta_softplus=True,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
scan_outputs = selective_state_update(
|
||||
mamba_cache_params.ssm_state,
|
||||
hidden_states.transpose(0, 1),
|
||||
discrete_time_step.transpose(0, 1),
|
||||
self.A,
|
||||
B,
|
||||
C,
|
||||
self.D,
|
||||
# z
|
||||
# gate.transpose(0, 1),
|
||||
None if self.yoco_kv else gate.transpose(0, 1),
|
||||
time_proj_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=mamba_cache_params.state_indices_tensor)
|
||||
scan_outputs = scan_outputs.transpose(0, 1)
|
||||
|
||||
# 4. Final linear projection
|
||||
if self.yoco_kv:
|
||||
# gate = gate.transpose(-1,-2).contiguous()
|
||||
yoco_key_values = scan_outputs.transpose(-2, -1)
|
||||
scan_outputs = self.swiGluActivation(scan_outputs, gate)
|
||||
|
||||
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
|
||||
-1))[0]
|
||||
|
||||
return contextualized_states, yoco_key_values
|
||||
|
||||
|
||||
class SambaYDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
layer_idx,
|
||||
cache_config,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.mlp = SambaYMLP(config)
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
self.yoco_mb = False
|
||||
self.yoco_cross = False
|
||||
if layer_idx >= config.num_hidden_layers // 2:
|
||||
self.yoco_mb = True
|
||||
self.yoco_cross = (layer_idx
|
||||
>= (config.num_hidden_layers // 2 + 2))
|
||||
self.use_mamba = config.mb_per_layer > 0 and \
|
||||
layer_idx % config.mb_per_layer == 0
|
||||
if self.use_mamba:
|
||||
factory_kwargs = {"dtype": None}
|
||||
self.attn = Phi4Mamba(config.hidden_size,
|
||||
layer_idx=layer_idx,
|
||||
yoco_cross=self.yoco_cross,
|
||||
yoco_kv=self.yoco_mb,
|
||||
**factory_kwargs)
|
||||
else:
|
||||
self.attn = SambaYAttention(config,
|
||||
layer_idx=layer_idx,
|
||||
yoco_cross=self.yoco_cross,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
ssm_output: Optional[torch.LongTensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if self.use_mamba:
|
||||
assert mamba_cache_params is not None
|
||||
else:
|
||||
assert mamba_cache_params is None
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(
|
||||
hidden_states.to(dtype=self.input_layernorm.weight.dtype))
|
||||
|
||||
if self.use_mamba:
|
||||
attn_outputs, ssm_output = self.attn(hidden_states,
|
||||
attn_metadata,
|
||||
mamba_cache_params,
|
||||
yoco_key_values=ssm_output)
|
||||
residual = residual.to(torch.float32)
|
||||
else:
|
||||
attn_outputs = self.attn(hidden_states, )
|
||||
hidden_states = residual + attn_outputs
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(
|
||||
hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states, ssm_output
|
||||
|
||||
|
||||
class SambaYModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
cache_config=None,
|
||||
quant_config=None,
|
||||
lora_config=None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
# Pipeline parallel is not supported since the second half of
|
||||
# the layers share the kv cache.
|
||||
if get_pp_group().world_size != 1:
|
||||
raise ValueError("Pipeline Parallel not supported")
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: SambaYDecoderLayer(config,
|
||||
int(prefix.split('.')[-1]),
|
||||
cache_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
mamba_state_idx = 0
|
||||
ssm_output = None
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
if i == self.config.num_hidden_layers // 2 + 2:
|
||||
# profile run
|
||||
kv_cache_idx = self.config.num_hidden_layers // 2 + 1
|
||||
cache_layer = self.layers[kv_cache_idx]
|
||||
kv_cache = cache_layer.attn.attn.kv_cache
|
||||
if kv_cache[0].numel() == 0:
|
||||
break
|
||||
|
||||
# Starting from this layer, we do not need to calculate
|
||||
# the kv cache since we reuse the kv cache from last layer.
|
||||
# If in prefill phase, we can <s>prune></s> truncate
|
||||
# the hidden state to save computation cost.
|
||||
if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
|
||||
selected_token_indices = torch.cumsum(
|
||||
attn_metadata.seq_lens_tensor, dim=0) - 1
|
||||
hidden_states = hidden_states.index_select(
|
||||
0, selected_token_indices)
|
||||
ssm_output = ssm_output.index_select(
|
||||
0, selected_token_indices)
|
||||
|
||||
if layer.use_mamba:
|
||||
if i < self.config.num_hidden_layers // 2 or \
|
||||
not layer.yoco_cross:
|
||||
mamba_cache = mamba_cache_params.at_layer_idx(
|
||||
mamba_state_idx)
|
||||
mamba_state_idx += 1
|
||||
else:
|
||||
mamba_cache = mamba_cache_params.at_layer_idx(
|
||||
mamba_state_idx - 1)
|
||||
|
||||
hidden_states, ssm_output = layer(hidden_states,
|
||||
positions,
|
||||
attn_metadata,
|
||||
mamba_cache,
|
||||
ssm_output=ssm_output)
|
||||
else:
|
||||
hidden_states, ssm_output = layer(
|
||||
hidden_states,
|
||||
positions,
|
||||
attn_metadata,
|
||||
None, # mamba_cache_params
|
||||
ssm_output=ssm_output)
|
||||
|
||||
hidden_states = self.final_layernorm(
|
||||
hidden_states.to(dtype=self.final_layernorm.weight.dtype))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
lora_config = vllm_config.lora_config
|
||||
quant_config = vllm_config.quant_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.vllm_config = vllm_config
|
||||
# Prefix caching and chunked prefill is not supported for this model.
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Phi4flash currently does not support prefix caching"
|
||||
assert not scheduler_config.chunked_prefill_enabled, \
|
||||
"Phi4Flash currently does not support prefix caching"
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model = SambaYModel(config,
|
||||
cache_config=cache_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=(
|
||||
DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size),
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.embedding_bias = None
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
logits_as_input=False)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if self.mamba_cache is None:
|
||||
num_mamba_layers = self.config.num_hidden_layers \
|
||||
// 2 // self.config.mb_per_layer + 1
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
|
||||
*self._get_mamba_cache_shape())
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# input_ids and hidden_states isn't a one-to-one mapping in prefill
|
||||
# stage due to YOCO optimization.
|
||||
hidden_states = self.model(input_ids, positions, attn_metadata,
|
||||
mamba_cache_params, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def _get_mamba_cache_shape(
|
||||
self
|
||||
) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
hidden_size = self.config.hidden_size
|
||||
mamba_expand = self.config.mamba_expand # 2
|
||||
mamba_d_conv = self.config.mamba_d_conv # 4
|
||||
mamba_d_state = self.config.mamba_d_state # 16
|
||||
conv_state_shape = (
|
||||
mamba_expand * hidden_size // world_size,
|
||||
mamba_d_conv - 1,
|
||||
)
|
||||
temporal_state_shape = (
|
||||
mamba_expand * hidden_size // world_size,
|
||||
mamba_d_state,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# If the shape is the same, it means that we have already
|
||||
# prune hidden states manually.
|
||||
prune_hidden_states = hidden_states.size(
|
||||
0) != sampling_metadata.selected_token_indices.size(0)
|
||||
processed_logits = self.logits_processor(
|
||||
self.lm_head,
|
||||
hidden_states,
|
||||
sampling_metadata,
|
||||
self.embedding_bias,
|
||||
prune_hidden_states=prune_hidden_states)
|
||||
return processed_logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
):
|
||||
weights = {name: weight for name, weight in weights}
|
||||
adjusted_weights = {}
|
||||
for name, weight in weights.items():
|
||||
if "A_log" in name:
|
||||
name = name.replace("A_log", "A")
|
||||
weight = -torch.exp(weight.float())
|
||||
if "inner_cross_attn." in name:
|
||||
name = name.replace("inner_cross_attn.", "")
|
||||
adjusted_weights[name] = weight
|
||||
adjusted_weights["lm_head.weight"] = weights[
|
||||
"model.embed_tokens.weight"]
|
||||
loaded_params: set[str] = set()
|
||||
for name, param in self.named_parameters():
|
||||
weight = adjusted_weights.get(name)
|
||||
if weight is not None and weight.shape != param.shape:
|
||||
logger.warning("Shape mismatch: %s %s %s", name, weight.shape,
|
||||
param.shape)
|
||||
loaded_params.add(name)
|
||||
missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
|
||||
strict=False)
|
||||
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
|
||||
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
|
||||
return loaded_params
|
||||
@ -110,6 +110,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
|
||||
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
|
||||
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
|
||||
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
|
||||
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
|
||||
@ -316,6 +316,10 @@ class CudaPlatformBase(Platform):
|
||||
logger.info("Using DualChunkFlashAttention backend.")
|
||||
return ("vllm.attention.backends.dual_chunk_flash_attn."
|
||||
"DualChunkFlashAttentionBackend")
|
||||
elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
|
||||
logger.info("Using DifferentialFlashAttention backend.")
|
||||
return ("vllm.attention.backends.differential_flash_attn."
|
||||
"DifferentialFlashAttentionBackend")
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
pass
|
||||
elif selected_backend:
|
||||
|
||||
@ -60,6 +60,7 @@ class _Backend(enum.Enum):
|
||||
IPEX = enum.auto()
|
||||
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
|
||||
DUAL_CHUNK_FLASH_ATTN = enum.auto()
|
||||
DIFFERENTIAL_FLASH_ATTN = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
FLEX_ATTENTION = enum.auto()
|
||||
|
||||
|
||||
@ -2888,8 +2888,9 @@ def get_mp_context():
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
ctx: dict[str, Any],
|
||||
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
|
||||
ctx: dict[str, Any],
|
||||
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
|
||||
shared_kv_cache_layers: Optional[dict[str, str]] = None
|
||||
) -> None:
|
||||
# Bind the kv_cache tensor to Attention modules, similar to
|
||||
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
|
||||
@ -2901,12 +2902,17 @@ def bind_kv_cache(
|
||||
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn
|
||||
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache
|
||||
# tensor
|
||||
# 5. Some models have attention layers that share kv cache with previous
|
||||
# layers, this is specified through shared_kv_cache_layers
|
||||
if shared_kv_cache_layers is None:
|
||||
shared_kv_cache_layers = {}
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
layer_need_kv_cache = [
|
||||
layer_name for layer_name in ctx
|
||||
if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type
|
||||
in (AttentionType.DECODER, AttentionType.ENCODER_DECODER))
|
||||
in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \
|
||||
and ctx[layer_name].kv_sharing_target_layer_name is None
|
||||
]
|
||||
layer_index_sorted = sorted(
|
||||
set(
|
||||
@ -2919,6 +2925,12 @@ def bind_kv_cache(
|
||||
assert len(forward_ctx.kv_cache) == len(kv_cache)
|
||||
for ve, ve_kv_cache in enumerate(kv_cache):
|
||||
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
|
||||
if shared_kv_cache_layers is not None:
|
||||
for layer_name, target_layer_name in shared_kv_cache_layers.items():
|
||||
assert extract_layer_index(target_layer_name) < \
|
||||
extract_layer_index(layer_name), \
|
||||
"v0 doesn't support interleaving kv sharing"
|
||||
ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
|
||||
|
||||
|
||||
def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any],
|
||||
|
||||
@ -1112,6 +1112,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
|
||||
dtype=np.int32)
|
||||
|
||||
self.cross_layer_shared_graph_block_tables = np.zeros(
|
||||
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
|
||||
dtype=np.int32)
|
||||
|
||||
# Attention-free but stateful models like Mamba need a placeholder attn
|
||||
# backend, as the attention metadata is needed to manage internal state.
|
||||
# However we must bypass attention selection altogether for some models
|
||||
|
||||
@ -9,7 +9,8 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
@ -345,8 +346,29 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
self.cache_engine[ve].gpu_cache
|
||||
for ve in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
# Layer pairings for cross-layer KV sharing.
|
||||
# If an Attention layer `layer_name` is in the keys of this dict, it
|
||||
# means this layer will perform attention using the keys and values
|
||||
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
|
||||
shared_kv_cache_layers: dict[str, str] = {}
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if (kv_tgt_layer :=
|
||||
attn_module.kv_sharing_target_layer_name) is not None:
|
||||
# The layer doesn't need its own KV cache and will use that of
|
||||
# the target layer. We skip creating a KVCacheSpec for it, so
|
||||
# that KV cache management logic will act as this layer does
|
||||
# not exist, and doesn't allocate KV cache for the layer. This
|
||||
# enables the memory saving of cross-layer kv sharing, allowing
|
||||
# a given amount of memory to accommodate longer context lengths
|
||||
# or enable more requests to be processed simultaneously.
|
||||
shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
||||
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
self.gpu_cache)
|
||||
self.gpu_cache, shared_kv_cache_layers)
|
||||
|
||||
def _warm_up_model(self) -> None:
|
||||
# warm up sizes that are not in cudagraph capture sizes,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user