Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 12:25:01 +08:00)
[V1] [Hybrid] Mamba2 Automatic Prefix Caching (#25752)
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
9705fba7b7
commit
ea507c3a93
@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Callable

import pytest

from tests.models.registry import HF_EXAMPLE_MODELS
@ -8,7 +10,7 @@ from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_logprobs_close
from ...utils import check_logprobs_close, check_outputs_equal

# Mark all tests as hybrid
pytestmark = pytest.mark.hybrid_model
|
||||
@ -332,3 +334,413 @@ def test_fp32_cache_state(
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
# Helper functions for the APC tests
|
||||
def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1):
|
||||
return {
|
||||
'model_name': model,
|
||||
'enable_prefix_caching': False,
|
||||
'max_model_len': max_model_len,
|
||||
'tensor_parallel_size': tensor_parallel_size,
|
||||
'gpu_memory_utilization': 0.4
|
||||
}
|
||||
|
||||
|
||||
def _get_vLLM_output(vllm_runner,
|
||||
kwargs,
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
num_repetitions=1,
|
||||
vllm_model=None):
|
||||
outs = []
|
||||
if vllm_model is None:
|
||||
vllm_model = vllm_runner(**kwargs)
|
||||
for _ in range(num_repetitions):
|
||||
if num_logprobs < 0:
|
||||
vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
|
||||
else:
|
||||
vllm_output = vllm_model.generate_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs)
|
||||
outs.append(vllm_output)
|
||||
|
||||
return outs, vllm_model
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
# of the test is executed using `check_outputs_equal`
|
||||
# instead of `check_logprobs_close`
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_apc_single_prompt(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
monkeypatch,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
n_repetitions: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts.
|
||||
generated_prompts = [MULTIPLE * example_prompts[0]]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs, n_repetitions)
|
||||
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
# In the first repetition, the caches are filled
|
||||
# In the second repetition, these caches are reused
|
||||
|
||||
compare_operator(
|
||||
outputs_0_lst=vllm_outputs_no_cache[0],
|
||||
outputs_1_lst=vllm_outputs_cache_itn,
|
||||
name_0="vllm_no_cache",
|
||||
name_1=f"vllm_cache_it_{r_idx + 1}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
# of the test is executed using `check_outputs_equal`
|
||||
# instead of `check_logprobs_close`
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_apc_single_prompt_block_align_alignment(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
monkeypatch,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
n_repetitions: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts. This custom prompt is used, as it causes the most issues
|
||||
generated_prompts = ["The president of the United States is " * MULTIPLE]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
|
||||
# Retrieve the default mamba state block size
|
||||
mamba_block_size = vllm_model.llm.llm_engine.cache_config. \
|
||||
mamba_block_size
|
||||
|
||||
# If the hybrid model does not define
# "mamba_block_size", assume a fixed constant
|
||||
if mamba_block_size is None:
|
||||
mamba_block_size = 512
|
||||
|
||||
mamba_block_size_multiplier = 10
|
||||
for offsets in [
|
||||
-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3
|
||||
]:
|
||||
|
||||
vllm_runner_kwargs[
|
||||
'max_num_batched_tokens'] = mamba_block_size_multiplier * \
|
||||
mamba_block_size - offsets
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens, num_logprobs,
|
||||
n_repetitions)
|
||||
|
||||
# Check alignment of the output logits when using APC
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
# In the first repetition, the caches are filled
|
||||
# In the second repetition, these caches are reused
|
||||
|
||||
compare_operator(
|
||||
outputs_0_lst=vllm_outputs_no_cache[0],
|
||||
outputs_1_lst=vllm_outputs_cache_itn,
|
||||
name_0="vllm_no_cache",
|
||||
name_1=f"vllm_cache_it_{r_idx + 1}",
|
||||
)
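For intuition, the offsets swept above place the scheduler's token budget just below or above multiples of the mamba block size, so that chunked-prefill boundaries land on both sides of a block boundary. A minimal standalone sketch of the resulting max_num_batched_tokens values (assuming mamba_block_size = 512 and the multiplier of 10 used above):

# Worked example (assumption: mamba_block_size = 512, multiplier = 10)
mamba_block_size, multiplier = 512, 10
for offset in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
    print(offset, multiplier * mamba_block_size - offset)
# -> (-3, 5123), (3, 5117), (131, 4989), (253, 4867)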
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
# of the test is executed using `check_outputs_equal`
|
||||
# instead of `check_logprobs_close`
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_apc_multiple_prompts_all_cached_outputs(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
monkeypatch,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
n_repetitions: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts.
|
||||
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs, n_repetitions)
|
||||
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
# In the first repetition, the caches are filled
|
||||
# In the second repetition, these caches are reused
|
||||
|
||||
compare_operator(
|
||||
outputs_0_lst=vllm_outputs_no_cache[0],
|
||||
outputs_1_lst=vllm_outputs_cache_itn,
|
||||
name_0="vllm_no_cache",
|
||||
name_1=f"vllm_cache_it_{r_idx + 1}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
# of the test is executed using `check_outputs_equal`
|
||||
# instead of `check_logprobs_close`
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_apc_multiple_prompts_block_align_alignment(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
monkeypatch,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
n_repetitions: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts. This custom prompt is used, as it causes the most issues
|
||||
prompt_text = "The president of the United States is "
|
||||
prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
|
||||
generated_prompts = [
|
||||
prompt_text[offset:] * MULTIPLE for offset in prompt_offsets
|
||||
]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(model, max_model_len,
|
||||
tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
|
||||
# Retrieve the default mamba state block size
|
||||
mamba_block_size = vllm_model.llm.llm_engine.cache_config. \
|
||||
mamba_block_size
|
||||
|
||||
# If the hybrid model does not define
# "mamba_block_size", assume a fixed constant
|
||||
if mamba_block_size is None:
|
||||
mamba_block_size = 512
|
||||
|
||||
mamba_block_size_multiplier = 10
|
||||
for offsets in [
|
||||
-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3
|
||||
]:
|
||||
|
||||
vllm_runner_kwargs[
|
||||
'max_num_batched_tokens'] = mamba_block_size_multiplier * \
|
||||
mamba_block_size - offsets
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens, num_logprobs,
|
||||
n_repetitions)
|
||||
|
||||
# Check alignment of the output logits when using APC
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
# In the first repetition, the caches are filled
|
||||
# In the second repetition, these caches are reused
|
||||
|
||||
compare_operator(
|
||||
outputs_0_lst=vllm_outputs_no_cache[0],
|
||||
outputs_1_lst=vllm_outputs_cache_itn,
|
||||
name_0="vllm_no_cache",
|
||||
name_1=f"vllm_cache_it_{r_idx + 1}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("n_repetitions", [2])
|
||||
# If num_logprobs is set to -1, then the stringent version
|
||||
# of the test is executed using `check_outputs_equal`
|
||||
# instead of `check_logprobs_close`
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_apc_multiple_prompts_partial_cached_outputs(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
monkeypatch,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
n_repetitions: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts.
|
||||
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
# Cache only part of all the prompts
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
vllm_outputs_partial_cache, vllm_model = _get_vLLM_output(
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
compare_operator(
|
||||
outputs_0_lst=vllm_outputs_no_cache[0][:3],
|
||||
outputs_1_lst=vllm_outputs_partial_cache[0],
|
||||
name_0="vllm_no_cache",
|
||||
name_1="vllm_partial_cache",
|
||||
)
|
||||
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
n_repetitions,
|
||||
vllm_model=vllm_model)
|
||||
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
# In the first repetition, the caches are filled
|
||||
# In the second repetition, these caches are reused
|
||||
|
||||
compare_operator(
|
||||
outputs_0_lst=vllm_outputs_no_cache[0],
|
||||
outputs_1_lst=vllm_outputs_cache_itn,
|
||||
name_0="vllm_no_cache",
|
||||
name_1=f"vllm_cache_it_{r_idx + 1}",
|
||||
)
|
||||
|
||||
@ -92,7 +92,8 @@ class CacheConfig:
|
||||
mamba_page_size_padded: Optional[int] = None
|
||||
""" Optional override for mamba page size; used by hybrid mamba/attention
|
||||
models to ensure exact alignment with attention page size."""
|
||||
|
||||
mamba_block_size: Optional[int] = None
|
||||
"""Size of a contiguous cache block in number of tokens for mamba cache."""
|
||||
mamba_cache_dtype: MambaDType = "auto"
|
||||
"""The data type to use for the Mamba cache (both the conv as well as the
|
||||
ssm state). If set to 'auto', the data type will be inferred from the model
|
||||
|
||||
@ -1563,7 +1563,12 @@ class EngineArgs:
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
if self.enable_prefix_caching is None:
|
||||
self.enable_prefix_caching = True
|
||||
# Disable prefix caching by default for hybrid models
|
||||
# since the feature is still experimental.
|
||||
if model_config.is_hybrid:
|
||||
self.enable_prefix_caching = False
|
||||
else:
|
||||
self.enable_prefix_caching = True
|
||||
else:
|
||||
|
||||
pooling_type = model_config.pooler_config.pooling_type
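With this change, hybrid models keep prefix caching off by default while other models default to on; users can still opt in explicitly for a hybrid model. A sketch of such an opt-in (the model name and cache dtype are illustrative, mirroring the tests added above):

from vllm import LLM

llm = LLM(
    model="ibm-ai-platform/Bamba-9B",   # example hybrid attention/Mamba2 model
    enable_prefix_caching=True,         # opt in; experimental for Mamba2 layers
    mamba_ssm_cache_dtype="float32",    # as used by the APC tests above
)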
|
||||
|
||||
@ -489,6 +489,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
prefix_caching_enabled = self.cache_config.enable_prefix_caching
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
@ -573,6 +576,25 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if prefix_caching_enabled:
|
||||
# If prefix caching is enabled, retrieve the relevant variables
|
||||
# for prefill and decode
|
||||
last_state_idx_d, last_state_idx_p = torch.split(
|
||||
attn_metadata.last_state_idx, [num_decodes, num_prefills],
|
||||
dim=0)
|
||||
current_last_idx_d, current_last_idx_p = torch.split(
|
||||
attn_metadata.current_last_idx, [num_decodes, num_prefills],
|
||||
dim=0)
|
||||
# Prefill-only variables:
|
||||
current_first_idx_p = attn_metadata.current_first_idx_p
|
||||
context_lens_p = attn_metadata.context_lens_p
|
||||
last_computed_offset_p = attn_metadata.last_computed_offset_p
|
||||
else:
|
||||
last_state_idx_d, last_state_idx_p = None, None
|
||||
current_last_idx_d, current_last_idx_p = None, None
|
||||
current_first_idx_p = None
|
||||
context_lens_p = None
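The torch.split calls above rely on the batch layout in which all decode requests precede the prefill requests; a toy illustration with hypothetical sizes:

import torch

num_decodes, num_prefills = 2, 3
# One entry per request: decode requests first, then prefills
current_last_idx = torch.tensor([4, 7, 1, 2, 5])
current_last_idx_d, current_last_idx_p = torch.split(
    current_last_idx, [num_decodes, num_prefills], dim=0)
# current_last_idx_d -> tensor([4, 7]), current_last_idx_p -> tensor([1, 2, 5])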
|
||||
|
||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||
# and decode outputs
|
||||
preallocated_ssm_out = torch.empty(
|
||||
@ -592,8 +614,17 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# Process prefill requests
|
||||
if has_prefill:
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "state_indices_tensor"
|
||||
# - It will read the initial states for every sequence
# that has "has_initial_states_p" == True
# from "cache_indices", using "state_indices_tensor_p".
# - It updates the "conv_state" cache in positions pointed
# to by "state_indices_tensor_p".
# In particular, it will always write the state at the
# sequence end.
# - If, in addition, "current_first_idx_p" and "current_last_idx_p"
# are provided (pointers into
# "state_indices_tensor_p"), it will write additional cache
# states aligned at "block_size_to_align".
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1)  # this is the form that causal-conv sees
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
@ -604,6 +635,11 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
current_first_idx=current_first_idx_p,
|
||||
current_last_idx=current_last_idx_p,
|
||||
initial_state_idx=last_state_idx_p,
|
||||
context_lens=context_lens_p,
|
||||
block_size_to_align=mamba_block_size,
|
||||
metadata=attn_metadata,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
@ -614,9 +650,13 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
if (has_initial_states_p is not None and prep_initial_states):
|
||||
kernel_ssm_indices = state_indices_tensor_p
|
||||
if prefix_caching_enabled:
|
||||
kernel_ssm_indices = state_indices_tensor_p.gather(
|
||||
1, last_state_idx_p.unsqueeze(1)).squeeze(1)
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
ssm_state[kernel_ssm_indices], 0)
|
||||
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
varlen_states = mamba_chunk_scan_combined_varlen(
|
||||
@ -638,18 +678,71 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
cu_chunk_seqlens=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
initial_states=initial_states,
|
||||
return_intermediate_states=prefix_caching_enabled,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
|
||||
self.head_dim),
|
||||
state_dtype=ssm_state.dtype)
|
||||
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||
ssm_state[state_indices_tensor_p] = varlen_states
|
||||
if prefix_caching_enabled:
|
||||
# Save states for sequences with more than just the final state:
|
||||
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
|
||||
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
|
||||
cache_blocks_to_fill = state_indices_tensor_p[
|
||||
seq_idx, current_first_idx_p[seq_idx]:
|
||||
current_first_idx_p[seq_idx] +
|
||||
n_blocks_to_fill[seq_idx]]
|
||||
# chunks = [0 1 2 3 4 5 6 ...]
|
||||
# First aligned chunk would typically be:
|
||||
# mamba_block_size = 1024, chunk_size = 256
|
||||
# 1024 // 256 - 1 --> chunks[3]
|
||||
# But when last chunk wasn't block aligned:
|
||||
# - last_computed_offset_p[seq_idx] // chunk_size
|
||||
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
|
||||
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
|
||||
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
|
||||
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
|
||||
chunk_stride = mamba_block_size // chunk_size
|
||||
first_aligned_chunk = \
|
||||
torch.concat([torch.zeros(1, \
|
||||
dtype=last_chunk_indices_p.dtype, \
|
||||
device=last_chunk_indices_p.device), \
|
||||
last_chunk_indices_p + 1])[seq_idx] \
|
||||
+ chunk_stride - 1 \
|
||||
- last_computed_offset_p[seq_idx] // chunk_size
|
||||
from_where = varlen_states[
|
||||
first_aligned_chunk:first_aligned_chunk +
|
||||
n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
|
||||
ssm_state[cache_blocks_to_fill] = from_where
|
||||
|
||||
# For all seqs, store the last state (Note: might be partial):
|
||||
ssm_state[state_indices_tensor_p.gather(1,
|
||||
current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
|
||||
varlen_states[last_chunk_indices_p]
|
||||
else:
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate)
|
||||
# tensor
|
||||
ssm_state[state_indices_tensor_p] = varlen_states
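For intuition, the chunk-selection arithmetic in the comments above can be reproduced in isolation. A minimal standalone sketch (indices are relative to the sequence's first chunk; the real code additionally offsets by the sequence's position in last_chunk_indices_p):

# Assumption: mamba_block_size = 1024 and chunk_size = 256, as in the comment above
mamba_block_size, chunk_size = 1024, 256
chunk_stride = mamba_block_size // chunk_size          # 4 chunks per mamba block
for last_computed_offset in [1000, 513, 256, 10]:
    completed_chunks = last_computed_offset // chunk_size
    first_aligned_chunk = chunk_stride - 1 - completed_chunks
    print(last_computed_offset, first_aligned_chunk)
# -> 1000: chunk[0], 513: chunk[1], 256: chunk[2], 10: chunk[3]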
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
if prefix_caching_enabled:
|
||||
state_indices_tensor_d_input = \
|
||||
state_indices_tensor_d.gather(1,
|
||||
last_state_idx_d.unsqueeze(1)).squeeze(1)
|
||||
state_indices_tensor_d_output = \
|
||||
state_indices_tensor_d.gather(1,
|
||||
current_last_idx_d.unsqueeze(1)).squeeze(1)
|
||||
# Note:
|
||||
# for decode always: current_first_idx_d == current_last_idx_d
|
||||
# at block boundaries: current_first_idx_d > last_state_idx_d
|
||||
else:
|
||||
# Without caching, read and write in-place to the same blocks:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d
|
||||
state_indices_tensor_d_output = state_indices_tensor_d
|
||||
|
||||
# 2. Convolution sequence transformation
|
||||
hidden_states_B_C_d = causal_conv1d_update(
|
||||
hidden_states_B_C_d,
|
||||
@ -657,7 +750,10 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor_d)
|
||||
conv_state_indices=state_indices_tensor_d,
|
||||
current_last_idx=current_last_idx_d,
|
||||
initial_state_idx=last_state_idx_d,
|
||||
)
|
||||
|
||||
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_d)
|
||||
@ -689,7 +785,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor_d,
|
||||
state_batch_indices=state_indices_tensor_d_input,
|
||||
dst_state_batch_indices=state_indices_tensor_d_output,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1,
|
||||
self.head_dim),
|
||||
)
|
||||
|
||||
@ -20,19 +20,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
w_ptr, # (dim, width)
|
||||
bias_ptr,
|
||||
initial_states_ptr, # conv_states_ptr
|
||||
cache_indices_ptr, # conv_state_indices_ptr
|
||||
cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains
|
||||
# the block indices relevant for each sequence
|
||||
# plus potential 0-padding at the beginning and at the end
|
||||
has_initial_states_ptr,
|
||||
query_start_loc_ptr,
|
||||
batch_ptr,
|
||||
token_chunk_offset_ptr,
|
||||
current_first_idx, # (batch,)
|
||||
current_last_idx, # (batch,)
|
||||
initial_state_idx, # (batch,)
|
||||
context_lens, # (batch,)
|
||||
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
|
||||
# Matrix dimensions
|
||||
batch: tl.int32, # actually padded_batch
|
||||
dim: tl.constexpr,
|
||||
seqlen: tl.int32, # cu_seqlen
|
||||
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
||||
# Strides
|
||||
stride_x_seq: tl.constexpr, # stride to get to next sequence,
|
||||
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
|
||||
stride_x_token: tl.
|
||||
constexpr, # stride to get to next token (same feature-index, same sequence-index)
|
||||
@ -42,18 +46,16 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
stride_istate_dim: tl.constexpr,
|
||||
stride_istate_token: tl.constexpr,
|
||||
stride_cache_indices: tl.constexpr,
|
||||
stride_o_seq: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M
|
||||
# others
|
||||
pad_slot_id: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
HAS_INITIAL_STATES: tl.constexpr,
|
||||
HAS_CACHE: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_APC_ENABLED: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
@ -84,26 +86,57 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# find the actual sequence length
|
||||
seqlen = sequence_end_index - sequence_start_index
|
||||
|
||||
B_size: tl.constexpr = (stride_block_m * BLOCK_M)
|
||||
|
||||
if IS_APC_ENABLED:
|
||||
# Handle the case where prefix caching is enabled.
# In particular, the program writes additional cache states to the blocks referenced by "cache_indices_ptr"
|
||||
|
||||
# Get the length of the completed sequence so far and compute the offset.
|
||||
current_first_index = tl.load(current_first_idx + idx_seq)
|
||||
current_last_index = tl.load(current_last_idx + idx_seq)
|
||||
sequence_completed_index = tl.load(context_lens + idx_seq)
|
||||
|
||||
# Compute the offset of the first stride_block_m-aligned full block
|
||||
# Value in "token-space"
|
||||
sequence_completed_offset_token = sequence_completed_index % B_size
|
||||
seq_completed_offset = B_size - sequence_completed_offset_token
|
||||
seq_end_offset = (seqlen - seq_completed_offset) % B_size
|
||||
last_full_block_token_index = sequence_end_index - seq_end_offset
|
||||
# If the sequence length minus seq_completed_offset is exactly B_size-aligned, then the last full block is the second-to-last one
|
||||
if seq_end_offset == 0:
|
||||
last_full_block_token_index = last_full_block_token_index - B_size
|
||||
|
||||
# Get the number of blocks to be filled for the current sequence
|
||||
# If n_block_to_fill = 0, then only the state at the sequence end is stored
|
||||
n_block_to_fill = current_last_index - current_first_index
|
||||
|
||||
# Get the index of the init block
|
||||
conv_state_init_index = tl.load(initial_state_idx + idx_seq)
|
||||
else:
|
||||
n_block_to_fill = 0
|
||||
current_last_index = 0
|
||||
conv_state_init_index = 0
|
||||
current_first_index = 0
|
||||
last_full_block_token_index = 0
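To make the block-alignment arithmetic above concrete, here is a small standalone sketch with hypothetical numbers (B_size, i.e. stride_block_m * BLOCK_M, assumed to be 512):

# Hypothetical numbers: a chunked-prefill continuation that starts at token 2048
# of the concatenated batch, adds 700 new tokens, and already has 300 tokens computed.
B_size = 512
sequence_start_index, seqlen, sequence_completed_index = 2048, 700, 300
sequence_end_index = sequence_start_index + seqlen                   # 2748

seq_completed_offset = B_size - sequence_completed_index % B_size    # 212 tokens to the next block boundary
seq_end_offset = (seqlen - seq_completed_offset) % B_size            # 488 tokens after the last full block
last_full_block_token_index = sequence_end_index - seq_end_offset    # 2260
if seq_end_offset == 0:
    # end is exactly block-aligned: the last full block is the second-to-last one
    last_full_block_token_index -= B_size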
|
||||
|
||||
token_offset = BLOCK_M * chunk_offset
|
||||
segment_len = min(BLOCK_M, seqlen - token_offset)
|
||||
|
||||
# base of the sequence
|
||||
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# cache_idx
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices).to(
|
||||
tl.int64)
|
||||
else:
|
||||
# cache_idx
|
||||
conv_state_batch_coord = idx_seq
|
||||
# cache_idx
|
||||
conv_states_input_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices +
|
||||
conv_state_init_index).to(tl.int64)
|
||||
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_state_batch_coord == pad_slot_id:
|
||||
if conv_states_input_coord == pad_slot_id:
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
conv_states_base = (conv_states_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(conv_states_input_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
|
||||
|
||||
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
|
||||
@ -113,10 +146,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# 2. update conv_state with new data [only by the Triton program handles chunk_offset=0]
|
||||
if chunk_offset == 0:
|
||||
# read from conv_states
|
||||
load_init_state = False
|
||||
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
|
||||
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(
|
||||
tl.int1)
|
||||
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
|
||||
if load_init_state:
|
||||
# load from conv_states
|
||||
prior_tokens = conv_states_base + (state_len -
|
||||
@ -175,15 +205,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
(idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
conv_states_ptrs_target = conv_states_base[None, :] + (
|
||||
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
|
||||
# Compute the offset where the last block should be written in the conv_states
|
||||
conv_states_output_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices +
|
||||
current_last_index).to(tl.int64)
|
||||
|
||||
conv_states_ptrs_target = (
|
||||
conv_states_ptr + (conv_states_output_coord *
|
||||
stride_conv_state_seq) + # Offset from seq
|
||||
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
|
||||
< dim)[None, :]
|
||||
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
|
||||
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||
tl.store(conv_states_ptrs_target, loaded_x, mask)
|
||||
|
||||
else:
|
||||
if load_init_state:
|
||||
@ -192,12 +230,12 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
|
||||
conv_states_ptrs_source = (
|
||||
conv_states_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(conv_states_input_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:,
|
||||
None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||
mask = ((conv_states_input_coord < num_cache_lines)
|
||||
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
|
||||
@ -280,6 +318,45 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N]
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
|
||||
# Store intermediate states aligned with stride_block_m
|
||||
# The additional states are cached starting from the last stride_block_m.
|
||||
# For example:
|
||||
# If n_block_to_fill = 0, then only the state at the sequence end is cached and the process below is not involved.
|
||||
# If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last
|
||||
# stride_block_m are cached.
|
||||
# For example chunk_offset = n_block_to_fill stores the state at last_full_block
|
||||
if (chunk_offset - 1) < n_block_to_fill:
|
||||
# Store the states at the chunk boundaries from the start of the sequence
|
||||
idx_tokens_last = (last_full_block_token_index -
|
||||
(n_block_to_fill - chunk_offset) * B_size -
|
||||
state_len) + tl.arange(
|
||||
0, NP2_STATELEN) # [BLOCK_M]
|
||||
x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + (
|
||||
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
|
||||
|
||||
mask_x = (
|
||||
(idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
# cache_idx
|
||||
conv_states_output_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices +
|
||||
current_first_index +
|
||||
(chunk_offset - 1)).to(tl.int64)
|
||||
|
||||
conv_states_ptrs_target = (
|
||||
conv_states_ptr + (conv_states_output_coord *
|
||||
stride_conv_state_seq) + # Offset from seq
|
||||
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & \
|
||||
(idx_feats < dim)[None, :]
|
||||
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
|
||||
tl.store(conv_states_ptrs_target, loaded_x, mask)
|
||||
|
||||
if HAS_BIAS:
|
||||
bias = bias_ptr + idx_feats
|
||||
mask_bias = idx_feats < dim
|
||||
@ -368,6 +445,11 @@ def causal_conv1d_fn(
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
current_first_idx: Optional[torch.Tensor] = None,
|
||||
current_last_idx: Optional[torch.Tensor] = None,
|
||||
initial_state_idx: Optional[torch.Tensor] = None,
|
||||
context_lens: Optional[torch.Tensor] = None,
|
||||
block_size_to_align=0,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
):
|
||||
@ -378,7 +460,7 @@ def causal_conv1d_fn(
|
||||
sequences are concatenated from left to right for varlen
|
||||
weight: (dim, width)
|
||||
conv_states: (...,dim,width - 1) itype
|
||||
updated inplace if provided
|
||||
updated inplace if cache_indices are not provided
|
||||
[it uses `cache_indices` to get the index to the cache of conv_state for that sequence
|
||||
|
||||
conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
|
||||
@ -410,7 +492,16 @@ def causal_conv1d_fn(
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
|
||||
current_first_idx: (batch,), dtype int32
The pointer into cache_indices where the first cache block to be filled is located.
current_last_idx: (batch,), dtype int32
The pointer into cache_indices where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into cache_indices where the cache block containing the initial state is located.
context_lens: (batch,), dtype int32
The number of tokens already completed for each sequence.
block_size_to_align: int
The block size to which the cached states are aligned.
|
||||
out: same shape as `x`
|
||||
"""
|
||||
if isinstance(activation, bool) and activation:
|
||||
@ -451,7 +542,6 @@ def causal_conv1d_fn(
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
padded_batch = query_start_loc.size(0) - 1
|
||||
stride_x_seq = 0
|
||||
stride_x_dim = x.stride(0)
|
||||
stride_x_token = x.stride(1)
|
||||
stride_w_dim = weight.stride(0)
|
||||
@ -460,6 +550,7 @@ def causal_conv1d_fn(
|
||||
stride_istate_dim = 0
|
||||
stride_istate_token = 0
|
||||
num_cache_lines = 0
|
||||
BLOCK_M = 8
|
||||
if conv_states is not None:
|
||||
# extensions to support vLLM:
|
||||
# 1. conv_states is used to replaced initial_states
|
||||
@ -475,11 +566,9 @@ def causal_conv1d_fn(
|
||||
stride_istate_token = conv_states.stride(2)
|
||||
assert stride_istate_dim == 1
|
||||
if out.dim() == 2:
|
||||
stride_o_seq = 0
|
||||
stride_o_dim = out.stride(0)
|
||||
stride_o_token = out.stride(1)
|
||||
else:
|
||||
stride_o_seq = out.stride(0)
|
||||
stride_o_dim = out.stride(1)
|
||||
stride_o_token = out.stride(2)
|
||||
stride_cache_indices = cache_indices.stride(
|
||||
@ -502,6 +591,12 @@ def causal_conv1d_fn(
|
||||
assert weight.stride(1) == 1
|
||||
assert (dim, width) == weight.shape
|
||||
assert is_channel_last, "Need to run in channel-last layout"
|
||||
if block_size_to_align is not None and block_size_to_align > 0:
|
||||
assert (
|
||||
block_size_to_align % BLOCK_M
|
||||
) == 0, "The mamba block size needs to be divisible by the BLOCK_M"
|
||||
else:
|
||||
block_size_to_align = BLOCK_M
|
||||
|
||||
if metadata is None:
|
||||
|
||||
@ -584,14 +679,16 @@ def causal_conv1d_fn(
|
||||
query_start_loc,
|
||||
batch_ptr,
|
||||
token_chunk_offset_ptr,
|
||||
current_first_idx,
|
||||
current_last_idx,
|
||||
initial_state_idx,
|
||||
context_lens,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
padded_batch,
|
||||
dim,
|
||||
cu_seqlen,
|
||||
num_cache_lines,
|
||||
# stride
|
||||
stride_x_seq,
|
||||
stride_x_dim,
|
||||
stride_x_token,
|
||||
stride_w_dim,
|
||||
@ -600,22 +697,20 @@ def causal_conv1d_fn(
|
||||
stride_istate_dim,
|
||||
stride_istate_token,
|
||||
stride_cache_indices,
|
||||
stride_o_seq,
|
||||
stride_o_dim,
|
||||
stride_o_token,
|
||||
block_size_to_align // BLOCK_M,
|
||||
# others
|
||||
pad_slot_id,
|
||||
# META
|
||||
HAS_BIAS=bias is not None,
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
HAS_INITIAL_STATES=has_initial_state is not None,
|
||||
HAS_CACHE=conv_states is not None,
|
||||
IS_CONTINUOUS_BATCHING=cache_indices is not None,
|
||||
IS_APC_ENABLED=current_last_idx is not None,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
#launch_cooperative_grid=True
|
||||
BLOCK_M=8,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=256,
|
||||
num_stages=2,
|
||||
)
|
||||
@ -629,10 +724,11 @@ def _causal_conv1d_update_kernel(
|
||||
w_ptr, # (dim, width)
|
||||
bias_ptr,
|
||||
conv_state_ptr,
|
||||
cache_seqlens_ptr, # circular buffer
|
||||
conv_state_indices_ptr,
|
||||
num_accepted_tokens_ptr,
|
||||
query_start_loc_ptr, # (batch + 1)
|
||||
current_last_idx, # (batch,)
|
||||
initial_state_idx, #(batch,)
|
||||
o_ptr, # (batch, dim, seqlen)
|
||||
# Matrix dimensions
|
||||
batch: int,
|
||||
@ -660,7 +756,7 @@ def _causal_conv1d_update_kernel(
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_APC_ENABLED: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
@ -674,15 +770,21 @@ def _causal_conv1d_update_kernel(
|
||||
# [BLOCK_N,] elements along the feature-dimension (channel)
|
||||
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# mask = idx_seq < batch
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices).to(
|
||||
tl.int64)
|
||||
if IS_APC_ENABLED:
|
||||
# Get the state from the initial_state_idx
|
||||
conv_state_init = tl.load(initial_state_idx + idx_seq)
|
||||
current_last_index = tl.load(current_last_idx + idx_seq)
|
||||
else:
|
||||
conv_state_batch_coord = idx_seq
|
||||
conv_state_init = 0
|
||||
current_last_index = 0
|
||||
|
||||
# cache_idx
|
||||
conv_states_input_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices +
|
||||
conv_state_init).to(tl.int64)
|
||||
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_state_batch_coord == pad_slot_id:
|
||||
if conv_states_input_coord == pad_slot_id:
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
|
||||
@ -726,7 +828,7 @@ def _causal_conv1d_update_kernel(
|
||||
|
||||
# STEP 1: READ init_state data
|
||||
conv_states_base = (conv_state_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(conv_states_input_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim))
|
||||
mask_w = idx_feats < dim
|
||||
|
||||
@ -754,12 +856,12 @@ def _causal_conv1d_update_kernel(
|
||||
# window manner, at each forward pass, the tokens are shift by 1, so we
|
||||
# load since idx_tokens + 1.
|
||||
conv_state_ptrs_source = (
|
||||
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
|
||||
conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) +
|
||||
conv_state_token_offset * stride_conv_state_tok +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
|
||||
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||
mask = ((conv_states_input_coord < num_cache_lines)
|
||||
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
|
||||
@ -778,11 +880,16 @@ def _causal_conv1d_update_kernel(
|
||||
|
||||
new_conv_state = tl.where(mask, conv_state, loaded_x)
|
||||
|
||||
conv_state_base = (conv_state_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
|
||||
conv_state_ptrs_target = conv_state_base + (
|
||||
idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
# Get the state from the initial_state_idx
|
||||
# cache_idx
|
||||
conv_states_offset = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices +
|
||||
current_last_index).to(tl.int64)
|
||||
conv_state_ptrs_target = (
|
||||
conv_state_ptr +
|
||||
(conv_states_offset * stride_conv_state_seq) + # Offset from seq
|
||||
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens * stride_conv_state_tok)[:, None]
|
||||
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||
tl.store(conv_state_ptrs_target, new_conv_state, mask)
|
||||
|
||||
@ -923,12 +1030,13 @@ def causal_conv1d_update(
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Union[bool, str, None] = None,
|
||||
cache_seqlens: Optional[torch.Tensor] = None,
|
||||
conv_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
max_query_len: int = -1,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
current_last_idx: Optional[torch.Tensor] = None,
|
||||
initial_state_idx: Optional[torch.Tensor] = None,
|
||||
validate_data=False,
|
||||
):
|
||||
"""
|
||||
@ -942,15 +1050,14 @@ def causal_conv1d_update(
|
||||
conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
cache_seqlens: (batch,), dtype int32.
|
||||
If not None, the conv_state is treated as a circular buffer.
|
||||
The conv_state will be updated by copying x to the conv_state
|
||||
starting at the index
|
||||
@cache_seqlens % state_len.
|
||||
conv_state_indices: (batch,), dtype int32
|
||||
If not None, the conv_state is a larger tensor along the batch dim,
|
||||
and we are selecting the batch coords specified by conv_state_indices.
|
||||
Useful for a continuous batching scenario.
|
||||
current_last_idx: (batch,), dtype int32
|
||||
The pointer into conv_state_indices, where the last cache block to be filled is located.
|
||||
initial_state_idx: (batch,), dtype int32
|
||||
The pointer into conv_state_indices, where the cache block containing the initial state is located.
|
||||
num_accepted_tokens: (batch,), dtype int32
|
||||
If not None, it indicates the number of accepted tokens for each
|
||||
sequence in the batch.
|
||||
@ -963,15 +1070,14 @@ def causal_conv1d_update(
|
||||
If query_start_loc is not None, this indicates the maximum query
|
||||
length in the batch.
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
if conv_state_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
|
||||
"""
|
||||
if validate_data:
|
||||
assert cache_seqlens is None # not implemented yet - ok for vLLM
|
||||
assert pad_slot_id is not None
|
||||
assert x.stride(1) == 1
|
||||
if isinstance(activation, bool):
|
||||
@ -1011,7 +1117,6 @@ def causal_conv1d_update(
|
||||
|
||||
assert num_cache_lines >= batch
|
||||
assert weight.stride(1) == 1 # Need this
|
||||
assert cache_seqlens is None # not needed for vLLM - circular buffer
|
||||
|
||||
# adopt the strategy in vLLM of overwriting 'x' directly, rather than creating a new tensor 'o'
|
||||
out = x
|
||||
@ -1050,10 +1155,11 @@ def causal_conv1d_update(
|
||||
weight,
|
||||
bias,
|
||||
conv_state,
|
||||
cache_seqlens,
|
||||
conv_state_indices,
|
||||
num_accepted_tokens,
|
||||
query_start_loc,
|
||||
current_last_idx,
|
||||
initial_state_idx,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
@ -1081,7 +1187,7 @@ def causal_conv1d_update(
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
IS_VARLEN=query_start_loc is not None,
|
||||
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
|
||||
IS_APC_ENABLED=current_last_idx is not None,
|
||||
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
|
||||
@ -52,6 +52,7 @@ def _selective_scan_update_kernel(
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
dst_state_batch_indices_ptr,
|
||||
pad_slot_id,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
@ -107,11 +108,17 @@ def _selective_scan_update_kernel(
|
||||
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
|
||||
# is the same as the batch id.
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
dst_state_batch_indices_ptr += pid_b
|
||||
dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
|
||||
dst_state_ptr = state_ptr + (dst_state_batch_idx * stride_state_batch +
|
||||
pid_h * stride_state_head)
|
||||
state_batch_indices_ptr += pid_b
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
|
||||
state_ptr += (state_batch_idx * stride_state_batch +
|
||||
pid_h * stride_state_head)
|
||||
else:
|
||||
dst_state_ptr = state_ptr + pid_b * stride_state_batch + \
|
||||
pid_h * stride_state_head
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
@ -131,6 +138,8 @@ def _selective_scan_update_kernel(
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
dst_state_ptrs = dst_state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
if HAS_DT_BIAS:
|
||||
@ -185,7 +194,7 @@ def _selective_scan_update_kernel(
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
tl.store(state_ptrs, state, mask=mask)
|
||||
tl.store(dst_state_ptrs, state, mask=mask)
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
@ -205,6 +214,7 @@ def selective_state_update(state,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
dst_state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None):
|
||||
"""
|
||||
@ -266,6 +276,11 @@ def selective_state_update(state,
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape == (batch, )
|
||||
if dst_state_batch_indices is not None:
|
||||
assert dst_state_batch_indices.shape == (batch, )
|
||||
else:
|
||||
# revert to the default behavior of in-place state updates
|
||||
dst_state_batch_indices = state_batch_indices
|
||||
assert out.shape == x.shape
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
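For intuition on the new dst_state_batch_indices argument: with prefix caching, a caller such as the Mamba2 mixer earlier in this diff reads the previous state from one cache block and writes the updated state to another. A toy sketch with hypothetical block ids:

import torch

# Hypothetical decode batch of 2 requests; each row of the state-indices table
# lists the cache blocks backing that request.
state_indices_tensor_d = torch.tensor([[7, 12, 14], [3, 9, 11]])
last_state_idx_d = torch.tensor([1, 1])     # block holding the previous SSM state
current_last_idx_d = torch.tensor([1, 2])   # block that should receive the new state

# Read from the previous block, write to the current one (they differ only
# when a request crosses a mamba block boundary, as for request 1 here).
state_batch_indices = state_indices_tensor_d.gather(
    1, last_state_idx_d.unsqueeze(1)).squeeze(1)       # tensor([12, 9])
dst_state_batch_indices = state_indices_tensor_d.gather(
    1, current_last_idx_d.unsqueeze(1)).squeeze(1)     # tensor([12, 11])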
|
||||
@ -292,6 +307,7 @@ def selective_state_update(state,
|
||||
z,
|
||||
out,
|
||||
state_batch_indices,
|
||||
dst_state_batch_indices,
|
||||
pad_slot_id,
|
||||
batch,
|
||||
nheads,
|
||||
|
||||
@ -35,6 +35,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
return_intermediate_states=False,
|
||||
seq_idx=None,
|
||||
cu_seqlens=None,
|
||||
cu_chunk_seqlens=None,
|
||||
@ -151,28 +152,32 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
initial_states=initial_states,
|
||||
)
|
||||
|
||||
return states[last_chunk_indices]
|
||||
if return_intermediate_states:
|
||||
return states
|
||||
else:
|
||||
return states[last_chunk_indices]
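Shape-wise, the change above only affects what is returned, not what is computed; a small sketch with hypothetical dimensions:

import torch

num_chunks, nheads, headdim, dstate = 12, 8, 64, 128       # hypothetical sizes
states = torch.randn(num_chunks, nheads, headdim, dstate)  # one state per chunk
last_chunk_indices = torch.tensor([3, 11])                 # last chunk of each of 2 sequences

final_states = states[last_chunk_indices]  # (2, nheads, headdim, dstate), default return
all_states = states                        # returned when return_intermediate_states=True,
                                           # so the caller can pick block-aligned chunk states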
|
||||
|
||||
|
||||
def mamba_chunk_scan_combined_varlen(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens,
|
||||
cu_chunk_seqlens,
|
||||
last_chunk_indices,
|
||||
seq_idx,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
state_dtype=None,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens,
|
||||
cu_chunk_seqlens,
|
||||
last_chunk_indices,
|
||||
seq_idx,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
return_intermediate_states=False,
|
||||
state_dtype=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
@ -213,6 +218,7 @@ def mamba_chunk_scan_combined_varlen(
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
initial_states=initial_states,
|
||||
return_intermediate_states=return_intermediate_states,
|
||||
seq_idx=seq_idx,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
|
||||
@ -453,12 +453,8 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
cache_config = vllm_config.cache_config
|
||||
lora_config = vllm_config.lora_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Bamba currently does not support prefix caching"
|
||||
|
||||
self.quant_config = vllm_config.quant_config
|
||||
|
||||
super().__init__()
|
||||
|
||||
@ -292,10 +292,33 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
||||
cache_config = vllm_config.cache_config
|
||||
compilation_config = vllm_config.compilation_config
|
||||
|
||||
# TODO(tdoublep): remove once prefix caching is enabled
|
||||
cache_config.enable_prefix_caching = False
|
||||
logger.info("Hybrid or mamba-based model detected: disabling prefix "
|
||||
"caching since it is not yet supported.")
|
||||
# Set mamba block size to max_model_len (this may get
|
||||
# overridden by prefix caching logic later)
|
||||
cache_config.mamba_block_size = model_config.max_model_len
|
||||
|
||||
# TODO(@tdoublep) find a better way to do this than whitelist
|
||||
MAMBA2_MODELS = [
|
||||
"BambaForCausalLM",
|
||||
"FalconH1ForCausalLM",
|
||||
"GraniteMoeHybridForCausalLM",
|
||||
"Mamba2ForCausalLM",
|
||||
"NemotronHForCausalLM",
|
||||
"Zamba2ForCausalLM",
|
||||
]
|
||||
if cache_config.enable_prefix_caching:
|
||||
if model_config.architecture in MAMBA2_MODELS:
|
||||
logger.info("Warning: Prefix caching is currently enabled. "
|
||||
"Its support for Mamba2 layers is experimental. "
|
||||
"Please report any issues you may observe.")
|
||||
else:
|
||||
logger.info("Hybrid or mamba-based model detected without "
|
||||
"support for prefix caching: disabling.")
|
||||
cache_config.enable_prefix_caching = False
|
||||
|
||||
# TODO(tdoublep): remove once cascade attention is supported
|
||||
logger.info("Disabling cascade attention since it is not supported "
|
||||
"for hybrid models.")
|
||||
model_config.disable_cascade_attn = True
|
||||
|
||||
# TODO(tdoublep): remove as full cuda graph support is added
|
||||
FCG_NOT_SUPPORTED_MODELS = [
|
||||
@ -360,12 +383,38 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
block_size=model_config.max_model_len,
|
||||
).page_size_bytes
|
||||
|
||||
# some attention backends (e.g. FA) only support setting
|
||||
# block size to multiple of 16, so let's suggest a value
|
||||
# that would work (note: FA is currently not compatible
|
||||
# with mamba layers, use FlashInfer instead).
|
||||
attn_block_size = 16 * cdiv(mamba_page_size,
|
||||
16 * attn_page_size_1_token)
|
||||
if cache_config.enable_prefix_caching:
|
||||
# With prefix caching, select attention block size to
|
||||
# optimize for mamba kernel performance
|
||||
|
||||
# mamba SSD kernel uses a chunk_size, e.g. 256
|
||||
# Align the block to the kernel: use lowest multiple of chunk_size
|
||||
# of attention tokens that would fit mamba_page_size:
|
||||
# e.g. for mamba page size = 788kB
|
||||
# attn_1_token = 2kB -> fits ~394 tokens
|
||||
# then round up to a multiple of 256 -> 512 tokens
|
||||
# End result:
|
||||
# attn_block_size = 512
|
||||
# mamba_block_size = 512 (aligned to a multiple of chunk_size)
|
||||
# TODO(tdoublep): this constraint can be relaxed fairly
|
||||
# easily by changing the way we layout chunks in the
|
||||
# mamba2 kernels.
|
||||
chunk_size = model_config.get_mamba_chunk_size()
|
||||
attn_tokens_per_mamba_state = \
|
||||
cdiv(mamba_page_size, attn_page_size_1_token)
|
||||
attn_block_size = chunk_size * \
|
||||
cdiv(attn_tokens_per_mamba_state, chunk_size)
|
||||
cache_config.mamba_block_size = attn_block_size
|
||||
else:
|
||||
# Without prefix caching, select minimum valid attention block size
|
||||
# to minimize mamba state padding
|
||||
|
||||
# some attention backends (e.g. FA) only support setting
|
||||
# block size to multiple of 16, so let's suggest a value
|
||||
# that would work (note: FA is currently not compatible
|
||||
# with mamba layers, use FlashInfer instead).
|
||||
attn_block_size = 16 * cdiv(mamba_page_size,
|
||||
16 * attn_page_size_1_token)
|
||||
|
||||
# override attention block size if either (a) the
|
||||
# user has not set it or (b) the user has set it
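For the prefix-caching branch above, the block-size arithmetic from the comments works out as follows (a standalone sketch; the page sizes are the illustrative numbers from the comment, and cdiv is redefined locally):

def cdiv(a: int, b: int) -> int:
    return -(a // -b)

mamba_page_size = 788 * 1024         # ~788 kB per mamba state (illustrative)
attn_page_size_1_token = 2 * 1024    # ~2 kB of KV cache per attention token (illustrative)
chunk_size = 256                     # mamba SSD kernel chunk size

attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)   # 394
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)  # 512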
|
||||
|
||||
@@ -540,11 +540,8 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
    config = vllm_config.model_config.hf_config
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    lora_config = vllm_config.lora_config
    scheduler_config = vllm_config.scheduler_config
    assert (not cache_config.enable_prefix_caching
            ), "FalconH1 currently does not support prefix caching"

    self.quant_config = vllm_config.quant_config


@@ -549,13 +549,8 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
    config = vllm_config.model_config.hf_config
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    lora_config = vllm_config.lora_config
    scheduler_config = vllm_config.scheduler_config
    if cache_config.enable_prefix_caching:
        raise RuntimeError(
            "GraniteMoeHybrid currently does not support prefix caching")

    self.quant_config = vllm_config.quant_config
    self.config = config
    self.scheduler_config = scheduler_config

@@ -222,11 +222,8 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    config = vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    lora_config = vllm_config.lora_config
    scheduler_config = vllm_config.scheduler_config
    assert not cache_config.enable_prefix_caching, \
        "Mamba does not support prefix caching"

    super().__init__()
    self.config = config

@@ -505,11 +505,8 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
    config = vllm_config.model_config.hf_config
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    lora_config = vllm_config.lora_config
    scheduler_config = vllm_config.scheduler_config
    assert not cache_config.enable_prefix_caching, \
        "NemotronH currently does not support prefix caching"

    self.quant_config = vllm_config.quant_config


@@ -868,11 +868,8 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
        (not supported by Mamba)
    """
    config = vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    lora_config = vllm_config.lora_config
    scheduler_config = vllm_config.scheduler_config
    assert not cache_config.enable_prefix_caching, \
        "Mamba does not support prefix caching"

    super().__init__()
    self.config = config
@@ -122,6 +122,11 @@ class Mamba2AttentionMetadata:
    last_chunk_indices_p: Optional[torch.Tensor]

    state_indices_tensor: torch.Tensor  # shape: [batch,]
    current_last_idx: torch.Tensor
    current_first_idx_p: torch.Tensor
    last_state_idx: torch.Tensor
    context_lens_p: torch.Tensor
    last_computed_offset_p: torch.Tensor

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: Optional[dict] = None
@@ -138,6 +143,24 @@ class Mamba2AttentionMetadataBuilder(
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
assert self.chunk_size is not None, (
    "chunk_size needs to be set in the model config for Mamba2 models")
if self.vllm_config.cache_config.enable_prefix_caching:
    self.state_indices_tensor = torch.empty(
        (self.decode_cudagraph_max_bs,
         cdiv(vllm_config.model_config.max_model_len,
              kv_cache_spec.block_size)),
        dtype=torch.int32,
        device=device,
    )
    self.current_last_idx = torch.empty(
        (self.decode_cudagraph_max_bs, ),
        dtype=torch.int32,
        device=device,
    )
    self.last_state_idx = torch.empty(
        (self.decode_cudagraph_max_bs, ),
        dtype=torch.int32,
        device=device,
    )

def build(self,
          common_prefix_len: int,
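As a rough illustration of the buffer shapes allocated in the hunk above (the engine limits here are assumed, not taken from the diff): with prefix caching on, the persistent state_indices_tensor holds one block index per mamba block of the longest possible request, while the two index buffers hold one scalar per request.

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

decode_cudagraph_max_bs = 32   # assumed CUDA-graph decode batch limit
max_model_len = 4096           # assumed engine setting
mamba_block_size = 512         # kv_cache_spec.block_size after APC sizing

# One row per (padded) decode request, one column per mamba block.
state_indices_shape = (decode_cudagraph_max_bs,
                       cdiv(max_model_len, mamba_block_size))   # (32, 8)
# current_last_idx / last_state_idx: one entry per request.
index_buffer_shape = (decode_cudagraph_max_bs, )                # (32,)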
@@ -158,7 +181,45 @@ class Mamba2AttentionMetadataBuilder(
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
context_lens, context_lens_p = None, None
current_first_idx, current_first_idx_p = None, None
last_computed_offset, last_computed_offset_p = None, None

if self.vllm_config.cache_config.enable_prefix_caching:
    # Return a tensor of shape (#requests, #max blocks)
    state_indices_tensor = common_attn_metadata.block_table_tensor

    # Additional cache-related variables:
    mamba_block_size = self.kv_cache_spec.block_size
    seq_lens_pending = (
        torch.roll(common_attn_metadata.query_start_loc, -1, -1) -
        common_attn_metadata.query_start_loc)[:-1]
    context_lens = common_attn_metadata.seq_lens - \
        seq_lens_pending
    last_computed_offset = \
        context_lens % mamba_block_size
    # Indices: last_computed <= current_first <= current_last
    # Cases:
    # last_computed == current_first if last state was partially
    # computed and needs to be updated
    # current_first == current_last if no block crossing occurs, and
    # only one state will be stored
    # 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]:
    current_last_idx = cdiv(context_lens + seq_lens_pending,
                            mamba_block_size) - 1
    current_first_idx = cdiv(context_lens + 1, mamba_block_size) - 1
    last_state_idx = cdiv(context_lens, mamba_block_size) - 1
    # -1 in case it's non-computed and causes later issues with indexing
    last_state_idx = \
        last_state_idx.clamp(min=0)

else:
    # Always return just a single block per each request:
    state_indices_tensor = common_attn_metadata.block_table_tensor[:,
                                                                   0]
    # Additional cache-related variables:
    current_last_idx = None
    last_state_idx = None

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
    split_decodes_and_prefills(
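A small worked example of the index arithmetic above, in plain Python rather than tensors (the request sizes are made up for illustration). With mamba_block_size = 512, a request with 512 tokens already computed and 600 new tokens scheduled crosses a block boundary in this step:

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

mamba_block_size = 512
context_len = 512   # tokens already computed for this request
pending = 600       # tokens scheduled in this step

last_computed_offset = context_len % mamba_block_size                  # 0 (block 0 complete)
last_state_idx = max(cdiv(context_len, mamba_block_size) - 1, 0)       # 0 -> reuse state[0]
current_first_idx = cdiv(context_len + 1, mamba_block_size) - 1        # 1 -> first block written now
current_last_idx = cdiv(context_len + pending, mamba_block_size) - 1   # 2 -> last block written now

assert last_state_idx <= current_first_idx <= current_last_idx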
@@ -178,6 +239,16 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc_p = common_attn_metadata.query_start_loc[
    -num_prefills - 1:] - num_decode_tokens

if self.vllm_config.cache_config.enable_prefix_caching:
    assert context_lens is not None
    context_lens_p = context_lens[num_reqs - num_prefills:num_reqs]
    assert last_computed_offset is not None
    last_computed_offset_p = last_computed_offset[
        num_reqs - num_prefills:num_reqs]
    assert current_first_idx is not None
    current_first_idx_p = current_first_idx[num_reqs -
                                            num_prefills:num_reqs]

num_computed_tokens_p = \
    common_attn_metadata.num_computed_tokens_cpu[
        num_reqs - num_prefills:num_reqs]
@@ -252,6 +323,19 @@ class Mamba2AttentionMetadataBuilder(
state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID

if self.vllm_config.cache_config.enable_prefix_caching:
    self.current_last_idx[:num_decodes].copy_(current_last_idx,
                                              non_blocking=True)
    current_last_idx = \
        self.current_last_idx[:num_input_tokens]
    current_last_idx[num_decodes:] = 0

    self.last_state_idx[:num_decodes].copy_(last_state_idx,
                                            non_blocking=True)
    last_state_idx = \
        self.last_state_idx[:num_input_tokens]
    last_state_idx[num_decodes:] = 0

attn_metadata = Mamba2AttentionMetadata(
    num_prefills=num_prefills,
    num_prefill_tokens=num_prefill_tokens,
@@ -269,5 +353,10 @@ class Mamba2AttentionMetadataBuilder(
    nums_dict=nums_dict,
    batch_ptr=batch_ptr,
    token_chunk_offset_ptr=token_chunk_offset_ptr,
    current_last_idx=current_last_idx,
    current_first_idx_p=current_first_idx_p,
    last_state_idx=last_state_idx,
    context_lens_p=context_lens_p,
    last_computed_offset_p=last_computed_offset_p,
)
return attn_metadata
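The copy-then-pad pattern in the prefix-caching branch above keeps the CUDA-graph input buffers at a fixed size: real per-request values go in front and the padded tail is set to a neutral constant. A tiny standalone illustration (sizes are made up; this is not the vLLM buffer itself):

import torch

decode_cudagraph_max_bs = 8
num_decodes, num_input_tokens = 3, 6   # 3 real decode requests, padded to 6

persistent = torch.zeros(decode_cudagraph_max_bs, dtype=torch.int32)
real_values = torch.tensor([5, 7, 2], dtype=torch.int32)

persistent[:num_decodes].copy_(real_values, non_blocking=True)
padded_view = persistent[:num_input_tokens]
padded_view[num_decodes:] = 0          # padded slots get a harmless index

print(padded_view.tolist())            # [5, 7, 2, 0, 0, 0]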
@@ -546,20 +546,38 @@ class MambaManager(SingleTypeKVCacheManager):
    kv_cache_spec,
    MambaSpec), ("MambaManager can only be used for mamba groups")
assert dcp_world_size == 1, "DCP not support mamba now."
# Prefix caching is not supported for mamba now. Always return empty
# list.
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
    [] for _ in range(len(kv_cache_group_ids)))

max_num_blocks = max_length // kv_cache_spec.block_size
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
    if cached_block := block_pool.get_cached_block(
            block_hashes[i], kv_cache_group_ids):
        for computed, cached in zip(computed_blocks, cached_block):
            # the hit length logic later assumes:
            # hit_length = len(hit_blocks_other_attn[0])
            # * self.other_block_size
            # so we insert dummy blocks at the beginning:
            if i > 0:
                computed.extend([block_pool.null_block] * i)
            computed.append(cached)
        break  # we just need the last match - early stopping

return computed_blocks

def remove_skipped_blocks(self, request_id: str,
                          num_computed_tokens: int) -> None:
    # Each request will always have 1 block at this moment, so no need to
    # remove blocks.
    # Here unused blocks may be freed up for running requests.
    # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2
    # (for which find_longest_cache_hit returns block_pool.null_block)
    pass

def get_num_common_prefix_blocks(self, request_id: str,
                                 num_running_requests: int) -> int:
    """
    cascade attention is not supported by mamba
    """
    return 0

def get_num_blocks_to_allocate(
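A minimal sketch of the hit layout this search produces (plain Python with hypothetical block objects): only the right-most cached block is kept as a real hit, and the positions before it are filled with the pool's null block so that len(hit_blocks) * block_size still reflects the number of cached tokens.

NULL_BLOCK = "<null>"                 # stand-in for block_pool.null_block
block_size = 512
cached_blocks = {3: "block@hash3"}    # assume only the 4th block hash is cached
max_num_blocks = 6

hit_blocks: list[str] = []
for i in range(max_num_blocks - 1, -1, -1):   # search right to left
    if i in cached_blocks:
        hit_blocks.extend([NULL_BLOCK] * i)   # dummy blocks for positions 0..i-1
        hit_blocks.append(cached_blocks[i])   # the actual hit
        break

print(len(hit_blocks) * block_size)           # 2048 cached tokens reported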
@@ -233,10 +233,8 @@ class MambaSpec(KVCacheSpec):
    return page_size

def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    # We allocate 1 block for each request now, so max_memory_usage_bytes is
    # the same as page_size_bytes.
    # Need to update this when supporting prefix caching.
    return self.page_size_bytes
    max_model_len = vllm_config.model_config.max_model_len
    return cdiv(max_model_len, self.block_size) * self.page_size_bytes


@dataclass(frozen=True)
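With prefix caching, a sequence can now own several mamba blocks, so the worst-case memory per sequence scales with cdiv(max_model_len, block_size). A quick back-of-the-envelope check with assumed, illustrative numbers:

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

max_model_len = 4096
block_size = 512                 # mamba_block_size chosen by the config logic above
page_size_bytes = 788 * 1024     # illustrative per-block state size

max_memory = cdiv(max_model_len, block_size) * page_size_bytes
print(max_memory)                # 8 blocks * 788 KiB = 6455296 bytes (~6.2 MiB)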
@@ -4240,21 +4240,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        not in ["qwen3_next"]):
    raise NotImplementedError(
        "Mamba with speculative decoding is not supported yet.")
if self.vllm_config.cache_config.enable_prefix_caching:
    raise NotImplementedError(
        "Prefix caching is not supported for Mamba yet.")
max_model_len = self.vllm_config.model_config.max_model_len

mamba_block_size = self.vllm_config.cache_config.mamba_block_size
page_size_padded = (
    self.vllm_config.cache_config.mamba_page_size_padded)

# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
    kv_cache_spec[layer_name] = MambaSpec(
        shapes=mamba_module.get_state_shape(),
        dtypes=mamba_module.get_state_dtype(),
        block_size=max_model_len,
        block_size=mamba_block_size,
        page_size_padded=page_size_padded,
        mamba_type=mamba_module.mamba_type,
        num_speculative_blocks=(