[V1][Mamba1] - Full CUDA and Piecewise CUDA Graphs Support (#23035)
Signed-off-by: asafg <asafg@ai21.com>
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Co-authored-by: asafg <asafg@ai21.com>

parent 2461d9e562
commit 3663870c72
@@ -107,7 +107,7 @@ to enable simultaneous generation and embedding using the same engine instance i
 #### Mamba Models
 
 Models using selective state-space mechanisms instead of standard transformer attention are supported.
-Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
+Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1.
 
 Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
 `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that
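Since the docs change above drops the `enforce_eager=True` requirement for Mamba-1, here is a minimal offline-inference sketch of what now works under V1 (model name, prompt, and sampling settings are illustrative; prefix caching still has to be disabled):

    from vllm import LLM, SamplingParams

    # Sketch only: Mamba-1 checkpoints can now run with CUDA graphs in V1,
    # so enforce_eager is no longer passed; prefix caching must stay off.
    llm = LLM(
        model="state-spaces/mamba-130m-hf",
        enable_prefix_caching=False,
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)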
@@ -54,16 +54,14 @@ V1_SUPPORTED_MODELS = [
     "tiiuae/Falcon-H1-0.5B-Base",
 ]
 
+FULL_CUDA_GRAPH_MODELS = [
+    "ai21labs/Jamba-tiny-dev",
+    "Zyphra/Zamba2-1.2B-instruct",
+]
+
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
-# Once we add support for FCG in Mamba1, this list will be removed and tests
-# all test cases will use enforce_eager=False
-ENFORCE_EAGER_MODELS_V1 = [
-    "state-spaces/mamba-130m-hf",
-    "ai21labs/Jamba-tiny-dev",
-]
-
 
 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])

@@ -101,19 +99,13 @@ def test_models(
             example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
-        enforce_eager = False
         with monkeypatch.context() as m:
            m.setenv("VLLM_USE_V1", "1")
            if model in HYBRID_MODELS:
                # required due to reorder_batch behaviour
                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-
-            if model in ENFORCE_EAGER_MODELS_V1:
-                enforce_eager = True
-
            with vllm_runner(model,
                             max_num_seqs=MAX_NUM_SEQS,
-                             enforce_eager=enforce_eager,
                             enable_prefix_caching=False) as vllm_model:
                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                    example_prompts, max_tokens, num_logprobs)

@@ -373,7 +365,7 @@ def test_distributed_correctness(
     )
 
 
-@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_full_cuda_graph(
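The body of `test_full_cuda_graph` is not part of this diff; as a hedged sketch of the configuration it exercises for the two `FULL_CUDA_GRAPH_MODELS`, assuming the engine is asked for full CUDA graphs through the `full_cuda_graph` compilation flag that the Mamba metadata builders read below:

    from vllm import LLM

    # Assumed invocation: with full CUDA graph capture requested, decode-only
    # batches get padded to the nearest captured size by the Mamba attention
    # metadata builders (see the backend changes later in this commit).
    llm = LLM(
        model="ai21labs/Jamba-tiny-dev",
        max_num_seqs=4,                               # mirrors MAX_NUM_SEQS
        enable_prefix_caching=False,
        compilation_config={"full_cuda_graph": True},
    )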
@@ -336,6 +336,7 @@ class CompilationConfig:
             "vllm.unified_attention",
             "vllm.unified_attention_with_output",
             "vllm.mamba_mixer2",
+            "vllm.mamba_mixer",
         ]
 
     def compute_hash(self) -> str:
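Registering "vllm.mamba_mixer" as a splitting op is what gives Mamba-1 piecewise CUDA graphs: torch.compile cuts the captured graph around every mixer call, so the SSM kernels run eagerly while the surrounding graph pieces are replayed from capture. A sketch that spells out the resulting default list explicitly (normally unnecessary, since this list is the default after this change):

    from vllm.config import CompilationConfig

    # Sketch: the default splitting-op list after this commit, written out.
    compilation_config = CompilationConfig(
        splitting_ops=[
            "vllm.unified_attention",
            "vllm.unified_attention_with_output",
            "vllm.mamba_mixer2",
            "vllm.mamba_mixer",
        ],
    )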
@@ -27,6 +27,8 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
 
 

@@ -183,22 +185,26 @@ class MambaMixer(MambaBase, CustomOp):
 
     def forward(self,
                 hidden_states: torch.Tensor,
+                output: torch.Tensor,
                 mamba_cache_params: Optional[MambaCacheParams] = None):
         if not envs.VLLM_USE_V1:
-            return CustomOp.forward(self, hidden_states, mamba_cache_params)
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params)
         else:
-            return self.forward_cuda(
+            torch.ops.vllm.mamba_mixer(
                 hidden_states,
-                mamba_cache_params,
+                output,
+                self.prefix,
             )
 
     def forward_native(self,
                        hidden_states: torch.Tensor,
+                       output: torch.Tensor,
                        mamba_cache_params: Optional[MambaCacheParams] = None):
         pass
 
     def forward_cuda(self,
                      hidden_states: torch.Tensor,
+                     output: torch.Tensor,
                      mamba_cache_params: Optional[MambaCacheParams] = None):
         """
         Run the Mamba-1 SSM pipeline.

@@ -237,6 +243,7 @@ class MambaMixer(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             has_initial_states = mamba1_metadata.has_initial_states
+            num_padded_decodes = mamba1_metadata.num_padded_decodes
         else:
             assert isinstance(attn_metadata, AttentionMetadata)
             assert mamba_cache_params is not None

@@ -248,6 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
             has_initial_states = None
             if context_lens_tensor is not None:
                 has_initial_states = context_lens_tensor > 0
+            num_padded_decodes = attn_metadata.num_decode_tokens
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)

@@ -267,6 +275,7 @@ class MambaMixer(MambaBase, CustomOp):
         num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
         has_prefill = num_prefill_tokens > 0
         has_decode = num_decode_tokens > 0
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens
 
         prefill_decode_split = split_batch_to_prefill_and_decode(
             hidden_states_BC,

@@ -278,6 +287,7 @@ class MambaMixer(MambaBase, CustomOp):
             num_decode_tokens,
             num_prefills,
             num_decodes,
+            num_padded_decodes,
         )
         hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
         hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d

@@ -371,7 +381,7 @@ class MambaMixer(MambaBase, CustomOp):
         else:
             out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
 
-        return out
+        output[:num_actual_tokens] = out
 
     def get_state_dtype(self) -> tuple[torch.dtype]:
         assert self.model_config is not None

@@ -421,18 +431,27 @@ def split_batch_to_prefill_and_decode(
     num_decode_tokens: int,
     num_prefills: int,
     num_decodes: int,
+    num_padded_decodes: int,
 ) -> PrefillDecodeSplit:
+    num_actual_tokens = num_prefill_tokens + num_padded_decodes
+
     if envs.VLLM_USE_V1:
         # In v1, decode tokens come first, then prefill tokens.
         hidden_states_BC_d, hidden_states_BC_p = torch.split(
-            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
-        gate_d, gate_p = torch.split(gate,
-                                     [num_decode_tokens, num_prefill_tokens],
+            hidden_states_BC[..., :num_actual_tokens],
+            [num_padded_decodes, num_prefill_tokens],
+            dim=-1)
+        gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
+                                     [num_padded_decodes, num_prefill_tokens],
                                      dim=-1)
+
+        # num_padded_decodes accounts for CUDA graph padding when applicable
         state_indices_tensor_d, state_indices_tensor_p = torch.split(
-            state_indices_tensor, [num_decodes, num_prefills], dim=0)
+            state_indices_tensor[:num_padded_decodes + num_prefills],
+            [num_padded_decodes, num_prefills],
+            dim=0)
         query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
-                             num_decodes if num_prefills > 0 else None)
+                             num_padded_decodes if num_prefills > 0 else None)
         has_initial_states_p = has_initial_states[-num_prefills:] if (
             has_initial_states is not None and num_prefills > 0) else None
     else:
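To make the padded split above concrete, a standalone sketch of the slicing that `num_padded_decodes` introduces (toy shapes and counts; in the real code the tensors are the batched mixer inputs and the padding comes from the attention metadata):

    import torch

    num_prefill_tokens = 5     # toy values
    num_decodes = 3            # real decode requests in the batch
    num_padded_decodes = 4     # decode slots after CUDA graph padding
    num_actual_tokens = num_prefill_tokens + num_padded_decodes

    hidden_states_BC = torch.randn(2, 16, 12)  # [..., tokens], 12 >= 9
    # Drop trailing capture padding, then split decode-first as in V1.
    hidden_states_BC_d, hidden_states_BC_p = torch.split(
        hidden_states_BC[..., :num_actual_tokens],
        [num_padded_decodes, num_prefill_tokens],
        dim=-1)
    assert hidden_states_BC_d.shape[-1] == num_padded_decodes
    assert hidden_states_BC_p.shape[-1] == num_prefill_tokens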
@@ -459,3 +478,32 @@ def split_batch_to_prefill_and_decode(
         query_start_loc_p=query_start_loc_p,
         has_initial_states_p=has_initial_states_p,
     )
+
+
+def mamba_mixer(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None)
+
+
+def mamba_mixer_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="mamba_mixer",
+    op_func=mamba_mixer,
+    mutates_args=["output"],
+    fake_impl=mamba_mixer_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
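The registration above turns the mixer into a destination-passing custom op: the caller allocates the output buffer, the op mutates it in place (mutates_args=["output"]) and returns nothing, and mamba_mixer_fake gives torch.compile a shape-only stand-in. A sketch of the calling convention used by the decoder layers in the model hunks below (shapes and the layer prefix are illustrative; the op call itself only works inside a running V1 forward context):

    import torch

    hidden_states = torch.randn(9, 768)   # illustrative [tokens, hidden_size]
    output = torch.empty_like(hidden_states)
    # Inside a V1 forward pass, the op looks the layer up by its prefix and
    # writes the mixer result into `output`, e.g.:
    #   torch.ops.vllm.mamba_mixer(hidden_states, output, "model.layers.0.mixer")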
@@ -10,6 +10,7 @@ from transformers import JambaConfig
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group

@@ -154,10 +155,10 @@ class JambaMambaDecoderLayer(nn.Module):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
 
-        hidden_states = self.mamba(hidden_states, mamba_cache_params)
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params)
         # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.pre_ff_layernorm(output, residual)
         hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 

@@ -278,6 +279,7 @@ ALL_DECODER_LAYER_TYPES = {
 }
 
 
+@support_torch_compile
 class JambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -9,6 +9,7 @@ from torch import nn
 from transformers import MambaConfig
 
 from vllm import envs
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm

@@ -81,10 +82,12 @@ class MambaDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states, mamba_cache_params)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params)
+        return output, residual
 
 
+@support_torch_compile
 class MambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -2,16 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
+from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadataBuilder)
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               split_decodes_and_prefills)
-from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 
 class Mamba1AttentionBackend(AttentionBackend):

@@ -31,24 +31,11 @@ class Mamba1AttentionMetadata:
     num_prefill_tokens: int
     num_decodes: int
     num_decode_tokens: int
+    num_padded_decodes: int
 
 
 class Mamba1AttentionMetadataBuilder(
-        AttentionMetadataBuilder[Mamba1AttentionMetadata]):
-    reorder_batch_threshold: ClassVar[int] = 1
-
-    def __init__(
-        self,
-        kv_cache_spec: AttentionSpec,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        layer_names: list[str],
-    ):
-        assert isinstance(kv_cache_spec, MambaSpec)
-        self.kv_cache_spec = kv_cache_spec
-        self.device = device
-        self.vllm_config = vllm_config
-        self.layer_names = layer_names
+        BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]):
 
     def build(
         self,

@@ -67,9 +54,18 @@ class Mamba1AttentionMetadataBuilder(
                 decode_threshold=1))
 
         has_initial_states = None
+        padded_decodes = num_decodes
 
         if num_prefills > 0:
             has_initial_states = context_lens_tensor > 0
+        elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
+              and self.compilation_config.full_cuda_graph):
+            state_indices_for_decode = state_indices_tensor[:num_decodes]
+            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(
+                state_indices_for_decode, non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:padded_decodes]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
 
         return Mamba1AttentionMetadata(
             query_start_loc=query_start_loc,

@@ -80,4 +76,5 @@ class Mamba1AttentionMetadataBuilder(
             num_prefill_tokens=num_prefill_tokens,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
+            num_padded_decodes=padded_decodes,
         )
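A self-contained sketch of the decode padding performed in build() above (toy numbers; pad_for_cudagraph is stubbed with a fixed capture size, and the PAD_SLOT_ID value is an assumption for the sentinel that the kernels skip):

    import torch

    PAD_SLOT_ID = -1    # assumed sentinel, as in vllm.attention.backends.utils
    CAPTURE_SIZE = 8    # stand-in for vllm_config.pad_for_cudagraph(num_decodes)

    num_decodes = 5
    state_indices = torch.arange(num_decodes, dtype=torch.int32)

    # Persistent buffer sized for the largest captured decode batch.
    persistent = torch.empty(CAPTURE_SIZE, dtype=torch.int32)
    persistent[:num_decodes].copy_(state_indices)
    padded = persistent[:CAPTURE_SIZE]
    padded[num_decodes:] = PAD_SLOT_ID   # padded slots point at no real state
    print(padded)  # tensor([0, 1, 2, 3, 4, -1, -1, -1], dtype=torch.int32)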
@@ -2,18 +2,18 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionCGSupport,
-                                              AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadataBuilder)
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               split_decodes_and_prefills)
-from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
+from vllm.v1.kv_cache_interface import AttentionSpec
 
 
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,

@@ -88,29 +88,14 @@ class Mamba2AttentionMetadata:
 
 
 class Mamba2AttentionMetadataBuilder(
-        AttentionMetadataBuilder[Mamba2AttentionMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-
-    reorder_batch_threshold: ClassVar[int] = 1
+        BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]):
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
-        assert isinstance(kv_cache_spec, MambaSpec)
-        self.kv_cache_spec = kv_cache_spec
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
-        self.vllm_config = vllm_config
-        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
-        self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_capture_size)
-        self.state_indices_tensor = torch.empty(
-            (self.decode_cudagraph_max_bs, ),
-            dtype=torch.int32,
-            device=device,
-        )
 
     def build(self,
               common_prefix_len: int,

@@ -187,19 +172,3 @@ class Mamba2AttentionMetadataBuilder(
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
-
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with Mamba.
-        """
-        m = common_attn_metadata
-
-        assert m.num_reqs == m.num_actual_tokens, \
-            "Mamba only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        m.max_query_len = 1  # decode-only
-
-        return self.build(0, m)
vllm/v1/attention/backends/mamba_attn.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import abc
+from typing import ClassVar, TypeVar
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
+from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
+
+M = TypeVar("M")
+
+
+class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
+    reorder_batch_threshold: ClassVar[int] = 1
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        assert isinstance(kv_cache_spec, MambaSpec)
+        self.kv_cache_spec = kv_cache_spec
+        self.device = device
+        self.vllm_config = vllm_config
+        self.layer_names = layer_names
+
+        self.compilation_config = vllm_config.compilation_config
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata) -> M:
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)