[V1] Remove V0 code paths for Hybrid models (#25400)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 02134245a9
commit f97da2c732
@@ -20,7 +20,9 @@ pytestmark = pytest.mark.hybrid_model
 SSM_MODELS = [
     "state-spaces/mamba-130m-hf",
     "tiiuae/falcon-mamba-tiny-dev",
-    "yujiepan/mamba2-codestral-v0.1-tiny-random",
+    # mamba2-codestral in transformers is broken pending:
+    # https://github.com/huggingface/transformers/pull/40861
+    #"yujiepan/mamba2-codestral-v0.1-tiny-random",
 ]
 
 HYBRID_MODELS = [
@@ -31,18 +33,7 @@ HYBRID_MODELS = [
     "ibm-granite/granite-4.0-tiny-preview",
     "tiiuae/Falcon-H1-0.5B-Base",
     "LiquidAI/LFM2-1.2B",
+    "tiny-random/qwen3-next-moe",
 ]
 
-V1_SUPPORTED_MODELS = [
-    "state-spaces/mamba-130m-hf",
-    "ai21labs/Jamba-tiny-dev",
-    "pfnet/plamo-2-1b",
-    "yujiepan/mamba2-codestral-v0.1-tiny-random",
-    "Zyphra/Zamba2-1.2B-instruct",
-    "hmellor/tiny-random-BambaForCausalLM",
-    "ibm-granite/granite-4.0-tiny-preview",
-    "tiiuae/Falcon-H1-0.5B-Base",
-    "LiquidAI/LFM2-1.2B",
-]
-
 FULL_CUDA_GRAPH_MODELS = [
@@ -51,10 +42,6 @@ FULL_CUDA_GRAPH_MODELS = [
     "Zyphra/Zamba2-1.2B-instruct",
 ]
 
-V0_UNSUPPORTED_MODELS = [
-    "LiquidAI/LFM2-1.2B",
-]
-
 FP32_STATE_MODELS = [
     "state-spaces/mamba-130m-hf",
     "Zyphra/Zamba2-1.2B-instruct",
@@ -88,19 +75,15 @@ def test_models(
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
 
-    if model in V1_SUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v1_outputs = None
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
-    if model in V1_SUPPORTED_MODELS:
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=vllm_v1_outputs,
-            name_0="hf",
-            name_1="vllm-v1",
-        )
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
 
 
@@ -299,14 +282,14 @@ def test_full_cuda_graph(
             example_prompts, max_tokens, num_logprobs)
 
     with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
     check_logprobs_close(
         outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_v1_outputs,
+        outputs_1_lst=vllm_outputs,
         name_0="hf",
-        name_1="vllm-v1",
+        name_1="vllm",
     )
 
 
@@ -340,12 +323,12 @@ def test_fp32_cache_state(
     with vllm_runner(model,
                      max_num_seqs=MAX_NUM_SEQS,
                      **{cache_dtype_param: "float32"}) as vllm_model:
-        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)
 
     check_logprobs_close(
         outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_v1_outputs,
+        outputs_1_lst=vllm_outputs,
         name_0="hf",
-        name_1="vllm-v1",
+        name_1="vllm",
     )
@@ -312,13 +312,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
     "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
-    "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning",  # noqa: E501
-                                            trust_remote_code=True,
-                                            v0_only=True,
-                                            max_model_len=10240),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
+                                         max_transformers_version="4.55.4",
+                                         transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers",  # noqa: E501
                                          trust_remote_code=True),
     "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
                                        max_transformers_version="4.53",
@@ -330,7 +328,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
     "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
-                                            min_transformers_version="4.56.2"),
+                                            extras={"tiny-random": "tiny-random/qwen3-next-moe"},  # noqa: E501
+                                            min_transformers_version="4.56.3"),
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
     "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct",  # noqa: E501
                                           trust_remote_code=True,
@@ -644,7 +643,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                                     trust_remote_code=True,
                                     speculative_model="XiaomiMiMo/MiMo-7B-RL"),
     "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
-                                    min_transformers_version="4.56.2"),
+                                    min_transformers_version="4.56.3"),
 }
 
 _TRANSFORMERS_BACKEND_MODELS = {
@@ -20,10 +20,7 @@ class MambaBase(AttentionLayerBase):
 
     # Contains the KV cache (mamba state) for the layer
     # in the shape specified by `self.get_state_shape`.
-    # The outer list is for v0 PP virtual engine. Though this code path
-    # only runs for v1, we have to do this to unify with the interface
-    # of Attention + v0 PP.
-    kv_cache: list[Iterable[torch.Tensor]]
+    kv_cache: tuple[torch.Tensor, ...]
 
     @abstractmethod
     def get_state_shape(self) -> Iterable[tuple[int, ...]]:
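Aside: the annotation change above reflects that, with the V0 pipeline-parallel wrapper gone, a Mamba layer's cache is just a flat tuple of state tensors. A self-contained sketch of that shape (the class and the shapes below are illustrative, not the vLLM implementation):

import torch

class ToyMambaLayer:
    # One tuple of state tensors per layer, e.g. (conv_state, ssm_state);
    # previously this was wrapped in an outer list for V0 pipeline parallelism.
    kv_cache: tuple[torch.Tensor, ...]

    def __init__(self) -> None:
        # Placeholders until the engine binds the real buffers.
        self.kv_cache = (torch.tensor([]), torch.tensor([]))

    def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
        # (conv_state, ssm_state) shapes for a hypothetical tiny layer.
        return ((3, 64), (16, 64))

layer = ToyMambaLayer()
conv_state, ssm_state = layer.kv_cache
print(len(layer.kv_cache), conv_state.shape, ssm_state.shape)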
@@ -15,7 +15,6 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import nn
 
-from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
@@ -42,8 +41,6 @@ if TYPE_CHECKING:
     import torch
     import torch.distributed
 
-    from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
-
 
 class MiniMaxText01RMSNormTP(CustomOp):
     name = "MiniMaxText01RMSNormTP"
@@ -225,7 +222,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
                                           self.tp_heads:(self.tp_rank + 1) *
                                           self.tp_heads].contiguous()
 
-        if envs.VLLM_USE_V1:
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
@@ -268,8 +264,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
                 break
             if _prefill_idx >= len(state_indices_tensor):
                 break
-            # prefills are packed at end of batch in V1
-            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+            offset = attn_metadata.num_decode_tokens
             _start = attn_metadata.query_start_loc[offset + _prefill_idx]
             _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
             slot_id = state_indices_tensor[offset + _prefill_idx]
@@ -291,10 +286,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
             hidden_decode = self._decode_infer(q, k, v, kv_cache,
                                                state_indices_tensor,
                                                attn_metadata)
-            if envs.VLLM_USE_V1:
-                hidden.insert(0, hidden_decode)
-            else:
-                hidden.append(hidden_decode)
+            hidden.insert(0, hidden_decode)
 
         if not hidden:
             return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
@@ -304,13 +296,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
 
     def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                       attn_metadata):
-        if not envs.VLLM_USE_V1:
-            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            num_prefills = getattr(attn_metadata, "num_prefills", 0)
-            slot_id = state_indices_tensor[num_prefills:]
-        else:
-            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+        q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+        k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+        v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
@@ -320,11 +305,7 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         return hidden
 
     def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams) -> None:
-        if not envs.VLLM_USE_V1:
-            self._forward(hidden_states, output, positions, kv_caches)
-        else:
-            torch.ops.vllm.linear_attention(
-                hidden_states,
-                output,
+                positions: torch.Tensor) -> None:
+        torch.ops.vllm.linear_attention(
+            hidden_states,
+            output,
@@ -333,11 +314,10 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
             )
 
     def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                 positions: torch.Tensor,
-                 kv_caches: Optional[MinimaxCacheParams]) -> None:
+                 positions: torch.Tensor) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
-        if envs.VLLM_USE_V1 and attn_metadata is not None:
+        if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
             assert isinstance(attn_metadata, LinearAttentionMetadata)
@@ -351,32 +331,26 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         qkvact = torch.nn.functional.silu(qkv32)
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-        if envs.VLLM_USE_V1:
         if attn_metadata is not None:
             kv_cache = self.kv_cache[forward_context.virtual_engine][0]
             state_indices_tensor = attn_metadata.state_indices_tensor
 
             num_prefills = getattr(attn_metadata, "num_prefills", 0)
             if num_prefills > 0:
-                num_decode_tokens = getattr(attn_metadata,
-                                            "num_decode_tokens", 0)
+                num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
+                                            0)
                 for prefill_idx in range(num_prefills):
-                    q_start = attn_metadata.query_start_loc[
-                        num_decode_tokens + prefill_idx]
-                    q_end = attn_metadata.query_start_loc[num_decode_tokens
-                                                          + prefill_idx +
-                                                          1]
+                    q_start = attn_metadata.query_start_loc[num_decode_tokens +
+                                                            prefill_idx]
+                    q_end = attn_metadata.query_start_loc[num_decode_tokens +
+                                                          prefill_idx + 1]
                     query_len = q_end - q_start
                     context_len = attn_metadata.seq_lens[
                         num_decode_tokens + prefill_idx] - query_len
                     if context_len == 0:
-                        block_to_clear = state_indices_tensor[
-                            num_decode_tokens + prefill_idx]
+                        block_to_clear = state_indices_tensor[num_decode_tokens
+                                                              + prefill_idx]
                         kv_cache[block_to_clear, ...] = 0
-        else:
-            assert kv_caches is not None
-            kv_cache = kv_caches.minimax_cache
-            state_indices_tensor = kv_caches.state_indices_tensor
 
         decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
         if attn_metadata is None:
@@ -410,8 +384,7 @@ def linear_attention(
     self = forward_context.no_compile_layers[layer_name]
     self._forward(hidden_states=hidden_states,
                   output=output,
-                  positions=positions,
-                  kv_caches=None)
+                  positions=positions)
 
 
 def linear_attention_fake(
@@ -1,177 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass
-from typing import Optional, Union
-
-import numpy as np
-import torch
-
-from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.attention.backends.placeholder_attn import (
-    PlaceholderAttentionMetadata)
-from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.platforms import current_platform
-from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
-from vllm.v1.attention.backends.mamba2_attn import (
-    Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
-
-
-@dataclass
-class Mamba2Metadata:
-    prep_initial_states: bool
-    chunk_size: int
-
-    has_initial_states_p: torch.Tensor
-    seq_idx_p: torch.Tensor
-    chunk_indices_p: torch.Tensor
-    chunk_offsets_p: torch.Tensor
-    """
-    With continuous batching layout of `x` in vLLM, to enable a Triton program
-    to handle a request in parallel, two supporting tensors are used
-    (batch_ptr, token_chunk_offset_ptr)
-    BLOCK_M = the # tokens to be handled by a Triton program
-    (can be customized for different hardware)
-
-    nums_dict:
-        tracks the data associated with a given value of BLOCK_M
-        BLOCK_M = #tokens handled by a Triton program
-        cu_seqlen: total tokens per batch
-        (used as flag to update other data at each new input)
-        batch_ptr: tracks batch-id handled by the Triton program
-        token_chunk_offset_ptr: tracks token group_idx handled by the Triton program
-    (Triton implementation of causal_conv1d handles parallelism in 3-axes
-    - feature-axis
-    - batch-axis
-    - sequence-axis)
-    """
-    nums_dict: Optional[dict] = None
-    cu_seqlen: Optional[int] = None
-    batch_ptr: Optional[torch.Tensor] = None
-    token_chunk_offset_ptr: Optional[torch.Tensor] = None
-
-
-def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
-    """Returns the appropriate metadata classes for the current platform."""
-    if current_platform.is_rocm():
-        from vllm.v1.attention.backends.rocm_aiter_fa import (
-            AiterFlashAttentionMetadata)
-        from vllm.v1.attention.backends.triton_attn import (
-            TritonAttentionMetadata)
-        return (AiterFlashAttentionMetadata, TritonAttentionMetadata,
-                PlaceholderAttentionMetadata)
-    if current_platform.is_cuda():
-        from vllm.v1.attention.backends.flash_attn import (
-            FlashAttentionMetadata)
-        from vllm.v1.attention.backends.xformers import (
-            XFormersAttentionMetadata)
-        return (FlashAttentionMetadata, XFormersAttentionMetadata,
-                PlaceholderAttentionMetadata)
-    raise ValueError(
-        f"Unsupported platform for Mamba2: {current_platform.device_type}")
-
-
-def prepare_mamba2_metadata(
-    chunk_size: int,
-    attn_metadata: AttentionMetadata,
-) -> Mamba2Metadata:
-
-    # compute number of prefill and decode requests
-    # NOTE: in V0 we assume prefills are before decodes
-    num_prefills = attn_metadata.num_prefills
-    num_prefill_tokens = attn_metadata.num_prefill_tokens
-
-    seq_idx_p = None
-    chunk_indices_p, chunk_offsets_p = None, None
-    # Need flags to indicate if there are initial states
-    # currently we really only support the FlashAttention backend
-    has_initial_states_p = None
-    prep_initial_states = False
-
-    # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
-    if num_prefills > 0:
-        attn_metadata_instances = get_platform_metadata_classes()
-        if (isinstance(attn_metadata, attn_metadata_instances)
-                and attn_metadata.context_lens_tensor is not None):
-            # precompute flag to avoid device syncs later in mamba2 layer
-            # forwards
-            # prep is only needed for mamba2 ssd prefill processing
-            has_initial_states_p = (
-                attn_metadata.context_lens_tensor[:num_prefills] > 0)
-            prep_initial_states = torch.any(has_initial_states_p).item()
-        query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
-        seq_idx_p = torch.repeat_interleave(torch.arange(
-            num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
-                                            query_start_loc_p.diff(),
-                                            output_size=num_prefill_tokens)
-        seq_idx_p.unsqueeze_(0)
-
-        # We compute metadata for chunked prefill once at the top level model
-        # forward and reuse them in mamba layers. If not needed, they will be
-        # ignored inside mamba kernels.
-        if prep_initial_states:
-            chunk_indices_p, chunk_offsets_p = \
-                _query_start_loc_to_chunk_indices_offsets(
-                    query_start_loc_p, chunk_size, num_prefill_tokens)
-
-    return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
-                          prep_initial_states=prep_initial_states,
-                          chunk_size=chunk_size,
-                          seq_idx_p=seq_idx_p,
-                          chunk_indices_p=chunk_indices_p,
-                          chunk_offsets_p=chunk_offsets_p)
-
-
-def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
-                    mamba2_metadata: Union[Mamba2Metadata,
-                                           Mamba2AttentionMetadata,
-                                           GDNAttentionMetadata]):
-    """
-    this is triggered upon handling a new input at the first layer
-    """
-    dim, cu_seqlen = x.shape
-    mamba2_metadata.cu_seqlen = cu_seqlen
-    seqlens = np.diff(query_start_loc.to('cpu'))
-    nums_dict = {}  # type: ignore
-    for BLOCK_M in [8]:  # cover all BLOCK_M values
-        nums = -(-seqlens // BLOCK_M)
-        nums_dict[BLOCK_M] = {}
-        nums_dict[BLOCK_M]['nums'] = nums
-        nums_dict[BLOCK_M]['tot'] = nums.sum().item()
-        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
-        nums_dict[BLOCK_M]['mlist'] = mlist
-        mlist_len = len(nums_dict[BLOCK_M]['mlist'])
-        nums_dict[BLOCK_M]['mlist_len'] = mlist_len
-        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
-        offsetlist = []  # type: ignore
-        for idx, num in enumerate(nums):
-            offsetlist.extend(range(num))
-        offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
-        nums_dict[BLOCK_M]['offsetlist'] = offsetlist
-
-        if mamba2_metadata.batch_ptr is None:
-            # Update default value after class definition
-            #mamba2_metadata.MAX_NUM_PROGRAMS *= 2
-            mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
-                                                   PAD_SLOT_ID,
-                                                   dtype=torch.int32,
-                                                   device='cuda')
-            mamba2_metadata.token_chunk_offset_ptr = torch.full(
-                (MAX_NUM_PROGRAMS, ),
-                PAD_SLOT_ID,
-                dtype=torch.int32,
-                device='cuda')
-        else:
-            if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS:
-                mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(
-                    PAD_SLOT_ID)
-                mamba2_metadata.token_chunk_offset_ptr.resize_(  # type: ignore
-                    MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
-
-        mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist)
-        mamba2_metadata.token_chunk_offset_ptr[  # type: ignore
-            0:mlist_len].copy_(offsetlist)
-        nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr
-        nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (
-            mamba2_metadata.token_chunk_offset_ptr)  # type: ignore
-    mamba2_metadata.nums_dict = nums_dict
-    return mamba2_metadata
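Aside: the docstring of the deleted Mamba2Metadata above describes how, for a given BLOCK_M, each Triton program is assigned one (request, token-chunk) pair through batch_ptr and token_chunk_offset_ptr. A self-contained sketch of that mapping, mirroring the nums/mlist/offsetlist bookkeeping of the deleted update_metadata (illustrative only; the equivalent logic is expected to live in the V1 attention metadata builders):

import numpy as np

def build_program_map(query_start_loc: np.ndarray, block_m: int = 8):
    """For each Triton program, record which request (batch id) and which
    BLOCK_M-sized chunk of that request's tokens it should handle."""
    seqlens = np.diff(query_start_loc)      # tokens per request
    nums = -(-seqlens // block_m)           # ceil-div: chunks per request
    batch_ptr = np.repeat(np.arange(len(nums)), nums)
    token_chunk_offset_ptr = np.concatenate([np.arange(n) for n in nums])
    return batch_ptr, token_chunk_offset_ptr

# Three requests with 5, 16 and 3 tokens -> 1 + 2 + 1 = 4 programs at BLOCK_M=8.
print(build_program_map(np.array([0, 5, 21, 24])))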
@@ -10,8 +10,6 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 
-from vllm import envs
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -28,7 +26,6 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
-from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -149,16 +146,12 @@ class MambaMixer(MambaBase, CustomOp):
             has_weight=rms_norm_has_weight,
         ) if use_rms_norm else None
 
-        if envs.VLLM_USE_V1:
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
-        # The outer list is for v0 PP virtual engine. Though this code path
-        # only runs for v1, we have to do this to unify with the interface
-        # of Attention + v0 PP.
         # The inner tuple is (conv_state, ssm_state)
-        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+        self.kv_cache = (torch.tensor([]), torch.tensor([]))
 
         self.model_config = model_config
         self.cache_config = cache_config
@@ -186,29 +179,18 @@ class MambaMixer(MambaBase, CustomOp):
         discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
         return discrete_time_step, B, C
 
-    def forward(self,
-                hidden_states: torch.Tensor,
-                output: torch.Tensor,
-                mamba_cache_params: Optional[MambaCacheParams] = None):
-        if not envs.VLLM_USE_V1:
-            CustomOp.forward(self, hidden_states, output, mamba_cache_params)
-        else:
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
         torch.ops.vllm.mamba_mixer(
             hidden_states,
             output,
             self.prefix,
         )
 
-    def forward_native(self,
-                       hidden_states: torch.Tensor,
-                       output: torch.Tensor,
-                       mamba_cache_params: Optional[MambaCacheParams] = None):
+    def forward_native(self, hidden_states: torch.Tensor,
+                       output: torch.Tensor):
         pass
 
-    def forward_cuda(self,
-                     hidden_states: torch.Tensor,
-                     output: torch.Tensor,
-                     mamba_cache_params: Optional[MambaCacheParams] = None):
+    def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
         """
         Run the Mamba-1 SSM pipeline.
 
@@ -234,7 +216,6 @@ class MambaMixer(MambaBase, CustomOp):
         forward_context: ForwardContext = get_forward_context()
         attn_metadata = forward_context.attn_metadata
 
-        if envs.VLLM_USE_V1:
         if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
@@ -247,18 +228,6 @@ class MambaMixer(MambaBase, CustomOp):
             ssm_state = self_kv_cache[1]
             has_initial_states = mamba1_metadata.has_initial_states
             num_padded_decodes = mamba1_metadata.num_padded_decodes
-        else:
-            assert isinstance(attn_metadata, AttentionMetadata)
-            assert mamba_cache_params is not None
-            conv_state = mamba_cache_params.conv_state
-            ssm_state = mamba_cache_params.ssm_state
-            state_indices_tensor = mamba_cache_params.state_indices_tensor
-            query_start_loc = attn_metadata.query_start_loc
-            context_lens_tensor = attn_metadata.context_lens_tensor
-            has_initial_states = None
-            if context_lens_tensor is not None:
-                has_initial_states = context_lens_tensor > 0
-            num_padded_decodes = attn_metadata.num_decode_tokens
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -267,7 +236,7 @@ class MambaMixer(MambaBase, CustomOp):
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if envs.VLLM_USE_V1 and attn_metadata is None:
+        if attn_metadata is None:
             # V1 profile run
             hidden_states_BC = hidden_states_BC.contiguous()
             return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
@@ -368,10 +337,7 @@ class MambaMixer(MambaBase, CustomOp):
                 out=scan_outputs_d)
             scan_outputs_d = scan_outputs_d.transpose(0, 1)
 
-            if envs.VLLM_USE_V1:
-                ssm_outputs.insert(0, scan_outputs_d)
-            else:
-                ssm_outputs.append(scan_outputs_d)
+            ssm_outputs.insert(0, scan_outputs_d)
 
         scan_outputs_combined = ssm_outputs[0] if len(
             ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
@@ -441,9 +407,9 @@ def split_batch_to_prefill_and_decode(
     num_decodes: int,
     num_padded_decodes: int,
 ) -> PrefillDecodeSplit:
 
     num_actual_tokens = num_prefill_tokens + num_padded_decodes
-    if envs.VLLM_USE_V1:
+
     # In v1, decode tokens come first, then prefill tokens.
     hidden_states_BC_d, hidden_states_BC_p = torch.split(
         hidden_states_BC[..., :num_actual_tokens],
@@ -462,19 +428,6 @@ def split_batch_to_prefill_and_decode(
         num_padded_decodes if num_prefills > 0 else None)
     has_initial_states_p = has_initial_states[-num_prefills:] if (
         has_initial_states is not None and num_prefills > 0) else None
-    else:
-        # In v0, prefill tokens come first, then decode tokens.
-        hidden_states_BC_p, hidden_states_BC_d = torch.split(
-            hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
-        gate_p, gate_d = torch.split(gate,
-                                     [num_prefill_tokens, num_decode_tokens],
-                                     dim=-1)
-        state_indices_tensor_p, state_indices_tensor_d = torch.split(
-            state_indices_tensor, [num_prefills, num_decodes], dim=0)
-        query_start_loc_p = (query_start_loc[:num_prefills +
-                                             1] if num_prefills > 0 else None)
-        has_initial_states_p = has_initial_states[:num_prefills] if (
-            has_initial_states is not None and num_prefills > 0) else None
 
     return PrefillDecodeSplit(
         hidden_states_BC_p=hidden_states_BC_p,
@@ -495,9 +448,7 @@ def mamba_mixer(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states,
-                      output=output,
-                      mamba_cache_params=None)
+    self.forward_cuda(hidden_states=hidden_states, output=output)
 
 
 def mamba_mixer_fake(
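Aside: the split_batch_to_prefill_and_decode hunks above keep only the V1 layout, where decode tokens are packed before prefill tokens in the varlen batch. A self-contained sketch of that split (shapes and names are illustrative, not the vLLM helper itself):

import torch

def split_decode_first(hidden_states: torch.Tensor, num_decode_tokens: int,
                       num_prefill_tokens: int):
    """V1 layout along the token axis: [decode tokens | prefill tokens | padding]."""
    num_actual_tokens = num_decode_tokens + num_prefill_tokens
    decode, prefill = torch.split(hidden_states[..., :num_actual_tokens],
                                  [num_decode_tokens, num_prefill_tokens],
                                  dim=-1)
    return decode, prefill

# Hidden dim 8; 4 decode tokens, 6 prefill tokens, 2 padded slots at the end.
x = torch.randn(8, 12)
d, p = split_decode_first(x, num_decode_tokens=4, num_prefill_tokens=6)
print(d.shape, p.shape)  # torch.Size([8, 4]) torch.Size([8, 6])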
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
 import torch
 from torch import nn
 
-from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -22,8 +21,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
-from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
-                                                              update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -36,7 +33,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
-from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -449,16 +445,12 @@ class MambaMixer2(MambaBase, CustomOp):
                                       self.use_rms_norm,
                                       eps=rms_norm_eps)
 
-        if envs.VLLM_USE_V1:
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
-        # The outer list is for v0 PP virtual engine. Though this code path
-        # only runs for v1, we have to do this to unify with the interface
-        # of Attention + v0 PP.
-        # The inner tuple is (conv_state, ssm_state)
-        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+        # The tuple is (conv_state, ssm_state)
+        self.kv_cache = (torch.tensor([]), torch.tensor([]))
 
         self.model_config = model_config
         self.cache_config = cache_config
@@ -468,8 +460,6 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
         pass
@@ -478,14 +468,8 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
-        if not envs.VLLM_USE_V1:
-            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
-                             mamba2_metadata, mup_vector)
-        else:
         torch.ops.vllm.mamba_mixer2(
             hidden_states,
             output,
@@ -497,40 +481,30 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
         forward_context = get_forward_context()
-        # mamba2_metadata contains metadata necessary for the mamba2 triton
+        # attn_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
         # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
-        if envs.VLLM_USE_V1:
         if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
-            mamba2_metadata = attn_metadata
             assert isinstance(attn_metadata, Mamba2AttentionMetadata)
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             # conv_state = (..., dim, width-1) yet contiguous along 'dim'
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
-        else:
-            conv_state = mamba_cache_params.conv_state
-            ssm_state = mamba_cache_params.ssm_state
-            state_indices_tensor = mamba_cache_params.state_indices_tensor
-
-        # Common members between V1 metadata and V0 metadata
-        if mamba2_metadata is not None:
-            has_initial_states_p = mamba2_metadata.has_initial_states_p
-            prep_initial_states = mamba2_metadata.prep_initial_states
-            chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx_p
-            chunk_indices_p = mamba2_metadata.chunk_indices_p
-            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
+            has_initial_states_p = attn_metadata.has_initial_states_p
+            prep_initial_states = attn_metadata.prep_initial_states
+            chunk_size = attn_metadata.chunk_size
+            seq_idx_p = attn_metadata.seq_idx_p
+            chunk_indices_p = attn_metadata.chunk_indices_p
+            chunk_offsets_p = attn_metadata.chunk_offsets_p
 
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
@@ -562,8 +536,8 @@ class MambaMixer2(MambaBase, CustomOp):
             dim=-1,
         )
 
-        if envs.VLLM_USE_V1 and attn_metadata is None:
-            # V1 profile run
+        if attn_metadata is None:
+            # profile run
             hidden_states_B_C = (hidden_states_B_C.transpose(
                 0, 1).clone().transpose(0, 1)).contiguous()
             hidden_states, _B, _C = split_hidden_states_B_C_fn(
@@ -579,10 +553,8 @@ class MambaMixer2(MambaBase, CustomOp):
         has_decode = num_decodes > 0
         num_actual_tokens = num_prefill_tokens + num_decodes
 
-        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        if envs.VLLM_USE_V1:
         hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
             hidden_states_B_C[:num_actual_tokens],
             [num_decodes, num_prefill_tokens],
@@ -602,26 +574,6 @@ class MambaMixer2(MambaBase, CustomOp):
         query_start_loc_p = (
             attn_metadata.query_start_loc[-num_prefills - 1:] -
             num_decodes if has_prefill else None)
-        else:
-            hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
-                hidden_states_B_C,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
-            dt_p, dt_d = torch.split(
-                dt,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
-            # Split along batch dimension
-            state_indices_tensor_p, state_indices_tensor_d = torch.split(
-                state_indices_tensor,
-                [num_prefills, num_decodes],
-                dim=0,
-            )
-            query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
-                                                               1]
-                                 if has_prefill else None)
 
         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
@@ -633,18 +585,11 @@ class MambaMixer2(MambaBase, CustomOp):
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
-        if envs.VLLM_USE_V1:
         preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
             preallocated_ssm_out,
             [num_decodes, num_prefill_tokens],
             dim=0,
         )
-        else:
-            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
-                preallocated_ssm_out,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
 
         # Process prefill requests
         if has_prefill:
@@ -653,9 +598,6 @@ class MambaMixer2(MambaBase, CustomOp):
             # pointed to by "state_indices_tensor"
             x = hidden_states_B_C_p.transpose(
                 0, 1)  # this is the form that causal-conv see
-            if mamba2_metadata.cu_seqlen is None:
-                mamba2_metadata = update_metadata(x, query_start_loc_p,
-                                                  mamba2_metadata)
             hidden_states_B_C_p = causal_conv1d_fn(
                 x,
                 conv_weights,
@@ -664,7 +606,7 @@ class MambaMixer2(MambaBase, CustomOp):
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
-                metadata=mamba2_metadata,
+                metadata=attn_metadata,
                 query_start_loc=query_start_loc_p).transpose(
                     0, 1)[:num_prefill_tokens]
 
@@ -806,8 +748,6 @@ def mamba_mixer2(
     self = forward_context.no_compile_layers[layer_name]
    self.forward_cuda(hidden_states=hidden_states,
                       output=output,
-                      mamba_cache_params=None,
-                      mamba2_metadata=None,
                       mup_vector=mup_vector)
 
 
@@ -100,7 +100,6 @@ class MambaStateShapeCalculator:
         intermediate_size: int,
         state_size: int,
         conv_kernel: int,
-        use_v1: bool = True,
     ) -> tuple[tuple[int, int], tuple[int, int]]:
         conv_state_shape = (divide(intermediate_size,
                                    tp_world_size), conv_kernel - 1)
@@ -108,10 +107,6 @@ class MambaStateShapeCalculator:
         temporal_state_shape = (divide(intermediate_size,
                                        tp_world_size), state_size)
 
-        # In V0, the conv_state shape was swapped during allocation in
-        # MambaCacheManager, but in V1 it needs to be determined here at the
-        # calculation level
-        if use_v1:
         conv_state_shape = conv_state_shape[1], conv_state_shape[0]
 
         return conv_state_shape, temporal_state_shape
@@ -126,7 +121,6 @@ class MambaStateShapeCalculator:
         head_dim: int,
         state_size: int,
         conv_kernel: int,
-        use_v1: bool = True,
     ) -> tuple[tuple[int, int], tuple[int, int, int]]:
         # if n_groups is not divisible by world_size, need to extend the shards
         # to ensure all groups needed by a head is sharded along with it
@@ -137,8 +131,6 @@ class MambaStateShapeCalculator:
 
         # contiguous along 'dim' axis
         conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
-        if not use_v1:
-            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
 
         # These are not TP-ed as they depend on A, dt_bias, D
         # - they are typically small
@@ -153,12 +145,9 @@ class MambaStateShapeCalculator:
         tp_world_size: int,
         intermediate_size: int,
         conv_kernel: int,
-        use_v1: bool = True,
     ) -> tuple[tuple[int, int]]:
         conv_dim = divide(intermediate_size, tp_world_size)
         conv_state_shape = (conv_kernel - 1, conv_dim)
-        if not use_v1:
-            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
         return (conv_state_shape, )
 
     @classmethod
@@ -183,7 +172,6 @@ class MambaStateShapeCalculator:
         head_v_dim: int,
         conv_kernel_size: int,
         num_spec: int = 0,
-        use_v1: bool = True,
     ):
         conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
         conv_state_shape = (
@@ -191,10 +179,6 @@ class MambaStateShapeCalculator:
             conv_kernel_size - 1 + num_spec,
         )
 
-        # In V0, the conv_state shape was swapped during allocation in
-        # MambaCacheManager, but in V1 it needs to be determined here at the
-        # calculation level
-        if use_v1:
         conv_state_shape = conv_state_shape[1], conv_state_shape[0]
 
         temporal_state_shape = (divide(num_v_heads,
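Aside: the MambaStateShapeCalculator hunks above drop the use_v1 switch, so the V1 convention of storing the conv state as (conv_kernel - 1, dim) is now unconditional. A self-contained sketch of the resulting Mamba-1 shape calculation (simplified; the real helper uses vLLM's divide and TP utilities):

def mamba1_state_shape(tp_world_size: int, intermediate_size: int,
                       state_size: int, conv_kernel: int):
    """Per-TP-rank conv and temporal (SSM) state shapes, V1 layout."""
    assert intermediate_size % tp_world_size == 0
    dim = intermediate_size // tp_world_size
    conv_state_shape = (conv_kernel - 1, dim)   # transposed vs. the old V0 layout
    temporal_state_shape = (dim, state_size)
    return conv_state_shape, temporal_state_shape

print(mamba1_state_shape(tp_world_size=2, intermediate_size=4096,
                         state_size=16, conv_kernel=4))
# ((3, 2048), (2048, 16))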
@@ -420,9 +420,7 @@ def causal_conv1d_fn(
     x = x.to(conv_states.dtype)
     out = torch.empty_like(x)
     if metadata is not None:
-        cu_seqlen = metadata.cu_seqlen
         nums_dict = metadata.nums_dict
-        #x = metadata.x
         args = nums_dict
         batch_ptr = metadata.batch_ptr
         token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
@@ -926,7 +924,6 @@ def causal_conv1d_update(
     query_start_loc: Optional[torch.Tensor] = None,
     max_query_len: int = -1,
     pad_slot_id: int = PAD_SLOT_ID,
-    metadata=None,
     validate_data=False,
 ):
     """
@@ -8,7 +8,6 @@ if TYPE_CHECKING:

 import torch

-from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -18,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 MergedColumnParallelLinear,
 RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
-from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
 from vllm.model_executor.layers.mamba.mamba_utils import (
 MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -71,15 +69,11 @@ class ShortConv(MambaBase, CustomOp):
 prefix=f"{prefix}.out_proj",
 )

-assert envs.VLLM_USE_V1, ("ShortConv layers are only supported in V1")
 compilation_config = get_current_vllm_config().compilation_config
 if prefix in compilation_config.static_forward_context:
 raise ValueError(f"Duplicate layer name: {prefix}")
 compilation_config.static_forward_context[prefix] = self
-# The outer list is for v0 PP virtual engine. Though this code path
-# only runs for v1, we have to do this to unify with the interface
-# of Attention + v0 PP.
-self.kv_cache = [(torch.tensor([]), )]
+self.kv_cache = (torch.tensor([]), )

 self.model_config = model_config
 self.cache_config = cache_config
@@ -89,7 +83,6 @@ class ShortConv(MambaBase, CustomOp):
 self,
 hidden_states: torch.Tensor,
 output: torch.Tensor,
-conv_metadata: ShortConvAttentionMetadata,
 ):
 return

@@ -97,7 +90,6 @@ class ShortConv(MambaBase, CustomOp):
 self,
 hidden_states: torch.Tensor,
 output: torch.Tensor,
-conv_metadata: ShortConvAttentionMetadata,
 ):
 torch.ops.vllm.short_conv(
 hidden_states,
@@ -109,7 +101,6 @@ class ShortConv(MambaBase, CustomOp):
 self,
 hidden_states: torch.Tensor,
 output: torch.Tensor,
-conv_metadata: ShortConvAttentionMetadata,
 ):
 forward_context = get_forward_context()
 # ShortConvAttentionMetadata contains metadata necessary for the
@@ -121,7 +112,6 @@ class ShortConv(MambaBase, CustomOp):
 if attn_metadata is not None:
 assert isinstance(attn_metadata, dict)
 attn_metadata = attn_metadata[self.prefix]
-conv_metadata = attn_metadata
 assert isinstance(attn_metadata, ShortConvAttentionMetadata)
 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
 conv_state = self_kv_cache[0].transpose(-1, -2)
@@ -181,9 +171,6 @@ class ShortConv(MambaBase, CustomOp):

 if has_prefill:
 Bx_p = (B_p * x_p).transpose(0, 1)
-if conv_metadata.cu_seqlen is None:
-conv_metadata = update_metadata(Bx_p, query_start_loc_p,
-conv_metadata)
 Bx = causal_conv1d_fn(Bx_p,
 conv_weights,
 self.conv.bias,
@@ -191,7 +178,7 @@ class ShortConv(MambaBase, CustomOp):
 conv_states=conv_state,
 has_initial_state=has_initial_states_p,
 cache_indices=state_indices_tensor_p,
-metadata=conv_metadata,
+metadata=attn_metadata,
 query_start_loc=query_start_loc_p).transpose(
 0, 1)[:num_prefill_tokens]

@@ -248,9 +235,7 @@ def short_conv(
 ) -> None:
 forward_context: ForwardContext = get_forward_context()
 self = forward_context.no_compile_layers[layer_name]
-self.forward_cuda(hidden_states=hidden_states,
-output=output,
-conv_metadata=None)
+self.forward_cuda(hidden_states=hidden_states, output=output)


 def short_conv_fake(
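A minimal sketch (hypothetical names, not the vLLM API) of the pattern the ShortConv hunks above settle on: per-layer metadata is looked up from the forward context by the layer's prefix instead of being passed in as a conv_metadata argument.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeForwardContext:
    # Stand-in for the real forward context; maps layer prefix -> metadata.
    attn_metadata: Optional[dict]


def lookup_layer_metadata(ctx: FakeForwardContext, prefix: str) -> Optional[Any]:
    # Mirrors the branch above: no metadata during profiling/warmup runs,
    # otherwise a dict keyed by layer prefix.
    if ctx.attn_metadata is None:
        return None
    assert isinstance(ctx.attn_metadata, dict)
    return ctx.attn_metadata[prefix]


ctx = FakeForwardContext(attn_metadata={"model.layers.0.conv": "metadata"})
assert lookup_layer_metadata(ctx, "model.layers.0.conv") == "metadata"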
@@ -9,21 +9,17 @@ import torch
 from torch import nn
 from transformers import BambaConfig

-from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 QKVParallelLinear,
 RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mamba.mamba2_metadata import (
-Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
 MambaStateDtypeCalculator, MambaStateShapeCalculator)
@@ -32,10 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
 DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
-MambaCacheParams)
 from vllm.sequence import IntermediateTensors
-from vllm.utils import LayerBlockType

 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
 SupportsQuant)
@@ -115,8 +108,6 @@ class BambaMixerDecoderLayer(nn.Module):
 self,
 hidden_states: torch.Tensor,
 residual: Optional[torch.Tensor],
-mamba_cache_params: MambaCacheParams,
-mamba2_metadata: Mamba2Metadata,
 **kwargs,
 ):
 if residual is None:
@@ -127,7 +118,7 @@ class BambaMixerDecoderLayer(nn.Module):
 hidden_states, residual)

 output = torch.empty_like(hidden_states)
-self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
+self.mamba(hidden_states, output)
 # Fully Connected
 hidden_states, residual = self.pre_ff_layernorm(output, residual)
 hidden_states = self.feed_forward(hidden_states)
@@ -315,22 +306,10 @@ class BambaModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-mamba_cache_params: MambaCacheParams,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

-attn_metadata = get_forward_context().attn_metadata

-if not envs.VLLM_USE_V1:
-mamba2_metadata = prepare_mamba2_metadata(
-chunk_size=self.config.mamba_chunk_size,
-attn_metadata=attn_metadata,
-)
-else:
-# v1 get mamba2_metadata from forward_context
-mamba2_metadata = None

 if get_pp_group().is_first_rank:
 if inputs_embeds is not None:
 hidden_states = inputs_embeds
@@ -343,23 +322,11 @@ class BambaModel(nn.Module):
 residual = intermediate_tensors["residual"]

 residual = None
-num_attn = 0
 for i, layer in enumerate(self.layers):
-if isinstance(layer, BambaAttentionDecoderLayer):
-num_attn += 1

-layer_mamba_cache_params = None
-if isinstance(layer,
-BambaMixerDecoderLayer) and mamba_cache_params:
-layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
-i - num_attn)

 hidden_states, residual = layer(
 positions=positions,
 hidden_states=hidden_states,
 residual=residual,
-mamba_cache_params=layer_mamba_cache_params,
-mamba2_metadata=mamba2_metadata,
 )

 if not get_pp_group().is_last_rank:
@@ -457,13 +424,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 def get_mamba_state_shape_from_config(
 cls,
 vllm_config: "VllmConfig",
-use_v1: bool = True,
 ) -> tuple[tuple[int, int], tuple[int, int, int]]:
 """Calculate shapes for Mamba's convolutional and state caches.

 Args:
 vllm_config: vLLM config
-use_v1: Get shapes for V1 (or V0)

 Returns:
 Tuple containing:
@@ -482,7 +447,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 head_dim=hf_config.mamba_d_head,
 state_size=hf_config.mamba_d_state,
 conv_kernel=hf_config.mamba_d_conv,
-use_v1=use_v1,
 )

 def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -515,8 +479,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 if not lora_config else lora_config.lora_vocab_padding_size,
 prefix=maybe_prefix(prefix, "lm_head"),
 )
-# Used to track and store by the Mamba cache between steps.
-self.mamba_cache: Optional[MambaCacheManager] = None

 self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
 config.vocab_size)
@@ -534,39 +496,11 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs):

-mamba_cache_params = None
-if not envs.VLLM_USE_V1:
-if self.mamba_cache is None:
-num_mamba_layers = \
-self.model_config.get_num_layers_by_block_type(
-self.vllm_config.parallel_config,
-LayerBlockType.mamba
-)
-mamba_state_shape = \
-self.get_mamba_state_shape_from_config(
-self.vllm_config, use_v1=False)
-mamba_state_dtype = \
-self.get_mamba_state_dtype_from_config(
-self.vllm_config)
-self.mamba_cache = MambaCacheManager(self.vllm_config,
-num_mamba_layers,
-*mamba_state_shape,
-*mamba_state_dtype)

-mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

-hidden_states = self.model(input_ids, positions, mamba_cache_params,
-intermediate_tensors, inputs_embeds)
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
+inputs_embeds)

 return hidden_states

-def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-return self.mamba_cache.copy_inputs_before_cuda_graphs(
-input_buffers, **kwargs)

-def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
-return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

 def compute_logits(
 self,
 hidden_states: torch.Tensor,
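A runnable sketch (hypothetical MixerLike class, not a vLLM type) of the calling convention the Bamba hunks above converge on: the decoder layer pre-allocates an output buffer and calls the mixer with just (hidden_states, output); cache tensors and metadata are resolved inside the mixer rather than threaded through the model.

import torch


class MixerLike(torch.nn.Module):
    """Stand-in for a Mamba mixer; writes its result into `output` in place."""

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        output.copy_(hidden_states * 2.0)  # placeholder for the real kernel


mixer = MixerLike()
hidden_states = torch.randn(4, 8)
output = torch.empty_like(hidden_states)
mixer(hidden_states, output)
assert torch.allclose(output, hidden_states * 2.0)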
@@ -1,137 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from abc import ABC, abstractmethod
-from typing import Any
-
-import torch
-
-from vllm.attention.backends.utils import PAD_SLOT_ID
-
-
-class ConstantSizeCache(ABC):
-"""
-Abstract base class for managing constant size caches
-like Mamba and Minimax.
-"""
-
-def __init__(self, max_batch_size: int):
-# Maps between the request id and a dict that maps between the seq_id
-# and its index inside the cache
-self.cache_indices_mapping: dict[str, dict[int, int]] = {}
-self.free_cache_indices = list(range(max_batch_size))
-
-@property
-@abstractmethod
-def cache(self) -> Any:
-"""Return the underlying cache tensor(s)"""
-pass
-
-@abstractmethod
-def _copy_cache(self, from_index: int, to_index: int):
-"""Copy cache data from one index to another"""
-pass
-
-def current_run_tensors(self, **kwargs) -> tuple:
-"""
-Return the tensors for the current run's conv and ssm state.
-"""
-if "seqlen_agnostic_capture_inputs" not in kwargs:
-# We get here only on Prefill/Eager mode runs
-request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-finished_requests_ids = kwargs["finished_requests_ids"]
-
-self._release_finished_requests(finished_requests_ids)
-state_indices = self._prepare_current_run_cache(
-request_ids_to_seq_ids, finished_requests_ids)
-
-state_indices_tensor = torch.as_tensor(state_indices,
-dtype=torch.int32,
-device="cuda")
-cache_tensors = self.cache
-else:
-# CUDA graph capturing runs
-cache_tensors, state_indices_tensor = kwargs[
-"seqlen_agnostic_capture_inputs"]
-
-return (cache_tensors, state_indices_tensor)
-
-def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-"""
-Copy the relevant state_indices into the CUDA graph input buffer
-"""
-assert all(
-key in kwargs
-for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
-finished_requests_ids = kwargs["finished_requests_ids"]
-request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-assert "seqlen_agnostic_capture_inputs" in input_buffers
-_, input_state_indices_buffer = input_buffers[
-"seqlen_agnostic_capture_inputs"]
-
-self._release_finished_requests(finished_requests_ids)
-state_indices = self._prepare_current_run_cache(
-request_ids_to_seq_ids, finished_requests_ids)
-cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
-state_indices)
-state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
-
-input_state_indices_buffer.copy_(
-torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
-
-def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
-"""
-Provide the CUDA graph capture runs with a buffer in adjusted size.
-The buffer is used to maintain the Cache during the CUDA graph replay
-runs.
-"""
-state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
-dtype=torch.int32,
-device="cuda")
-return (self.cache, state_indices_tensor)
-
-def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
-finished_requests_ids) -> int:
-"""
-Assign (req_id,seq_id) pair to a `destination_index` index, if
-already occupied, move the occupying index to a free index.
-"""
-if cur_rid in finished_requests_ids:
-# set as pad, do not allocate destination index
-return PAD_SLOT_ID
-elif cur_rid not in self.cache_indices_mapping:
-destination_index = self.free_cache_indices.pop()
-self.cache_indices_mapping[cur_rid] = {seq_id: destination_index}
-return destination_index
-elif seq_id not in (seq_ids2indices :=
-self.cache_indices_mapping[cur_rid]):
-# parallel sampling , where n > 1, assume prefill have
-# already happened, so we copy the
-# existing cache into the siblings seq_ids caches
-index_exists = next(iter(seq_ids2indices.values()))
-# case of decoding n>1, copy prefill cache to decoding indices
-destination_index = self.free_cache_indices.pop()
-self._copy_cache(from_index=index_exists,
-to_index=destination_index)
-self.cache_indices_mapping[cur_rid][seq_id] = destination_index
-return destination_index
-else:
-return self.cache_indices_mapping[cur_rid][seq_id]
-
-def _prepare_current_run_cache(
-self, request_ids_to_seq_ids: dict[str, list[int]],
-finished_requests_ids: list[str]) -> list[int]:
-return [
-self._assign_seq_id_to_cache_index(req_id, seq_id,
-finished_requests_ids)
-for req_id, seq_ids in request_ids_to_seq_ids.items()
-for seq_id in seq_ids
-]
-
-def _release_finished_requests(self,
-finished_seq_groups_req_ids: list[str]):
-for req_id in finished_seq_groups_req_ids:
-if req_id in self.cache_indices_mapping:
-for seq_id in self.cache_indices_mapping[req_id]:
-self.free_cache_indices.append(
-self.cache_indices_mapping[req_id][seq_id])
-self.cache_indices_mapping.pop(req_id)
@@ -8,21 +8,17 @@ import torch
 from torch import nn
 from transformers import FalconH1Config

-from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 QKVParallelLinear,
 RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mamba.mamba2_metadata import (
-Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
 MambaStateDtypeCalculator, MambaStateShapeCalculator)
@@ -31,8 +27,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
 DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
-MambaCacheParams)
 from vllm.sequence import IntermediateTensors

 from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
@@ -179,16 +173,12 @@ class FalconH1SSMDecoderLayer(nn.Module):
 self,
 hidden_states: torch.Tensor,
 residual: Optional[torch.Tensor],
-mamba_cache_params: MambaCacheParams,
-mamba2_metadata: Mamba2Metadata,
 **kwargs,
 ):
 output = torch.empty_like(hidden_states)
 self.mamba(
 hidden_states,
 output,
-mamba_cache_params,
-mamba2_metadata=mamba2_metadata,
 mup_vector=self.mup_vector,
 )
 return output, residual
@@ -364,8 +354,6 @@ class FalconH1ParallelHybrid(nn.Module):
 self,
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
-mamba_cache_params: MambaCacheParams,
-mamba2_metadata: Mamba2Metadata,
 **kwargs,
 ):
 residual = hidden_states
@@ -382,12 +370,10 @@ class FalconH1ParallelHybrid(nn.Module):

 # Process input through the SSM branch.
 # FalconH1SSMDecoderLayer expects hidden_states, attn_metadata,
-# residual, mamba_cache_params, and sequence_idx.
+# residual, and sequence_idx.
 ssm_hidden, _ = self.mamba(
 hidden_states=hidden_states * self.ssm_in_multiplier,
 residual=residual,
-mamba_cache_params=mamba_cache_params,
-mamba2_metadata=mamba2_metadata,
 **kwargs,
 )
 # Sum the outputs from both branches.
@@ -464,25 +450,10 @@ class FalconH1Model(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-mamba_cache_params: MambaCacheParams,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

-# pass a sequence index tensor, that is required for
-# proper continuous batching computation including
-# chunked prefill
-attn_metadata = get_forward_context().attn_metadata

-if not envs.VLLM_USE_V1:
-mamba2_metadata = prepare_mamba2_metadata(
-chunk_size=self.config.mamba_chunk_size,
-attn_metadata=attn_metadata,
-)
-else:
-# v1 get mamba2_metadata from forward_context
-mamba2_metadata = None

 if get_pp_group().is_first_rank:
 if inputs_embeds is not None:
 hidden_states = inputs_embeds * self.embedding_multiplier
@@ -495,14 +466,9 @@ class FalconH1Model(nn.Module):

 for i in range(self.start_layer, self.end_layer):
 layer = self.layers[i]
-layer_mamba_cache_params = None
-if mamba_cache_params:
-layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
 hidden_states = layer(
 positions=positions,
 hidden_states=hidden_states,
-mamba_cache_params=layer_mamba_cache_params,
-mamba2_metadata=mamba2_metadata,
 )
 if not get_pp_group().is_last_rank:
 return IntermediateTensors({
@@ -541,13 +507,11 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 def get_mamba_state_shape_from_config(
 cls,
 vllm_config: "VllmConfig",
-use_v1: bool = True,
 ) -> tuple[tuple[int, int], tuple[int, int, int]]:
 """Calculate shapes for Mamba's convolutional and state caches.

 Args:
 vllm_config: vLLM config
-use_v1: Get shapes for V1 (or V0)

 Returns:
 Tuple containing:
@@ -570,7 +534,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 head_dim=hf_config.mamba_d_head,
 state_size=hf_config.mamba_d_state,
 conv_kernel=hf_config.mamba_d_conv,
-use_v1=use_v1,
 )

 def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -592,7 +555,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 prefix=maybe_prefix(prefix, "model"))
 self.tie_word_embeddings = config.tie_word_embeddings
 self.unpadded_vocab_size = config.vocab_size
-self.mamba_cache: Optional[MambaCacheManager] = None
 if lora_config:
 self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
 if get_pp_group().is_last_rank:
@@ -637,40 +599,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 **kwargs,
 ):

-mamba_cache_params = None
-if not envs.VLLM_USE_V1:
-if self.mamba_cache is None:
-mamba_state_shape = \
-self.get_mamba_state_shape_from_config(
-self.vllm_config, use_v1=False)
-mamba_state_dtype = \
-self.get_mamba_state_dtype_from_config(
-self.vllm_config)
-self.mamba_cache = MambaCacheManager(
-self.vllm_config,
-self.config.num_hidden_layers,
-*mamba_state_shape,
-*mamba_state_dtype,
-)
-mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

 hidden_states = self.model(
 input_ids,
 positions,
-mamba_cache_params,
 intermediate_tensors,
 inputs_embeds,
 )

 return hidden_states

-def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-return self.mamba_cache.copy_inputs_before_cuda_graphs(
-input_buffers, **kwargs)

-def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
-return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

 def compute_logits(
 self,
 hidden_states: torch.Tensor,
@@ -9,19 +9,15 @@ import torch
 from torch import nn
 from transformers import GraniteMoeHybridConfig

-from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
 RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mamba.mamba2_metadata import (
-Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
 MambaStateDtypeCalculator, MambaStateShapeCalculator)
@@ -30,10 +26,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
 DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
-MambaCacheParams)
 from vllm.sequence import IntermediateTensors
-from vllm.utils import LayerBlockType

 from .granitemoe import GraniteMoeMoE
 from .granitemoeshared import GraniteMoeSharedMLP
@@ -102,14 +95,12 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
 self,
 hidden_states: torch.Tensor,
 residual: Optional[torch.Tensor],
-mamba_cache_params: MambaCacheParams,
-mamba2_metadata: Mamba2Metadata,
 **kwargs,
 ):
 residual = hidden_states
 hidden_states = self.input_layernorm(hidden_states)
 output = torch.empty_like(hidden_states)
-self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
+self.mamba(hidden_states, output)
 hidden_states = residual + output * self.residual_multiplier

 residual = hidden_states
@@ -182,8 +173,6 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
 positions: torch.Tensor,
 hidden_states: torch.Tensor,
 residual: Optional[torch.Tensor],
-mamba_cache_params: MambaCacheParams,
-mamba2_metadata: Mamba2Metadata,
 ) -> torch.Tensor:
 residual = hidden_states
 hidden_states = self.input_layernorm(hidden_states)
@@ -366,22 +355,10 @@ class GraniteMoeHybridModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-mamba_cache_params: MambaCacheParams,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

-attn_metadata = get_forward_context().attn_metadata

-if not envs.VLLM_USE_V1:
-mamba2_metadata = prepare_mamba2_metadata(
-chunk_size=self.config.mamba_chunk_size,
-attn_metadata=attn_metadata,
-)
-else:
-# v1 get mamba2_metadata from forward_context
-mamba2_metadata = None

 if get_pp_group().is_first_rank:
 if inputs_embeds is not None:
 hidden_states = inputs_embeds
@@ -399,20 +376,9 @@ class GraniteMoeHybridModel(nn.Module):
 for i, layer in enumerate(self.layers):
 if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer):
 num_attn += 1
+hidden_states, residual = layer(positions=positions,
-layer_mamba_cache_params = None
-if isinstance(
-layer,
-GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params:
-layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
-i - num_attn)

-hidden_states, residual = layer(
-positions=positions,
 hidden_states=hidden_states,
-residual=residual,
+residual=residual)
-mamba_cache_params=layer_mamba_cache_params,
-mamba2_metadata=mamba2_metadata)

 if not get_pp_group().is_last_rank:
 return IntermediateTensors({
@@ -552,13 +518,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
 def get_mamba_state_shape_from_config(
 cls,
 vllm_config: "VllmConfig",
-use_v1: bool = True,
 ) -> tuple[tuple[int, int], tuple[int, int, int]]:
 """Calculate shapes for Mamba's convolutional and state caches.

 Args:
 vllm_config: vLLM config
-use_v1: Get shapes for V1 (or V0)

 Returns:
 Tuple containing:
@@ -577,7 +541,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
 head_dim=hf_config.mamba_d_head,
 state_size=hf_config.mamba_d_state,
 conv_kernel=hf_config.mamba_d_conv,
-use_v1=use_v1,
 )

 def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -620,9 +583,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
 scale=1 /
 self.config.logits_scaling)

-# Used to track and store by the Mamba cache between steps.
-self.mamba_cache: Optional[MambaCacheManager] = None

 self.make_empty_intermediate_tensors = (
 self.model.make_empty_intermediate_tensors)

@@ -636,38 +596,11 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs):

-mamba_cache_params = None
-if not envs.VLLM_USE_V1:
-if self.mamba_cache is None:
-num_mamba_layers = (
-self.model_config.get_num_layers_by_block_type(
-self.vllm_config.parallel_config,
-LayerBlockType.mamba))
-mamba_state_shape = \
-self.get_mamba_state_shape_from_config(
-self.vllm_config, use_v1=False)
-mamba_state_dtype = \
-self.get_mamba_state_dtype_from_config(
-self.vllm_config)
-self.mamba_cache = MambaCacheManager(self.vllm_config,
-num_mamba_layers,
-*mamba_state_shape,
-*mamba_state_dtype)

-mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

-hidden_states = self.model(input_ids, positions, mamba_cache_params,
-intermediate_tensors, inputs_embeds)
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
+inputs_embeds)

 return hidden_states

-def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-return self.mamba_cache.copy_inputs_before_cuda_graphs(
-input_buffers, **kwargs)

-def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
-return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

 def compute_logits(
 self,
 hidden_states: torch.Tensor,
@@ -9,7 +9,6 @@ import torch
 from torch import nn
 from transformers import JambaConfig

-from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
@@ -30,10 +29,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaMLP as JambaMLP
-from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
-MambaCacheParams)
 from vllm.sequence import IntermediateTensors
-from vllm.utils import LayerBlockType

 from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
@@ -145,7 +141,6 @@ class JambaMambaDecoderLayer(nn.Module):
 self,
 hidden_states: torch.Tensor,
 residual: Optional[torch.Tensor],
-mamba_cache_params: MambaCacheParams,
 **kwargs,
 ):
 if residual is None:
@@ -156,7 +151,7 @@ class JambaMambaDecoderLayer(nn.Module):
 hidden_states, residual)

 output = torch.empty_like(hidden_states)
-self.mamba(hidden_states, output, mamba_cache_params)
+self.mamba(hidden_states, output)
 # Fully Connected
 hidden_states, residual = self.pre_ff_layernorm(output, residual)
 hidden_states = self.feed_forward(hidden_states)
@@ -333,7 +328,6 @@ class JambaModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-mamba_cache_params: MambaCacheParams,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
@@ -348,24 +342,11 @@ class JambaModel(nn.Module):
 hidden_states = intermediate_tensors["hidden_states"]
 residual = intermediate_tensors["residual"]

-kv_cache_index = 0
-mamba_cache_index = 0
 for layer in islice(self.layers, self.start_layer, self.end_layer):
-layer_mamba_cache_params = None
+hidden_states, residual = layer(positions=positions,
-if isinstance(layer, JambaAttentionDecoderLayer):
-kv_cache_index += 1
-if isinstance(layer,
-JambaMambaDecoderLayer) and mamba_cache_params:
-current_state_layer = mamba_cache_index
-layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
-current_state_layer)
-mamba_cache_index += 1

-hidden_states, residual = layer(
-positions=positions,
 hidden_states=hidden_states,
-residual=residual,
+residual=residual)
-mamba_cache_params=layer_mamba_cache_params)
 if not get_pp_group().is_last_rank:
 return IntermediateTensors({
 "hidden_states": hidden_states,
@@ -503,8 +484,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 if not lora_config else lora_config.lora_vocab_padding_size,
 prefix=maybe_prefix(prefix, "lm_head"),
 )
-# Used to track and store by the Mamba cache between steps.
-self.mamba_cache: Optional[MambaCacheManager] = None

 self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
 config.vocab_size)
@@ -521,24 +500,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs):
-# NOTE: mamba_cache_params is not needed for v1
-mamba_cache_params = None
-if not envs.VLLM_USE_V1:
-if self.mamba_cache is None:
-num_layers = self.model_config.get_num_layers_by_block_type(
-self.vllm_config.parallel_config, LayerBlockType.mamba)
-state_shape = self.get_mamba_state_shape_from_config(
-self.vllm_config)
-state_dtype = self.get_mamba_state_dtype_from_config(
-self.vllm_config)
-self.mamba_cache = MambaCacheManager(self.vllm_config,
-num_layers, *state_shape,
-*state_dtype)

-mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
-hidden_states = self.model(input_ids, positions, mamba_cache_params,
-intermediate_tensors, inputs_embeds)
+hidden_states = self.model(input_ids, positions, intermediate_tensors,
+inputs_embeds)
 return hidden_states

 def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@@ -574,7 +538,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 intermediate_size=hf_config.mamba_expand * hidden_size,
 state_size=hf_config.mamba_d_state,
 conv_kernel=hf_config.mamba_d_conv,
-use_v1=envs.VLLM_USE_V1,
 )

 def compute_logits(
@@ -8,7 +8,6 @@ import torch
 import torch.nn as nn
 from transformers import Lfm2Config

-from vllm import envs
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
@@ -297,7 +296,6 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
 self.conv(
 hidden_states,
 output,
-conv_metadata=None,
 )
 hidden_states, residual = self.ffn_norm(output, residual)
 hidden_states = self.feed_forward(hidden_states)
@@ -459,13 +457,11 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 def get_mamba_state_shape_from_config(
 cls,
 vllm_config: "VllmConfig",
-use_v1: bool = True,
 ) -> tuple[tuple[int, int]]:
 """ Calculate shapes for LFM2's convolutional cache.

 Args:
 vllm_config: vLLM config
-use_v1: Get shapes for V1 (or V0)

 Returns:
 Tuple containing:
@@ -478,7 +474,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 tp_world_size=parallel_config.tensor_parallel_size,
 intermediate_size=hf_config.conv_dim,
 conv_kernel=hf_config.conv_L_cache,
-use_v1=use_v1,
 )

 def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
@@ -489,8 +484,6 @@ class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
 scheduler_config = vllm_config.scheduler_config
 assert (not cache_config.enable_prefix_caching
 ), "Lfm2 currently does not support prefix caching"
-assert envs.VLLM_USE_V1, (
-"Lfm2ForCausalLM doesn't support vLLM v0. Please enable v1")

 super().__init__()
 self.config = config
@@ -8,7 +8,6 @@ import torch
 from torch import nn
 from transformers import MambaConfig

-from vllm import envs
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
@@ -24,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
 IsAttentionFree, SupportsPP)
-from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
-MambaCacheParams)
 from vllm.sequence import IntermediateTensors
-from vllm.utils import LayerBlockType

 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
 make_empty_intermediate_tensors_factory, make_layers,
@@ -72,7 +68,6 @@ class MambaDecoderLayer(nn.Module):
 self,
 hidden_states: torch.Tensor,
 residual: Optional[torch.Tensor],
-mamba_cache_params: MambaCacheParams,
 **kwargs,
 ):
 if residual is None:
@@ -82,7 +77,7 @@ class MambaDecoderLayer(nn.Module):
 hidden_states, residual = self.norm(hidden_states, residual)

 output = torch.empty_like(hidden_states)
-self.mixer(hidden_states, output, mamba_cache_params)
+self.mixer(hidden_states, output)
 return output, residual


@@ -134,7 +129,6 @@ class MambaModel(nn.Module):
 self,
 input_ids: torch.Tensor,
 positions: torch.Tensor,
-mamba_cache_params: Optional[MambaCacheParams] = None,
 intermediate_tensors: Optional[IntermediateTensors] = None,
 inputs_embeds: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
@@ -151,17 +145,9 @@ class MambaModel(nn.Module):

 for i in range(self.start_layer, self.end_layer):
 layer = self.layers[i]
+hidden_states, residual = layer(positions=positions,
-layer_cache_params = None
-if mamba_cache_params is not None:
-layer_cache_params = mamba_cache_params.at_layer_idx(
-i - self.start_layer)

-hidden_states, residual = layer(
-positions=positions,
 hidden_states=hidden_states,
-residual=residual,
+residual=residual)
-mamba_cache_params=layer_cache_params)
 if not get_pp_group().is_last_rank:
 return IntermediateTensors({
 "hidden_states": hidden_states,
@@ -225,9 +211,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
 prefix=maybe_prefix(prefix, "lm_head"),
 )

-# Used to track and store by the Mamba cache between steps.
-self.mamba_cache: Optional[MambaCacheManager] = None

 self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
 config.vocab_size)

@@ -244,22 +227,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
 inputs_embeds: Optional[torch.Tensor] = None,
 **kwargs):

-mamba_cache_params = None
-if not envs.VLLM_USE_V1:
-if self.mamba_cache is None:
-num_layers = self.model_config.get_num_layers_by_block_type(
-self.vllm_config.parallel_config, LayerBlockType.mamba)
-state_shape = self.get_mamba_state_shape_from_config(
-self.vllm_config)
-state_dtype = self.get_mamba_state_dtype_from_config(
-self.vllm_config)
-self.mamba_cache = MambaCacheManager(self.vllm_config,
-num_layers, *state_shape,
-*state_dtype)

-mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

-hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
+hidden_states = self.backbone(input_ids, positions,
 intermediate_tensors, inputs_embeds)

 return hidden_states
@@ -288,8 +256,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
 tp_world_size=parallel_config.tensor_parallel_size,
 intermediate_size=hf_config.intermediate_size,
 state_size=hf_config.state_size,
-conv_kernel=hf_config.conv_kernel,
-use_v1=envs.VLLM_USE_V1)
+conv_kernel=hf_config.conv_kernel)

 def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
 return self.mamba_cache.copy_inputs_before_cuda_graphs(
@ -8,16 +8,11 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MambaConfig
|
from transformers import MambaConfig
|
||||||
|
|
||||||
from vllm import envs
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
from vllm.forward_context import get_forward_context
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
|
||||||
Mamba2Metadata, prepare_mamba2_metadata)
|
|
||||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
@ -28,10 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import (HasInnerState,
|
from vllm.model_executor.models.interfaces import (HasInnerState,
|
||||||
IsAttentionFree)
|
IsAttentionFree)
|
||||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
|
||||||
MambaCacheParams)
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import LayerBlockType
|
|
||||||
|
|
||||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
@ -74,8 +66,6 @@ class Mamba2DecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
mamba_cache_params: MambaCacheParams,
|
|
||||||
mamba2_metadata: Mamba2Metadata,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if residual is None:
|
if residual is None:
|
||||||
@ -85,7 +75,7 @@ class Mamba2DecoderLayer(nn.Module):
|
|||||||
hidden_states, residual = self.norm(hidden_states, residual)
|
hidden_states, residual = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
output = torch.empty_like(hidden_states)
|
output = torch.empty_like(hidden_states)
|
||||||
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
self.mixer(hidden_states, output)
|
||||||
return output, residual
|
return output, residual
|
||||||
|
|
||||||
|
|
||||||
@ -137,7 +127,6 @@ class Mamba2Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
mamba_cache_params: MambaCacheParams,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -152,25 +141,10 @@ class Mamba2Model(nn.Module):
|
|||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
|
||||||
|
|
||||||
if not envs.VLLM_USE_V1:
|
|
||||||
mamba2_metadata = prepare_mamba2_metadata(
|
|
||||||
chunk_size=self.config.chunk_size,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# v1 get mamba2_metadata from forward_context
|
|
||||||
mamba2_metadata = None
|
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(positions=positions,
|
||||||
positions=positions,
|
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
residual=residual,
|
residual=residual)
|
||||||
mamba_cache_params=mamba_cache_params.at_layer_idx(
|
|
||||||
i - self.start_layer) if mamba_cache_params else None,
|
|
||||||
mamba2_metadata=mamba2_metadata)
|
|
||||||
|
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
@ -222,13 +196,11 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
|||||||
def get_mamba_state_shape_from_config(
|
def get_mamba_state_shape_from_config(
|
||||||
cls,
|
cls,
|
||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
use_v1: bool = True,
|
|
||||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vllm_config: vLLM config
|
vllm_config: vLLM config
|
||||||
use_v1: Get shapes for V1 (or V0)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
Tuple containing:
|
||||||
@ -247,7 +219,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
|||||||
head_dim=hf_config.head_dim,
|
head_dim=hf_config.head_dim,
|
||||||
state_size=hf_config.state_size,
|
state_size=hf_config.state_size,
|
||||||
conv_kernel=hf_config.conv_kernel,
|
conv_kernel=hf_config.conv_kernel,
|
||||||
use_v1=use_v1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
@ -282,9 +253,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
|||||||
if config.tie_word_embeddings:
|
if config.tie_word_embeddings:
|
||||||
self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
|
self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings)
|
||||||
|
|
||||||
# Used to track and store by the Mamba cache between steps.
|
|
||||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
config.vocab_size)
|
config.vocab_size)
|
||||||
|
|
||||||
@ -300,29 +268,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
|||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
if not envs.VLLM_USE_V1:
|
|
||||||
if self.mamba_cache is None:
|
|
||||||
num_mamba_layers = (
|
|
||||||
self.model_config.get_num_layers_by_block_type(
|
|
||||||
self.vllm_config.parallel_config,
|
|
||||||
LayerBlockType.mamba))
|
|
||||||
mamba_state_shape = \
|
|
||||||
self.get_mamba_state_shape_from_config(
|
|
||||||
self.vllm_config, use_v1=False)
|
|
||||||
mamba_state_dtype = \
|
|
||||||
self.get_mamba_state_dtype_from_config(
|
|
||||||
self.vllm_config)
|
|
||||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
|
||||||
num_mamba_layers,
|
|
||||||
*mamba_state_shape,
|
|
||||||
*mamba_state_dtype)
|
|
||||||
|
|
||||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
hidden_states = self.backbone(input_ids, positions,
|
||||||
else:
|
|
||||||
# NOTE: mamba_cache_params is not needed for v1
|
|
||||||
mamba_cache_params = None
|
|
||||||
|
|
||||||
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
|
|
||||||
intermediate_tensors, inputs_embeds)
|
intermediate_tensors, inputs_embeds)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@ -1,83 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MambaCacheParams:
|
|
||||||
conv_state: torch.Tensor = torch.Tensor()
|
|
||||||
ssm_state: torch.Tensor = torch.Tensor()
|
|
||||||
state_indices_tensor: torch.Tensor = torch.Tensor()
|
|
||||||
|
|
||||||
def at_layer_idx(self, layer_idx):
|
|
||||||
return MambaCacheParams(self.conv_state[layer_idx],
|
|
||||||
self.ssm_state[layer_idx],
|
|
||||||
self.state_indices_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
class MambaCacheManager(ConstantSizeCache):
|
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
|
|
||||||
conv_state_shape: tuple[int, int],
|
|
||||||
temporal_state_shape: tuple[int, int],
|
|
||||||
conv_state_dtype: torch.dtype,
|
|
||||||
temporal_state_dtype: torch.dtype):
|
|
||||||
|
|
||||||
self.conv_state_dtype = conv_state_dtype
|
|
||||||
self.temporal_state_dtype = temporal_state_dtype
|
|
||||||
|
|
||||||
# Determine max batch size to set size of MambaCache
|
|
||||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
|
||||||
if not vllm_config.model_config.enforce_eager:
|
|
||||||
max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)
|
|
||||||
|
|
||||||
# Initialize parent class
|
|
||||||
super().__init__(max_batch_size)
|
|
||||||
|
|
||||||
# assume conv_state = (dim, state_len)
|
|
||||||
assert conv_state_shape[0] > conv_state_shape[1]
|
|
||||||
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
|
||||||
(conv_state_shape[1], conv_state_shape[0]),
|
|
||||||
dtype=self.conv_state_dtype,
|
|
||||||
device="cuda").transpose(-1, -2)
|
|
||||||
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
|
||||||
temporal_state_shape,
|
|
||||||
dtype=self.temporal_state_dtype,
|
|
||||||
device="cuda")
|
|
||||||
|
|
||||||
self._mamba_cache = (conv_state, temporal_state)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache(self):
|
|
||||||
return self._mamba_cache
|
|
||||||
|
|
||||||
def _copy_cache(self, from_index: int, to_index: int):
|
|
||||||
for cache_t in self.cache:
|
|
||||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
|
||||||
non_blocking=True)
|
|
||||||
|
|
||||||
def current_run_tensors(self, **kwargs) -> MambaCacheParams:
|
|
||||||
"""
|
|
||||||
Return the tensors for the current run's conv and ssm state.
|
|
||||||
"""
|
|
||||||
cache_tensors, state_indices_tensor = super().current_run_tensors(
|
|
||||||
**kwargs)
|
|
||||||
return MambaCacheParams(cache_tensors[0], cache_tensors[1],
|
|
||||||
state_indices_tensor)
|
|
||||||
|
|
||||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
|
||||||
"""
|
|
||||||
Provide the CUDA graph capture runs with a buffer in adjusted size.
|
|
||||||
The buffer is used to maintain the Mamba Cache during the CUDA graph
|
|
||||||
replay runs.
|
|
||||||
"""
|
|
||||||
return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device="cuda")
|
|
||||||
@ -1,36 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MinimaxCacheParams:
|
|
||||||
minimax_cache: torch.Tensor = torch.Tensor()
|
|
||||||
state_indices_tensor: torch.Tensor = torch.Tensor()
|
|
||||||
|
|
||||||
def at_layer_idx(self, layer_idx):
|
|
||||||
return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
|
|
||||||
self.state_indices_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
class MinimaxCacheManager(ConstantSizeCache):
|
|
||||||
|
|
||||||
def __init__(self, dtype, cache_shape):
|
|
||||||
super().__init__(cache_shape[1]) # max_batch_size is cache_shape[1]
|
|
||||||
self._minimax_cache = torch.empty(size=cache_shape,
|
|
||||||
dtype=dtype,
|
|
||||||
device="cuda")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cache(self):
|
|
||||||
return self._minimax_cache
|
|
||||||
|
|
||||||
def _copy_cache(self, from_index: int, to_index: int):
|
|
||||||
assert len(self.cache) > 0
|
|
||||||
for cache_t in self.cache:
|
|
||||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
|
||||||
non_blocking=True)
|
|
||||||
@ -14,7 +14,6 @@ import torch.distributed
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MiniMaxConfig
|
from transformers import MiniMaxConfig
|
||||||
|
|
||||||
from vllm import envs
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
@ -44,7 +43,6 @@ from vllm.model_executor.models.utils import maybe_prefix
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import HasInnerState, IsHybrid
|
from .interfaces import HasInnerState, IsHybrid
|
||||||
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
|
|
||||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||||
|
|
||||||
|
|
||||||
@ -404,7 +402,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|||||||
def forward(self,
|
def forward(self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: Union[list[dict], Optional[torch.Tensor]],
|
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
is_warmup: bool = False,
|
is_warmup: bool = False,
|
||||||
@ -418,7 +415,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|||||||
hidden_states=layernorm_output,
|
hidden_states=layernorm_output,
|
||||||
output=self_attention_output,
|
output=self_attention_output,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
kv_caches=kv_caches,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
residual = residual * self.layernorm_attention_alpha
|
residual = residual * self.layernorm_attention_alpha
|
||||||
@ -563,10 +559,6 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
self._dtype = _dummy.dtype
|
self._dtype = _dummy.dtype
|
||||||
del _dummy
|
del _dummy
|
||||||
|
|
||||||
if not envs.VLLM_USE_V1:
|
|
||||||
self.minimax_cache = MinimaxCacheManager(
|
|
||||||
dtype=torch.float32, cache_shape=self.cache_shape)
|
|
||||||
|
|
||||||
norm_kwargs = {}
|
norm_kwargs = {}
|
||||||
if hasattr(config, "rms_norm_eps"):
|
if hasattr(config, "rms_norm_eps"):
|
||||||
norm_kwargs["eps"] = config.rms_norm_eps
|
norm_kwargs["eps"] = config.rms_norm_eps
|
||||||
@ -614,25 +606,6 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
|
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
if not envs.VLLM_USE_V1 and attn_metadata is None:
|
|
||||||
return None
|
|
||||||
if not envs.VLLM_USE_V1:
|
|
||||||
if "request_ids_to_seq_ids" not in kwargs:
|
|
||||||
kwargs["request_ids_to_seq_ids"] = {}
|
|
||||||
if "finished_requests_ids" not in kwargs:
|
|
||||||
kwargs["finished_requests_ids"] = []
|
|
||||||
(
|
|
||||||
minimax_cache_tensors,
|
|
||||||
state_indices_tensor,
|
|
||||||
) = self.minimax_cache.current_run_tensors(**kwargs)
|
|
||||||
if getattr(attn_metadata, "num_prefills", 0) > 0:
|
|
||||||
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
|
|
||||||
state_indices_tensor)
|
|
||||||
else:
|
|
||||||
minimax_cache_params = None
|
|
||||||
|
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
@ -645,20 +618,10 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
minimax_cache_index = 0
|
|
||||||
|
|
||||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||||
_caches = None
|
|
||||||
if not envs.VLLM_USE_V1 and isinstance(
|
|
||||||
layer.self_attn, MiniMaxText01LinearAttention):
|
|
||||||
current_state_layer = minimax_cache_index
|
|
||||||
_caches = minimax_cache_params.at_layer_idx(
|
|
||||||
current_state_layer)
|
|
||||||
minimax_cache_index += 1
|
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
kv_caches=_caches,
|
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
)
|
)
|
||||||
@ -1003,13 +966,11 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
|||||||
def get_mamba_state_shape_from_config(
|
def get_mamba_state_shape_from_config(
|
||||||
cls,
|
cls,
|
||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
use_v1: bool = True,
|
|
||||||
) -> tuple[tuple[int, ...], ...]:
|
) -> tuple[tuple[int, ...], ...]:
|
||||||
"""Calculate shape for MiniMaxText01LinearAttention cache.
|
"""Calculate shape for MiniMaxText01LinearAttention cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vllm_config: vLLM config
|
vllm_config: vLLM config
|
||||||
use_v1: Get shapes for V1 (or V0)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
Tuple containing:
|
||||||
|
|||||||
@ -23,21 +23,17 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm import envs
|
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
from vllm.forward_context import get_forward_context
|
|
||||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
|
||||||
Mamba2Metadata, prepare_mamba2_metadata)
|
|
||||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
@ -49,14 +45,11 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
||||||
SupportsLoRA, SupportsPP,
|
SupportsLoRA, SupportsPP,
|
||||||
SupportsQuant)
|
SupportsQuant)
|
||||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
|
||||||
MambaCacheParams)
|
|
||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
|
AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
|
||||||
make_layers, maybe_prefix)
|
make_layers, maybe_prefix)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.configs import NemotronHConfig
|
from vllm.transformers_utils.configs import NemotronHConfig
|
||||||
from vllm.utils import LayerBlockType
|
|
||||||
|
|
||||||
|
|
||||||
class NemotronHMLP(nn.Module):
|
class NemotronHMLP(nn.Module):
|
||||||
@ -181,8 +174,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
mamba_cache_params: MambaCacheParams,
|
|
||||||
mamba2_metadata: Mamba2Metadata,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if residual is None:
|
if residual is None:
|
||||||
@ -192,7 +183,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
|
|||||||
hidden_states, residual = self.norm(hidden_states, residual)
|
hidden_states, residual = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
output = torch.empty_like(hidden_states)
|
output = torch.empty_like(hidden_states)
|
||||||
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
|
self.mixer(hidden_states, output)
|
||||||
return output, residual
|
return output, residual
|
||||||
|
|
||||||
|
|
||||||
@ -370,22 +361,10 @@ class NemotronHModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
mamba_cache_params: MambaCacheParams,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
|
||||||
|
|
||||||
if not envs.VLLM_USE_V1:
|
|
||||||
mamba2_metadata = prepare_mamba2_metadata(
|
|
||||||
chunk_size=self.config.chunk_size,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# v1 get mamba2_metadata from forward_context
|
|
||||||
mamba2_metadata = None
|
|
||||||
|
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
@ -398,22 +377,11 @@ class NemotronHModel(nn.Module):
|
|||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
num_non_mamba_layers = 0
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
layer_mamba_cache_params = None
|
|
||||||
if isinstance(layer,
|
|
||||||
NemotronHMambaDecoderLayer) and mamba_cache_params:
|
|
||||||
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
|
|
||||||
i - num_non_mamba_layers)
|
|
||||||
else:
|
|
||||||
num_non_mamba_layers += 1
|
|
||||||
|
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
mamba_cache_params=layer_mamba_cache_params,
|
|
||||||
mamba2_metadata=mamba2_metadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
@ -508,13 +476,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
|||||||
def get_mamba_state_shape_from_config(
|
def get_mamba_state_shape_from_config(
|
||||||
cls,
|
cls,
|
||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
use_v1: bool = True,
|
|
||||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||||
"""Calculate shapes for Mamba's convolutional and state caches.
|
"""Calculate shapes for Mamba's convolutional and state caches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vllm_config: vLLM config
|
vllm_config: vLLM config
|
||||||
use_v1: Get shapes for V1 (or V0)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
Tuple containing:
|
||||||
@ -533,7 +499,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
|||||||
head_dim=hf_config.mamba_head_dim,
|
head_dim=hf_config.mamba_head_dim,
|
||||||
state_size=hf_config.ssm_state_size,
|
state_size=hf_config.ssm_state_size,
|
||||||
conv_kernel=hf_config.conv_kernel,
|
conv_kernel=hf_config.conv_kernel,
|
||||||
use_v1=use_v1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
@ -566,8 +531,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
|||||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||||
prefix=maybe_prefix(prefix, "lm_head"),
|
prefix=maybe_prefix(prefix, "lm_head"),
|
||||||
)
|
)
|
||||||
# Used to track and store by the Mamba cache between steps.
|
|
||||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
config.vocab_size)
|
config.vocab_size)
|
||||||
@ -584,40 +547,11 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
|||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
mamba_cache_params = None
|
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||||
if not envs.VLLM_USE_V1:
|
inputs_embeds)
|
||||||
if self.mamba_cache is None:
|
|
||||||
|
|
||||||
num_mamba_layers = \
|
|
||||||
self.model_config.get_num_layers_by_block_type(
|
|
||||||
self.vllm_config.parallel_config,
|
|
||||||
LayerBlockType.mamba
|
|
||||||
)
|
|
||||||
mamba_state_shape = \
|
|
||||||
self.get_mamba_state_shape_from_config(
|
|
||||||
self.vllm_config, use_v1=False)
|
|
||||||
mamba_state_dtype = \
|
|
||||||
self.get_mamba_state_dtype_from_config(
|
|
||||||
self.vllm_config)
|
|
||||||
self.mamba_cache = MambaCacheManager(self.vllm_config,
|
|
||||||
num_mamba_layers,
|
|
||||||
*mamba_state_shape,
|
|
||||||
*mamba_state_dtype)
|
|
||||||
|
|
||||||
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
|
|
||||||
|
|
||||||
hidden_states = self.model(input_ids, positions, mamba_cache_params,
|
|
||||||
intermediate_tensors, inputs_embeds)
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
|
||||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
|
||||||
input_buffers, **kwargs)
|
|
||||||
|
|
||||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
|
||||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
|
||||||
|
|
||||||
def compute_logits(
|
def compute_logits(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|||||||
@ -1,731 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import math
|
|
||||||
from collections.abc import Iterable
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers.activations import ACT2FN
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
|
||||||
from vllm.attention.selector import _Backend
|
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
|
||||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
||||||
MergedColumnParallelLinear,
|
|
||||||
RowParallelLinear)
|
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
||||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
|
||||||
causal_conv1d_fn, causal_conv1d_update)
|
|
||||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
|
||||||
selective_scan_fn, selective_state_update)
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
|
||||||
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
|
|
||||||
SupportsV0Only)
|
|
||||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
|
||||||
MambaCacheParams)
|
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
|
|
||||||
from .utils import make_layers, maybe_prefix
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SwiGLUActivation(nn.Module):
|
|
||||||
|
|
||||||
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
|
||||||
return x1 * nn.functional.silu(x2)
|
|
||||||
|
|
||||||
|
|
||||||
class SambaYMLP(nn.Module):
|
|
||||||
"""Gated Linear Unit.
|
|
||||||
|
|
||||||
Reference:
|
|
||||||
Language Modeling with Gated Convolutional Networks.
|
|
||||||
https://arxiv.org/pdf/1612.08083v3.pdf.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.config = config
|
|
||||||
self.fc1 = nn.Linear(config.hidden_size,
|
|
||||||
2 * config.intermediate_size,
|
|
||||||
bias=False)
|
|
||||||
self.fc2 = nn.Linear(config.intermediate_size,
|
|
||||||
config.hidden_size,
|
|
||||||
bias=False)
|
|
||||||
|
|
||||||
self.activation_fn = ACT2FN[config.hidden_act]
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
y = self.fc1(hidden_states)
|
|
||||||
gate, y = y.chunk(2, dim=-1)
|
|
||||||
y = y * self.activation_fn(gate)
|
|
||||||
return self.fc2(y)
|
|
||||||
|
|
||||||
|
|
||||||
def get_virtual_engine():
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
|
||||||
return forward_context.virtual_engine
|
|
||||||
|
|
||||||
|
|
||||||
class SambaYAttention(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
config,
|
|
||||||
layer_idx: Optional[int] = None,
|
|
||||||
yoco_cross: bool = False,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
prefix: str = ""):
|
|
||||||
super().__init__()
|
|
||||||
if layer_idx is None:
|
|
||||||
logger.warning_once(
|
|
||||||
f"Instantiating {self.__class__.__name__} without passing "
|
|
||||||
"a `layer_idx` is not recommended and will lead to errors "
|
|
||||||
"during the forward call if caching is used. Please make "
|
|
||||||
"sure to provide a `layer_idx` when creating this class.")
|
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
self.num_heads = config.num_attention_heads
|
|
||||||
self.head_dim = self.hidden_size // self.num_heads
|
|
||||||
self.num_key_value_heads = config.num_key_value_heads
|
|
||||||
self.yoco_cross = yoco_cross
|
|
||||||
|
|
||||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
|
||||||
raise ValueError("hidden_size must be divisible by num_heads "
|
|
||||||
f"(got `hidden_size`: {self.hidden_size} and "
|
|
||||||
f"`num_heads`: {self.num_heads}).")
|
|
||||||
|
|
||||||
op_size = self.num_heads * self.head_dim + 2 * (
|
|
||||||
self.num_key_value_heads * self.head_dim)
|
|
||||||
self.out_proj = nn.Linear(self.num_heads * self.head_dim,
|
|
||||||
self.hidden_size,
|
|
||||||
bias=True)
|
|
||||||
if yoco_cross:
|
|
||||||
self.Wqkv = nn.Linear(self.hidden_size,
|
|
||||||
self.num_heads * self.head_dim,
|
|
||||||
bias=True)
|
|
||||||
else:
|
|
||||||
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
|
|
||||||
|
|
||||||
# disable sliding window for the second half of the model
|
|
||||||
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
|
|
||||||
sliding_window = config.sliding_window if is_sliding else None
|
|
||||||
|
|
||||||
assert self.num_heads % 2 == 0, 'num_heads should be even'
|
|
||||||
assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'
|
|
||||||
|
|
||||||
self.lambda_init = self.lambda_init_fn(layer_idx)
|
|
||||||
self.lambda_q1 = nn.Parameter(
|
|
||||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
|
||||||
std=0.1))
|
|
||||||
self.lambda_k1 = nn.Parameter(
|
|
||||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
|
||||||
std=0.1))
|
|
||||||
self.lambda_q2 = nn.Parameter(
|
|
||||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
|
||||||
std=0.1))
|
|
||||||
self.lambda_k2 = nn.Parameter(
|
|
||||||
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
|
|
||||||
std=0.1))
|
|
||||||
self.subln = nn.RMSNorm(2 * self.head_dim,
|
|
||||||
eps=1e-5,
|
|
||||||
elementwise_affine=True)
|
|
||||||
|
|
||||||
params = {
|
|
||||||
'differential_flash_attention_config': {
|
|
||||||
'lambda_init': self.lambda_init,
|
|
||||||
'lambda_q1': self.lambda_q1,
|
|
||||||
'lambda_k1': self.lambda_k1,
|
|
||||||
'lambda_q2': self.lambda_q2,
|
|
||||||
'lambda_k2': self.lambda_k2,
|
|
||||||
"subln": self.subln,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if yoco_cross:
|
|
||||||
kv_shared_layer_index = config.num_hidden_layers // 2 + 1
|
|
||||||
kv_sharing_target_layer_name = \
|
|
||||||
f"model.layers.{kv_shared_layer_index}.self_attn.attn"
|
|
||||||
else:
|
|
||||||
kv_sharing_target_layer_name = None
|
|
||||||
|
|
||||||
self.attn = Attention(
|
|
||||||
self.num_heads,
|
|
||||||
self.head_dim,
|
|
||||||
self.head_dim**-0.5,
|
|
||||||
num_kv_heads=self.num_key_value_heads,
|
|
||||||
cache_config=cache_config,
|
|
||||||
per_layer_sliding_window=sliding_window,
|
|
||||||
prefix=f"{prefix}.attn",
|
|
||||||
attn_type=AttentionType.DECODER,
|
|
||||||
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
|
|
||||||
**params)
|
|
||||||
assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\
|
|
||||||
"DIFFERENTIAL_FLASH_ATTN required"
|
|
||||||
|
|
||||||
def lambda_init_fn(self, depth):
|
|
||||||
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
):
|
|
||||||
|
|
||||||
if not self.yoco_cross: # need to generate kv-cache
|
|
||||||
qkv = self.Wqkv(hidden_states)
|
|
||||||
q, k, v = qkv.split([
|
|
||||||
self.hidden_size, self.num_key_value_heads * self.head_dim,
|
|
||||||
self.num_key_value_heads * self.head_dim
|
|
||||||
],
|
|
||||||
dim=-1)
|
|
||||||
attn_output = self.attn(q, k, v)
|
|
||||||
else: # reuse the kv cache, full attention
|
|
||||||
q = self.Wqkv(hidden_states)
|
|
||||||
attn_output = self.attn(q, None, None)
|
|
||||||
attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
|
|
||||||
return self.out_proj(attn_output)
|
|
||||||
|
|
||||||
|
|
||||||
class Phi4Mamba(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model,
|
|
||||||
d_state=16,
|
|
||||||
d_conv=4,
|
|
||||||
expand=2,
|
|
||||||
dt_rank="auto",
|
|
||||||
dt_min=0.001,
|
|
||||||
dt_max=0.1,
|
|
||||||
dt_init="random", # difference
|
|
||||||
dt_scale=1.0, # difference
|
|
||||||
dt_init_floor=1e-4,
|
|
||||||
conv_bias=True,
|
|
||||||
bias=False,
|
|
||||||
use_fast_path=True, # Fused kernel options
|
|
||||||
layer_idx=None,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
yoco_cross=False,
|
|
||||||
yoco_kv=False,
|
|
||||||
):
|
|
||||||
factory_kwargs = {"params_dtype": dtype} # difference
|
|
||||||
super().__init__()
|
|
||||||
self.yoco_cross = yoco_cross
|
|
||||||
self.yoco_kv = yoco_kv
|
|
||||||
self.d_model = d_model
|
|
||||||
self.d_state = d_state
|
|
||||||
self.d_conv = d_conv
|
|
||||||
self.expand = expand
|
|
||||||
self.d_inner = int(self.expand * self.d_model)
|
|
||||||
self.dt_rank = math.ceil(self.d_model /
|
|
||||||
16) if dt_rank == "auto" else dt_rank
|
|
||||||
self.use_fast_path = use_fast_path
|
|
||||||
self.layer_idx = layer_idx
|
|
||||||
self.swiGluActivation = SwiGLUActivation()
|
|
||||||
if self.yoco_cross:
|
|
||||||
self.in_proj = MergedColumnParallelLinear(self.d_model,
|
|
||||||
[self.d_inner],
|
|
||||||
bias=bias,
|
|
||||||
**factory_kwargs)
|
|
||||||
self.out_proj = RowParallelLinear(self.d_inner,
|
|
||||||
self.d_model,
|
|
||||||
bias=bias,
|
|
||||||
**factory_kwargs)
|
|
||||||
return
|
|
||||||
self.conv1d = ColumnParallelLinear(
|
|
||||||
input_size=d_conv,
|
|
||||||
output_size=self.d_inner,
|
|
||||||
bias=conv_bias,
|
|
||||||
params_dtype=dtype,
|
|
||||||
)
|
|
||||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
|
||||||
# Can't do this in `weight_loader` since it already exists in
|
|
||||||
# `ColumnParallelLinear` and `set_weight_attrs`
|
|
||||||
# doesn't allow to override it
|
|
||||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
|
||||||
|
|
||||||
self.in_proj = MergedColumnParallelLinear(
|
|
||||||
self.d_model,
|
|
||||||
[self.d_inner] * 2,
|
|
||||||
bias=bias,
|
|
||||||
params_dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# selective projection used to make dt, B and C input dependent
|
|
||||||
self.x_proj = RowParallelLinear(
|
|
||||||
self.d_inner,
|
|
||||||
self.dt_rank + self.d_state * 2,
|
|
||||||
bias=False,
|
|
||||||
params_dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# time step projection (discretization) -
|
|
||||||
# In the forward we need to apply dt_proj without the bias,
|
|
||||||
# as the bias is added in the selective scan kernel.
|
|
||||||
self.dt_proj = ColumnParallelLinear(
|
|
||||||
self.dt_rank,
|
|
||||||
self.d_inner,
|
|
||||||
bias=True,
|
|
||||||
skip_bias_add=True,
|
|
||||||
params_dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# # D "skip" parameter
|
|
||||||
# self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
|
|
||||||
self.A = nn.Parameter(
|
|
||||||
torch.empty(
|
|
||||||
self.d_inner,
|
|
||||||
self.d_state,
|
|
||||||
dtype=torch.float32,
|
|
||||||
))
|
|
||||||
self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))
|
|
||||||
|
|
||||||
self.out_proj = RowParallelLinear(
|
|
||||||
self.d_inner,
|
|
||||||
self.d_model,
|
|
||||||
bias=bias,
|
|
||||||
input_is_parallel=True,
|
|
||||||
params_dtype=dtype,
|
|
||||||
)
|
|
||||||
self.activation = "silu"
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
mamba_cache_params: MambaCacheParams,
|
|
||||||
yoco_key_values=None) -> torch.Tensor:
|
|
||||||
|
|
||||||
if self.yoco_cross:
|
|
||||||
out = self.in_proj(hidden_states)[0]
|
|
||||||
out = self.swiGluActivation(yoco_key_values, out)
|
|
||||||
out = self.out_proj(out)
|
|
||||||
return out[0], yoco_key_values
|
|
||||||
|
|
||||||
# 1. Gated MLP's linear projection
|
|
||||||
# projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
|
||||||
projected_states = self.in_proj(
|
|
||||||
hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1)
|
|
||||||
hidden_states, gate = projected_states.chunk(2, dim=-2)
|
|
||||||
|
|
||||||
# 2. Convolution sequence transformation
|
|
||||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
|
||||||
self.conv1d.weight.size(2))
|
|
||||||
|
|
||||||
if attn_metadata.query_start_loc is not None \
|
|
||||||
and attn_metadata.context_lens_tensor is not None:
|
|
||||||
# |---------- N-1 iteration --------|
|
|
||||||
# |---------------- N iteration ---------------------|
|
|
||||||
# |- tokenA -|......................|-- newTokens ---|
|
|
||||||
# |---------- context_len ----------|
|
|
||||||
# |-------------------- seq_len ---------------------|
|
|
||||||
# |-- query_len ---|
|
|
||||||
hidden_states = causal_conv1d_fn(
|
|
||||||
hidden_states,
|
|
||||||
conv_weights,
|
|
||||||
self.conv1d.bias,
|
|
||||||
activation=self.activation,
|
|
||||||
conv_states=mamba_cache_params.conv_state,
|
|
||||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
|
||||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
|
||||||
query_start_loc=attn_metadata.query_start_loc)
|
|
||||||
else:
|
|
||||||
hidden_states = causal_conv1d_update(
|
|
||||||
hidden_states.transpose(0, 1),
|
|
||||||
mamba_cache_params.conv_state,
|
|
||||||
conv_weights,
|
|
||||||
self.conv1d.bias,
|
|
||||||
self.activation,
|
|
||||||
conv_state_indices=mamba_cache_params.state_indices_tensor)
|
|
||||||
hidden_states = hidden_states.transpose(0, 1)
|
|
||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
|
||||||
# 3.a. input varying initialization of time_step, B and C
|
|
||||||
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
|
|
||||||
|
|
||||||
time_step, B, C = torch.split(
|
|
||||||
ssm_parameters,
|
|
||||||
[self.dt_rank, self.d_state, self.d_state],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
|
|
||||||
|
|
||||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
|
||||||
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
|
||||||
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
|
|
||||||
self.dt_proj, "bias") else None)
|
|
||||||
|
|
||||||
if attn_metadata.query_start_loc is not None \
|
|
||||||
and attn_metadata.context_lens_tensor is not None:
|
|
||||||
scan_outputs = selective_scan_fn(
|
|
||||||
hidden_states,
|
|
||||||
mamba_cache_params.ssm_state,
|
|
||||||
discrete_time_step,
|
|
||||||
self.A,
|
|
||||||
B.transpose(-2, -1),
|
|
||||||
C.transpose(-2, -1),
|
|
||||||
self.D.float(),
|
|
||||||
# z,
|
|
||||||
None if self.yoco_kv else gate,
|
|
||||||
time_proj_bias,
|
|
||||||
delta_softplus=True,
|
|
||||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
|
||||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
|
||||||
query_start_loc=attn_metadata.query_start_loc)
|
|
||||||
else:
|
|
||||||
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
|
|
||||||
selective_state_update(
|
|
||||||
mamba_cache_params.ssm_state,
|
|
||||||
hidden_states.transpose(0, 1),
|
|
||||||
discrete_time_step.transpose(0, 1),
|
|
||||||
self.A,
|
|
||||||
B,
|
|
||||||
C,
|
|
||||||
self.D,
|
|
||||||
# z
|
|
||||||
# gate.transpose(0, 1),
|
|
||||||
None if self.yoco_kv else gate.transpose(0, 1),
|
|
||||||
time_proj_bias,
|
|
||||||
dt_softplus=True,
|
|
||||||
state_batch_indices=mamba_cache_params.state_indices_tensor,
|
|
||||||
out=scan_outputs)
|
|
||||||
scan_outputs = scan_outputs.transpose(0, 1)
|
|
||||||
|
|
||||||
# 4. Final linear projection
|
|
||||||
if self.yoco_kv:
|
|
||||||
# gate = gate.transpose(-1,-2).contiguous()
|
|
||||||
yoco_key_values = scan_outputs.transpose(-2, -1)
|
|
||||||
scan_outputs = self.swiGluActivation(scan_outputs, gate)
|
|
||||||
|
|
||||||
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
|
|
||||||
-1))[0]
|
|
||||||
|
|
||||||
return contextualized_states, yoco_key_values
|
|
||||||
|
|
||||||
|
|
||||||
class SambaYDecoderLayer(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
layer_idx,
|
|
||||||
cache_config,
|
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.config = config
|
|
||||||
self.layer_idx = layer_idx
|
|
||||||
|
|
||||||
self.mlp = SambaYMLP(config)
|
|
||||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
|
||||||
eps=config.layer_norm_eps)
|
|
||||||
|
|
||||||
self.yoco_mb = False
|
|
||||||
self.yoco_cross = False
|
|
||||||
if layer_idx >= config.num_hidden_layers // 2:
|
|
||||||
self.yoco_mb = True
|
|
||||||
self.yoco_cross = (layer_idx
|
|
||||||
>= (config.num_hidden_layers // 2 + 2))
|
|
||||||
self.use_mamba = config.mb_per_layer > 0 and \
|
|
||||||
layer_idx % config.mb_per_layer == 0
|
|
||||||
if self.use_mamba:
|
|
||||||
factory_kwargs = {"dtype": None}
|
|
||||||
self.attn = Phi4Mamba(config.hidden_size,
|
|
||||||
layer_idx=layer_idx,
|
|
||||||
yoco_cross=self.yoco_cross,
|
|
||||||
yoco_kv=self.yoco_mb,
|
|
||||||
**factory_kwargs)
|
|
||||||
else:
|
|
||||||
self.attn = SambaYAttention(config,
|
|
||||||
layer_idx=layer_idx,
|
|
||||||
yoco_cross=self.yoco_cross,
|
|
||||||
cache_config=cache_config,
|
|
||||||
prefix=f"{prefix}.self_attn")
|
|
||||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
|
||||||
eps=config.layer_norm_eps)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
mamba_cache_params: MambaCacheParams,
|
|
||||||
ssm_output: Optional[torch.LongTensor] = None,
|
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
|
||||||
if self.use_mamba:
|
|
||||||
assert mamba_cache_params is not None
|
|
||||||
else:
|
|
||||||
assert mamba_cache_params is None
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.input_layernorm(
|
|
||||||
hidden_states.to(dtype=self.input_layernorm.weight.dtype))
|
|
||||||
|
|
||||||
if self.use_mamba:
|
|
||||||
attn_outputs, ssm_output = self.attn(hidden_states,
|
|
||||||
attn_metadata,
|
|
||||||
mamba_cache_params,
|
|
||||||
yoco_key_values=ssm_output)
|
|
||||||
residual = residual.to(torch.float32)
|
|
||||||
else:
|
|
||||||
attn_outputs = self.attn(hidden_states, )
|
|
||||||
hidden_states = residual + attn_outputs
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.post_attention_layernorm(
|
|
||||||
hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
return hidden_states, ssm_output
|
|
||||||
|
|
||||||
|
|
||||||
class SambaYModel(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
config,
|
|
||||||
cache_config=None,
|
|
||||||
quant_config=None,
|
|
||||||
lora_config=None,
|
|
||||||
prefix: str = "") -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.vocab_size = config.vocab_size
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
|
||||||
self.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
org_num_embeddings=config.vocab_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pipeline parallel is not supported since the second half of
|
|
||||||
# the layers share the kv cache.
|
|
||||||
if get_pp_group().world_size != 1:
|
|
||||||
raise ValueError("Pipeline Parallel not supported")
|
|
||||||
|
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
|
||||||
config.num_hidden_layers,
|
|
||||||
lambda prefix: SambaYDecoderLayer(config,
|
|
||||||
int(prefix.split('.')[-1]),
|
|
||||||
cache_config,
|
|
||||||
prefix=prefix),
|
|
||||||
prefix=f"{prefix}.layers")
|
|
||||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
|
||||||
eps=config.layer_norm_eps)
|
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.embed_tokens(input_ids)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.Tensor],
|
|
||||||
positions: torch.Tensor,
|
|
||||||
attn_metadata: AttentionMetadata,
|
|
||||||
mamba_cache_params: MambaCacheParams,
|
|
||||||
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-
-        if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
-        else:
-            assert intermediate_tensors is not None
-            hidden_states = intermediate_tensors["hidden_states"]
-
-        mamba_state_idx = 0
-        ssm_output = None
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            if i == self.config.num_hidden_layers // 2 + 2:
-                # profile run
-                kv_cache_idx = self.config.num_hidden_layers // 2 + 1
-                cache_layer = self.layers[kv_cache_idx]
-                kv_cache = cache_layer.attn.attn.kv_cache
-                if kv_cache[0].numel() == 0:
-                    break
-
-                # Starting from this layer, we do not need to calculate
-                # the kv cache since we reuse the kv cache from last layer.
-                # If in prefill phase, we can truncate
-                # the hidden state to save computation cost.
-                if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
-                    selected_token_indices = torch.cumsum(
-                        attn_metadata.seq_lens_tensor, dim=0) - 1
-                    hidden_states = hidden_states.index_select(
-                        0, selected_token_indices)
-                    ssm_output = ssm_output.index_select(
-                        0, selected_token_indices)
-
-            if layer.use_mamba:
-                if i < self.config.num_hidden_layers // 2 or \
-                        not layer.yoco_cross:
-                    mamba_cache = mamba_cache_params.at_layer_idx(
-                        mamba_state_idx)
-                    mamba_state_idx += 1
-                else:
-                    mamba_cache = mamba_cache_params.at_layer_idx(
-                        mamba_state_idx - 1)
-
-                hidden_states, ssm_output = layer(hidden_states,
-                                                  positions,
-                                                  attn_metadata,
-                                                  mamba_cache,
-                                                  ssm_output=ssm_output)
-            else:
-                hidden_states, ssm_output = layer(
-                    hidden_states,
-                    positions,
-                    attn_metadata,
-                    None,  # mamba_cache_params
-                    ssm_output=ssm_output)
-
-        hidden_states = self.final_layernorm(
-            hidden_states.to(dtype=self.final_layernorm.weight.dtype))
-        return hidden_states
-
-
-class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        lora_config = vllm_config.lora_config
-        quant_config = vllm_config.quant_config
-        scheduler_config = vllm_config.scheduler_config
-        self.compilation_config = vllm_config.compilation_config
-        self.vllm_config = vllm_config
-        # Prefix caching and chunked prefill is not supported for this model.
-        assert not cache_config.enable_prefix_caching, \
-            "Phi4flash currently does not support prefix caching"
-        assert not scheduler_config.chunked_prefill_enabled, \
-            "Phi4Flash currently does not support prefix caching"
-        super().__init__()
-        self.config = config
-        self.model_config = vllm_config.model_config
-        self.scheduler_config = scheduler_config
-        self.model = SambaYModel(config,
-                                 cache_config=cache_config,
-                                 prefix=maybe_prefix(prefix, "model"))
-        self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-        self.lm_head = ParallelLMHead(
-            self.unpadded_vocab_size,
-            config.hidden_size,
-            org_num_embeddings=config.vocab_size,
-            padding_size=(
-                DEFAULT_VOCAB_PADDING_SIZE
-                # We need bigger padding if using lora for kernel
-                # compatibility
-                if not lora_config else lora_config.lora_vocab_padding_size),
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "lm_head"),
-        )
-        self.embedding_bias = None
-        # Used to track and store by the Mamba cache between steps.
-        self.mamba_cache: Optional[MambaCacheManager] = None
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size,
-                                                logits_as_input=False)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        if self.mamba_cache is None:
-            num_mamba_layers = self.config.num_hidden_layers \
-                // 2 // self.config.mb_per_layer + 1
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config,
-                num_mamba_layers,
-                *self._get_mamba_cache_shape(),
-                self.lm_head.weight.dtype,
-                self.lm_head.weight.dtype,
-            )
-        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
-
-        attn_metadata = get_forward_context().attn_metadata
-        # input_ids and hidden_states isn't a one-to-one mapping in prefill
-        # stage due to YOCO optimization.
-        hidden_states = self.model(input_ids, positions, attn_metadata,
-                                   mamba_cache_params, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def _get_mamba_cache_shape(
-        self
-    ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = self.config.hidden_size
-        mamba_expand = self.config.mamba_expand  # 2
-        mamba_d_conv = self.config.mamba_d_conv  # 4
-        mamba_d_state = self.config.mamba_d_state  # 16
-        conv_state_shape = (
-            mamba_expand * hidden_size // world_size,
-            mamba_d_conv - 1,
-        )
-        temporal_state_shape = (
-            mamba_expand * hidden_size // world_size,
-            mamba_d_state,
-        )
-        return conv_state_shape, temporal_state_shape
-
-    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-        return self.mamba_cache.copy_inputs_before_cuda_graphs(
-            input_buffers, **kwargs)
-
-    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
-        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> Optional[torch.Tensor]:
-        processed_logits = self.logits_processor(
-            self.lm_head,
-            hidden_states,
-            self.embedding_bias,
-        )
-        return processed_logits
-
-    def load_weights(
-        self,
-        weights: Iterable[tuple[str, torch.Tensor]],
-    ):
-        weights = {name: weight for name, weight in weights}
-        adjusted_weights = {}
-        for name, weight in weights.items():
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-                weight = -torch.exp(weight.float())
-            if "inner_cross_attn." in name:
-                name = name.replace("inner_cross_attn.", "")
-            adjusted_weights[name] = weight
-        adjusted_weights["lm_head.weight"] = weights[
-            "model.embed_tokens.weight"]
-        loaded_params: set[str] = set()
-        for name, param in self.named_parameters():
-            weight = adjusted_weights.get(name)
-            if weight is not None and weight.shape != param.shape:
-                logger.warning("Shape mismatch: %s %s %s", name, weight.shape,
-                               param.shape)
-            loaded_params.add(name)
-        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
-                                                             strict=False)
-        assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
-        assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
-        return loaded_params
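
For reference, here is a minimal standalone sketch of the cache-shape arithmetic that the deleted _get_mamba_cache_shape helper above performed. It is illustrative only (plain Python, made-up hidden size), not vLLM code:

def mamba_cache_shapes(hidden_size: int, mamba_expand: int, mamba_d_conv: int,
                       mamba_d_state: int, world_size: int = 1):
    # Per-layer cache sizes, mirroring the deleted helper:
    # the conv cache keeps the last (d_conv - 1) inputs per channel,
    # the temporal (SSM) cache keeps one d_state-wide state per channel.
    channels = mamba_expand * hidden_size // world_size
    conv_state_shape = (channels, mamba_d_conv - 1)
    temporal_state_shape = (channels, mamba_d_state)
    return conv_state_shape, temporal_state_shape

# Illustrative values; the inline comments in the deleted code hint at 2/4/16.
print(mamba_cache_shapes(hidden_size=2048, mamba_expand=2,
                         mamba_d_conv=4, mamba_d_state=16))
# -> ((4096, 3), (4096, 16))
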
@@ -12,7 +12,6 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
@@ -29,8 +28,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
-from vllm.model_executor.layers.mamba.mamba2_metadata import (
-    Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -47,15 +44,13 @@ from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsPP)
-from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
-                                                    MambaCacheParams)
 from vllm.model_executor.models.utils import (
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
     make_layers, maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import LayerBlockType, direct_register_custom_op
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
 
 
@@ -194,16 +189,12 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
 
         self.chunk_size = self.config.mamba_chunk_size
 
-        if envs.VLLM_USE_V1:
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
-        # The outer list is for v0 PP virtual engine. Though this code path
-        # only runs for v1, we have to do this to unify with the interface
-        # of Attention + v0 PP.
-        # The inner tuple is (conv_state, ssm_state)
-        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+        # The tuple is (conv_state, ssm_state)
+        self.kv_cache = (torch.tensor([]), torch.tensor([]))
         assert self.chunk_size != -1, "chunk_size must be set for v1"
 
         self.prefix = prefix
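
The hunk above shows the wiring this PR standardizes on: each mixer registers itself in the compilation config's static forward context under its layer prefix and owns a (conv_state, ssm_state) tuple that the V1 engine later fills in. A self-contained sketch of that pattern, using illustrative names (ToyMixer, static_forward_context) rather than the real vLLM classes:

import torch

static_forward_context: dict[str, object] = {}

class ToyMixer:
    def __init__(self, prefix: str):
        if prefix in static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_forward_context[prefix] = self
        self.prefix = prefix
        # (conv_state, ssm_state); the engine swaps in real buffers later.
        self.kv_cache = (torch.tensor([]), torch.tensor([]))

    def forward(self, attn_metadata: dict):
        # Per-layer metadata is looked up by prefix instead of being passed in.
        layer_metadata = attn_metadata[self.prefix]
        conv_state, ssm_state = self.kv_cache
        return layer_metadata, conv_state.shape, ssm_state.shape

mixer = ToyMixer("model.layers.0.mixer")
print(mixer.forward({"model.layers.0.mixer": {"chunk_size": 256}}))
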
@@ -227,8 +218,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         pass
@@ -237,14 +226,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
-        if not envs.VLLM_USE_V1:
-            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
-                             mamba2_metadata)
-        else:
         torch.ops.vllm.plamo2_mamba_mixer(
             hidden_states,
             output,
@@ -255,41 +238,31 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
 
         forward_context = get_forward_context()
-        # mamba2_metadata contains metadata necessary for the mamba2 triton
+        # attn_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
         # modes; they are computed at top-level model forward since they
         # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
-        if envs.VLLM_USE_V1:
         if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
-                mamba2_metadata = attn_metadata
             assert isinstance(attn_metadata, Mamba2AttentionMetadata)
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             # conv_state = (..., dim, width-1) yet contiguous along 'dim'
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
-        else:
-            conv_state = mamba_cache_params.conv_state
-            ssm_state = mamba_cache_params.ssm_state
-            state_indices_tensor = mamba_cache_params.state_indices_tensor
-
-        # Common members between V1 metadata and V0 metadata
-        if mamba2_metadata is not None:
-            has_initial_states_p = mamba2_metadata.has_initial_states_p
-            prep_initial_states = mamba2_metadata.prep_initial_states
-            chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx_p
-            chunk_indices_p = mamba2_metadata.chunk_indices_p
-            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
+            has_initial_states_p = attn_metadata.has_initial_states_p
+            prep_initial_states = attn_metadata.prep_initial_states
+            chunk_size = attn_metadata.chunk_size
+            seq_idx_p = attn_metadata.seq_idx_p
+            chunk_indices_p = attn_metadata.chunk_indices_p
+            chunk_offsets_p = attn_metadata.chunk_offsets_p
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)
@@ -299,8 +272,8 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if envs.VLLM_USE_V1 and attn_metadata is None:
-            # V1 profile run
+        if attn_metadata is None:
+            # profile run
             hidden_states = (hidden_states.transpose(0, 1).clone().transpose(
                 0, 1)).contiguous()
             output[:] = self.out_proj(hidden_states)
@@ -316,7 +289,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
-        if envs.VLLM_USE_V1:
         hidden_states_d, hidden_states_p = torch.split(
             hidden_states[:num_actual_tokens],
             [num_decodes, num_prefill_tokens],
@@ -334,24 +306,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         query_start_loc_p = (
             attn_metadata.query_start_loc[-num_prefills - 1:] -
             num_decodes if has_prefill else None)
-        else:
-            hidden_states_p, hidden_states_d = torch.split(
-                hidden_states,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
-            gate_p, gate_d = torch.split(gate,
-                                         [num_prefill_tokens, num_decodes],
-                                         dim=0)
-            # Split along batch dimension
-            state_indices_tensor_p, state_indices_tensor_d = torch.split(
-                state_indices_tensor,
-                [num_prefills, num_decodes],
-                dim=0,
-            )
-            query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
-                                                               1]
-                                 if has_prefill else None)
 
         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
@@ -363,18 +317,11 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
-        if envs.VLLM_USE_V1:
         preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
             preallocated_ssm_out,
             [num_decodes, num_prefill_tokens],
             dim=0,
         )
-        else:
-            preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
-                preallocated_ssm_out,
-                [num_prefill_tokens, num_decodes],
-                dim=0,
-            )
 
         # Process prefill requests
         if has_prefill:
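
As the NOTE kept in the hunk above says, V1 batches place decode tokens before prefill tokens, which is why the surviving split uses [num_decodes, num_prefill_tokens]. A standalone illustration of that split with made-up sizes:

import torch

num_decodes, num_prefill_tokens = 2, 5
hidden_states = torch.arange(7 * 4, dtype=torch.float32).view(7, 4)

# Decode tokens come first in the flattened token dimension.
hidden_states_d, hidden_states_p = torch.split(
    hidden_states, [num_decodes, num_prefill_tokens], dim=0)

print(hidden_states_d.shape)  # torch.Size([2, 4]) - one token per decode seq
print(hidden_states_p.shape)  # torch.Size([5, 4]) - all prefill tokens
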
@@ -383,9 +330,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
             # pointed to by "state_indices_tensor"
             x = hidden_states_p.transpose(
                 0, 1)  # this is the form that causal-conv see
-            if mamba2_metadata.cu_seqlen is None:
-                mamba2_metadata = update_metadata(x, query_start_loc_p,
-                                                  mamba2_metadata)
             hidden_states_p = causal_conv1d_fn(
                 x,
                 conv_weights,
@@ -394,7 +338,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
-                metadata=mamba2_metadata,
+                metadata=attn_metadata,
                 query_start_loc=query_start_loc_p)
             hidden_states_p = hidden_states_p.transpose(0, 1)
             hidden_states_p = hidden_states_p[:num_prefill_tokens]
@@ -470,7 +414,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
                 -1, self.num_heads // self.tp_size, self.head_dim)
 
             # - the hidden is reshaped into (bs, num_heads, head_dim)
-            # - mamba_cache_params.ssm_state's slots will be selected
+            # - ssm_state's slots will be selected
             #   using state_indices_tensor_d
 
             # NOTE: final output is an in-place update of out tensor
@@ -530,10 +474,7 @@ def plamo2_mamba_mixer(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states,
-                      output=output,
-                      mamba_cache_params=None,
-                      mamba2_metadata=None)
+    self.forward_cuda(hidden_states=hidden_states, output=output)
 
 
 def plamo2_mamba_mixer_fake(
@@ -731,8 +672,6 @@ class Plamo2DecoderLayer(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
         if residual is None:
@@ -747,8 +686,6 @@ class Plamo2DecoderLayer(nn.Module):
             output = torch.empty_like(hidden_states)
             mixer_kwargs = {
                 "output": output,
-                "mamba_cache_params": mamba_cache_params,
-                "mamba2_metadata": mamba2_metadata,
             }
         else:
             mixer_kwargs = {
@@ -790,23 +727,12 @@ class Plamo2Decoder(torch.nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
     ) -> torch.Tensor:
-        mamba_cache_index = 0
         for layer in islice(self.layers, self.start_layer, self.end_layer):
-            layer_mamba_cache_params = None
-            if layer.is_mamba and mamba_cache_params is not None:
-                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
-                    mamba_cache_index)
-                mamba_cache_index += 1
-
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 residual=residual,
-                mamba_cache_params=layer_mamba_cache_params,
-                mamba2_metadata=mamba2_metadata,
             )
         return hidden_states, residual
 
@@ -844,7 +770,6 @@ class Plamo2Model(torch.nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -859,23 +784,10 @@ class Plamo2Model(torch.nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        if not envs.VLLM_USE_V1:
-            attn_metadata: AttentionMetadata = get_forward_context(
-            ).attn_metadata
-            mamba2_metadata = prepare_mamba2_metadata(
-                chunk_size=self.config.mamba_chunk_size,
-                attn_metadata=attn_metadata,
-            )
-        else:
-            # v1 get mamba2_metadata from forward_context
-            mamba2_metadata = None
-
         hidden_states, residual = self.layers(
             positions=positions,
             hidden_states=hidden_states,
             residual=residual,
-            mamba_cache_params=mamba_cache_params,
-            mamba2_metadata=mamba2_metadata,
         )
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -925,9 +837,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
         if self.config.tie_word_embeddings:
             self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
 
-        # Used to track and store by the Mamba cache between steps.
-        self.mamba_cache: Optional[MambaCacheManager] = None
-
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 self.config.vocab_size)
         self.make_empty_intermediate_tensors = (
@@ -942,39 +851,11 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
-        if not envs.VLLM_USE_V1:
-            if self.mamba_cache is None:
-                num_mamba_layers = (
-                    self.model_config.get_num_layers_by_block_type(
-                        self.vllm_config.parallel_config,
-                        LayerBlockType.mamba))
-
-                mamba_state_shape = self.get_mamba_state_shape_from_config(
-                    self.vllm_config, use_v1=False)
-                mamba_state_dtype = \
-                    self.get_mamba_state_dtype_from_config(
-                        self.vllm_config)
-                self.mamba_cache = MambaCacheManager(self.vllm_config,
-                                                     num_mamba_layers,
-                                                     *mamba_state_shape,
-                                                     *mamba_state_dtype)
-
-            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
-        else:
-            # NOTE: mamba_cache_params is not needed for v1
-            mamba_cache_params = None
-
-        hidden_states = self.model(input_ids, positions, mamba_cache_params,
-                                   intermediate_tensors, inputs_embeds)
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
-    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-        return self.mamba_cache.copy_inputs_before_cuda_graphs(
-            input_buffers, **kwargs)
-
-    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
-        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
-
     @classmethod
     def get_mamba_state_dtype_from_config(
         cls,
@@ -991,12 +872,10 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
     def get_mamba_state_shape_from_config(
         cls,
         vllm_config: "VllmConfig",
-        use_v1: bool = True,
     ) -> tuple[tuple[int, int], tuple[int, int, int]]:
         """Calculate shapes for Mamba's convolutional and state caches.
         Args:
             vllm_config: vLLM config
-            use_v1: Get shapes for V1 (or V0)
         Returns:
             Tuple containing:
             - conv_state_shape: Shape for convolutional state cache
@@ -1015,7 +894,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
             head_dim=hf_config.hidden_size_per_head,
             state_size=hf_config.mamba_d_state,
             conv_kernel=hf_config.mamba_d_conv,
-            use_v1=use_v1,
         )
 
     def compute_logits(
@@ -11,7 +11,6 @@ from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN
 
-from vllm import envs
 from vllm.attention import Attention, AttentionBackend, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
@@ -35,7 +34,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
-from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     mamba_v2_sharded_weight_loader)
 from vllm.model_executor.layers.mamba.mamba_utils import (
@@ -51,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, sharded_weight_loader)
-from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -198,14 +195,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.gated_delta_net_state_shape(
-            self.tp_size,
-            self.num_k_heads,
-            self.num_v_heads,
-            self.head_k_dim,
-            self.head_v_dim,
-            self.conv_kernel_size,
-            self.num_spec,
-            use_v1=True)
+            self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
+            self.head_v_dim, self.conv_kernel_size, self.num_spec)
 
     def __init__(
         self,
@@ -394,7 +385,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        cache_params: Optional[MambaCacheParams] = None,
     ):
         return torch.ops.vllm.gdn_attention(
             hidden_states,
@@ -416,7 +406,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
 
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
-        conv_metadata = attn_metadata
         assert isinstance(attn_metadata, GDNAttentionMetadata)
         has_initial_state = attn_metadata.has_initial_state
         spec_query_start_loc = attn_metadata.spec_query_start_loc
@@ -479,12 +468,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
             mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
-            if conv_metadata.cu_seqlen is None:
-                conv_metadata = update_metadata(mixed_qkv_non_spec_T,
-                                                non_spec_query_start_loc,
-                                                conv_metadata)
             # - "cache_indices" updates the conv_state cache in positions
-            # pointed to by "mamba_cache_params.state_indices_tensor"
+            # pointed to by "state_indices_tensor"
             mixed_qkv_non_spec = causal_conv1d_fn(
                 mixed_qkv_non_spec_T,
                 conv_weights,
@@ -494,7 +479,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                 has_initial_state=has_initial_state,
                 cache_indices=non_spec_state_indices_tensor,
                 query_start_loc=non_spec_query_start_loc,
-                metadata=conv_metadata,
+                metadata=attn_metadata,
             ).transpose(0, 1)
         elif attn_metadata.num_decodes > 0:
             mixed_qkv_non_spec = causal_conv1d_update(
@@ -1075,7 +1060,6 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         scheduler_config = vllm_config.scheduler_config
         assert not cache_config.enable_prefix_caching, \
             "Qwen3Next currently does not support prefix caching"
-        assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
         self.quant_config = vllm_config.quant_config
 
         super().__init__()
@@ -1195,14 +1179,10 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         num_spec = (vllm_config.speculative_config.num_speculative_tokens
                     if vllm_config.speculative_config else 0)
         return MambaStateShapeCalculator.gated_delta_net_state_shape(
-            tp_size,
-            hf_config.linear_num_key_heads,
-            hf_config.linear_num_value_heads,
-            hf_config.linear_key_head_dim,
-            hf_config.linear_value_head_dim,
-            hf_config.linear_conv_kernel_dim,
-            num_spec,
-            use_v1=True)
+            tp_size, hf_config.linear_num_key_heads,
+            hf_config.linear_num_value_heads, hf_config.linear_key_head_dim,
+            hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim,
+            num_spec)
 
     def compute_logits(
         self,
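
In the GDN hunks above, sequences carrying speculative (draft) tokens are kept separate from normal ones with a boolean mask, for example has_initial_state[~spec_sequence_masks]. A small standalone illustration of that masking, with invented values:

import torch

spec_sequence_masks = torch.tensor([False, True, False, True])
state_indices = torch.tensor([3, 7, 1, 4])

# Non-spec sequences go down the regular conv/SSM path.
non_spec_state_indices = state_indices[~spec_sequence_masks]
spec_state_indices = state_indices[spec_sequence_masks]
print(non_spec_state_indices.tolist(), spec_state_indices.tolist())
# [3, 1] [7, 4]
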
@@ -134,7 +134,6 @@ _TEXT_GENERATION_MODELS = {
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
     "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
-    "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
     "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
@@ -15,12 +15,10 @@ import torch
 from torch import nn
 from transformers import Zamba2Config
 
-from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -29,8 +27,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mamba.mamba2_metadata import (
-    Mamba2Metadata, prepare_mamba2_metadata)
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
@@ -39,8 +35,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
-                                                    MambaCacheParams)
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import HasInnerState, IsHybrid
@@ -515,8 +509,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
         transformer_hidden_states: Optional[torch.Tensor] = None,
         positions: Optional[torch.Tensor] = None,
         original_hidden_states: Optional[torch.Tensor] = None,
@@ -525,8 +517,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
 
         Args:
             hidden_states: Input tensor [batch_size, seq_len, hidden_size]
-            mamba_cache_params: Parameters for Mamba's state caches
-                (one for conv, one for ssm)
             transformer_hidden_states: Optional output from transformer path
                 Added to input if provided (used in hybrid architecture)
             positions: Optional position IDs (unused in Mamba)
@@ -555,8 +545,6 @@ class Zamba2MambaDecoderLayer(nn.Module):
         self.mamba(
             hidden_states,
             output,
-            mamba_cache_params=mamba_cache_params,
-            mamba2_metadata=mamba2_metadata,
         )
 
         # residual connection after mamba
@@ -607,8 +595,6 @@ class Zamba2HybridLayer(nn.Module):
         hidden_states: torch.Tensor,
         original_hidden_states: torch.Tensor,
         positions: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
-        mamba2_metadata: Mamba2Metadata,
     ) -> torch.Tensor:
         """Forward pass through the hybrid layer.
 
@@ -623,8 +609,6 @@ class Zamba2HybridLayer(nn.Module):
             original_hidden_states: Original input for transformer residual
                 connection
             positions: Position IDs for positional embeddings
-            mamba_cache_params: Parameters for Mamba's state caches
-                (one for conv, one for ssm)
 
         Returns:
             Output tensor combining transformer and Mamba representations
@@ -644,8 +628,6 @@ class Zamba2HybridLayer(nn.Module):
         layer_outputs = self.mamba_decoder(
             hidden_states,
             transformer_hidden_states=transformer_hidden_states,
-            mamba_cache_params=mamba_cache_params,
-            mamba2_metadata=mamba2_metadata,
         )
 
         return layer_outputs
@@ -752,7 +734,6 @@ class Zamba2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        mamba_cache_params: MambaCacheParams,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """Forward pass through the model.
@@ -760,8 +741,6 @@ class Zamba2Model(nn.Module):
         Args:
             input_ids: Input token IDs
             positions: Position IDs for embeddings
-            mamba_cache_params: Parameters for Mamba's state caches
-                (one for conv, one for ssm)
             inputs_embeds: Optional pre-computed input embeddings
 
         Returns:
@@ -773,33 +752,13 @@ class Zamba2Model(nn.Module):
             inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = inputs_embeds
 
-        attn_metadata = get_forward_context().attn_metadata
-
-        if not envs.VLLM_USE_V1:
-            mamba2_metadata = prepare_mamba2_metadata(
-                chunk_size=self.config.chunk_size,
-                attn_metadata=attn_metadata,
-            )
-        else:
-            # v1 get mamba2_metadata from forward_context
-            mamba2_metadata = None
-
         # Process through layers
         original_hidden_states = torch.clone(hidden_states)
         for layer_idx, layer in enumerate(self.layers):
-
-            layer_mamba_cache_params = None
-            if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
-                    and mamba_cache_params):
-                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
-                    layer_idx)
-
             layer_outputs = layer(
                 hidden_states,
                 original_hidden_states=original_hidden_states,
                 positions=positions,
-                mamba_cache_params=layer_mamba_cache_params,
-                mamba2_metadata=mamba2_metadata,
             )
             hidden_states = layer_outputs
 
@@ -870,13 +829,11 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
     def get_mamba_state_shape_from_config(
         cls,
         vllm_config: "VllmConfig",
-        use_v1: bool = True,
     ) -> tuple[tuple[int, int], tuple[int, int, int]]:
         """Calculate shapes for Mamba's convolutional and state caches.
 
         Args:
             vllm_config: vLLM config
-            use_v1: Get shapes for V1 (or V0)
 
         Returns:
             Tuple containing:
@@ -896,7 +853,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
             head_dim=hf_config.mamba_headdim,
             state_size=hf_config.mamba_d_state,
             conv_kernel=hf_config.mamba_d_conv,
-            use_v1=use_v1,
         )
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
@@ -945,9 +901,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
         # Tie weights with input embeddings if using same dimensions
         self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
 
-        # Used to track and store by the Mamba cache between steps.
-        self.mamba_cache: Optional[MambaCacheManager] = None
-
         # Initialize logits processing and sampling
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
@@ -977,61 +930,15 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
         Returns:
             Output hidden states
         """
-        # Initialize Mamba cache if needed
-        mamba_cache_params = None
-        if not envs.VLLM_USE_V1:
-            if self.mamba_cache is None:
-                num_mamba_layers = self.config.num_hidden_layers
-                mamba_state_shape = \
-                    self.get_mamba_state_shape_from_config(
-                        self.vllm_config, use_v1=False)
-                mamba_state_dtype = \
-                    self.get_mamba_state_dtype_from_config(
-                        self.vllm_config)
-                self.mamba_cache = MambaCacheManager(self.vllm_config,
-                                                     num_mamba_layers,
-                                                     *mamba_state_shape,
-                                                     *mamba_state_dtype)
-
-            # Get cache parameters for current run
-            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
-
         # Forward pass through model
         hidden_states = self.model(
             input_ids,
             positions,
-            mamba_cache_params,
             inputs_embeds,
         )
 
         return hidden_states
 
-    def copy_inputs_before_cuda_graphs(
-            self, input_buffers: dict[str, torch.Tensor],
-            **kwargs: Any) -> dict[str, torch.Tensor]:
-        """Copy inputs before CUDA graph capture.
-
-        Args:
-            input_buffers: Dictionary of input tensors
-            **kwargs: Additional arguments passed to cache manager
-
-        Returns:
-            Updated input buffers
-        """
-        return self.mamba_cache.copy_inputs_before_cuda_graphs(
-            input_buffers, **kwargs)
-
-    def get_seqlen_agnostic_capture_inputs(
-            self, batch_size: int) -> dict[str, torch.Tensor]:
-        """Get inputs for sequence-length-agnostic graph capture.
-
-        Args:
-            batch_size: Size of batch to capture
-        Returns:
-            Dictionary of capture inputs
-        """
-        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
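
After the Zamba2 and PLaMo2 edits above, a mamba layer is invoked with only the hidden states and a preallocated output buffer; cache tensors and metadata are fetched internally. A hedged, standalone sketch of that in-place calling convention, where toy_mixer merely stands in for the real fused kernel:

import torch

def toy_mixer(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Stand-in for the mamba mixer: write the result into the caller's buffer.
    output[:] = 2.0 * hidden_states

hidden_states = torch.ones(4, 8)
output = torch.empty_like(hidden_states)
toy_mixer(hidden_states, output)
print(output.mean().item())  # 2.0 - the result arrived via the output buffer
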
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                               AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
+                                              compute_causal_conv1d_metadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
@@ -52,7 +53,6 @@ class GDNAttentionMetadata:
 
     # The following attributes are for triton implementation of causal_conv1d
     nums_dict: Optional[dict] = None
-    cu_seqlen: Optional[int] = None
     batch_ptr: Optional[torch.Tensor] = None
     token_chunk_offset_ptr: Optional[torch.Tensor] = None
 
@@ -134,6 +134,7 @@ class GDNAttentionMetadataBuilder(
         context_lens = m.num_computed_tokens_cpu
         context_lens_tensor = context_lens.to(query_start_loc.device)
         seq_lens_tensor = m.seq_lens
+        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
 
         if (not self.use_spec_decode or num_draft_tokens is None
                 or num_draft_tokens.sum().item() == 0):
@@ -210,6 +211,8 @@ class GDNAttentionMetadataBuilder(
             has_initial_state = context_lens_tensor > 0
             if spec_sequence_masks is not None:
                 has_initial_state = has_initial_state[~spec_sequence_masks]
+            nums_dict, batch_ptr, token_chunk_offset_ptr = \
+                compute_causal_conv1d_metadata(non_spec_query_start_loc)
         else:
             has_initial_state = None
         num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
@@ -297,6 +300,9 @@ class GDNAttentionMetadataBuilder(
             spec_sequence_masks=spec_sequence_masks,
             spec_token_masks=spec_token_masks,
             num_accepted_tokens=num_accepted_tokens,
+            nums_dict=nums_dict,
+            batch_ptr=batch_ptr,
+            token_chunk_offset_ptr=token_chunk_offset_ptr,
         )
         return attn_metadata
 
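
The builder changes above move causal-conv1d launch metadata out of the layers: it is computed once per step in the metadata builder and carried in the metadata object, instead of being filled in lazily via update_metadata. A rough, self-contained sketch of that compute-once shape; ConvMetadata and build_metadata are illustrative stand-ins, and the real compute_causal_conv1d_metadata returns different structures:

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class ConvMetadata:
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None

def build_metadata(query_start_loc_p: torch.Tensor) -> ConvMetadata:
    # Derive per-sequence bookkeeping once, up front, so the layers never
    # have to mutate shared metadata during the forward pass.
    seqlens = (query_start_loc_p[1:] - query_start_loc_p[:-1]).tolist()
    return ConvMetadata(
        nums_dict={"seqlens": seqlens},
        batch_ptr=query_start_loc_p[:-1].clone(),
        token_chunk_offset_ptr=torch.zeros(len(seqlens), dtype=torch.int32),
    )

meta = build_metadata(torch.tensor([0, 3, 7, 12]))
print(meta.nums_dict)  # {'seqlens': [3, 4, 5]}
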
@@ -7,11 +7,12 @@ from typing import Optional
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.mamba_attn import (
     BaseMambaAttentionMetadataBuilder)
-from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
+from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
+                                              CommonAttentionMetadata,
+                                              compute_causal_conv1d_metadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
@@ -131,7 +132,6 @@ class Mamba2AttentionMetadata:
 
     # The following attributes are for triton implementation of causal_conv1d
     nums_dict: Optional[dict] = None
-    cu_seqlen: Optional[int] = None
     batch_ptr: Optional[torch.Tensor] = None
     token_chunk_offset_ptr: Optional[torch.Tensor] = None
 
@@ -161,6 +161,9 @@ class Mamba2AttentionMetadataBuilder(
         has_initial_states_p = None
         prep_initial_states = False
 
+        # for causal_conv1d
+        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
+
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
@@ -198,6 +201,9 @@ class Mamba2AttentionMetadataBuilder(
                     query_start_loc_p, self.chunk_size,
                     num_prefill_tokens))
 
+            nums_dict, batch_ptr, token_chunk_offset_ptr = \
+                compute_causal_conv1d_metadata(query_start_loc_p)
+
         elif num_decodes <= self.decode_cudagraph_max_bs:
             # Pad state tensor for CUDA graph
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
@@ -220,5 +226,8 @@ class Mamba2AttentionMetadataBuilder(
             chunk_indices_p=chunk_indices_p,
             chunk_offsets_p=chunk_offsets_p,
             state_indices_tensor=state_indices_tensor,
+            nums_dict=nums_dict,
+            batch_ptr=batch_ptr,
+            token_chunk_offset_ptr=token_chunk_offset_ptr,
         )
         return attn_metadata
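
PAD_SLOT_ID, which this diff re-homes in the V1 attention utils (see the final hunk below), is the sentinel used when a decode batch is padded up to a CUDA-graph capture size: padded entries must point at a slot the kernels know to skip. A standalone illustration with made-up slot numbers:

import torch

PAD_SLOT_ID = -1
state_indices = torch.tensor([5, 9, 2])   # real cache slots for 3 decodes
padded_batch_size = 8                     # CUDA-graph capture size

padded = torch.full((padded_batch_size,), PAD_SLOT_ID,
                    dtype=state_indices.dtype)
padded[:state_indices.numel()] = state_indices
print(padded)  # tensor([ 5,  9,  2, -1, -1, -1, -1, -1])
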
@@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
+                                              compute_causal_conv1d_metadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
@@ -33,7 +34,6 @@ class ShortConvAttentionMetadata:
 
     # For causal_conv1d
     nums_dict: Optional[dict] = None
-    cu_seqlen: Optional[int] = None
     batch_ptr: Optional[torch.Tensor] = None
     token_chunk_offset_ptr: Optional[torch.Tensor] = None
 
@@ -57,6 +57,9 @@ class ShortConvAttentionMetadataBuilder(
 
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
 
+        # for causal_conv1d
+        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
+
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
                 common_attn_metadata,
@@ -70,6 +73,12 @@ class ShortConvAttentionMetadataBuilder(
             has_initial_states = has_initial_states_cpu.to(
                 query_start_loc.device)
 
+            query_start_loc_p = common_attn_metadata.query_start_loc[
+                -num_prefills - 1:] - num_decode_tokens
+
+            nums_dict, batch_ptr, token_chunk_offset_ptr = \
+                compute_causal_conv1d_metadata(query_start_loc_p)
+
         attn_metadata = ShortConvAttentionMetadata(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
@@ -78,5 +87,8 @@ class ShortConvAttentionMetadataBuilder(
             query_start_loc=query_start_loc,
             has_initial_states=has_initial_states,
             state_indices_tensor=state_indices_tensor,
+            nums_dict=nums_dict,
+            batch_ptr=batch_ptr,
+            token_chunk_offset_ptr=token_chunk_offset_ptr,
         )
         return attn_metadata
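As a sanity check on the new prefill slice above, here is a small worked example with hypothetical numbers, assuming the V1 convention that decode requests are ordered before prefill requests in the batch: with two single-token decodes followed by prefills of length 5 and 7, the slice recovers the cumulative offsets of just the prefill tokens.

import torch

# Hypothetical batch layout: 2 decode requests (1 token each), then
# 2 prefill requests of lengths 5 and 7 -> 14 tokens total.
query_start_loc = torch.tensor([0, 1, 2, 7, 14])
num_prefills = 2
num_decode_tokens = 2

# Same expression as in the hunk above: keep the last num_prefills + 1
# boundaries and shift them so the first prefill token sits at offset 0.
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decode_tokens
print(query_start_loc_p)  # tensor([ 0,  5, 12])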
@@ -34,6 +34,8 @@ logger = init_logger(__name__)
 KVCacheLayoutType = Literal["NHD", "HND"]
 _KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
 
+PAD_SLOT_ID = -1
+
 
 def is_valid_kv_cache_layout(value: str) -> bool:
     return value in get_args(KVCacheLayoutType)
@@ -838,3 +840,52 @@ def create_fast_prefill_custom_backend(
         builder_cls=FastPrefillAttentionBuilder)
 
     return attn_backend
+
+
+def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
+
+    # Needed for causal_conv1d
+    seqlens = query_start_loc_p.diff().to('cpu')
+    nums_dict = {}  # type: ignore
+    batch_ptr = None
+    token_chunk_offset_ptr = None
+    for BLOCK_M in [8]:  # cover all BLOCK_M values
+        nums = -(-seqlens // BLOCK_M)
+        nums_dict[BLOCK_M] = {}
+        nums_dict[BLOCK_M]['nums'] = nums
+        nums_dict[BLOCK_M]['tot'] = nums.sum().item()
+        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
+        nums_dict[BLOCK_M]['mlist'] = mlist
+        mlist_len = len(nums_dict[BLOCK_M]['mlist'])
+        nums_dict[BLOCK_M]['mlist_len'] = mlist_len
+        MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2
+        offsetlist = []  # type: ignore
+        for idx, num in enumerate(nums):
+            offsetlist.extend(range(num))
+        offsetlist = torch.tensor(offsetlist, dtype=torch.int32)
+        nums_dict[BLOCK_M]['offsetlist'] = offsetlist
+
+        if batch_ptr is None:
+            # Update default value after class definition
+            batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
+                                   PAD_SLOT_ID,
+                                   dtype=torch.int32,
+                                   device='cuda')
+            token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
+                                                PAD_SLOT_ID,
+                                                dtype=torch.int32,
+                                                device='cuda')
+        else:
+            if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
+                batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
+                token_chunk_offset_ptr.resize_(  # type: ignore
+                    MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
+
+        batch_ptr[0:mlist_len].copy_(mlist)
+        token_chunk_offset_ptr[  # type: ignore
+            0:mlist_len].copy_(offsetlist)
+        nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr
+        nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr
+                                                        )  # type: ignore
+
+    return nums_dict, batch_ptr, token_chunk_offset_ptr
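To make the new helper's bookkeeping concrete, the snippet below reproduces only its CPU-side arithmetic for a hypothetical prefill batch (three sequences of lengths 5, 16 and 9, BLOCK_M = 8). The real function goes on to copy mlist and offsetlist into CUDA buffers of size max(1024, mlist_len) * 2 padded with PAD_SLOT_ID; that part is skipped here so the example runs without a GPU.

import numpy as np
import torch

# Hypothetical prefill cumulative offsets for sequence lengths 5, 16, 9.
query_start_loc_p = torch.tensor([0, 5, 21, 30])
BLOCK_M = 8

seqlens = query_start_loc_p.diff()      # tensor([ 5, 16,  9])
nums = -(-seqlens // BLOCK_M)           # ceil division -> tensor([1, 2, 2])
tot = nums.sum().item()                 # 5 conv1d programs in total

# Sequence index handled by each program: tensor([0, 1, 1, 2, 2])
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))

# Block offset of each program within its sequence: tensor([0, 0, 1, 0, 1])
offsetlist = torch.tensor([off for n in nums for off in range(n)],
                          dtype=torch.int32)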