mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 08:45:01 +08:00
[V1] Remove V0 code paths for Hybrid models (#25400)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
02134245a9
commit
f97da2c732
@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
|
||||
SSM_MODELS = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"tiiuae/falcon-mamba-tiny-dev",
|
||||
"yujiepan/mamba2-codestral-v0.1-tiny-random",
|
||||
# mamba2-codestral in transformers is broken pending:
|
||||
# https://github.com/huggingface/transformers/pull/40861
|
||||
#"yujiepan/mamba2-codestral-v0.1-tiny-random",
|
||||
]
|
||||
|
||||
HYBRID_MODELS = [
|
||||
@ -31,18 +33,7 @@ HYBRID_MODELS = [
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
"tiiuae/Falcon-H1-0.5B-Base",
|
||||
"LiquidAI/LFM2-1.2B",
|
||||
]
|
||||
|
||||
V1_SUPPORTED_MODELS = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"ai21labs/Jamba-tiny-dev",
|
||||
"pfnet/plamo-2-1b",
|
||||
"yujiepan/mamba2-codestral-v0.1-tiny-random",
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
"hmellor/tiny-random-BambaForCausalLM",
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
"tiiuae/Falcon-H1-0.5B-Base",
|
||||
"LiquidAI/LFM2-1.2B",
|
||||
"tiny-random/qwen3-next-moe",
|
||||
]
|
||||
|
||||
FULL_CUDA_GRAPH_MODELS = [
|
||||
@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
]
|
||||
|
||||
V0_UNSUPPORTED_MODELS = [
|
||||
"LiquidAI/LFM2-1.2B",
|
||||
]
|
||||
|
||||
FP32_STATE_MODELS = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
@ -88,20 +75,16 @@ def test_models(
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
if model in V1_SUPPORTED_MODELS:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v1_outputs = None
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
if model in V1_SUPPORTED_MODELS:
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_v1_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm-v1",
|
||||
)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@ -299,14 +282,14 @@ def test_full_cuda_graph(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_v1_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm-v1",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@ -340,12 +323,12 @@ def test_fp32_cache_state(
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
**{cache_dtype_param: "float32"}) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_v1_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm-v1",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
@ -312,14 +312,12 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
|
||||
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
|
||||
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
|
||||
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
v0_only=True,
|
||||
max_model_len=10240),
|
||||
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
|
||||
trust_remote_code=True),
|
||||
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
|
||||
trust_remote_code=True),
|
||||
max_transformers_version="4.55.4",
|
||||
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
|
||||
@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
||||
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
||||
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
min_transformers_version="4.56.2"),
|
||||
extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501
|
||||
min_transformers_version="4.56.3"),
|
||||
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
||||
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
|
||||
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
min_transformers_version="4.56.2"),
|
||||
min_transformers_version="4.56.3"),
|
||||
}
|
||||
|
||||
_TRANSFORMERS_BACKEND_MODELS = {
|
||||
|
||||
@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):
|
||||
|
||||
# Contains the KV cache (mamba state) for the layer
|
||||
# in the shape specified by `self.get_state_shape`.
|
||||
# The outer list is for v0 PP virtual engine. Though this code path
|
||||
# only runs for v1, we have to do this to unify with the interface
|
||||
# of Attention + v0 PP.
|
||||
kv_cache: list[Iterable[torch.Tensor]]
|
||||
kv_cache: tuple[torch.Tensor, ...]
|
||||
|
||||
@abstractmethod
|
||||
def get_state_shape(self) -> Iterable[tuple[int, ...]]:
|
||||
|
||||
@ -15,7 +15,6 @@ import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
@ -42,8 +41,6 @@ if TYPE_CHECKING:
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
|
||||
|
||||
|
||||
class MiniMaxText01RMSNormTP(CustomOp):
|
||||
name = "MiniMaxText01RMSNormTP"
|
||||
@ -225,11 +222,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
self.tp_heads:(self.tp_rank + 1) *
|
||||
self.tp_heads].contiguous()
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
@staticmethod
|
||||
def weight_direct_load(param: torch.Tensor,
|
||||
@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
break
|
||||
if _prefill_idx >= len(state_indices_tensor):
|
||||
break
|
||||
# prefills are packed at end of batch in V1
|
||||
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
|
||||
offset = attn_metadata.num_decode_tokens
|
||||
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||
@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
hidden_decode = self._decode_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor,
|
||||
attn_metadata)
|
||||
if envs.VLLM_USE_V1:
|
||||
hidden.insert(0, hidden_decode)
|
||||
else:
|
||||
hidden.append(hidden_decode)
|
||||
hidden.insert(0, hidden_decode)
|
||||
|
||||
if not hidden:
|
||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||
@ -304,40 +296,28 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
|
||||
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
||||
attn_metadata):
|
||||
if not envs.VLLM_USE_V1:
|
||||
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
slot_id = state_indices_tensor[num_prefills:]
|
||||
else:
|
||||
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
|
||||
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
|
||||
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
|
||||
slot_id, 32)
|
||||
return hidden
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: MinimaxCacheParams) -> None:
|
||||
if not envs.VLLM_USE_V1:
|
||||
self._forward(hidden_states, output, positions, kv_caches)
|
||||
else:
|
||||
torch.ops.vllm.linear_attention(
|
||||
hidden_states,
|
||||
output,
|
||||
positions,
|
||||
self.prefix,
|
||||
)
|
||||
positions: torch.Tensor) -> None:
|
||||
torch.ops.vllm.linear_attention(
|
||||
hidden_states,
|
||||
output,
|
||||
positions,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[MinimaxCacheParams]) -> None:
|
||||
positions: torch.Tensor) -> None:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if envs.VLLM_USE_V1 and attn_metadata is not None:
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||
@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
qkvact = torch.nn.functional.silu(qkv32)
|
||||
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
|
||||
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
if attn_metadata is not None:
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
if num_prefills > 0:
|
||||
num_decode_tokens = getattr(attn_metadata,
|
||||
"num_decode_tokens", 0)
|
||||
for prefill_idx in range(num_prefills):
|
||||
q_start = attn_metadata.query_start_loc[
|
||||
num_decode_tokens + prefill_idx]
|
||||
q_end = attn_metadata.query_start_loc[num_decode_tokens
|
||||
+ prefill_idx +
|
||||
1]
|
||||
query_len = q_end - q_start
|
||||
context_len = attn_metadata.seq_lens[
|
||||
num_decode_tokens + prefill_idx] - query_len
|
||||
if context_len == 0:
|
||||
block_to_clear = state_indices_tensor[
|
||||
num_decode_tokens + prefill_idx]
|
||||
kv_cache[block_to_clear, ...] = 0
|
||||
else:
|
||||
assert kv_caches is not None
|
||||
kv_cache = kv_caches.minimax_cache
|
||||
state_indices_tensor = kv_caches.state_indices_tensor
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
if num_prefills > 0:
|
||||
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
|
||||
0)
|
||||
for prefill_idx in range(num_prefills):
|
||||
q_start = attn_metadata.query_start_loc[num_decode_tokens +
|
||||
prefill_idx]
|
||||
q_end = attn_metadata.query_start_loc[num_decode_tokens +
|
||||
prefill_idx + 1]
|
||||
query_len = q_end - q_start
|
||||
context_len = attn_metadata.seq_lens[
|
||||
num_decode_tokens + prefill_idx] - query_len
|
||||
if context_len == 0:
|
||||
block_to_clear = state_indices_tensor[num_decode_tokens
|
||||
+ prefill_idx]
|
||||
kv_cache[block_to_clear, ...] = 0
|
||||
|
||||
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
||||
if attn_metadata is None:
|
||||
@ -410,8 +384,7 @@ def linear_attention(
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self._forward(hidden_states=hidden_states,
|
||||
output=output,
|
||||
positions=positions,
|
||||
kv_caches=None)
|
||||
positions=positions)
|
||||
|
||||
|
||||
def linear_attention_fake(
|
||||
|
||||
@ -1,177 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.placeholder_attn import (
|
||||
PlaceholderAttentionMetadata)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
from vllm.v1.attention.backends.mamba2_attn import (
|
||||
Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba2Metadata:
|
||||
prep_initial_states: bool
|
||||
chunk_size: int
|
||||
|
||||
has_initial_states_p: torch.Tensor
|
||||
seq_idx_p: torch.Tensor
|
||||
chunk_indices_p: torch.Tensor
|
||||
chunk_offsets_p: torch.Tensor
|
||||
"""
|
||||
With continuous batching layout of `x` in vLLM, to enable a Triton program
|
||||
to handle a request in parallel, two supporting tensors are used
|
||||
(batch_ptr, token_chunk_offset_ptr)
|
||||
BLOCK_M = the # tokens to be handled by a Triton program
|
||||
(can be customized for different hardware)
|
||||
|
||||
nums_dict:
|
||||
tracks the data associated with a given value of BLOCK_M
|
||||
BLOCK_M = #tokens handled by a Triton program
|
||||
cu_seqlen: total tokens per batch
|
||||
(used as flag to update other data at each new input)
|
||||
batch_ptr: tracks batch-id handled by the Triton program
|
||||
token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
|
||||
(Triton implementation of causal_conv1d handles parallelism in 3-axes
|
||||
- feature-axis
|
||||
- batch-axis
|
||||
- sequence-axis)
|
||||
"""
|
||||
nums_dict: Optional[dict] = None
|
||||
cu_seqlen: Optional[int] = None
|
||||
batch_ptr: Optional[torch.Tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
||||
"""Returns the appropriate metadata classes for the current platform."""
|
||||
if current_platform.is_rocm():
|
||||
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
||||
AiterFlashAttentionMetadata)
|
||||
from vllm.v1.attention.backends.triton_attn import (
|
||||
TritonAttentionMetadata)
|
||||
return (AiterFlashAttentionMetadata, TritonAttentionMetadata,
|
||||
PlaceholderAttentionMetadata)
|
||||
if current_platform.is_cuda():
|
||||
from vllm.v1.attention.backends.flash_attn import (
|
||||
FlashAttentionMetadata)
|
||||
from vllm.v1.attention.backends.xformers import (
|
||||
XFormersAttentionMetadata)
|
||||
return (FlashAttentionMetadata, XFormersAttentionMetadata,
|
||||
PlaceholderAttentionMetadata)
|
||||
raise ValueError(
|
||||
f"Unsupported platform for Mamba2: {current_platform.device_type}")
|
||||
|
||||
|
||||
def prepare_mamba2_metadata(
|
||||
chunk_size: int,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Mamba2Metadata:
|
||||
|
||||
# compute number of prefill and decode requests
|
||||
# NOTE: in V0 we assume prefills are before decodes
|
||||
num_prefills = attn_metadata.num_prefills
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
|
||||
seq_idx_p = None
|
||||
chunk_indices_p, chunk_offsets_p = None, None
|
||||
# Need flags to indicate if there are initial states
|
||||
# currently we really only support the FlashAttention backend
|
||||
has_initial_states_p = None
|
||||
prep_initial_states = False
|
||||
|
||||
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
|
||||
if num_prefills > 0:
|
||||
attn_metadata_instances = get_platform_metadata_classes()
|
||||
if (isinstance(attn_metadata, attn_metadata_instances)
|
||||
and attn_metadata.context_lens_tensor is not None):
|
||||
# precompute flag to avoid device syncs later in mamba2 layer
|
||||
# forwards
|
||||
# prep is only needed for mamba2 ssd prefill processing
|
||||
has_initial_states_p = (
|
||||
attn_metadata.context_lens_tensor[:num_prefills] > 0)
|
||||
prep_initial_states = torch.any(has_initial_states_p).item()
|
||||
query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
|
||||
seq_idx_p = torch.repeat_interleave(torch.arange(
|
||||
num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
|
||||
query_start_loc_p.diff(),
|
||||
output_size=num_prefill_tokens)
|
||||
seq_idx_p.unsqueeze_(0)
|
||||
|
||||
# We compute metadata for chunked prefill once at the top level model
|
||||
# forward and reuse them in mamba layers. If not needed, they will be
|
||||
# ignored inside mamba kernels.
|
||||
if prep_initial_states:
|
||||
chunk_indices_p, chunk_offsets_p = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc_p, chunk_size, num_prefill_tokens)
|
||||
|
||||
return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
|
||||
prep_initial_states=prep_initial_states,
|
||||
chunk_size=chunk_size,
|
||||
seq_idx_p=seq_idx_p,
|
||||
chunk_indices_p=chunk_indices_p,
|
||||
chunk_offsets_p=chunk_offsets_p)
|
||||
|
||||
|
||||
def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
|
||||
mamba2_metadata: Union[Mamba2Metadata,
|
||||
Mamba2AttentionMetadata,
|
||||
GDNAttentionMetadata]):
|
||||
"""
|
||||
this is triggered upon handling a new input at the first layer
|
||||
"""
|
||||
dim, cu_seqlen = x.shape
|
||||
mamba2_metadata.cu_seqlen = cu_seqlen
|
||||
seqlens = np.diff(query_start_loc.to('cpu'))
|
||||
nums_dict = {} # type: ignore
|
||||
for BLOCK_M in [8]: # cover all BLOCK_M values
|
||||
nums = -(-seqlens // BLOCK_M)
|
||||
nums_dict[BLOCK_M] = {}
|
||||
nums_dict[BLOCK_M]['nums'] = nums
|
||||
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
|
||||
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
|
||||
nums_dict[BLOCK_M]['mlist'] = mlist
|
||||
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
|
||||
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
|
||||
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
|
||||
offsetlist = [] # type: ignore
|
||||
for idx, num in enumerate(nums):
|
||||
offsetlist.extend(range(num))
|
||||
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
|
||||
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
|
||||
|
||||
if mamba2_metadata.batch_ptr is None:
|
||||
# Update default value after class definition
|
||||
#mamba2_metadata.MAX_NUM_PROGRAMS *= 2
|
||||
mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
mamba2_metadata.token_chunk_offset_ptr = torch.full(
|
||||
(MAX_NUM_PROGRAMS, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
else:
|
||||
if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
|
||||
mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
|
||||
PAD_SLOT_ID)
|
||||
mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore
|
||||
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
|
||||
|
||||
mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
|
||||
mamba2_metadata.token_chunk_offset_ptr[ # type: ignore
|
||||
0:mlist_len].copy_(offsetlist)
|
||||
nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
|
||||
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
|
||||
mamba2_metadata.token_chunk_offset_ptr) # type: ignore
|
||||
mamba2_metadata.nums_dict = nums_dict
|
||||
return mamba2_metadata
|
||||
@ -10,8 +10,6 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
has_weight=rms_norm_has_weight,
|
||||
) if use_rms_norm else None
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The outer list is for v0 PP virtual engine. Though this code path
|
||||
# only runs for v1, we have to do this to unify with the interface
|
||||
# of Attention + v0 PP.
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||
return discrete_time_step, B, C
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: Optional[MambaCacheParams] = None):
|
||||
if not envs.VLLM_USE_V1:
|
||||
CustomOp.forward(self, hidden_states, output, mamba_cache_params)
|
||||
else:
|
||||
torch.ops.vllm.mamba_mixer(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
|
||||
torch.ops.vllm.mamba_mixer(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def forward_native(self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: Optional[MambaCacheParams] = None):
|
||||
def forward_native(self, hidden_states: torch.Tensor,
|
||||
output: torch.Tensor):
|
||||
pass
|
||||
|
||||
def forward_cuda(self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: Optional[MambaCacheParams] = None):
|
||||
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
|
||||
"""
|
||||
Run the Mamba-1 SSM pipeline.
|
||||
|
||||
@ -234,31 +216,18 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
mamba1_metadata = attn_metadata
|
||||
assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
|
||||
query_start_loc = mamba1_metadata.query_start_loc
|
||||
state_indices_tensor = mamba1_metadata.state_indices_tensor
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states = mamba1_metadata.has_initial_states
|
||||
num_padded_decodes = mamba1_metadata.num_padded_decodes
|
||||
else:
|
||||
assert isinstance(attn_metadata, AttentionMetadata)
|
||||
assert mamba_cache_params is not None
|
||||
conv_state = mamba_cache_params.conv_state
|
||||
ssm_state = mamba_cache_params.ssm_state
|
||||
state_indices_tensor = mamba_cache_params.state_indices_tensor
|
||||
query_start_loc = attn_metadata.query_start_loc
|
||||
context_lens_tensor = attn_metadata.context_lens_tensor
|
||||
has_initial_states = None
|
||||
if context_lens_tensor is not None:
|
||||
has_initial_states = context_lens_tensor > 0
|
||||
num_padded_decodes = attn_metadata.num_decode_tokens
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
mamba1_metadata = attn_metadata
|
||||
assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
|
||||
query_start_loc = mamba1_metadata.query_start_loc
|
||||
state_indices_tensor = mamba1_metadata.state_indices_tensor
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states = mamba1_metadata.has_initial_states
|
||||
num_padded_decodes = mamba1_metadata.num_padded_decodes
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if envs.VLLM_USE_V1 and attn_metadata is None:
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
hidden_states_BC = hidden_states_BC.contiguous()
|
||||
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
|
||||
@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
out=scan_outputs_d)
|
||||
scan_outputs_d = scan_outputs_d.transpose(0, 1)
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
ssm_outputs.insert(0, scan_outputs_d)
|
||||
else:
|
||||
ssm_outputs.append(scan_outputs_d)
|
||||
ssm_outputs.insert(0, scan_outputs_d)
|
||||
|
||||
scan_outputs_combined = ssm_outputs[0] if len(
|
||||
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
|
||||
@ -441,40 +407,27 @@ def split_batch_to_prefill_and_decode(
|
||||
num_decodes: int,
|
||||
num_padded_decodes: int,
|
||||
) -> PrefillDecodeSplit:
|
||||
|
||||
num_actual_tokens = num_prefill_tokens + num_padded_decodes
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
# In v1, decode tokens come first, then prefill tokens.
|
||||
hidden_states_BC_d, hidden_states_BC_p = torch.split(
|
||||
hidden_states_BC[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
# In v1, decode tokens come first, then prefill tokens.
|
||||
hidden_states_BC_d, hidden_states_BC_p = torch.split(
|
||||
hidden_states_BC[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
|
||||
# num_padded_decodes accounts for CUDA graph padding when applicable
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[:num_padded_decodes + num_prefills],
|
||||
[num_padded_decodes, num_prefills],
|
||||
dim=0)
|
||||
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
|
||||
num_padded_decodes if num_prefills > 0 else None)
|
||||
has_initial_states_p = has_initial_states[-num_prefills:] if (
|
||||
has_initial_states is not None and num_prefills > 0) else None
|
||||
else:
|
||||
# In v0, prefill tokens come first, then decode tokens.
|
||||
hidden_states_BC_p, hidden_states_BC_d = torch.split(
|
||||
hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
|
||||
gate_p, gate_d = torch.split(gate,
|
||||
[num_prefill_tokens, num_decode_tokens],
|
||||
dim=-1)
|
||||
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||
state_indices_tensor, [num_prefills, num_decodes], dim=0)
|
||||
query_start_loc_p = (query_start_loc[:num_prefills +
|
||||
1] if num_prefills > 0 else None)
|
||||
has_initial_states_p = has_initial_states[:num_prefills] if (
|
||||
has_initial_states is not None and num_prefills > 0) else None
|
||||
# num_padded_decodes accounts for CUDA graph padding when applicable
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[:num_padded_decodes + num_prefills],
|
||||
[num_padded_decodes, num_prefills],
|
||||
dim=0)
|
||||
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
|
||||
num_padded_decodes if num_prefills > 0 else None)
|
||||
has_initial_states_p = has_initial_states[-num_prefills:] if (
|
||||
has_initial_states is not None and num_prefills > 0) else None
|
||||
|
||||
return PrefillDecodeSplit(
|
||||
hidden_states_BC_p=hidden_states_BC_p,
|
||||
@ -495,9 +448,7 @@ def mamba_mixer(
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states,
|
||||
output=output,
|
||||
mamba_cache_params=None)
|
||||
self.forward_cuda(hidden_states=hidden_states, output=output)
|
||||
|
||||
|
||||
def mamba_mixer_fake(
|
||||
|
||||
@ -9,7 +9,6 @@ if TYPE_CHECKING:
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
|
||||
update_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
LoaderFunction, composed_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self.use_rms_norm,
|
||||
eps=rms_norm_eps)
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The outer list is for v0 PP virtual engine. Though this code path
|
||||
# only runs for v1, we have to do this to unify with the interface
|
||||
# of Attention + v0 PP.
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
pass
|
||||
@ -478,59 +468,43 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if not envs.VLLM_USE_V1:
|
||||
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
|
||||
mamba2_metadata, mup_vector)
|
||||
else:
|
||||
torch.ops.vllm.mamba_mixer2(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
mup_vector,
|
||||
)
|
||||
torch.ops.vllm.mamba_mixer2(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
mup_vector,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
# mamba2_metadata contains metadata necessary for the mamba2 triton
|
||||
# attn_metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
mamba2_metadata = attn_metadata
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
else:
|
||||
conv_state = mamba_cache_params.conv_state
|
||||
ssm_state = mamba_cache_params.ssm_state
|
||||
state_indices_tensor = mamba_cache_params.state_indices_tensor
|
||||
|
||||
# Common members between V1 metadata and V0 metadata
|
||||
if mamba2_metadata is not None:
|
||||
has_initial_states_p = mamba2_metadata.has_initial_states_p
|
||||
prep_initial_states = mamba2_metadata.prep_initial_states
|
||||
chunk_size = mamba2_metadata.chunk_size
|
||||
seq_idx_p = mamba2_metadata.seq_idx_p
|
||||
chunk_indices_p = mamba2_metadata.chunk_indices_p
|
||||
chunk_offsets_p = mamba2_metadata.chunk_offsets_p
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
prep_initial_states = attn_metadata.prep_initial_states
|
||||
chunk_size = attn_metadata.chunk_size
|
||||
seq_idx_p = attn_metadata.seq_idx_p
|
||||
chunk_indices_p = attn_metadata.chunk_indices_p
|
||||
chunk_offsets_p = attn_metadata.chunk_offsets_p
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_V1 and attn_metadata is None:
|
||||
# V1 profile run
|
||||
if attn_metadata is None:
|
||||
# profile run
|
||||
hidden_states_B_C = (hidden_states_B_C.transpose(
|
||||
0, 1).clone().transpose(0, 1)).contiguous()
|
||||
hidden_states, _B, _C = split_hidden_states_B_C_fn(
|
||||
@ -579,49 +553,27 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
has_decode = num_decodes > 0
|
||||
num_actual_tokens = num_prefill_tokens + num_decodes
|
||||
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
if envs.VLLM_USE_V1:
|
||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||
hidden_states_B_C[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
dt_d, dt_p = torch.split(
|
||||
dt[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[:num_actual_tokens],
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (
|
||||
attn_metadata.query_start_loc[-num_prefills - 1:] -
|
||||
num_decodes if has_prefill else None)
|
||||
else:
|
||||
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
|
||||
hidden_states_B_C,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
dt_p, dt_d = torch.split(
|
||||
dt,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_prefills, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
|
||||
1]
|
||||
if has_prefill else None)
|
||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||
hidden_states_B_C[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
dt_d, dt_p = torch.split(
|
||||
dt[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[:num_actual_tokens],
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (
|
||||
attn_metadata.query_start_loc[-num_prefills - 1:] -
|
||||
num_decodes if has_prefill else None)
|
||||
|
||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||
# and decode outputs
|
||||
@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
if envs.VLLM_USE_V1:
|
||||
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Process prefill requests
|
||||
if has_prefill:
|
||||
@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# pointed to by "state_indices_tensor"
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1) # this is the form that causal-conv see
|
||||
if mamba2_metadata.cu_seqlen is None:
|
||||
mamba2_metadata = update_metadata(x, query_start_loc_p,
|
||||
mamba2_metadata)
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
x,
|
||||
conv_weights,
|
||||
@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
metadata=mamba2_metadata,
|
||||
metadata=attn_metadata,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
|
||||
@ -806,8 +748,6 @@ def mamba_mixer2(
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states,
|
||||
output=output,
|
||||
mamba_cache_params=None,
|
||||
mamba2_metadata=None,
|
||||
mup_vector=mup_vector)
|
||||
|
||||
|
||||
|
||||
@ -100,7 +100,6 @@ class MambaStateShapeCalculator:
|
||||
intermediate_size: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
conv_state_shape = (divide(intermediate_size,
|
||||
tp_world_size), conv_kernel - 1)
|
||||
@ -108,11 +107,7 @@ class MambaStateShapeCalculator:
|
||||
temporal_state_shape = (divide(intermediate_size,
|
||||
tp_world_size), state_size)
|
||||
|
||||
# In V0, the conv_state shape was swapped during allocation in
|
||||
# MambaCacheManager, but in V1 it needs to be determined here at the
|
||||
# calculation level
|
||||
if use_v1:
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@ -126,7 +121,6 @@ class MambaStateShapeCalculator:
|
||||
head_dim: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
@ -137,8 +131,6 @@ class MambaStateShapeCalculator:
|
||||
|
||||
# contiguous along 'dim' axis
|
||||
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
|
||||
if not use_v1:
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
@ -153,12 +145,9 @@ class MambaStateShapeCalculator:
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
conv_kernel: int,
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int]]:
|
||||
conv_dim = divide(intermediate_size, tp_world_size)
|
||||
conv_state_shape = (conv_kernel - 1, conv_dim)
|
||||
if not use_v1:
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
return (conv_state_shape, )
|
||||
|
||||
@classmethod
|
||||
@ -183,7 +172,6 @@ class MambaStateShapeCalculator:
|
||||
head_v_dim: int,
|
||||
conv_kernel_size: int,
|
||||
num_spec: int = 0,
|
||||
use_v1: bool = True,
|
||||
):
|
||||
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
|
||||
conv_state_shape = (
|
||||
@ -191,11 +179,7 @@ class MambaStateShapeCalculator:
|
||||
conv_kernel_size - 1 + num_spec,
|
||||
)
|
||||
|
||||
# In V0, the conv_state shape was swapped during allocation in
|
||||
# MambaCacheManager, but in V1 it needs to be determined here at the
|
||||
# calculation level
|
||||
if use_v1:
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
temporal_state_shape = (divide(num_v_heads,
|
||||
tp_world_size), head_k_dim, head_v_dim)
|
||||
|
||||
@ -420,9 +420,7 @@ def causal_conv1d_fn(
|
||||
x = x.to(conv_states.dtype)
|
||||
out = torch.empty_like(x)
|
||||
if metadata is not None:
|
||||
cu_seqlen = metadata.cu_seqlen
|
||||
nums_dict = metadata.nums_dict
|
||||
#x = metadata.x
|
||||
args = nums_dict
|
||||
batch_ptr = metadata.batch_ptr
|
||||
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
|
||||
@ -926,7 +924,6 @@ def causal_conv1d_update(
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
max_query_len: int = -1,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
):
|
||||
"""
|
||||
|
||||
@ -8,7 +8,6 @@ if TYPE_CHECKING:
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp):
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The outer list is for v0 PP virtual engine. Though this code path
|
||||
# only runs for v1, we have to do this to unify with the interface
|
||||
# of Attention + v0 PP.
|
||||
self.kv_cache = [(torch.tensor([]), )]
|
||||
self.kv_cache = (torch.tensor([]), )
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
conv_metadata: ShortConvAttentionMetadata,
|
||||
):
|
||||
return
|
||||
|
||||
@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
conv_metadata: ShortConvAttentionMetadata,
|
||||
):
|
||||
torch.ops.vllm.short_conv(
|
||||
hidden_states,
|
||||
@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
conv_metadata: ShortConvAttentionMetadata,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
# ShortConvAttentionMetadata contains metadata necessary for the
|
||||
@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp):
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
conv_metadata = attn_metadata
|
||||
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp):
|
||||
|
||||
if has_prefill:
|
||||
Bx_p = (B_p * x_p).transpose(0, 1)
|
||||
if conv_metadata.cu_seqlen is None:
|
||||
conv_metadata = update_metadata(Bx_p, query_start_loc_p,
|
||||
conv_metadata)
|
||||
Bx = causal_conv1d_fn(Bx_p,
|
||||
conv_weights,
|
||||
self.conv.bias,
|
||||
@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp):
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
metadata=conv_metadata,
|
||||
metadata=attn_metadata,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
|
||||
@ -248,9 +235,7 @@ def short_conv(
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states,
|
||||
output=output,
|
||||
conv_metadata=None)
|
||||
self.forward_cuda(hidden_states=hidden_states, output=output)
|
||||
|
||||
|
||||
def short_conv_fake(
|
||||
|
||||
@ -9,21 +9,17 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import BambaConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
|
||||
SupportsQuant)
|
||||
@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module):
|
||||
hidden_states, residual)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
||||
self.mamba(hidden_states, output)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(output, residual)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
@ -315,22 +306,10 @@ class BambaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.mamba_chunk_size,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
else:
|
||||
# v1 get mamba2_metadata from forward_context
|
||||
mamba2_metadata = None
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
@ -343,23 +322,11 @@ class BambaModel(nn.Module):
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
residual = None
|
||||
num_attn = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
if isinstance(layer, BambaAttentionDecoderLayer):
|
||||
num_attn += 1
|
||||
|
||||
layer_mamba_cache_params = None
|
||||
if isinstance(layer,
|
||||
BambaMixerDecoderLayer) and mamba_cache_params:
|
||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
||||
i - num_attn)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
mamba_cache_params=layer_mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
head_dim=hf_config.mamba_d_head,
|
||||
state_size=hf_config.mamba_d_state,
|
||||
conv_kernel=hf_config.mamba_d_conv,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
|
||||
mamba_cache_params = None
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
num_mamba_layers = \
|
||||
self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config,
|
||||
LayerBlockType.mamba
|
||||
)
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
mamba_state_dtype = \
|
||||
self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape,
|
||||
*mamba_state_dtype)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.model(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -1,137 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
|
||||
|
||||
class ConstantSizeCache(ABC):
|
||||
"""
|
||||
Abstract base class for managing constant size caches
|
||||
like Mamba and Minimax.
|
||||
"""
|
||||
|
||||
def __init__(self, max_batch_size: int):
|
||||
# Maps between the request id and a dict that maps between the seq_id
|
||||
# and its index inside the cache
|
||||
self.cache_indices_mapping: dict[str, dict[int, int]] = {}
|
||||
self.free_cache_indices = list(range(max_batch_size))
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def cache(self) -> Any:
|
||||
"""Return the underlying cache tensor(s)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _copy_cache(self, from_index: int, to_index: int):
|
||||
"""Copy cache data from one index to another"""
|
||||
pass
|
||||
|
||||
def current_run_tensors(self, **kwargs) -> tuple:
|
||||
"""
|
||||
Return the tensors for the current run's conv and ssm state.
|
||||
"""
|
||||
if "seqlen_agnostic_capture_inputs" not in kwargs:
|
||||
# We get here only on Prefill/Eager mode runs
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
state_indices = self._prepare_current_run_cache(
|
||||
request_ids_to_seq_ids, finished_requests_ids)
|
||||
|
||||
state_indices_tensor = torch.as_tensor(state_indices,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
cache_tensors = self.cache
|
||||
else:
|
||||
# CUDA graph capturing runs
|
||||
cache_tensors, state_indices_tensor = kwargs[
|
||||
"seqlen_agnostic_capture_inputs"]
|
||||
|
||||
return (cache_tensors, state_indices_tensor)
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
"""
|
||||
Copy the relevant state_indices into the CUDA graph input buffer
|
||||
"""
|
||||
assert all(
|
||||
key in kwargs
|
||||
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
assert "seqlen_agnostic_capture_inputs" in input_buffers
|
||||
_, input_state_indices_buffer = input_buffers[
|
||||
"seqlen_agnostic_capture_inputs"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
state_indices = self._prepare_current_run_cache(
|
||||
request_ids_to_seq_ids, finished_requests_ids)
|
||||
cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
|
||||
state_indices)
|
||||
state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
|
||||
|
||||
input_state_indices_buffer.copy_(
|
||||
torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
"""
|
||||
Provide the CUDA graph capture runs with a buffer in adjusted size.
|
||||
The buffer is used to maintain the Cache during the CUDA graph replay
|
||||
runs.
|
||||
"""
|
||||
state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
return (self.cache, state_indices_tensor)
|
||||
|
||||
def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
|
||||
finished_requests_ids) -> int:
|
||||
"""
|
||||
Assign (req_id,seq_id) pair to a `destination_index` index, if
|
||||
already occupied, move the occupying index to a free index.
|
||||
"""
|
||||
if cur_rid in finished_requests_ids:
|
||||
# set as pad, do not allocate destination index
|
||||
return PAD_SLOT_ID
|
||||
elif cur_rid not in self.cache_indices_mapping:
|
||||
destination_index = self.free_cache_indices.pop()
|
||||
self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
|
||||
return destination_index
|
||||
elif seq_id not in (seq_ids2indices :=
|
||||
self.cache_indices_mapping[cur_rid]):
|
||||
# parallel sampling , where n > 1, assume prefill have
|
||||
# already happened, so we copy the
|
||||
# existing cache into the siblings seq_ids caches
|
||||
index_exists = next(iter(seq_ids2indices.values()))
|
||||
# case of decoding n>1, copy prefill cache to decoding indices
|
||||
destination_index = self.free_cache_indices.pop()
|
||||
self._copy_cache(from_index=index_exists,
|
||||
to_index=destination_index)
|
||||
self.cache_indices_mapping[cur_rid][seq_id] = destination_index
|
||||
return destination_index
|
||||
else:
|
||||
return self.cache_indices_mapping[cur_rid][seq_id]
|
||||
|
||||
def _prepare_current_run_cache(
|
||||
self, request_ids_to_seq_ids: dict[str, list[int]],
|
||||
finished_requests_ids: list[str]) -> list[int]:
|
||||
return [
|
||||
self._assign_seq_id_to_cache_index(req_id, seq_id,
|
||||
finished_requests_ids)
|
||||
for req_id, seq_ids in request_ids_to_seq_ids.items()
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
|
||||
def _release_finished_requests(self,
|
||||
finished_seq_groups_req_ids: list[str]):
|
||||
for req_id in finished_seq_groups_req_ids:
|
||||
if req_id in self.cache_indices_mapping:
|
||||
for seq_id in self.cache_indices_mapping[req_id]:
|
||||
self.free_cache_indices.append(
|
||||
self.cache_indices_mapping[req_id][seq_id])
|
||||
self.cache_indices_mapping.pop(req_id)
|
||||
@ -8,21 +8,17 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import FalconH1Config
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
||||
@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(
|
||||
hidden_states,
|
||||
output,
|
||||
mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
mup_vector=self.mup_vector,
|
||||
)
|
||||
return output, residual
|
||||
@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
residual = hidden_states
|
||||
@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module):
|
||||
|
||||
# Process input through the SSM branch.
|
||||
# FalconH1SSMDecoderLayer expects hidden_states, attn_metadata,
|
||||
# residual, mamba_cache_params, and sequence_idx.
|
||||
# residual, and sequence_idx.
|
||||
ssm_hidden, _ = self.mamba(
|
||||
hidden_states=hidden_states * self.ssm_in_multiplier,
|
||||
residual=residual,
|
||||
mamba_cache_params=mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
# Sum the outputs from both branches.
|
||||
@ -464,25 +450,10 @@ class FalconH1Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# pass a sequence index tensor, that is required for
|
||||
# proper continuous batching computation including
|
||||
# chunked prefill
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.mamba_chunk_size,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
else:
|
||||
# v1 get mamba2_metadata from forward_context
|
||||
mamba2_metadata = None
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds * self.embedding_multiplier
|
||||
@ -495,14 +466,9 @@ class FalconH1Model(nn.Module):
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
layer_mamba_cache_params = None
|
||||
if mamba_cache_params:
|
||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
|
||||
hidden_states = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
mamba_cache_params=layer_mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
head_dim=hf_config.mamba_d_head,
|
||||
state_size=hf_config.mamba_d_state,
|
||||
conv_kernel=hf_config.mamba_d_conv,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.tie_word_embeddings = config.tie_word_embeddings
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
if get_pp_group().is_last_rank:
|
||||
@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
mamba_cache_params = None
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
mamba_state_dtype = \
|
||||
self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.vllm_config,
|
||||
self.config.num_hidden_layers,
|
||||
*mamba_state_shape,
|
||||
*mamba_state_dtype,
|
||||
)
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
mamba_cache_params,
|
||||
intermediate_tensors,
|
||||
inputs_embeds,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -9,19 +9,15 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import GraniteMoeHybridConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .granitemoe import GraniteMoeMoE
|
||||
from .granitemoeshared import GraniteMoeSharedMLP
|
||||
@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
||||
self.mamba(hidden_states, output)
|
||||
hidden_states = residual + output * self.residual_multiplier
|
||||
|
||||
residual = hidden_states
|
||||
@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.mamba_chunk_size,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
else:
|
||||
# v1 get mamba2_metadata from forward_context
|
||||
mamba2_metadata = None
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
for i, layer in enumerate(self.layers):
|
||||
if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer):
|
||||
num_attn += 1
|
||||
|
||||
layer_mamba_cache_params = None
|
||||
if isinstance(
|
||||
layer,
|
||||
GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params:
|
||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
||||
i - num_attn)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
mamba_cache_params=layer_mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata)
|
||||
hidden_states, residual = layer(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
|
||||
head_dim=hf_config.mamba_d_head,
|
||||
state_size=hf_config.mamba_d_state,
|
||||
conv_kernel=hf_config.mamba_d_conv,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
|
||||
scale=1 /
|
||||
self.config.logits_scaling)
|
||||
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
|
||||
mamba_cache_params = None
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
num_mamba_layers = (
|
||||
self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config,
|
||||
LayerBlockType.mamba))
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
mamba_state_dtype = \
|
||||
self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape,
|
||||
*mamba_state_dtype)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.model(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -9,7 +9,6 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import JambaConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
||||
@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module):
|
||||
hidden_states, residual)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mamba(hidden_states, output, mamba_cache_params)
|
||||
self.mamba(hidden_states, output)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(output, residual)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
@ -333,7 +328,6 @@ class JambaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@ -348,24 +342,11 @@ class JambaModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
kv_cache_index = 0
|
||||
mamba_cache_index = 0
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
layer_mamba_cache_params = None
|
||||
if isinstance(layer, JambaAttentionDecoderLayer):
|
||||
kv_cache_index += 1
|
||||
if isinstance(layer,
|
||||
JambaMambaDecoderLayer) and mamba_cache_params:
|
||||
current_state_layer = mamba_cache_index
|
||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
||||
current_state_layer)
|
||||
mamba_cache_index += 1
|
||||
hidden_states, residual = layer(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
mamba_cache_params=layer_mamba_cache_params)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
# NOTE: mamba_cache_params is not needed for v1
|
||||
mamba_cache_params = None
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
num_layers = self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config, LayerBlockType.mamba)
|
||||
state_shape = self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config)
|
||||
state_dtype = self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
num_layers, *state_shape,
|
||||
*state_dtype)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.model(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
intermediate_size=hf_config.mamba_expand * hidden_size,
|
||||
state_size=hf_config.mamba_d_state,
|
||||
conv_kernel=hf_config.mamba_d_conv,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
|
||||
@ -8,7 +8,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import Lfm2Config
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
|
||||
self.conv(
|
||||
hidden_states,
|
||||
output,
|
||||
conv_metadata=None,
|
||||
)
|
||||
hidden_states, residual = self.ffn_norm(output, residual)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int]]:
|
||||
""" Calculate shapes for LFM2's convolutional cache.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
tp_world_size=parallel_config.tensor_parallel_size,
|
||||
intermediate_size=hf_config.conv_dim,
|
||||
conv_kernel=hf_config.conv_L_cache,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
assert (not cache_config.enable_prefix_caching
|
||||
), "Lfm2 currently does not support prefix caching"
|
||||
assert envs.VLLM_USE_V1, (
|
||||
"Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1")
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@ -8,7 +8,6 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import MambaConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState,
|
||||
IsAttentionFree, SupportsPP)
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module):
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mixer(hidden_states, output, mamba_cache_params)
|
||||
self.mixer(hidden_states, output)
|
||||
return output, residual
|
||||
|
||||
|
||||
@ -134,7 +129,6 @@ class MambaModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: Optional[MambaCacheParams] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@ -151,17 +145,9 @@ class MambaModel(nn.Module):
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
|
||||
layer_cache_params = None
|
||||
if mamba_cache_params is not None:
|
||||
layer_cache_params = mamba_cache_params.at_layer_idx(
|
||||
i - self.start_layer)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
mamba_cache_params=layer_cache_params)
|
||||
hidden_states, residual = layer(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
|
||||
@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
|
||||
mamba_cache_params = None
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
num_layers = self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config, LayerBlockType.mamba)
|
||||
state_shape = self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config)
|
||||
state_dtype = self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
num_layers, *state_shape,
|
||||
*state_dtype)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
|
||||
hidden_states = self.backbone(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
|
||||
tp_world_size=parallel_config.tensor_parallel_size,
|
||||
intermediate_size=hf_config.intermediate_size,
|
||||
state_size=hf_config.state_size,
|
||||
conv_kernel=hf_config.conv_kernel,
|
||||
use_v1=envs.VLLM_USE_V1)
|
||||
conv_kernel=hf_config.conv_kernel)
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
|
||||
@ -8,16 +8,11 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import MambaConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState,
|
||||
IsAttentionFree)
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module):
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
||||
self.mixer(hidden_states, output)
|
||||
return output, residual
|
||||
|
||||
|
||||
@ -137,7 +127,6 @@ class Mamba2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@ -152,25 +141,10 @@ class Mamba2Model(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.chunk_size,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
else:
|
||||
# v1 get mamba2_metadata from forward_context
|
||||
mamba2_metadata = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
mamba_cache_params=mamba_cache_params.at_layer_idx(
|
||||
i - self.start_layer) if mamba_cache_params else None,
|
||||
mamba2_metadata=mamba2_metadata)
|
||||
hidden_states, residual = layer(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
head_dim=hf_config.head_dim,
|
||||
state_size=hf_config.state_size,
|
||||
conv_kernel=hf_config.conv_kernel,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
|
||||
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
|
||||
@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
num_mamba_layers = (
|
||||
self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config,
|
||||
LayerBlockType.mamba))
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
mamba_state_dtype = \
|
||||
self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape,
|
||||
*mamba_state_dtype)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
else:
|
||||
# NOTE: mamba_cache_params is not needed for v1
|
||||
mamba_cache_params = None
|
||||
|
||||
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
|
||||
hidden_states = self.backbone(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -1,83 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaCacheParams:
|
||||
conv_state: torch.Tensor = torch.Tensor()
|
||||
ssm_state: torch.Tensor = torch.Tensor()
|
||||
state_indices_tensor: torch.Tensor = torch.Tensor()
|
||||
|
||||
def at_layer_idx(self, layer_idx):
|
||||
return MambaCacheParams(self.conv_state[layer_idx],
|
||||
self.ssm_state[layer_idx],
|
||||
self.state_indices_tensor)
|
||||
|
||||
|
||||
class MambaCacheManager(ConstantSizeCache):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
|
||||
conv_state_shape: tuple[int, int],
|
||||
temporal_state_shape: tuple[int, int],
|
||||
conv_state_dtype: torch.dtype,
|
||||
temporal_state_dtype: torch.dtype):
|
||||
|
||||
self.conv_state_dtype = conv_state_dtype
|
||||
self.temporal_state_dtype = temporal_state_dtype
|
||||
|
||||
# Determine max batch size to set size of MambaCache
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
if not vllm_config.model_config.enforce_eager:
|
||||
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
|
||||
|
||||
# Initialize parent class
|
||||
super().__init__(max_batch_size)
|
||||
|
||||
# assume conv_state = (dim, state_len)
|
||||
assert conv_state_shape[0] > conv_state_shape[1]
|
||||
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||
(conv_state_shape[1], conv_state_shape[0]),
|
||||
dtype=self.conv_state_dtype,
|
||||
device="cuda").transpose(-1, -2)
|
||||
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||
temporal_state_shape,
|
||||
dtype=self.temporal_state_dtype,
|
||||
device="cuda")
|
||||
|
||||
self._mamba_cache = (conv_state, temporal_state)
|
||||
|
||||
@property
|
||||
def cache(self):
|
||||
return self._mamba_cache
|
||||
|
||||
def _copy_cache(self, from_index: int, to_index: int):
|
||||
for cache_t in self.cache:
|
||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
||||
non_blocking=True)
|
||||
|
||||
def current_run_tensors(self, **kwargs) -> MambaCacheParams:
|
||||
"""
|
||||
Return the tensors for the current run's conv and ssm state.
|
||||
"""
|
||||
cache_tensors, state_indices_tensor = super().current_run_tensors(
|
||||
**kwargs)
|
||||
return MambaCacheParams(cache_tensors[0], cache_tensors[1],
|
||||
state_indices_tensor)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
"""
|
||||
Provide the CUDA graph capture runs with a buffer in adjusted size.
|
||||
The buffer is used to maintain the Mamba Cache during the CUDA graph
|
||||
replay runs.
|
||||
"""
|
||||
return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
@ -1,36 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinimaxCacheParams:
|
||||
minimax_cache: torch.Tensor = torch.Tensor()
|
||||
state_indices_tensor: torch.Tensor = torch.Tensor()
|
||||
|
||||
def at_layer_idx(self, layer_idx):
|
||||
return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
|
||||
self.state_indices_tensor)
|
||||
|
||||
|
||||
class MinimaxCacheManager(ConstantSizeCache):
|
||||
|
||||
def __init__(self, dtype, cache_shape):
|
||||
super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1]
|
||||
self._minimax_cache = torch.empty(size=cache_shape,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
@property
|
||||
def cache(self):
|
||||
return self._minimax_cache
|
||||
|
||||
def _copy_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.cache) > 0
|
||||
for cache_t in self.cache:
|
||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
||||
non_blocking=True)
|
||||
@ -14,7 +14,6 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from transformers import MiniMaxConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid
|
||||
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||
|
||||
|
||||
@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Union[list[dict], Optional[torch.Tensor]],
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
is_warmup: bool = False,
|
||||
@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
hidden_states=layernorm_output,
|
||||
output=self_attention_output,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
|
||||
residual = residual * self.layernorm_attention_alpha
|
||||
@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module):
|
||||
self._dtype = _dummy.dtype
|
||||
del _dummy
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
self.minimax_cache = MinimaxCacheManager(
|
||||
dtype=torch.float32, cache_shape=self.cache_shape)
|
||||
|
||||
norm_kwargs = {}
|
||||
if hasattr(config, "rms_norm_eps"):
|
||||
norm_kwargs["eps"] = config.rms_norm_eps
|
||||
@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module):
|
||||
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if not envs.VLLM_USE_V1 and attn_metadata is None:
|
||||
return None
|
||||
if not envs.VLLM_USE_V1:
|
||||
if "request_ids_to_seq_ids" not in kwargs:
|
||||
kwargs["request_ids_to_seq_ids"] = {}
|
||||
if "finished_requests_ids" not in kwargs:
|
||||
kwargs["finished_requests_ids"] = []
|
||||
(
|
||||
minimax_cache_tensors,
|
||||
state_indices_tensor,
|
||||
) = self.minimax_cache.current_run_tensors(**kwargs)
|
||||
if getattr(attn_metadata, "num_prefills", 0) > 0:
|
||||
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
|
||||
**kwargs)
|
||||
|
||||
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
|
||||
state_indices_tensor)
|
||||
else:
|
||||
minimax_cache_params = None
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is None:
|
||||
@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
minimax_cache_index = 0
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
_caches = None
|
||||
if not envs.VLLM_USE_V1 and isinstance(
|
||||
layer.self_attn, MiniMaxText01LinearAttention):
|
||||
current_state_layer = minimax_cache_index
|
||||
_caches = minimax_cache_params.at_layer_idx(
|
||||
current_state_layer)
|
||||
minimax_cache_index += 1
|
||||
hidden_states, residual = layer(
|
||||
hidden_states=hidden_states,
|
||||
positions=positions,
|
||||
kv_caches=_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
)
|
||||
@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, ...], ...]:
|
||||
"""Calculate shape for MiniMaxText01LinearAttention cache.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
|
||||
@ -23,21 +23,17 @@ from typing import Optional
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
||||
SupportsLoRA, SupportsPP,
|
||||
SupportsQuant)
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
|
||||
make_layers, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import NemotronHConfig
|
||||
from vllm.utils import LayerBlockType
|
||||
|
||||
|
||||
class NemotronHMLP(nn.Module):
|
||||
@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
||||
self.mixer(hidden_states, output)
|
||||
return output, residual
|
||||
|
||||
|
||||
@ -370,22 +361,10 @@ class NemotronHModel(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.chunk_size,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
else:
|
||||
# v1 get mamba2_metadata from forward_context
|
||||
mamba2_metadata = None
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
@ -398,22 +377,11 @@ class NemotronHModel(nn.Module):
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
residual = None
|
||||
num_non_mamba_layers = 0
|
||||
for i, layer in enumerate(self.layers):
|
||||
layer_mamba_cache_params = None
|
||||
if isinstance(layer,
|
||||
NemotronHMambaDecoderLayer) and mamba_cache_params:
|
||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
||||
i - num_non_mamba_layers)
|
||||
else:
|
||||
num_non_mamba_layers += 1
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
mamba_cache_params=layer_mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
head_dim=hf_config.mamba_head_dim,
|
||||
state_size=hf_config.ssm_state_size,
|
||||
conv_kernel=hf_config.conv_kernel,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
|
||||
mamba_cache_params = None
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
|
||||
num_mamba_layers = \
|
||||
self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config,
|
||||
LayerBlockType.mamba
|
||||
)
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
mamba_state_dtype = \
|
||||
self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape,
|
||||
*mamba_state_dtype)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
hidden_states = self.model(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -1,731 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention.selector import _Backend
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
||||
SupportsV0Only)
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import make_layers, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SwiGLUActivation(nn.Module):
|
||||
|
||||
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
||||
return x1 * nn.functional.silu(x2)
|
||||
|
||||
|
||||
class SambaYMLP(nn.Module):
|
||||
"""Gated Linear Unit.
|
||||
|
||||
Reference:
|
||||
Language Modeling with Gated Convolutional Networks.
|
||||
https://arxiv.org/pdf/1612.08083v3.pdf.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.fc1 = nn.Linear(config.hidden_size,
|
||||
2 * config.intermediate_size,
|
||||
bias=False)
|
||||
self.fc2 = nn.Linear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
y = self.fc1(hidden_states)
|
||||
gate, y = y.chunk(2, dim=-1)
|
||||
y = y * self.activation_fn(gate)
|
||||
return self.fc2(y)
|
||||
|
||||
|
||||
def get_virtual_engine():
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
return forward_context.virtual_engine
|
||||
|
||||
|
||||
class SambaYAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
layer_idx: Optional[int] = None,
|
||||
yoco_cross: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing "
|
||||
"a `layer_idx` is not recommended and will lead to errors "
|
||||
"during the forward call if caching is used. Please make "
|
||||
"sure to provide a `layer_idx` when creating this class.")
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.yoco_cross = yoco_cross
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError("hidden_size must be divisible by num_heads "
|
||||
f"(got `hidden_size`: {self.hidden_size} and "
|
||||
f"`num_heads`: {self.num_heads}).")
|
||||
|
||||
op_size = self.num_heads * self.head_dim + 2 * (
|
||||
self.num_key_value_heads * self.head_dim)
|
||||
self.out_proj = nn.Linear(self.num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=True)
|
||||
if yoco_cross:
|
||||
self.Wqkv = nn.Linear(self.hidden_size,
|
||||
self.num_heads * self.head_dim,
|
||||
bias=True)
|
||||
else:
|
||||
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
|
||||
|
||||
# disable sliding window for the second half of the model
|
||||
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
||||
sliding_window = config.sliding_window if is_sliding else None
|
||||
|
||||
assert self.num_heads % 2 == 0, 'num_heads should be even'
|
||||
assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'
|
||||
|
||||
self.lambda_init = self.lambda_init_fn(layer_idx)
|
||||
self.lambda_q1 = nn.Parameter(
|
||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
||||
std=0.1))
|
||||
self.lambda_k1 = nn.Parameter(
|
||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
||||
std=0.1))
|
||||
self.lambda_q2 = nn.Parameter(
|
||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
||||
std=0.1))
|
||||
self.lambda_k2 = nn.Parameter(
|
||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
||||
std=0.1))
|
||||
self.subln = nn.RMSNorm(2 * self.head_dim,
|
||||
eps=1e-5,
|
||||
elementwise_affine=True)
|
||||
|
||||
params = {
|
||||
'differential_flash_attention_config': {
|
||||
'lambda_init': self.lambda_init,
|
||||
'lambda_q1': self.lambda_q1,
|
||||
'lambda_k1': self.lambda_k1,
|
||||
'lambda_q2': self.lambda_q2,
|
||||
'lambda_k2': self.lambda_k2,
|
||||
"subln": self.subln,
|
||||
}
|
||||
}
|
||||
|
||||
if yoco_cross:
|
||||
kv_shared_layer_index = config.num_hidden_layers // 2 + 1
|
||||
kv_sharing_target_layer_name = \
|
||||
f"model.layers.{kv_shared_layer_index}.self_attn.attn"
|
||||
else:
|
||||
kv_sharing_target_layer_name = None
|
||||
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.head_dim**-0.5,
|
||||
num_kv_heads=self.num_key_value_heads,
|
||||
cache_config=cache_config,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
|
||||
**params)
|
||||
assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\
|
||||
"DIFFERENTIAL_FLASH_ATTN required"
|
||||
|
||||
def lambda_init_fn(self, depth):
|
||||
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
|
||||
if not self.yoco_cross: # need to generate kv-cache
|
||||
qkv = self.Wqkv(hidden_states)
|
||||
q, k, v = qkv.split([
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim,
|
||||
self.num_key_value_heads * self.head_dim
|
||||
],
|
||||
dim=-1)
|
||||
attn_output = self.attn(q, k, v)
|
||||
else: # reuse the kv cache, full attention
|
||||
q = self.Wqkv(hidden_states)
|
||||
attn_output = self.attn(q, None, None)
|
||||
attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
|
||||
return self.out_proj(attn_output)
|
||||
|
||||
|
||||
class Phi4Mamba(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
d_state=16,
|
||||
d_conv=4,
|
||||
expand=2,
|
||||
dt_rank="auto",
|
||||
dt_min=0.001,
|
||||
dt_max=0.1,
|
||||
dt_init="random", # difference
|
||||
dt_scale=1.0, # difference
|
||||
dt_init_floor=1e-4,
|
||||
conv_bias=True,
|
||||
bias=False,
|
||||
use_fast_path=True, # Fused kernel options
|
||||
layer_idx=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
yoco_cross=False,
|
||||
yoco_kv=False,
|
||||
):
|
||||
factory_kwargs = {"params_dtype": dtype} # difference
|
||||
super().__init__()
|
||||
self.yoco_cross = yoco_cross
|
||||
self.yoco_kv = yoco_kv
|
||||
self.d_model = d_model
|
||||
self.d_state = d_state
|
||||
self.d_conv = d_conv
|
||||
self.expand = expand
|
||||
self.d_inner = int(self.expand * self.d_model)
|
||||
self.dt_rank = math.ceil(self.d_model /
|
||||
16) if dt_rank == "auto" else dt_rank
|
||||
self.use_fast_path = use_fast_path
|
||||
self.layer_idx = layer_idx
|
||||
self.swiGluActivation = SwiGLUActivation()
|
||||
if self.yoco_cross:
|
||||
self.in_proj = MergedColumnParallelLinear(self.d_model,
|
||||
[self.d_inner],
|
||||
bias=bias,
|
||||
**factory_kwargs)
|
||||
self.out_proj = RowParallelLinear(self.d_inner,
|
||||
self.d_model,
|
||||
bias=bias,
|
||||
**factory_kwargs)
|
||||
return
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=d_conv,
|
||||
output_size=self.d_inner,
|
||||
bias=conv_bias,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `set_weight_attrs`
|
||||
# doesn't allow to override it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
self.d_model,
|
||||
[self.d_inner] * 2,
|
||||
bias=bias,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
self.x_proj = RowParallelLinear(
|
||||
self.d_inner,
|
||||
self.dt_rank + self.d_state * 2,
|
||||
bias=False,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
|
||||
# time step projection (discretization) -
|
||||
# In the forward we need to apply dt_proj without the bias,
|
||||
# as the bias is added in the selective scan kernel.
|
||||
self.dt_proj = ColumnParallelLinear(
|
||||
self.dt_rank,
|
||||
self.d_inner,
|
||||
bias=True,
|
||||
skip_bias_add=True,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
|
||||
# # D "skip" parameter
|
||||
# self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
|
||||
self.A = nn.Parameter(
|
||||
torch.empty(
|
||||
self.d_inner,
|
||||
self.d_state,
|
||||
dtype=torch.float32,
|
||||
))
|
||||
self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.d_inner,
|
||||
self.d_model,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
params_dtype=dtype,
|
||||
)
|
||||
self.activation = "silu"
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
yoco_key_values=None) -> torch.Tensor:
|
||||
|
||||
if self.yoco_cross:
|
||||
out = self.in_proj(hidden_states)[0]
|
||||
out = self.swiGluActivation(yoco_key_values, out)
|
||||
out = self.out_proj(out)
|
||||
return out[0], yoco_key_values
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
# projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
projected_states = self.in_proj(
|
||||
hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1)
|
||||
hidden_states, gate = projected_states.chunk(2, dim=-2)
|
||||
|
||||
# 2. Convolution sequence transformation
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if attn_metadata.query_start_loc is not None \
|
||||
and attn_metadata.context_lens_tensor is not None:
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
hidden_states = causal_conv1d_fn(
|
||||
hidden_states,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=mamba_cache_params.conv_state,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
hidden_states = causal_conv1d_update(
|
||||
hidden_states.transpose(0, 1),
|
||||
mamba_cache_params.conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=mamba_cache_params.state_indices_tensor)
|
||||
hidden_states = hidden_states.transpose(0, 1)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
# 3.a. input varying initialization of time_step, B and C
|
||||
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
|
||||
|
||||
time_step, B, C = torch.split(
|
||||
ssm_parameters,
|
||||
[self.dt_rank, self.d_state, self.d_state],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
|
||||
|
||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
||||
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
|
||||
self.dt_proj, "bias") else None)
|
||||
|
||||
if attn_metadata.query_start_loc is not None \
|
||||
and attn_metadata.context_lens_tensor is not None:
|
||||
scan_outputs = selective_scan_fn(
|
||||
hidden_states,
|
||||
mamba_cache_params.ssm_state,
|
||||
discrete_time_step,
|
||||
self.A,
|
||||
B.transpose(-2, -1),
|
||||
C.transpose(-2, -1),
|
||||
self.D.float(),
|
||||
# z,
|
||||
None if self.yoco_kv else gate,
|
||||
time_proj_bias,
|
||||
delta_softplus=True,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
|
||||
selective_state_update(
|
||||
mamba_cache_params.ssm_state,
|
||||
hidden_states.transpose(0, 1),
|
||||
discrete_time_step.transpose(0, 1),
|
||||
self.A,
|
||||
B,
|
||||
C,
|
||||
self.D,
|
||||
# z
|
||||
# gate.transpose(0, 1),
|
||||
None if self.yoco_kv else gate.transpose(0, 1),
|
||||
time_proj_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=mamba_cache_params.state_indices_tensor,
|
||||
out=scan_outputs)
|
||||
scan_outputs = scan_outputs.transpose(0, 1)
|
||||
|
||||
# 4. Final linear projection
|
||||
if self.yoco_kv:
|
||||
# gate = gate.transpose(-1,-2).contiguous()
|
||||
yoco_key_values = scan_outputs.transpose(-2, -1)
|
||||
scan_outputs = self.swiGluActivation(scan_outputs, gate)
|
||||
|
||||
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
|
||||
-1))[0]
|
||||
|
||||
return contextualized_states, yoco_key_values
|
||||
|
||||
|
||||
class SambaYDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
layer_idx,
|
||||
cache_config,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.mlp = SambaYMLP(config)
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
self.yoco_mb = False
|
||||
self.yoco_cross = False
|
||||
if layer_idx >= config.num_hidden_layers // 2:
|
||||
self.yoco_mb = True
|
||||
self.yoco_cross = (layer_idx
|
||||
>= (config.num_hidden_layers // 2 + 2))
|
||||
self.use_mamba = config.mb_per_layer > 0 and \
|
||||
layer_idx % config.mb_per_layer == 0
|
||||
if self.use_mamba:
|
||||
factory_kwargs = {"dtype": None}
|
||||
self.attn = Phi4Mamba(config.hidden_size,
|
||||
layer_idx=layer_idx,
|
||||
yoco_cross=self.yoco_cross,
|
||||
yoco_kv=self.yoco_mb,
|
||||
**factory_kwargs)
|
||||
else:
|
||||
self.attn = SambaYAttention(config,
|
||||
layer_idx=layer_idx,
|
||||
yoco_cross=self.yoco_cross,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
ssm_output: Optional[torch.LongTensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if self.use_mamba:
|
||||
assert mamba_cache_params is not None
|
||||
else:
|
||||
assert mamba_cache_params is None
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(
|
||||
hidden_states.to(dtype=self.input_layernorm.weight.dtype))
|
||||
|
||||
if self.use_mamba:
|
||||
attn_outputs, ssm_output = self.attn(hidden_states,
|
||||
attn_metadata,
|
||||
mamba_cache_params,
|
||||
yoco_key_values=ssm_output)
|
||||
residual = residual.to(torch.float32)
|
||||
else:
|
||||
attn_outputs = self.attn(hidden_states, )
|
||||
hidden_states = residual + attn_outputs
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(
|
||||
hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states, ssm_output
|
||||
|
||||
|
||||
class SambaYModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
cache_config=None,
|
||||
quant_config=None,
|
||||
lora_config=None,
|
||||
prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
# Pipeline parallel is not supported since the second half of
|
||||
# the layers share the kv cache.
|
||||
if get_pp_group().world_size != 1:
|
||||
raise ValueError("Pipeline Parallel not supported")
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: SambaYDecoderLayer(config,
|
||||
int(prefix.split('.')[-1]),
|
||||
cache_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
mamba_state_idx = 0
|
||||
ssm_output = None
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
if i == self.config.num_hidden_layers // 2 + 2:
|
||||
# profile run
|
||||
kv_cache_idx = self.config.num_hidden_layers // 2 + 1
|
||||
cache_layer = self.layers[kv_cache_idx]
|
||||
kv_cache = cache_layer.attn.attn.kv_cache
|
||||
if kv_cache[0].numel() == 0:
|
||||
break
|
||||
|
||||
# Starting from this layer, we do not need to calculate
|
||||
# the kv cache since we reuse the kv cache from last layer.
|
||||
# If in prefill phase, we can <s>prune></s> truncate
|
||||
# the hidden state to save computation cost.
|
||||
if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
|
||||
selected_token_indices = torch.cumsum(
|
||||
attn_metadata.seq_lens_tensor, dim=0) - 1
|
||||
hidden_states = hidden_states.index_select(
|
||||
0, selected_token_indices)
|
||||
ssm_output = ssm_output.index_select(
|
||||
0, selected_token_indices)
|
||||
|
||||
if layer.use_mamba:
|
||||
if i < self.config.num_hidden_layers // 2 or \
|
||||
not layer.yoco_cross:
|
||||
mamba_cache = mamba_cache_params.at_layer_idx(
|
||||
mamba_state_idx)
|
||||
mamba_state_idx += 1
|
||||
else:
|
||||
mamba_cache = mamba_cache_params.at_layer_idx(
|
||||
mamba_state_idx - 1)
|
||||
|
||||
hidden_states, ssm_output = layer(hidden_states,
|
||||
positions,
|
||||
attn_metadata,
|
||||
mamba_cache,
|
||||
ssm_output=ssm_output)
|
||||
else:
|
||||
hidden_states, ssm_output = layer(
|
||||
hidden_states,
|
||||
positions,
|
||||
attn_metadata,
|
||||
None, # mamba_cache_params
|
||||
ssm_output=ssm_output)
|
||||
|
||||
hidden_states = self.final_layernorm(
|
||||
hidden_states.to(dtype=self.final_layernorm.weight.dtype))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
lora_config = vllm_config.lora_config
|
||||
quant_config = vllm_config.quant_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.vllm_config = vllm_config
|
||||
# Prefix caching and chunked prefill is not supported for this model.
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Phi4flash currently does not support prefix caching"
|
||||
assert not scheduler_config.chunked_prefill_enabled, \
|
||||
"Phi4Flash currently does not support prefix caching"
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model = SambaYModel(config,
|
||||
cache_config=cache_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=(
|
||||
DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
self.embedding_bias = None
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
logits_as_input=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if self.mamba_cache is None:
|
||||
num_mamba_layers = self.config.num_hidden_layers \
|
||||
// 2 // self.config.mb_per_layer + 1
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.vllm_config,
|
||||
num_mamba_layers,
|
||||
*self._get_mamba_cache_shape(),
|
||||
self.lm_head.weight.dtype,
|
||||
self.lm_head.weight.dtype,
|
||||
)
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# input_ids and hidden_states isn't a one-to-one mapping in prefill
|
||||
# stage due to YOCO optimization.
|
||||
hidden_states = self.model(input_ids, positions, attn_metadata,
|
||||
mamba_cache_params, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def _get_mamba_cache_shape(
|
||||
self
|
||||
) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
hidden_size = self.config.hidden_size
|
||||
mamba_expand = self.config.mamba_expand # 2
|
||||
mamba_d_conv = self.config.mamba_d_conv # 4
|
||||
mamba_d_state = self.config.mamba_d_state # 16
|
||||
conv_state_shape = (
|
||||
mamba_expand * hidden_size // world_size,
|
||||
mamba_d_conv - 1,
|
||||
)
|
||||
temporal_state_shape = (
|
||||
mamba_expand * hidden_size // world_size,
|
||||
mamba_d_state,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> Optional[torch.Tensor]:
|
||||
processed_logits = self.logits_processor(
|
||||
self.lm_head,
|
||||
hidden_states,
|
||||
self.embedding_bias,
|
||||
)
|
||||
return processed_logits
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
):
|
||||
weights = {name: weight for name, weight in weights}
|
||||
adjusted_weights = {}
|
||||
for name, weight in weights.items():
|
||||
if "A_log" in name:
|
||||
name = name.replace("A_log", "A")
|
||||
weight = -torch.exp(weight.float())
|
||||
if "inner_cross_attn." in name:
|
||||
name = name.replace("inner_cross_attn.", "")
|
||||
adjusted_weights[name] = weight
|
||||
adjusted_weights["lm_head.weight"] = weights[
|
||||
"model.embed_tokens.weight"]
|
||||
loaded_params: set[str] = set()
|
||||
for name, param in self.named_parameters():
|
||||
weight = adjusted_weights.get(name)
|
||||
if weight is not None and weight.shape != param.shape:
|
||||
logger.warning("Shape mismatch: %s %s %s", name, weight.shape,
|
||||
param.shape)
|
||||
loaded_params.add(name)
|
||||
missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
|
||||
strict=False)
|
||||
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
|
||||
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
|
||||
return loaded_params
|
||||
@ -12,7 +12,6 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
||||
SupportsPP)
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.model_executor.models.utils import (
|
||||
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
|
||||
make_layers, maybe_prefix)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType, direct_register_custom_op
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
|
||||
|
||||
|
||||
@ -194,17 +189,13 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
|
||||
self.chunk_size = self.config.mamba_chunk_size
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The outer list is for v0 PP virtual engine. Though this code path
|
||||
# only runs for v1, we have to do this to unify with the interface
|
||||
# of Attention + v0 PP.
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
||||
assert self.chunk_size != -1, "chunk_size must be set for v1"
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
assert self.chunk_size != -1, "chunk_size must be set for v1"
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
@ -237,59 +226,43 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
if not envs.VLLM_USE_V1:
|
||||
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
|
||||
mamba2_metadata)
|
||||
else:
|
||||
torch.ops.vllm.plamo2_mamba_mixer(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
torch.ops.vllm.plamo2_mamba_mixer(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
forward_context = get_forward_context()
|
||||
# mamba2_metadata contains metadata necessary for the mamba2 triton
|
||||
# attn_metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
mamba2_metadata = attn_metadata
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
else:
|
||||
conv_state = mamba_cache_params.conv_state
|
||||
ssm_state = mamba_cache_params.ssm_state
|
||||
state_indices_tensor = mamba_cache_params.state_indices_tensor
|
||||
|
||||
# Common members between V1 metadata and V0 metadata
|
||||
if mamba2_metadata is not None:
|
||||
has_initial_states_p = mamba2_metadata.has_initial_states_p
|
||||
prep_initial_states = mamba2_metadata.prep_initial_states
|
||||
chunk_size = mamba2_metadata.chunk_size
|
||||
seq_idx_p = mamba2_metadata.seq_idx_p
|
||||
chunk_indices_p = mamba2_metadata.chunk_indices_p
|
||||
chunk_offsets_p = mamba2_metadata.chunk_offsets_p
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
prep_initial_states = attn_metadata.prep_initial_states
|
||||
chunk_size = attn_metadata.chunk_size
|
||||
seq_idx_p = attn_metadata.seq_idx_p
|
||||
chunk_indices_p = attn_metadata.chunk_indices_p
|
||||
chunk_offsets_p = attn_metadata.chunk_offsets_p
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)
|
||||
@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if envs.VLLM_USE_V1 and attn_metadata is None:
|
||||
# V1 profile run
|
||||
if attn_metadata is None:
|
||||
# profile run
|
||||
hidden_states = (hidden_states.transpose(0, 1).clone().transpose(
|
||||
0, 1)).contiguous()
|
||||
output[:] = self.out_proj(hidden_states)
|
||||
@ -316,42 +289,23 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
if envs.VLLM_USE_V1:
|
||||
hidden_states_d, hidden_states_p = torch.split(
|
||||
hidden_states[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
gate_d, gate_p = torch.split(gate[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (
|
||||
attn_metadata.query_start_loc[-num_prefills - 1:] -
|
||||
num_decodes if has_prefill else None)
|
||||
else:
|
||||
hidden_states_p, hidden_states_d = torch.split(
|
||||
hidden_states,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
gate_p, gate_d = torch.split(gate,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_prefills, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
|
||||
1]
|
||||
if has_prefill else None)
|
||||
hidden_states_d, hidden_states_p = torch.split(
|
||||
hidden_states[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
gate_d, gate_p = torch.split(gate[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (
|
||||
attn_metadata.query_start_loc[-num_prefills - 1:] -
|
||||
num_decodes if has_prefill else None)
|
||||
|
||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||
# and decode outputs
|
||||
@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
if envs.VLLM_USE_V1:
|
||||
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Process prefill requests
|
||||
if has_prefill:
|
||||
@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
# pointed to by "state_indices_tensor"
|
||||
x = hidden_states_p.transpose(
|
||||
0, 1) # this is the form that causal-conv see
|
||||
if mamba2_metadata.cu_seqlen is None:
|
||||
mamba2_metadata = update_metadata(x, query_start_loc_p,
|
||||
mamba2_metadata)
|
||||
hidden_states_p = causal_conv1d_fn(
|
||||
x,
|
||||
conv_weights,
|
||||
@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
metadata=mamba2_metadata,
|
||||
metadata=attn_metadata,
|
||||
query_start_loc=query_start_loc_p)
|
||||
hidden_states_p = hidden_states_p.transpose(0, 1)
|
||||
hidden_states_p = hidden_states_p[:num_prefill_tokens]
|
||||
@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
|
||||
-1, self.num_heads // self.tp_size, self.head_dim)
|
||||
|
||||
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||
# - mamba_cache_params.ssm_state's slots will be selected
|
||||
# - ssm_state's slots will be selected
|
||||
# using state_indices_tensor_d
|
||||
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
@ -530,10 +474,7 @@ def plamo2_mamba_mixer(
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states,
|
||||
output=output,
|
||||
mamba_cache_params=None,
|
||||
mamba2_metadata=None)
|
||||
self.forward_cuda(hidden_states=hidden_states, output=output)
|
||||
|
||||
|
||||
def plamo2_mamba_mixer_fake(
|
||||
@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module):
|
||||
output = torch.empty_like(hidden_states)
|
||||
mixer_kwargs = {
|
||||
"output": output,
|
||||
"mamba_cache_params": mamba_cache_params,
|
||||
"mamba2_metadata": mamba2_metadata,
|
||||
}
|
||||
else:
|
||||
mixer_kwargs = {
|
||||
@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
) -> torch.Tensor:
|
||||
mamba_cache_index = 0
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
layer_mamba_cache_params = None
|
||||
if layer.is_mamba and mamba_cache_params is not None:
|
||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
||||
mamba_cache_index)
|
||||
mamba_cache_index += 1
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
mamba_cache_params=layer_mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
)
|
||||
return hidden_states, residual
|
||||
|
||||
@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
attn_metadata: AttentionMetadata = get_forward_context(
|
||||
).attn_metadata
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.mamba_chunk_size,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
else:
|
||||
# v1 get mamba2_metadata from forward_context
|
||||
mamba2_metadata = None
|
||||
|
||||
hidden_states, residual = self.layers(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
mamba_cache_params=mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
self.config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs):
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
num_mamba_layers = (
|
||||
self.model_config.get_num_layers_by_block_type(
|
||||
self.vllm_config.parallel_config,
|
||||
LayerBlockType.mamba))
|
||||
|
||||
mamba_state_shape = self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
mamba_state_dtype = \
|
||||
self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape,
|
||||
*mamba_state_dtype)
|
||||
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
else:
|
||||
# NOTE: mamba_cache_params is not needed for v1
|
||||
mamba_cache_params = None
|
||||
|
||||
hidden_states = self.model(input_ids, positions, mamba_cache_params,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- conv_state_shape: Shape for convolutional state cache
|
||||
@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
|
||||
head_dim=hf_config.hidden_size_per_head,
|
||||
state_size=hf_config.mamba_d_state,
|
||||
conv_kernel=hf_config.mamba_d_conv,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
|
||||
@ -11,7 +11,6 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import Attention, AttentionBackend, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
|
||||
@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
|
||||
mamba_v2_sharded_weight_loader)
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheParams
|
||||
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.gated_delta_net_state_shape(
|
||||
self.tp_size,
|
||||
self.num_k_heads,
|
||||
self.num_v_heads,
|
||||
self.head_k_dim,
|
||||
self.head_v_dim,
|
||||
self.conv_kernel_size,
|
||||
self.num_spec,
|
||||
use_v1=True)
|
||||
self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
|
||||
self.head_v_dim, self.conv_kernel_size, self.num_spec)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
cache_params: Optional[MambaCacheParams] = None,
|
||||
):
|
||||
return torch.ops.vllm.gdn_attention(
|
||||
hidden_states,
|
||||
@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
conv_metadata = attn_metadata
|
||||
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
||||
has_initial_state = attn_metadata.has_initial_state
|
||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||
@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
# 2.2: process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
|
||||
if conv_metadata.cu_seqlen is None:
|
||||
conv_metadata = update_metadata(mixed_qkv_non_spec_T,
|
||||
non_spec_query_start_loc,
|
||||
conv_metadata)
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "mamba_cache_params.state_indices_tensor"
|
||||
# pointed to by "state_indices_tensor"
|
||||
mixed_qkv_non_spec = causal_conv1d_fn(
|
||||
mixed_qkv_non_spec_T,
|
||||
conv_weights,
|
||||
@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
has_initial_state=has_initial_state,
|
||||
cache_indices=non_spec_state_indices_tensor,
|
||||
query_start_loc=non_spec_query_start_loc,
|
||||
metadata=conv_metadata,
|
||||
metadata=attn_metadata,
|
||||
).transpose(0, 1)
|
||||
elif attn_metadata.num_decodes > 0:
|
||||
mixed_qkv_non_spec = causal_conv1d_update(
|
||||
@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Qwen3Next currently does not support prefix caching"
|
||||
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
|
||||
self.quant_config = vllm_config.quant_config
|
||||
|
||||
super().__init__()
|
||||
@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
num_spec = (vllm_config.speculative_config.num_speculative_tokens
|
||||
if vllm_config.speculative_config else 0)
|
||||
return MambaStateShapeCalculator.gated_delta_net_state_shape(
|
||||
tp_size,
|
||||
hf_config.linear_num_key_heads,
|
||||
hf_config.linear_num_value_heads,
|
||||
hf_config.linear_key_head_dim,
|
||||
hf_config.linear_value_head_dim,
|
||||
hf_config.linear_conv_kernel_dim,
|
||||
num_spec,
|
||||
use_v1=True)
|
||||
tp_size, hf_config.linear_num_key_heads,
|
||||
hf_config.linear_num_value_heads, hf_config.linear_key_head_dim,
|
||||
hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim,
|
||||
num_spec)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
|
||||
@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = {
|
||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
|
||||
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
|
||||
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
|
||||
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
|
||||
@ -15,12 +15,10 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import Zamba2Config
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
Mamba2Metadata, prepare_mamba2_metadata)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||
MambaCacheParams)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid
|
||||
@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
transformer_hidden_states: Optional[torch.Tensor] = None,
|
||||
positions: Optional[torch.Tensor] = None,
|
||||
original_hidden_states: Optional[torch.Tensor] = None,
|
||||
@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
|
||||
Args:
|
||||
hidden_states: Input tensor [batch_size, seq_len, hidden_size]
|
||||
mamba_cache_params: Parameters for Mamba's state caches
|
||||
(one for conv, one for ssm)
|
||||
transformer_hidden_states: Optional output from transformer path
|
||||
Added to input if provided (used in hybrid architecture)
|
||||
positions: Optional position IDs (unused in Mamba)
|
||||
@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
|
||||
self.mamba(
|
||||
hidden_states,
|
||||
output,
|
||||
mamba_cache_params=mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
)
|
||||
|
||||
# residual connection after mamba
|
||||
@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
original_hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
mamba2_metadata: Mamba2Metadata,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through the hybrid layer.
|
||||
|
||||
@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module):
|
||||
original_hidden_states: Original input for transformer residual
|
||||
connection
|
||||
positions: Position IDs for positional embeddings
|
||||
mamba_cache_params: Parameters for Mamba's state caches
|
||||
(one for conv, one for ssm)
|
||||
|
||||
Returns:
|
||||
Output tensor combining transformer and Mamba representations
|
||||
@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module):
|
||||
layer_outputs = self.mamba_decoder(
|
||||
hidden_states,
|
||||
transformer_hidden_states=transformer_hidden_states,
|
||||
mamba_cache_params=mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
)
|
||||
|
||||
return layer_outputs
|
||||
@ -752,7 +734,6 @@ class Zamba2Model(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
mamba_cache_params: MambaCacheParams,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
"""Forward pass through the model.
|
||||
@ -760,8 +741,6 @@ class Zamba2Model(nn.Module):
|
||||
Args:
|
||||
input_ids: Input token IDs
|
||||
positions: Position IDs for embeddings
|
||||
mamba_cache_params: Parameters for Mamba's state caches
|
||||
(one for conv, one for ssm)
|
||||
inputs_embeds: Optional pre-computed input embeddings
|
||||
|
||||
Returns:
|
||||
@ -773,33 +752,13 @@ class Zamba2Model(nn.Module):
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.chunk_size,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
else:
|
||||
# v1 get mamba2_metadata from forward_context
|
||||
mamba2_metadata = None
|
||||
|
||||
# Process through layers
|
||||
original_hidden_states = torch.clone(hidden_states)
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
|
||||
layer_mamba_cache_params = None
|
||||
if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
|
||||
and mamba_cache_params):
|
||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
||||
layer_idx)
|
||||
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
original_hidden_states=original_hidden_states,
|
||||
positions=positions,
|
||||
mamba_cache_params=layer_mamba_cache_params,
|
||||
mamba2_metadata=mamba2_metadata,
|
||||
)
|
||||
hidden_states = layer_outputs
|
||||
|
||||
@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
head_dim=hf_config.mamba_headdim,
|
||||
state_size=hf_config.mamba_d_state,
|
||||
conv_kernel=hf_config.mamba_d_conv,
|
||||
use_v1=use_v1,
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
# Tie weights with input embeddings if using same dimensions
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
# Initialize logits processing and sampling
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
Returns:
|
||||
Output hidden states
|
||||
"""
|
||||
# Initialize Mamba cache if needed
|
||||
mamba_cache_params = None
|
||||
if not envs.VLLM_USE_V1:
|
||||
if self.mamba_cache is None:
|
||||
num_mamba_layers = self.config.num_hidden_layers
|
||||
mamba_state_shape = \
|
||||
self.get_mamba_state_shape_from_config(
|
||||
self.vllm_config, use_v1=False)
|
||||
mamba_state_dtype = \
|
||||
self.get_mamba_state_dtype_from_config(
|
||||
self.vllm_config)
|
||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
||||
num_mamba_layers,
|
||||
*mamba_state_shape,
|
||||
*mamba_state_dtype)
|
||||
|
||||
# Get cache parameters for current run
|
||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
||||
|
||||
# Forward pass through model
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
mamba_cache_params,
|
||||
inputs_embeds,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(
|
||||
self, input_buffers: dict[str, torch.Tensor],
|
||||
**kwargs: Any) -> dict[str, torch.Tensor]:
|
||||
"""Copy inputs before CUDA graph capture.
|
||||
|
||||
Args:
|
||||
input_buffers: Dictionary of input tensors
|
||||
**kwargs: Additional arguments passed to cache manager
|
||||
|
||||
Returns:
|
||||
Updated input buffers
|
||||
"""
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(
|
||||
self, batch_size: int) -> dict[str, torch.Tensor]:
|
||||
"""Get inputs for sequence-length-agnostic graph capture.
|
||||
|
||||
Args:
|
||||
batch_size: Size of batch to capture
|
||||
Returns:
|
||||
Dictionary of capture inputs
|
||||
"""
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -12,6 +12,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
@ -52,7 +53,6 @@ class GDNAttentionMetadata:
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: Optional[dict] = None
|
||||
cu_seqlen: Optional[int] = None
|
||||
batch_ptr: Optional[torch.Tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.Tensor] = None
|
||||
|
||||
@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder(
|
||||
context_lens = m.num_computed_tokens_cpu
|
||||
context_lens_tensor = context_lens.to(query_start_loc.device)
|
||||
seq_lens_tensor = m.seq_lens
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if (not self.use_spec_decode or num_draft_tokens is None
|
||||
or num_draft_tokens.sum().item() == 0):
|
||||
@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder(
|
||||
has_initial_state = context_lens_tensor > 0
|
||||
if spec_sequence_masks is not None:
|
||||
has_initial_state = has_initial_state[~spec_sequence_masks]
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = \
|
||||
compute_causal_conv1d_metadata(non_spec_query_start_loc)
|
||||
else:
|
||||
has_initial_state = None
|
||||
num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
|
||||
@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder(
|
||||
spec_sequence_masks=spec_sequence_masks,
|
||||
spec_token_masks=spec_token_masks,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
@ -7,11 +7,12 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadataBuilder)
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
@ -131,7 +132,6 @@ class Mamba2AttentionMetadata:
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: Optional[dict] = None
|
||||
cu_seqlen: Optional[int] = None
|
||||
batch_ptr: Optional[torch.Tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.Tensor] = None
|
||||
|
||||
@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder(
|
||||
has_initial_states_p = None
|
||||
prep_initial_states = False
|
||||
|
||||
# for causal_conv1d
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder(
|
||||
query_start_loc_p, self.chunk_size,
|
||||
num_prefill_tokens))
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = \
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
|
||||
elif num_decodes <= self.decode_cudagraph_max_bs:
|
||||
# Pad state tensor for CUDA graph
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
|
||||
@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder(
|
||||
chunk_indices_p=chunk_indices_p,
|
||||
chunk_offsets_p=chunk_offsets_p,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
@ -33,7 +34,6 @@ class ShortConvAttentionMetadata:
|
||||
|
||||
# For causal_conv1d
|
||||
nums_dict: Optional[dict] = None
|
||||
cu_seqlen: Optional[int] = None
|
||||
batch_ptr: Optional[torch.Tensor] = None
|
||||
token_chunk_offset_ptr: Optional[torch.Tensor] = None
|
||||
|
||||
@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder(
|
||||
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
# for causal_conv1d
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder(
|
||||
has_initial_states = has_initial_states_cpu.to(
|
||||
query_start_loc.device)
|
||||
|
||||
query_start_loc_p = common_attn_metadata.query_start_loc[
|
||||
-num_prefills - 1:] - num_decode_tokens
|
||||
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = \
|
||||
compute_causal_conv1d_metadata(query_start_loc_p)
|
||||
|
||||
attn_metadata = ShortConvAttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder(
|
||||
query_start_loc=query_start_loc,
|
||||
has_initial_states=has_initial_states,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
nums_dict=nums_dict,
|
||||
batch_ptr=batch_ptr,
|
||||
token_chunk_offset_ptr=token_chunk_offset_ptr,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
@ -34,6 +34,8 @@ logger = init_logger(__name__)
|
||||
KVCacheLayoutType = Literal["NHD", "HND"]
|
||||
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
def is_valid_kv_cache_layout(value: str) -> bool:
|
||||
return value in get_args(KVCacheLayoutType)
|
||||
@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend(
|
||||
builder_cls=FastPrefillAttentionBuilder)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
|
||||
|
||||
# Needed for causal_conv1d
|
||||
seqlens = query_start_loc_p.diff().to('cpu')
|
||||
nums_dict = {} # type: ignore
|
||||
batch_ptr = None
|
||||
token_chunk_offset_ptr = None
|
||||
for BLOCK_M in [8]: # cover all BLOCK_M values
|
||||
nums = -(-seqlens // BLOCK_M)
|
||||
nums_dict[BLOCK_M] = {}
|
||||
nums_dict[BLOCK_M]['nums'] = nums
|
||||
nums_dict[BLOCK_M]['tot'] = nums.sum().item()
|
||||
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
|
||||
nums_dict[BLOCK_M]['mlist'] = mlist
|
||||
mlist_len = len(nums_dict[BLOCK_M]['mlist'])
|
||||
nums_dict[BLOCK_M]['mlist_len'] = mlist_len
|
||||
MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
|
||||
offsetlist = [] # type: ignore
|
||||
for idx, num in enumerate(nums):
|
||||
offsetlist.extend(range(num))
|
||||
offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
|
||||
nums_dict[BLOCK_M]['offsetlist'] = offsetlist
|
||||
|
||||
if batch_ptr is None:
|
||||
# Update default value after class definition
|
||||
batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
else:
|
||||
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
|
||||
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
|
||||
token_chunk_offset_ptr.resize_( # type: ignore
|
||||
MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
|
||||
|
||||
batch_ptr[0:mlist_len].copy_(mlist)
|
||||
token_chunk_offset_ptr[ # type: ignore
|
||||
0:mlist_len].copy_(offsetlist)
|
||||
nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr
|
||||
nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr
|
||||
) # type: ignore
|
||||
|
||||
return nums_dict, batch_ptr, token_chunk_offset_ptr
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user