[V1] [Hybrid] Mamba2 Automatic Prefix Caching (#25752)

Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Stan Wozniak 2025-10-04 06:34:22 +02:00 committed by GitHub
parent 9705fba7b7
commit ea507c3a93
18 changed files with 917 additions and 147 deletions

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable
import pytest
from tests.models.registry import HF_EXAMPLE_MODELS
@ -8,7 +10,7 @@ from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
from ...utils import check_logprobs_close
from ...utils import check_logprobs_close, check_outputs_equal
# Mark all tests as hybrid
pytestmark = pytest.mark.hybrid_model
@ -332,3 +334,413 @@ def test_fp32_cache_state(
name_0="hf",
name_1="vllm",
)
# Helper functions for the APC tests
def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1):
return {
'model_name': model,
'enable_prefix_caching': False,
'max_model_len': max_model_len,
'tensor_parallel_size': tensor_parallel_size,
'gpu_memory_utilization': 0.4
}
def _get_vLLM_output(vllm_runner,
kwargs,
prompts,
max_tokens,
num_logprobs,
num_repetitions=1,
vllm_model=None):
outs = []
if vllm_model is None:
vllm_model = vllm_runner(**kwargs)
for _ in range(num_repetitions):
if num_logprobs < 0:
vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
else:
vllm_output = vllm_model.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs)
outs.append(vllm_output)
return outs, vllm_model
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts.
generated_prompts = [MULTIPLE * example_prompts[0]]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
vllm_runner_kwargs['enable_prefix_caching'] = True
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs, n_repetitions)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt_block_align_alignment(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts. This custom prompt is used, as it causes the most issues
generated_prompts = ["The president of the United States is " * MULTIPLE]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
vllm_runner_kwargs['enable_prefix_caching'] = True
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
# Retrieve the default mamba state block size
mamba_block_size = vllm_model.llm.llm_engine.cache_config. \
mamba_block_size
# In case the hybrid model does not have the
# "mamba_block_size" assume a fixed constant
if mamba_block_size is None:
mamba_block_size = 512
mamba_block_size_multiplier = 10
for offsets in [
-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3
]:
vllm_runner_kwargs[
'max_num_batched_tokens'] = mamba_block_size_multiplier * \
mamba_block_size - offsets
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens, num_logprobs,
n_repetitions)
# Check alignment of the output logits when using APC
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_all_cached_outputs(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts.
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
vllm_runner_kwargs['enable_prefix_caching'] = True
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs, n_repetitions)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_block_align_alignment(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts. This custom prompt is used, as it causes the most issues
prompt_text = "The president of the United States is "
prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
generated_prompts = [
prompt_text[offset:] * MULTIPLE for offset in prompt_offsets
]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(model, max_model_len,
tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
vllm_runner_kwargs['enable_prefix_caching'] = True
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
# Retrieve the default mamba state block size
mamba_block_size = vllm_model.llm.llm_engine.cache_config. \
mamba_block_size
# In case the hybrid model does not have the
# "mamba_block_size" assume a fixed constant
if mamba_block_size is None:
mamba_block_size = 512
mamba_block_size_multiplier = 10
for offsets in [
-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3
]:
vllm_runner_kwargs[
'max_num_batched_tokens'] = mamba_block_size_multiplier * \
mamba_block_size - offsets
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens, num_logprobs,
n_repetitions)
# Check alignment of the output logits when using APC
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_partial_cached_outputs(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
n_repetitions: int,
num_logprobs: int,
tensor_parallel_size: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
compare_operator: Callable = check_logprobs_close \
if num_logprobs > 0 else check_outputs_equal # type: ignore
MULTIPLE = 300
# Sample prompts.
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
max_model_len = max(
len(prompt) + max_tokens for prompt in generated_prompts)
vllm_runner_kwargs = _get_vllm_runner_params(
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts, max_tokens,
num_logprobs)
# Cache only part of all the prompts
vllm_runner_kwargs['enable_prefix_caching'] = True
vllm_outputs_partial_cache, vllm_model = _get_vLLM_output(
vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens,
num_logprobs)
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0][:3],
outputs_1_lst=vllm_outputs_partial_cache[0],
name_0="vllm_no_cache",
name_1="vllm_partial_cache",
)
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens,
num_logprobs,
n_repetitions,
vllm_model=vllm_model)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)

View File

@ -92,7 +92,8 @@ class CacheConfig:
mamba_page_size_padded: Optional[int] = None
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
mamba_block_size: Optional[int] = None
"""Size of a contiguous cache block in number of tokens for mamba cache."""
mamba_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (both the conv as well as the
ssm state). If set to 'auto', the data type will be inferred from the model

View File

@ -1563,7 +1563,12 @@ class EngineArgs:
self.enable_prefix_caching = False
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
# Disable prefix caching by default for hybrid models
# since the feature is still experimental.
if model_config.is_hybrid:
self.enable_prefix_caching = False
else:
self.enable_prefix_caching = True
else:
pooling_type = model_config.pooler_config.pooling_type
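Because prefix caching now defaults to off for hybrid models, it has to be requested explicitly. A minimal, hedged sketch (not part of this diff) using the offline LLM entrypoint; the model name is only illustrative:

# Hedged sketch: opt in to automatic prefix caching for a hybrid Mamba2 model
# now that the default for hybrid architectures is False.
from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-granite/granite-4.0-tiny-preview",  # illustrative hybrid model
    enable_prefix_caching=True,  # override the new hybrid-model default
)
outputs = llm.generate(["A long shared prefix ..."] * 2,
                       SamplingParams(temperature=0.0, max_tokens=32))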

View File

@ -489,6 +489,9 @@ class MambaMixer2(MambaBase, CustomOp):
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
@ -573,6 +576,25 @@ class MambaMixer2(MambaBase, CustomOp):
dim=0,
)
if prefix_caching_enabled:
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
last_state_idx_d, last_state_idx_p = torch.split(
attn_metadata.last_state_idx, [num_decodes, num_prefills],
dim=0)
current_last_idx_d, current_last_idx_p = torch.split(
attn_metadata.current_last_idx, [num_decodes, num_prefills],
dim=0)
# Prefill-only variables:
current_first_idx_p = attn_metadata.current_first_idx_p
context_lens_p = attn_metadata.context_lens_p
last_computed_offset_p = attn_metadata.last_computed_offset_p
else:
last_state_idx_d, last_state_idx_p = None, None
current_last_idx_d, current_last_idx_p = None, None
current_first_idx_p = None
context_lens_p = None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
@ -592,8 +614,17 @@ class MambaMixer2(MambaBase, CustomOp):
# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor"
# - It reads the initial state for every sequence that has
#   "has_initial_states_p" == True from the cache positions
#   given by "state_indices_tensor_p".
# - It updates the "conv_state" cache in positions pointed
#   to by "state_indices_tensor_p".
#   In particular, it always writes the state at the
#   sequence end.
# - In addition, if "current_first_idx_p" and "current_last_idx_p"
#   are provided (pointers into "state_indices_tensor_p"), it
#   writes additional cache states aligned at
#   "block_size_to_align".
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
hidden_states_B_C_p = causal_conv1d_fn(
@ -604,6 +635,11 @@ class MambaMixer2(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
current_first_idx=current_first_idx_p,
current_last_idx=current_last_idx_p,
initial_state_idx=last_state_idx_p,
context_lens=context_lens_p,
block_size_to_align=mamba_block_size,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
@ -614,9 +650,13 @@ class MambaMixer2(MambaBase, CustomOp):
# 3. State Space Model sequence transformation
initial_states = None
if (has_initial_states_p is not None and prep_initial_states):
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, last_state_idx_p.unsqueeze(1)).squeeze(1)
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)
ssm_state[kernel_ssm_indices], 0)
# NOTE: final output is an in-place update of out tensor
varlen_states = mamba_chunk_scan_combined_varlen(
@ -638,18 +678,71 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlens=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
initial_states=initial_states,
return_intermediate_states=prefix_caching_enabled,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor_p] = varlen_states
if prefix_caching_enabled:
# Save states for sequences with more than just the final state:
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
cache_blocks_to_fill = state_indices_tensor_p[
seq_idx, current_first_idx_p[seq_idx]:
current_first_idx_p[seq_idx] +
n_blocks_to_fill[seq_idx]]
# chunks = [0 1 2 3 4 5 6 ...]
# First aligned chunk would typically be:
# mamba_block_size = 1024, chunk_size = 256
# 1024 // 256 - 1 --> chunks[3]
# But when last chunk wasn't block aligned:
# - last_computed_offset_p[seq_idx] // chunk_size
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
torch.concat([torch.zeros(1, \
dtype=last_chunk_indices_p.dtype, \
device=last_chunk_indices_p.device), \
last_chunk_indices_p + 1])[seq_idx] \
+ chunk_stride - 1 \
- last_computed_offset_p[seq_idx] // chunk_size
from_where = varlen_states[
first_aligned_chunk:first_aligned_chunk +
n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
ssm_state[cache_blocks_to_fill] = from_where
# For all seqs, store the last state (Note: might be partial):
ssm_state[state_indices_tensor_p.gather(1,
current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
varlen_states[last_chunk_indices_p]
else:
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
# tensor
ssm_state[state_indices_tensor_p] = varlen_states
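A small worked example (not part of the diff) of the first_aligned_chunk arithmetic in the prefix-caching branch above, vectorized over two prefill sequences and using the figures from the comment (mamba_block_size = 1024, chunk_size = 256):

# Hedged numeric sketch of the chunk bookkeeping above.
import torch

mamba_block_size, chunk_size = 1024, 256
chunk_stride = mamba_block_size // chunk_size          # 4 chunks per block

# Two prefill sequences whose last chunks end at flattened chunk indices 3 and 9,
last_chunk_indices_p = torch.tensor([3, 9])
# and whose already-computed prefixes ended 1000 and 10 tokens into a block.
last_computed_offset_p = torch.tensor([1000, 10])

# Chunk index at which each sequence starts: [0, last_chunk_indices_p[0] + 1]
chunk_start = torch.concat([
    torch.zeros(1, dtype=last_chunk_indices_p.dtype),
    last_chunk_indices_p + 1,
])[:2]                                                  # -> tensor([0, 4])
first_aligned_chunk = (chunk_start + chunk_stride - 1
                       - last_computed_offset_p // chunk_size)
# -> tensor([0, 7]): seq 0 already completed 3 chunks, so chunk 0 is block-aligned;
#    seq 1 starts at chunk 4 and skips 3 partial chunks, so chunk 7 is stored first.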
# Process decode requests
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = \
state_indices_tensor_d.gather(1,
last_state_idx_d.unsqueeze(1)).squeeze(1)
state_indices_tensor_d_output = \
state_indices_tensor_d.gather(1,
current_last_idx_d.unsqueeze(1)).squeeze(1)
# Note:
# for decode always: current_first_idx_d == current_last_idx_d
# at block boundaries: current_first_idx_d > last_state_idx_d
else:
# Without caching, read and write in-place to the same blocks:
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d
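A hedged sketch (illustrative tensors only, not part of the diff) of the gather indexing used for the decode path above:

# Toy example: two decode requests, each with a row of up to 4 mamba blocks.
import torch

state_indices_tensor_d = torch.tensor([[7, 3, 9, 0],
                                       [5, 2, 0, 0]])
last_state_idx_d = torch.tensor([1, 0])    # block holding the last computed state
current_last_idx_d = torch.tensor([2, 0])  # block the new state belongs to

src = state_indices_tensor_d.gather(1, last_state_idx_d.unsqueeze(1)).squeeze(1)
dst = state_indices_tensor_d.gather(1, current_last_idx_d.unsqueeze(1)).squeeze(1)
# src == tensor([3, 5]) and dst == tensor([9, 5]):
# request 0 crosses a block boundary (read block 3, write block 9),
# request 1 stays within its block (read and write block 5 in place).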
# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
@ -657,7 +750,10 @@ class MambaMixer2(MambaBase, CustomOp):
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d)
conv_state_indices=state_indices_tensor_d,
current_last_idx=current_last_idx_d,
initial_state_idx=last_state_idx_d,
)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
hidden_states_B_C_d)
@ -689,7 +785,8 @@ class MambaMixer2(MambaBase, CustomOp):
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)

View File

@ -20,19 +20,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching
w_ptr, # (dim, width)
bias_ptr,
initial_states_ptr, # conv_states_ptr
cache_indices_ptr, # conv_state_indices_ptr
cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains
# the block indices relevant for each sequence
# plus potential 0-padding at the beginning and at the end
has_initial_states_ptr,
query_start_loc_ptr,
batch_ptr,
token_chunk_offset_ptr,
current_first_idx, # (batch,)
current_last_idx, # (batch,)
initial_state_idx, # (batch,)
context_lens, # (batch,)
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
# Matrix dimensions
batch: tl.int32, # actually padded_batch
dim: tl.constexpr,
seqlen: tl.int32, # cu_seqlen
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_seq: tl.constexpr, # stride to get to next sequence,
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
stride_x_token: tl.
constexpr, # stride to get to next token (same feature-index, same sequence-index)
@ -42,18 +46,16 @@ def _causal_conv1d_fwd_kernel( # continuous batching
stride_istate_dim: tl.constexpr,
stride_istate_token: tl.constexpr,
stride_cache_indices: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
HAS_INITIAL_STATES: tl.constexpr,
HAS_CACHE: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_APC_ENABLED: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
NP2_STATELEN: tl.constexpr,
BLOCK_M: tl.constexpr,
@ -84,26 +86,57 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# find the actual sequence length
seqlen = sequence_end_index - sequence_start_index
B_size: tl.constexpr = (stride_block_m * BLOCK_M)
if IS_APC_ENABLED:
# Handle the case where prefix caching is enabled.
# In that case, the program writes additional cache states at the positions given by "cache_indices_ptr".
# Get the length of the completed sequence so far and compute the offset.
current_first_index = tl.load(current_first_idx + idx_seq)
current_last_index = tl.load(current_last_idx + idx_seq)
sequence_completed_index = tl.load(context_lens + idx_seq)
# Compute the offset of the first stride_block_m-aligned full block
# Value in "token-space"
sequence_completed_offset_token = sequence_completed_index % B_size
seq_completed_offset = B_size - sequence_completed_offset_token
seq_end_offset = (seqlen - seq_completed_offset) % B_size
last_full_block_token_index = sequence_end_index - seq_end_offset
# If the sequence end is itself B_size-aligned (seq_end_offset == 0), the last full block handled here is the second-to-last one
if seq_end_offset == 0:
last_full_block_token_index = last_full_block_token_index - B_size
# Get the number of blocks to be filled for the current sequence
# If n_block_to_fill = 0, then only the state at the sequence end is stored
n_block_to_fill = current_last_index - current_first_index
# Get the index of the init block
conv_state_init_index = tl.load(initial_state_idx + idx_seq)
else:
n_block_to_fill = 0
current_last_index = 0
conv_state_init_index = 0
current_first_index = 0
last_full_block_token_index = 0
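A short numeric sketch (not part of the diff) of the token-space bookkeeping above, assuming B_size = 512 and a sequence with 300 tokens already cached plus 700 new tokens starting at flattened index 1000:

# Hedged numeric sketch of last_full_block_token_index (B_size = 512).
B_size = 512
sequence_start_index, seqlen = 1000, 700             # position in the flattened batch
sequence_end_index = sequence_start_index + seqlen   # 1700
sequence_completed_index = 300                       # tokens already computed (context)

sequence_completed_offset_token = sequence_completed_index % B_size   # 300
seq_completed_offset = B_size - sequence_completed_offset_token       # 212
seq_end_offset = (seqlen - seq_completed_offset) % B_size             # 488
last_full_block_token_index = sequence_end_index - seq_end_offset     # 1212
# Flattened index 1212 corresponds to absolute position 300 + 212 = 512,
# i.e. the only B_size boundary crossed by this prefill.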
token_offset = BLOCK_M * chunk_offset
segment_len = min(BLOCK_M, seqlen - token_offset)
# base of the sequence
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
if IS_CONTINUOUS_BATCHING:
# cache_idx
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices).to(
tl.int64)
else:
# cache_idx
conv_state_batch_coord = idx_seq
# cache_idx
conv_states_input_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices +
conv_state_init_index).to(tl.int64)
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
if conv_states_input_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
conv_states_base = (conv_states_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
@ -113,10 +146,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# 2. update conv_state with new data [only by the Triton program handles chunk_offset=0]
if chunk_offset == 0:
# read from conv_states
load_init_state = False
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(
tl.int1)
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
if load_init_state:
# load from conv_states
prior_tokens = conv_states_base + (state_len -
@ -175,15 +205,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_target = conv_states_base[None, :] + (
idx_tokens_conv * stride_conv_state_tok)[:, None]
# Compute the offset where the last block should be written in the conv_states
conv_states_output_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices +
current_last_index).to(tl.int64)
conv_states_ptrs_target = (
conv_states_ptr + (conv_states_output_coord *
stride_conv_state_seq) + # Offset from seq
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens_conv * stride_conv_state_tok)[:, None]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl.store(conv_states_ptrs_target, new_conv_state, mask)
tl.store(conv_states_ptrs_target, loaded_x, mask)
else:
if load_init_state:
@ -192,12 +230,12 @@ def _causal_conv1d_fwd_kernel( # continuous batching
conv_states_ptrs_source = (
conv_states_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:,
None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
mask = ((conv_states_input_coord < num_cache_lines)
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
@ -280,6 +318,45 @@ def _causal_conv1d_fwd_kernel( # continuous batching
conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
# Store intermediate states aligned with stride_block_m.
# The additional states are cached starting from the last stride_block_m boundary.
# For example:
# If n_block_to_fill = 0, only the state at the sequence end is cached and the logic below is skipped.
# If n_block_to_fill > 0, the state at the sequence end plus the states at the last n_block_to_fill
# stride_block_m boundaries are cached.
# For example, chunk_offset = n_block_to_fill stores the state at last_full_block_token_index.
if (chunk_offset - 1) < n_block_to_fill:
# Store the states at the chunk boundaries from the start of the sequence
idx_tokens_last = (last_full_block_token_index -
(n_block_to_fill - chunk_offset) * B_size -
state_len) + tl.arange(
0, NP2_STATELEN) # [BLOCK_M]
x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + (
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
mask_x = (
(idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# cache_idx
conv_states_output_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices +
current_first_index +
(chunk_offset - 1)).to(tl.int64)
conv_states_ptrs_target = (
conv_states_ptr + (conv_states_output_coord *
stride_conv_state_seq) + # Offset from seq
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens_conv * stride_conv_state_tok)[:, None]
mask = (idx_tokens_conv < state_len)[:, None] & \
(idx_feats < dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl.store(conv_states_ptrs_target, loaded_x, mask)
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
@ -368,6 +445,11 @@ def causal_conv1d_fn(
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
current_first_idx: Optional[torch.Tensor] = None,
current_last_idx: Optional[torch.Tensor] = None,
initial_state_idx: Optional[torch.Tensor] = None,
context_lens: Optional[torch.Tensor] = None,
block_size_to_align=0,
metadata=None,
validate_data=False,
):
@ -378,7 +460,7 @@ def causal_conv1d_fn(
sequences are concatenated from left to right for varlen
weight: (dim, width)
conv_states: (...,dim,width - 1) itype
updated inplace if provided
updated inplace if cache_indices are not provided
[it uses `cache_indices` to get the index into the conv_state cache for that sequence
conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
@ -410,7 +492,16 @@ def causal_conv1d_fn(
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
current_first_idx: (batch,), dtype int32
The pointer into cache_indices where the first cache block to be filled is located.
current_last_idx: (batch,), dtype int32
The pointer into cache_indices where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into cache_indices where the cache block containing the initial state is located.
context_lens: (batch,), dtype int32
The number of tokens already completed for each sequence
block_size_to_align: int
The block size to align the cached states to
out: same shape as `x`
"""
if isinstance(activation, bool) and activation:
@ -451,7 +542,6 @@ def causal_conv1d_fn(
np2_statelen = triton.next_power_of_2(state_len)
padded_batch = query_start_loc.size(0) - 1
stride_x_seq = 0
stride_x_dim = x.stride(0)
stride_x_token = x.stride(1)
stride_w_dim = weight.stride(0)
@ -460,6 +550,7 @@ def causal_conv1d_fn(
stride_istate_dim = 0
stride_istate_token = 0
num_cache_lines = 0
BLOCK_M = 8
if conv_states is not None:
# extensions to support vLLM:
# 1. conv_states is used to replaced initial_states
@ -475,11 +566,9 @@ def causal_conv1d_fn(
stride_istate_token = conv_states.stride(2)
assert stride_istate_dim == 1
if out.dim() == 2:
stride_o_seq = 0
stride_o_dim = out.stride(0)
stride_o_token = out.stride(1)
else:
stride_o_seq = out.stride(0)
stride_o_dim = out.stride(1)
stride_o_token = out.stride(2)
stride_cache_indices = cache_indices.stride(
@ -502,6 +591,12 @@ def causal_conv1d_fn(
assert weight.stride(1) == 1
assert (dim, width) == weight.shape
assert is_channel_last, "Need to run in channel-last layout"
if block_size_to_align is not None and block_size_to_align > 0:
assert (
block_size_to_align % BLOCK_M
) == 0, "The mamba block size needs to be divisible by the BLOCK_M"
else:
block_size_to_align = BLOCK_M
if metadata is None:
@ -584,14 +679,16 @@ def causal_conv1d_fn(
query_start_loc,
batch_ptr,
token_chunk_offset_ptr,
current_first_idx,
current_last_idx,
initial_state_idx,
context_lens,
out,
# Matrix dimensions
padded_batch,
dim,
cu_seqlen,
num_cache_lines,
# stride
stride_x_seq,
stride_x_dim,
stride_x_token,
stride_w_dim,
@ -600,22 +697,20 @@ def causal_conv1d_fn(
stride_istate_dim,
stride_istate_token,
stride_cache_indices,
stride_o_seq,
stride_o_dim,
stride_o_token,
block_size_to_align // BLOCK_M,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
HAS_INITIAL_STATES=has_initial_state is not None,
HAS_CACHE=conv_states is not None,
IS_CONTINUOUS_BATCHING=cache_indices is not None,
IS_APC_ENABLED=current_last_idx is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
#launch_cooperative_grid=True
BLOCK_M=8,
BLOCK_M=BLOCK_M,
BLOCK_N=256,
num_stages=2,
)
@ -629,10 +724,11 @@ def _causal_conv1d_update_kernel(
w_ptr, # (dim, width)
bias_ptr,
conv_state_ptr,
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
query_start_loc_ptr, # (batch + 1)
current_last_idx, # (batch,)
initial_state_idx, #(batch,)
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
@ -660,7 +756,7 @@ def _causal_conv1d_update_kernel(
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_APC_ENABLED: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
@ -674,15 +770,21 @@ def _causal_conv1d_update_kernel(
# [BLOCK_N,] elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices).to(
tl.int64)
if IS_APC_ENABLED:
# Get the state from the initial_state_idx
conv_state_init = tl.load(initial_state_idx + idx_seq)
current_last_index = tl.load(current_last_idx + idx_seq)
else:
conv_state_batch_coord = idx_seq
conv_state_init = 0
current_last_index = 0
# cache_idx
conv_states_input_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices +
conv_state_init).to(tl.int64)
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
if conv_states_input_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
@ -726,7 +828,7 @@ def _causal_conv1d_update_kernel(
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
mask_w = idx_feats < dim
@ -754,12 +856,12 @@ def _causal_conv1d_update_kernel(
# window manner, at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
mask = ((conv_states_input_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
@ -778,11 +880,16 @@ def _causal_conv1d_update_kernel(
new_conv_state = tl.where(mask, conv_state, loaded_x)
conv_state_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
conv_state_ptrs_target = conv_state_base + (
idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
# Get the state from the initial_state_idx
# cache_idx
conv_states_offset = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices +
current_last_index).to(tl.int64)
conv_state_ptrs_target = (
conv_state_ptr +
(conv_states_offset * stride_conv_state_seq) + # Offset from seq
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens * stride_conv_state_tok)[:, None]
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_state_ptrs_target, new_conv_state, mask)
@ -923,12 +1030,13 @@ def causal_conv1d_update(
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Union[bool, str, None] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID,
current_last_idx: Optional[torch.Tensor] = None,
initial_state_idx: Optional[torch.Tensor] = None,
validate_data=False,
):
"""
@ -942,15 +1050,14 @@ def causal_conv1d_update(
conv_state: (..., dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
current_last_idx: (batch,), dtype int32
The pointer into conv_state_indices where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into conv_state_indices where the cache block containing the initial state is located.
num_accepted_tokens: (batch,), dtype int32
If not None, it indicates the number of accepted tokens for each
sequence in the batch.
@ -963,15 +1070,14 @@ def causal_conv1d_update(
If query_start_loc is not None, this indicates the maximum query
length in the batch.
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
if conv_state_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
"""
if validate_data:
assert cache_seqlens is None # not implemented yet - ok for vLLM
assert pad_slot_id is not None
assert x.stride(1) == 1
if isinstance(activation, bool):
@ -1011,7 +1117,6 @@ def causal_conv1d_update(
assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
assert cache_seqlens is None # not needed for vLLM - circular buffer
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
out = x
@ -1050,10 +1155,11 @@ def causal_conv1d_update(
weight,
bias,
conv_state,
cache_seqlens,
conv_state_indices,
num_accepted_tokens,
query_start_loc,
current_last_idx,
initial_state_idx,
out,
# Matrix dimensions
batch,
@ -1081,7 +1187,7 @@ def causal_conv1d_update(
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_VARLEN=query_start_loc is not None,
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_APC_ENABLED=current_last_idx is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,

View File

@ -52,6 +52,7 @@ def _selective_scan_update_kernel(
z_ptr,
out_ptr,
state_batch_indices_ptr,
dst_state_batch_indices_ptr,
pad_slot_id,
# Matrix dimensions
batch,
@ -107,11 +108,17 @@ def _selective_scan_update_kernel(
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
dst_state_batch_indices_ptr += pid_b
dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
dst_state_ptr = state_ptr + (dst_state_batch_idx * stride_state_batch +
pid_h * stride_state_head)
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += (state_batch_idx * stride_state_batch +
pid_h * stride_state_head)
else:
dst_state_ptr = state_ptr + pid_b * stride_state_batch + \
pid_h * stride_state_head
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
@ -131,6 +138,8 @@ def _selective_scan_update_kernel(
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
offs_n[None, :] * stride_state_dstate)
dst_state_ptrs = dst_state_ptr + (offs_m[:, None] * stride_state_dim +
offs_n[None, :] * stride_state_dstate)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
if HAS_DT_BIAS:
@ -185,7 +194,7 @@ def _selective_scan_update_kernel(
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
tl.store(state_ptrs, state, mask=mask)
tl.store(dst_state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
@ -205,6 +214,7 @@ def selective_state_update(state,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
dst_state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
out=None):
"""
@ -266,6 +276,11 @@ def selective_state_update(state,
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape == (batch, )
if dst_state_batch_indices is not None:
assert dst_state_batch_indices.shape == (batch, )
else:
# revert to the default behavior of in-place state updates
dst_state_batch_indices = state_batch_indices
assert out.shape == x.shape
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
@ -292,6 +307,7 @@ def selective_state_update(state,
z,
out,
state_batch_indices,
dst_state_batch_indices,
pad_slot_id,
batch,
nheads,

View File

@ -35,6 +35,7 @@ def _mamba_chunk_scan_combined_fwd(x,
z=None,
dt_bias=None,
initial_states=None,
return_intermediate_states=False,
seq_idx=None,
cu_seqlens=None,
cu_chunk_seqlens=None,
@ -151,28 +152,32 @@ def _mamba_chunk_scan_combined_fwd(x,
initial_states=initial_states,
)
return states[last_chunk_indices]
if return_intermediate_states:
return states
else:
return states[last_chunk_indices]
def mamba_chunk_scan_combined_varlen(
x,
dt,
A,
B,
C,
chunk_size,
cu_seqlens,
cu_chunk_seqlens,
last_chunk_indices,
seq_idx,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
x,
dt,
A,
B,
C,
chunk_size,
cu_seqlens,
cu_chunk_seqlens,
last_chunk_indices,
seq_idx,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
return_intermediate_states=False,
state_dtype=None,
):
"""
Argument:
@ -213,6 +218,7 @@ def mamba_chunk_scan_combined_varlen(
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
return_intermediate_states=return_intermediate_states,
seq_idx=seq_idx,
cu_seqlens=cu_seqlens,
cu_chunk_seqlens=cu_chunk_seqlens,

View File

@ -453,12 +453,8 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Bamba currently does not support prefix caching"
self.quant_config = vllm_config.quant_config
super().__init__()

View File

@ -292,10 +292,33 @@ class MambaModelConfig(VerifyAndUpdateConfig):
cache_config = vllm_config.cache_config
compilation_config = vllm_config.compilation_config
# TODO(tdoublep): remove once prefix caching is enabled
cache_config.enable_prefix_caching = False
logger.info("Hybrid or mamba-based model detected: disabling prefix "
"caching since it is not yet supported.")
# Set mamba block size to max_model_len (this may get
# overridden by the prefix caching logic later)
cache_config.mamba_block_size = model_config.max_model_len
# TODO(@tdoublep) find a better way to do this than whitelist
MAMBA2_MODELS = [
"BambaForCausalLM",
"FalconH1ForCausalLM",
"GraniteMoeHybridForCausalLM",
"Mamba2ForCausalLM",
"NemotronHForCausalLM",
"Zamba2ForCausalLM",
]
if cache_config.enable_prefix_caching:
if model_config.architecture in MAMBA2_MODELS:
logger.info("Warning: Prefix caching is currently enabled. "
"Its support for Mamba2 layers is experimental. "
"Please report any issues you may observe.")
else:
logger.info("Hybrid or mamba-based model detected without "
"support for prefix caching: disabling.")
cache_config.enable_prefix_caching = False
# TODO(tdoublep): remove once cascade attention is supported
logger.info("Disabling cascade attention since it is not supported "
"for hybrid models.")
model_config.disable_cascade_attn = True
# TODO(tdoublep): remove as full cuda graph support is added
FCG_NOT_SUPPORTED_MODELS = [
@ -360,12 +383,38 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
block_size=model_config.max_model_len,
).page_size_bytes
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size,
16 * attn_page_size_1_token)
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
# mamba SSD kernel uses a chunk_size, e.g. 256
# Align the block to the kernel: use lowest multiple of chunk_size
# of attention tokens that would fit mamba_page_size:
# e.g. for mamba page size = 788kB
# attn_1_token = 2kB -> fits ~394 tokens
# then round up to a multiple of 256 -> 512 tokens
# End result:
# attn_block_size = 512
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
# TODO(tdoublep): this constraint can be relaxed fairly
# easily by changing the way we layout chunks in the
# mamba2 kernels.
chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = \
cdiv(mamba_page_size, attn_page_size_1_token)
attn_block_size = chunk_size * \
cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size
else:
# Without prefix caching, select minimum valid attention block size
# to minimize mamba state padding
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size,
16 * attn_page_size_1_token)
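A small numeric sketch (not part of the diff) of the prefix-caching branch above, plugging in the figures from the comment:

# Hedged numeric sketch of the block-size selection with prefix caching enabled.
def cdiv(a, b):
    return (a + b - 1) // b

mamba_page_size = 788 * 1024       # ~788 kB mamba state per block (example value)
attn_page_size_1_token = 2 * 1024  # ~2 kB attention KV cache per token (example value)
chunk_size = 256                   # mamba2 SSD kernel chunk size

attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)   # 394
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)  # 512
# Both attn_block_size and mamba_block_size end up at 512 tokens,
# a multiple of the kernel chunk size.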
# override attention block size if either (a) the
# user has not set it or (b) the user has set it

View File

@ -540,11 +540,8 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert (not cache_config.enable_prefix_caching
), "FalconH1 currently does not support prefix caching"
self.quant_config = vllm_config.quant_config

View File

@ -549,13 +549,8 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
if cache_config.enable_prefix_caching:
raise RuntimeError(
"GraniteMoeHybrid currently does not support prefix caching")
self.quant_config = vllm_config.quant_config
self.config = config
self.scheduler_config = scheduler_config

View File

@ -222,11 +222,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching"
super().__init__()
self.config = config

View File

@ -505,11 +505,8 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"NemotronH currently does not support prefix caching"
self.quant_config = vllm_config.quant_config

View File

@ -868,11 +868,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
(not supported by Mamba)
"""
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching"
super().__init__()
self.config = config

View File

@ -122,6 +122,11 @@ class Mamba2AttentionMetadata:
last_chunk_indices_p: Optional[torch.Tensor]
state_indices_tensor: torch.Tensor # shape: [batch,]
current_last_idx: torch.Tensor
current_first_idx_p: torch.Tensor
last_state_idx: torch.Tensor
context_lens_p: torch.Tensor
last_computed_offset_p: torch.Tensor
# The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None
@ -138,6 +143,24 @@ class Mamba2AttentionMetadataBuilder(
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models")
if self.vllm_config.cache_config.enable_prefix_caching:
self.state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,
cdiv(vllm_config.model_config.max_model_len,
kv_cache_spec.block_size)),
dtype=torch.int32,
device=device,
)
self.current_last_idx = torch.empty(
(self.decode_cudagraph_max_bs, ),
dtype=torch.int32,
device=device,
)
self.last_state_idx = torch.empty(
(self.decode_cudagraph_max_bs, ),
dtype=torch.int32,
device=device,
)
def build(self,
common_prefix_len: int,
@ -158,7 +181,45 @@ class Mamba2AttentionMetadataBuilder(
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
context_lens, context_lens_p = None, None
current_first_idx, current_first_idx_p = None, None
last_computed_offset, last_computed_offset_p = None, None
if self.vllm_config.cache_config.enable_prefix_caching:
# Use the full block table: a tensor of shape (#requests, max #blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related variables:
mamba_block_size = self.kv_cache_spec.block_size
seq_lens_pending = (
torch.roll(common_attn_metadata.query_start_loc, -1, -1) -
common_attn_metadata.query_start_loc)[:-1]
context_lens = common_attn_metadata.seq_lens - \
seq_lens_pending
last_computed_offset = \
context_lens % mamba_block_size
# Indices: last_computed <= current_first <= current_last
# Cases:
# last_computed == current_first if last state was partially
# computed and needs to be updated
# current_first == current_last if no block crossing occurs, and
# only one state will be stored
# 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]:
current_last_idx = cdiv(context_lens + seq_lens_pending,
mamba_block_size) - 1
current_first_idx = cdiv(context_lens + 1, mamba_block_size) - 1
last_state_idx = cdiv(context_lens, mamba_block_size) - 1
# Clamp the -1 that occurs when nothing is computed yet, as it would cause indexing issues later
last_state_idx = \
last_state_idx.clamp(min=0)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:,
0]
# Additional cache-related variables:
current_last_idx = None
last_state_idx = None
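A short worked example (not part of the diff) of the index arithmetic in the prefix-caching branch above, assuming mamba_block_size = 512 and two requests:

# Hedged numeric sketch of the prefix-caching indices computed above.
import torch

def cdiv(a, b):
    return (a + b - 1) // b

mamba_block_size = 512
seq_lens = torch.tensor([1000, 600])         # total tokens after this step
seq_lens_pending = torch.tensor([700, 600])  # tokens scheduled in this step
context_lens = seq_lens - seq_lens_pending   # tensor([300, 0]) already computed

last_computed_offset = context_lens % mamba_block_size                    # [300, 0]
current_last_idx = cdiv(context_lens + seq_lens_pending,
                        mamba_block_size) - 1                             # [1, 1]
current_first_idx = cdiv(context_lens + 1, mamba_block_size) - 1          # [0, 0]
last_state_idx = (cdiv(context_lens, mamba_block_size) - 1).clamp(min=0)  # [0, 0]
# Request 0 resumes a partially filled block 0 (last_computed == current_first)
# and will additionally fill block 1; request 1 starts from scratch.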
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
@ -178,6 +239,16 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p = common_attn_metadata.query_start_loc[
-num_prefills - 1:] - num_decode_tokens
if self.vllm_config.cache_config.enable_prefix_caching:
assert context_lens is not None
context_lens_p = context_lens[num_reqs - num_prefills:num_reqs]
assert last_computed_offset is not None
last_computed_offset_p = last_computed_offset[
num_reqs - num_prefills:num_reqs]
assert current_first_idx is not None
current_first_idx_p = current_first_idx[num_reqs -
num_prefills:num_reqs]
num_computed_tokens_p = \
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills:num_reqs]
@ -252,6 +323,19 @@ class Mamba2AttentionMetadataBuilder(
state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
self.current_last_idx[:num_decodes].copy_(current_last_idx,
non_blocking=True)
current_last_idx = \
self.current_last_idx[:num_input_tokens]
current_last_idx[num_decodes:] = 0
self.last_state_idx[:num_decodes].copy_(last_state_idx,
non_blocking=True)
last_state_idx = \
self.last_state_idx[:num_input_tokens]
last_state_idx[num_decodes:] = 0
attn_metadata = Mamba2AttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
@ -269,5 +353,10 @@ class Mamba2AttentionMetadataBuilder(
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
current_last_idx=current_last_idx,
current_first_idx_p=current_first_idx_p,
last_state_idx=last_state_idx,
context_lens_p=context_lens_p,
last_computed_offset_p=last_computed_offset_p,
)
return attn_metadata

View File

@ -546,20 +546,38 @@ class MambaManager(SingleTypeKVCacheManager):
kv_cache_spec,
MambaSpec), ("MambaManager can only be used for mamba groups")
assert dcp_world_size == 1, "DCP not support mamba now."
# Prefix caching is not supported for mamba now. Always return empty
# list.
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids):
for computed, cached in zip(computed_blocks, cached_block):
# the hit length logic later assumes:
# hit_length = len(hit_blocks_other_attn[0])
# * self.other_block_size
# so we insert dummy blocks at the beginning:
if i > 0:
computed.extend([block_pool.null_block] * i)
computed.append(cached)
break # we just need the last match - early stopping
return computed_blocks
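A hedged toy model (not vLLM code) of the structure this returns: only the rightmost cache hit is kept, padded on the left with null blocks so that the caller's hit-length arithmetic (number of hit blocks times the block size) still holds:

# Toy stand-in: a 4-block search window where only the hash at index 2 is cached.
null_block = "NULL"
cached = {2: "BLOCK_A"}                       # block index -> cached block
max_num_blocks = 4

computed: list = []
for i in range(max_num_blocks - 1, -1, -1):   # search right to left
    if i in cached:                           # early stop on the first hit
        computed.extend([null_block] * i)     # dummy blocks before the hit
        computed.append(cached[i])
        break
print(computed)  # ['NULL', 'NULL', 'BLOCK_A'] -> hit length = 3 * block_size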
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Each request will always have 1 block at this moment, so no need to
# remove blocks.
# Here unused blocks may be freed up for running requests.
# TODO(@s3woz) Free up all blocks that aren't needed by Mamba2
# (for which find_longest_cache_hit returns block_pool.null_block)
pass
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
cascade attention is not supported by mamba
"""
return 0
def get_num_blocks_to_allocate(

View File

@ -233,10 +233,8 @@ class MambaSpec(KVCacheSpec):
return page_size
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
# We allocate 1 block for each request now, so max_memory_usage_bytes is
# the same as page_size_bytes.
# Need to update this when supporting prefix caching.
return self.page_size_bytes
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
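A hedged numeric sketch (not part of the diff) of the new worst-case bound, with an illustrative page size:

# Worst-case per-layer mamba cache for one request under prefix caching.
def cdiv(a, b):
    return (a + b - 1) // b

max_model_len, block_size = 32768, 512
page_size_bytes = 788 * 1024                  # illustrative mamba page size
print(cdiv(max_model_len, block_size) * page_size_bytes / 2**20)  # ~49.2 MiB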
@dataclass(frozen=True)

View File

@ -4240,21 +4240,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
not in ["qwen3_next"]):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = self.vllm_config.model_config.max_model_len
mamba_block_size = self.vllm_config.cache_config.mamba_block_size
page_size_padded = (
self.vllm_config.cache_config.mamba_page_size_padded)
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
block_size=mamba_block_size,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(