[Model] Support Mamba (#6484)

Tyler Michael Smith 2024-10-11 11:40:06 -04:00 committed by GitHub
parent df3dcdf49d
commit 7342a7d7f8
29 changed files with 1603 additions and 343 deletions

View File

@ -18,7 +18,13 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
# Run basic model test
docker exec cpu-test bash -c "
pip install pytest matplotlib einops transformers_stream_generator
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
pytest -v -s tests/models -m \"not vlm\" \
--ignore=tests/models/test_embedding.py \
--ignore=tests/models/test_oot_registration.py \
--ignore=tests/models/test_registry.py \
--ignore=tests/models/test_jamba.py \
--ignore=tests/models/test_mamba.py \
--ignore=tests/models/test_danube3_4b.py" # Mamba kernels and Danube3-4B on CPU are not supported
# online inference
docker exec cpu-test bash -c "

View File

@ -27,6 +27,7 @@ docker exec cpu-test bash -c "
pytest -v -s tests/models/decoder_only/language \
--ignore=tests/models/test_fp8.py \
--ignore=tests/models/decoder_only/language/test_jamba.py \
--ignore=tests/models/decoder_only/language/test_mamba.py \
--ignore=tests/models/decoder_only/language/test_granitemoe.py \
--ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU are not supported

View File

@ -152,6 +152,11 @@ Text Generation
- :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
- ✅︎
- ✅︎
* - :code:`MambaForCausalLM`
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
- ✅︎
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.

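The new :code:`MambaForCausalLM` row above can be exercised with vLLM's offline API. A minimal smoke-test sketch against one of the listed checkpoints (illustrative only; run it on GPU, since the CPU pipelines above explicitly skip the Mamba kernels):

from vllm import LLM, SamplingParams

# Illustrative only: load one of the Mamba checkpoints listed in the table.
llm = LLM(model="state-spaces/mamba-130m-hf")
params = SamplingParams(temperature=0.0, max_tokens=32)  # greedy decoding
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)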
View File

@ -20,22 +20,22 @@ def test_env(name: str, device: str, monkeypatch):
if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.is_openvino", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
False)
assert backend.name == name
@ -46,32 +46,37 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported data type
backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported head size
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
True)
assert backend.name != STR_FLASH_ATTN_VAL
@ -79,4 +84,4 @@ def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError):
which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
which_attn_to_use(16, None, torch.float16, None, 16, False)

View File

@ -0,0 +1,295 @@
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.
Run `pytest tests/models/test_mamba.py`.
"""
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size
from ...utils import check_outputs_equal
MODELS = ["state-spaces/mamba-130m-hf"]
# Use lower-level interfaces to create this greedy generator, as mamba will
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
def generate_greedy(model_name, example_prompts, max_tokens):
# Create a text generation pipeline
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Generate texts from the prompts
outputs = []
for prompt in example_prompts:
# Tokenize the input prompt with truncation
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
input_ids = inputs["input_ids"].to(model.device)
# Generate text using the model's generate method directly
generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
generated_text = tokenizer.decode(generated_ids[0],
skip_special_tokens=True)
outputs.append((generated_ids[0].tolist(), generated_text))
return outputs
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
hf_outputs = generate_greedy(model, example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_batching(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
for_loop_outputs = []
with vllm_runner(model, dtype=dtype) as vllm_model:
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])
batched_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
check_outputs_equal(
outputs_0_lst=for_loop_outputs,
outputs_1_lst=batched_outputs,
name_0="for_loop_vllm",
name_1="batched_vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int) -> None:
# Tests chunked prefill in conjunction with n>1. In this case, prefill is
# populated with decoding tokens and we test that it doesn't fail.
# This test might fail if cache is not allocated correctly for n > 1
# decoding steps inside a chunked prefill forward pass (where we have both
# prefill and decode together)
sampling_params = SamplingParams(n=3,
temperature=1,
seed=0,
max_tokens=max_tokens)
with vllm_runner(
model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=30,
max_num_seqs=10 # forces prefill chunks with decoding
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str,
max_tokens: int,
chunked_prefill_token_size: int) -> None:
"""
Checks exact match decode between huggingface model and vllm runner with
chunked prefill.
"""
max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size
non_chunked = generate_greedy(model, example_prompts, max_tokens)
with vllm_runner(model,
dtype=dtype,
enable_chunked_prefill=True,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs) as vllm_model:
chunked = vllm_model.generate_greedy(example_prompts,
max_tokens=max_tokens)
check_outputs_equal(
outputs_0_lst=chunked,
outputs_1_lst=non_chunked,
name_0="chunked",
name_1="non_chunked",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [15])
def test_parallel_sampling(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
for_loop_outputs = []
for _ in range(10):
for_loop_outputs.append(
# using example_prompts index 1 instead of 0 since with 0 the
# logprobs get really close and the test doesn't pass
vllm_model.generate_greedy([example_prompts[1]], max_tokens)
[0])
sampling_params = SamplingParams(n=10,
temperature=0.001,
seed=0,
max_tokens=max_tokens)
n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
sampling_params)
token_ids, texts = n_lt_1_outputs[0]
n_lt_1_outputs = [(token_id, text)
for token_id, text in zip(token_ids, texts)]
check_outputs_equal(
outputs_0_lst=n_lt_1_outputs,
outputs_1_lst=for_loop_outputs,
name_0="vllm_n_lt_1_outputs",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
example_prompts.append(example_prompts[0])
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
except RuntimeError:
pytest.fail(
"Couldn't run batch size which is not equal to a Cuda Graph "
"captured batch size. "
"Could be related to mamba cache not padded correctly")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# Tests that outputs are identical with and w/o preemptions (recompute)
assert dtype == "float"
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = False
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=preempt_vllm_outputs,
outputs_1_lst=vllm_outputs,
name_0="vllm_preepmtions",
name_1="vllm",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is for verifying that the Mamba inner state management doesn't
# collapse in the case where the number of incoming requests and
# finished_requests_ids is larger than the maximum Mamba block capacity.
# This could generally happen because Mamba supports a statelessness
# mechanism where it can clean up new incoming requests in a single step.
try:
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
except ValueError:
pytest.fail("Mamba inner state wasn't cleaned up properly between"
"steps finished requests registered unnecessarily ")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is for verifying that the Mamba state is cleaned up between
# steps. If it's not cleaned, an error would be expected.
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
pytest.fail("Mamba inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)

View File

@ -0,0 +1,324 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
# Placeholder attention backend for models like Mamba and embedding models that
# lack attention.
class PlaceholderAttentionBackend(AttentionBackend):
"""Placeholder backend for when no attention is needed."""
@staticmethod
def get_name() -> str:
return "placeholder-attn"
@staticmethod
def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
return PlaceholderAttentionImpl
@staticmethod
def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
return PlaceholderAttentionMetadataBuilder
@staticmethod
def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
return PlaceholderAttentionMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (1, 1, 1, 1, 1)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
return
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
return
@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
"""Attention metadata for prefill and decode batched together."""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum query length in the batch.
max_query_len: Optional[int]
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# The 2nd dimension is padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
_cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
_cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.seq_start_loc is not None
# Placeholders
slot_mapping = torch.empty(0)
block_tables = torch.empty(0)
self._cached_prefill_metadata = PlaceholderAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
decode_query_len=0,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=block_tables,
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.seq_lens_tensor is not None
# Placeholders
slot_mapping = torch.empty(0)
block_tables = torch.empty(0)
self._cached_decode_metadata = PlaceholderAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
decode_query_len=self.decode_query_len,
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class PlaceholderAttentionMetadataBuilder(
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
"""
is_prompt = inter_data.is_prompt
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
decode_query_len = max(decode_query_lens)
else:
decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
if use_captured_graph:
num_decode_tokens = batch_size
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
# Placeholders
slot_mapping = torch.empty(0)
block_tables = torch.empty(0)
return PlaceholderAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
decode_query_len=decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
class PlaceholderAttentionImpl(AttentionImpl):
def __init__(self, *args, **kwargs) -> None:
return
def forward(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError

View File

@ -42,10 +42,12 @@ class Attention(nn.Module):
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
sliding_window = cache_config.sliding_window
is_attention_free = cache_config.is_attention_free
else:
kv_cache_dtype = "auto"
block_size = 16
sliding_window = None
is_attention_free = False
if num_kv_heads is None:
num_kv_heads = num_heads
@ -76,9 +78,9 @@ class Attention(nn.Module):
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size, blocksparse_params
attn_backend = get_attn_backend(head_size, sliding_window, dtype,
kv_cache_dtype, block_size,
is_attention_free, blocksparse_params
is not None)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,

View File

@ -24,6 +24,7 @@ class _Backend(enum.Enum):
FLASHINFER = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
def backend_name_to_enum(backend_name: str) -> _Backend:
@ -88,13 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
@lru_cache(maxsize=None)
def get_attn_backend(
num_heads: int,
head_size: int,
num_kv_heads: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
@ -105,9 +105,8 @@ def get_attn_backend(
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
sliding_window, dtype, kv_cache_dtype,
block_size)
backend = which_attn_to_use(head_size, sliding_window, dtype,
kv_cache_dtype, block_size, is_attention_free)
if backend == _Backend.FLASH_ATTN:
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
@ -146,23 +145,31 @@ def get_attn_backend(
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
elif backend == _Backend.NO_ATTENTION:
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
return PlaceholderAttentionBackend
else:
raise ValueError("Invalid attention backend.")
def which_attn_to_use(
num_heads: int,
head_size: int,
num_kv_heads: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
) -> _Backend:
"""Returns which flash attention backend to use."""
# Default case.
selected_backend = _Backend.FLASH_ATTN
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
return _Backend.NO_ATTENTION
# Check whether a particular choice of backend was
# previously forced.
#
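As a quick illustration of the new selection path (a sketch based on the updated signature and the placeholder backend added above, not part of the diff):

import torch
from vllm.attention.selector import get_attn_backend

# For an attention-free model the attention-specific arguments are irrelevant:
# which_attn_to_use() short-circuits to NO_ATTENTION and the placeholder
# backend is returned.
backend = get_attn_backend(head_size=0, sliding_window=None,
                           dtype=torch.float16, kv_cache_dtype="auto",
                           block_size=16, is_attention_free=True)
print(backend.get_name())  # "placeholder-attn"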

View File

@ -196,6 +196,9 @@ class ModelConfig:
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self.is_attention_free = self._init_attention_free()
self.has_inner_state = self._init_has_inner_state()
self.override_neuron_config = override_neuron_config if is_neuron(
) else None
self._verify_embedding_mode()
@ -216,6 +219,14 @@ class ModelConfig:
return None
def _init_attention_free(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_attention_free_model(architectures)
def _init_has_inner_state(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.model_has_inner_state(architectures)
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow", "mistral"]:
@ -438,6 +449,10 @@ class ModelConfig:
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
if self.is_attention_free:
return 0
if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models.
@ -469,6 +484,9 @@ class ModelConfig:
return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads)
if self.is_attention_free:
return 0
attributes = [
# For Falcon:
"n_head_kv",
@ -511,31 +529,17 @@ class ModelConfig:
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return end - start
def contains_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> bool:
"""True for Mamba/SSM models (Jamba)"""
return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
def get_layers_block_type(self,
parallel_config: "ParallelConfig") -> List[str]:
num_layers = self.get_num_layers(parallel_config)
# Transformers supports layers_block_type @property
return getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
def get_num_attention_layers(self,
parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t == "attention"
])
if self.is_attention_free:
return 0
def _get_num_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t != "attention"
])
num_layers = self.get_num_layers(parallel_config)
# Transformers supports layers_block_type @property
layers = getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
return len([t for t in layers if t == "attention"])
def get_multimodal_config(self) -> "MultiModalConfig":
"""
@ -585,6 +589,7 @@ class CacheConfig:
gpu_memory_utilization: float,
swap_space: float,
cache_dtype: str,
is_attention_free: bool = False,
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
@ -595,6 +600,7 @@ class CacheConfig:
self.swap_space_bytes = swap_space * GiB_bytes
self.num_gpu_blocks_override = num_gpu_blocks_override
self.cache_dtype = cache_dtype
self.is_attention_free = is_attention_free
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb

View File

@ -36,10 +36,10 @@ class BlockSpaceManager(ABC):
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
return BlockSpaceManagerV2
if version == "embedding":
from vllm.core.embedding_model_block_manager import (
EmbeddingModelBlockSpaceManager)
return EmbeddingModelBlockSpaceManager
if version == "placeholder":
from vllm.core.placeholder_block_space_manager import (
PlaceholderBlockSpaceManager)
return PlaceholderBlockSpaceManager
raise ValueError(f"Unknown version {version=}")

View File

@ -5,9 +5,10 @@ from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
"""An embedding version of BlockSpaceManager for use in environments
with embedding models where block management is not required.
class PlaceholderBlockSpaceManager(BlockSpaceManager):
"""A version of BlockSpaceManager for use in environments
where block management is not required.
For example: embedding models or attention-free models like Mamba.
This class provides the same interface as BlockSpaceManager, but its
methods perform no actions or return simple values like True in specific
@ -40,7 +41,7 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
seq: Sequence,
num_lookahead_slots: int,
) -> List[Tuple[int, int]]:
return None # type: ignore
return []
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
pass

View File

@ -314,8 +314,9 @@ class Scheduler:
version = "v1"
if self.scheduler_config.use_v2_block_manager:
version = "v2"
if self.scheduler_config.embedding_mode:
version = "embedding"
if (self.scheduler_config.embedding_mode
or self.cache_config.is_attention_free):
version = "placeholder"
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version)

View File

@ -912,6 +912,7 @@ class EngineArgs:
gpu_memory_utilization=self.gpu_memory_utilization,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
is_attention_free=model_config.is_attention_free,
num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching,
@ -945,13 +946,9 @@ class EngineArgs:
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_model is not None
has_seqlen_agnostic_layers = (
model_config.contains_seqlen_agnostic_layers(
parallel_config))
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
and not has_seqlen_agnostic_layers):
and not self.enable_prompt_adapter):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models with "

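With the seqlen-agnostic exclusion dropped above, chunked prefill can be combined with Mamba the same way the new tests do. A minimal sketch mirroring the test_chunked_prefill parameters (illustrative settings, not recommendations):

from vllm import LLM, SamplingParams

llm = LLM(model="state-spaces/mamba-130m-hf",
          enable_chunked_prefill=True,
          max_num_batched_tokens=16,  # forces prompts to be split into chunks
          max_num_seqs=16)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)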
View File

@ -6,7 +6,8 @@ import json
import os
import tempfile
from collections import defaultdict
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
Tuple, Union)
import filelock
import gguf
@ -559,6 +560,38 @@ def row_parallel_weight_loader(param: torch.Tensor,
return default_weight_loader(param, loaded_weight)
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
"""Create a weight loader that shards the weights along the given axis"""
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tensor_model_parallel_rank()
shard_size = param.data.shape[shard_axis]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
return default_weight_loader(param, loaded_weight)
return loader
def composed_weight_loader(
loader: LoaderFunction, fn: Callable[[torch.Tensor],
torch.Tensor]) -> LoaderFunction:
"""Create a weight loader that post-processes the weights after loading"""
def composed_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
loader(param, loaded_weight)
param.data.copy_(fn(param))
return
return composed_loader
def initialize_dummy_weights(
model: torch.nn.Module,
low: float = -1e-3,

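The two helpers above are consumed later in this commit by the Jamba and Mamba mixers. In isolation the pattern is roughly the following sketch (the parameter shapes here are made up for illustration):

import torch
from torch import nn

from vllm.model_executor.model_loader.weight_utils import (
    composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.utils import set_weight_attrs

# D is sharded along dim 0 across tensor-parallel ranks when its checkpoint
# weight is loaded.
D = nn.Parameter(torch.ones(128))
set_weight_attrs(D, {"weight_loader": sharded_weight_loader(0)})

# A is sharded the same way and then re-parameterized as -exp(A) after loading.
A = nn.Parameter(torch.empty(128, 16, dtype=torch.float32))
a_weight_loader = composed_weight_loader(
    sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(A, {"weight_loader": a_weight_loader})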
View File

@ -271,7 +271,7 @@ class HasInnerState(Protocol):
"""
A flag that indicates this model has inner state.
Models that have inner state usually need access to the scheduler_config
for max_num_seqs ,etc... (Currently only used by Jamba)
for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
"""
def __init__(self,
@ -307,3 +307,46 @@ def has_inner_state(
return isinstance(model, _HasInnerStateType)
return isinstance(model, HasInnerState)
@runtime_checkable
class IsAttentionFree(Protocol):
"""The interface required for all models like Mamba that lack attention,
but do have state whose size is constant wrt the number of tokens."""
is_attention_free: ClassVar[Literal[True]] = True
"""
A flag that indicates this model has no attention.
Used for block manager and attention backend selection.
True for Mamba but not Jamba.
"""
def __init__(self) -> None:
...
@runtime_checkable
class _IsAttentionFreeType(Protocol):
is_attention_free: ClassVar[Literal[True]]
def __init__(self) -> None:
...
@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
...
@overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
...
def is_attention_free(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
if isinstance(model, type):
return isinstance(model, _IsAttentionFreeType)
return isinstance(model, IsAttentionFree)
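The new protocol is consumed via isinstance checks (e.g. ModelRegistry.is_attention_free_model in the config changes above). A minimal sketch with a hypothetical class showing how a model opts in and how the helper behaves:

from vllm.model_executor.models.interfaces import (IsAttentionFree,
                                                   is_attention_free)

class ToyAttentionFreeModel(IsAttentionFree):  # hypothetical example class
    pass

assert is_attention_free(ToyAttentionFreeModel)    # works on the class itself
assert is_attention_free(ToyAttentionFreeModel())  # ...and on instances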

View File

@ -1,18 +1,16 @@
# coding=utf-8
"""Inference-only Jamba model."""
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -29,7 +27,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheManager
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
@ -99,16 +99,6 @@ class JambaMambaMixer(nn.Module):
bias=True,
skip_bias_add=True)
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
param.data.copy_(
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
dim=0)[tp_rank])
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
weight_loader(param, -torch.exp(loaded_weight.float()))
tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter(
torch.empty(
@ -118,8 +108,10 @@ class JambaMambaMixer(nn.Module):
))
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
set_weight_attrs(self.D, {"weight_loader": weight_loader})
set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
self.out_proj = RowParallelLinear(
self.intermediate_size,
@ -571,10 +563,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Used to track and store the Mamba cache between steps.
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
@ -586,203 +576,36 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs):
if not self.mamba_cache:
self._prepare_mamba_cache()
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
mamba_cache = self._release_finished_and_prepare_mamba_cache(
finished_requests_ids, request_ids_to_seq_ids)
else:
# CUDA graph capturing runs
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
layers_type = self.config.layers_block_type
num_mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache[0],
mamba_cache[1])
attn_metadata, mamba_cache_tensors[0],
mamba_cache_tensors[1])
return hidden_states
def _swap_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _move_out_if_already_occupied(self, index: int,
all_occupied_indices: List[int]):
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# In case occupied, move the occupied to a new empty block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
"""
Assign (req_id,seq_id) pair to a `destination_index` index, if
already occupied, move the occupying index to a free index.
"""
all_occupied_indices = self._get_all_occupied_indices()
if cur_rid not in self.mamba_cache_indices_mapping:
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling , where n > 1, assume prefill have
# already happened now we only need to copy the already
# existing cache into the siblings seq_ids caches
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
# case of decoding n>1, copy prefill cache to decoding indices
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]
) -> Tuple[torch.Tensor, torch.Tensor]:
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
batch_size = len(request_ids_to_seq_ids_flatten)
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate cache index for requests that run
# and finish right after
continue
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)
self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]
return (conv_state, temporal_state)
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]
def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)
def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})
def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return
def _release_finished_and_prepare_mamba_cache(
self, finished_requests_ids,
request_ids_to_seq_ids) -> Tuple[torch.Tensor, torch.Tensor]:
self._release_mamba_cache(finished_requests_ids)
return self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
finished_requests_ids)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
that was provided during the capture runs
(JambaForCausalLM.mamba_gc_cache_buffer).
"""
self._release_finished_and_prepare_mamba_cache(
kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"])
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping.pop(req_id)
def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self
) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape = (
@ -790,31 +613,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
self.config.mamba_d_conv - 1,
)
temporal_state_shape = (
self.config.mamba_expand * self.config.hidden_size // world_size,
self.config.mamba_expand * hidden_size // world_size,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def _prepare_mamba_cache(self):
dtype = self.lm_head.weight.dtype
layers_type = self.config.layers_block_type
mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE) + 2)
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None
self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda"),
torch.empty(size=(mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda"))
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@ -0,0 +1,499 @@
# coding=utf-8
"""PyTorch MAMBA model."""
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
composed_weight_loader, default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import MambaCacheManager
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)
KVCache = Tuple[torch.Tensor, torch.Tensor]
@dataclass
class MambaCacheParams:
is_prompt: bool = False
conv_state: torch.Tensor = torch.Tensor()
ssm_state: torch.Tensor = torch.Tensor()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class MambaMixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""
def __init__(self, config: MambaConfig, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = int(config.time_step_rank)
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.intermediate_size,
bias=config.use_conv_bias,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow overriding it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(self.hidden_size,
[self.intermediate_size] * 2,
bias=config.use_bias)
# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
self.intermediate_size,
self.time_step_rank + self.ssm_state_size * 2,
bias=False,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(self.time_step_rank,
self.intermediate_size,
bias=True,
skip_bias_add=True)
tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter(
torch.empty(
self.intermediate_size // tp_size,
self.ssm_state_size,
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
self.out_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=config.use_bias,
input_is_parallel=True,
)
self.activation = config.hidden_act
def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
ssm_state: torch.Tensor):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states
class MambaMLP(nn.Module):
def __init__(
self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
hidden_act = config.hidden_act
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class MambaDecoderLayer(nn.Module):
def __init__(self,
config: MambaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.mixer = MambaMixer(config, layer_idx)
self.feed_forward = MambaMLP(config, quant_config=quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
**kwargs,
):
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
ssm_state)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
class MambaModel(nn.Module):
def __init__(
self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embeddings = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
decoder_layers = []
for i in range(config.num_hidden_layers):
decoder_layers.append(
MambaDecoderLayer(config,
layer_idx=i,
cache_config=cache_config,
quant_config=quant_config))
self.layers = nn.ModuleList(decoder_layers)
self.norm_f = RMSNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
current_ssm_state = ssm_state[i]
current_conv_state = conv_state[i]
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
conv_state=current_conv_state,
ssm_state=current_ssm_state,
)
hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embeddings": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
config: MambaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None:
assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching"
super().__init__()
self.config = config
self.scheduler_config = scheduler_config
self.backbone = MambaModel(config,
cache_config=cache_config,
quant_config=quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = self.backbone.embeddings
# Used to track and store the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, self.config.num_hidden_layers,
max_batch_size, *self._get_mamba_cache_shape())
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
input_ids, attn_metadata, **kwargs)
hidden_states = self.backbone(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_tensors[0],
mamba_cache_tensors[1])
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape = (
self.config.intermediate_size // world_size,
self.config.conv_kernel - 1,
)
temporal_state_shape = (
self.config.intermediate_size // world_size,
self.config.state_size,
)
return conv_state_shape, temporal_state_shape
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "A_log" in name:
name = name.replace("A_log", "A")
if ".self_attn." in name:
name = name.replace(".self_attn", "")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
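
For reference, the fused `causal_conv1d_fn`/`selective_scan_fn` kernels called in `MambaMixer.forward` above implement the discretized SSM recurrence h_t = exp(Δ_t·A)·h_{t-1} + Δ_t·B_t·x_t, y_t = C_t·h_t + D·x_t. Below is a minimal, unbatched pure-PyTorch sketch of that recurrence; the function name and layout are illustrative assumptions rather than part of this commit, and the real kernels additionally handle `query_start_loc` batching and persist the state into `ssm_state` between decode steps.

import torch
import torch.nn.functional as F

def reference_selective_scan(x, delta, A, B, C, D, z=None,
                             delta_bias=None, delta_softplus=True):
    # x, delta, z: (dim, seqlen); A: (dim, state); B, C: (state, seqlen); D: (dim,)
    if delta_bias is not None:
        delta = delta + delta_bias[:, None]
    if delta_softplus:
        delta = F.softplus(delta)
    dim, state = A.shape
    h = x.new_zeros(dim, state)
    ys = []
    for t in range(x.shape[-1]):
        # Discretize and step the state: h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * x_t
        h = torch.exp(delta[:, t, None] * A) * h \
            + delta[:, t, None] * B[None, :, t] * x[:, t, None]
        # Readout: y_t = C_t . h_t + D * x_t
        ys.append((h * C[None, :, t]).sum(-1) + D * x[:, t])
    y = torch.stack(ys, dim=-1)
    # Optional SiLU gating, corresponding to the `gate` argument above
    return y * F.silu(z) if z is not None else y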

View File

@ -0,0 +1,222 @@
from typing import Dict, List, Optional
import torch
from vllm.attention.backends.abstract import AttentionMetadata
class MambaCacheManager:
def __init__(self, dtype, num_mamba_layers, max_batch_size,
conv_state_shape, temporal_state_shape):
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda")
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda")
self.mamba_cache = (conv_state, temporal_state)
# Maps each request id to a dict that maps each seq_id to its
# index inside self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
def current_run_tensors(self, input_ids: torch.Tensor,
attn_metadata: AttentionMetadata, **kwargs):
"""
Return the tensors for the current run's conv and ssm state.
"""
if "seqlen_agnostic_capture_inputs" not in kwargs:
# We get here only on Prefill/Eager mode runs
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
finished_requests_ids = kwargs["finished_requests_ids"]
self._release_finished_requests(finished_requests_ids)
mamba_cache_tensors = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, finished_requests_ids)
else:
# CUDA graph capturing runs
mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
return mamba_cache_tensors
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache into the CUDA graph input buffer
that was provided during the capture runs
(see get_seqlen_agnostic_capture_inputs).
"""
assert all(
key in kwargs
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
finished_requests_ids = kwargs["finished_requests_ids"]
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
self._release_finished_requests(finished_requests_ids)
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
finished_requests_ids)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer of the adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
def _swap_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
def _move_out_if_already_occupied(self, index: int,
all_occupied_indices: List[int]):
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# If the index is occupied, move the occupant to a free block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
"""
Assign the (req_id, seq_id) pair to `destination_index`; if that index
is already occupied, first move the occupying entry to a free index.
"""
all_occupied_indices = self._get_all_occupied_indices()
if cur_rid not in self.mamba_cache_indices_mapping:
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# Parallel sampling (n > 1): prefill is assumed to have already
# happened, so we only need to copy the existing cache into the
# sibling seq_ids' cache slots.
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
# case of decoding n>1, copy prefill cache to decoding indices
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# The seq_id already exists, but not at the right destination;
# swap it with whatever is occupying that slot.
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]):
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
batch_size = len(request_ids_to_seq_ids_flatten)
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate a cache index for requests that finish
# within this very run
continue
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)
self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]
return (conv_state, temporal_state)
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]
def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# Move all occupied but currently-not-running blocks out of
# the first `batch_size` blocks
destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# Move the non-running index out of the current batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)
def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})
def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping.pop(req_id)
def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")

View File

@ -14,7 +14,8 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
from .interfaces import supports_multimodal, supports_pp
from .interfaces import (has_inner_state, is_attention_free,
supports_multimodal, supports_pp)
from .interfaces_base import is_embedding_model, is_text_generation_model
logger = init_logger(__name__)
@ -52,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
@ -157,6 +159,8 @@ class _ModelInfo:
is_embedding_model: bool
supports_multimodal: bool
supports_pp: bool
has_inner_state: bool
is_attention_free: bool
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
@ -165,6 +169,8 @@ class _ModelInfo:
is_embedding_model=is_embedding_model(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model),
)
@ -380,6 +386,14 @@ class _ModelRegistry:
) -> bool:
return self.inspect_model_cls(architectures).supports_pp
def model_has_inner_state(self, architectures: Union[str,
List[str]]) -> bool:
return self.inspect_model_cls(architectures).has_inner_state
def is_attention_free_model(self, architectures: Union[str,
List[str]]) -> bool:
return self.inspect_model_cls(architectures).is_attention_free
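A quick sketch of how the two new registry helpers above can be queried, assuming the usual top-level `vllm.ModelRegistry` import path:

from vllm import ModelRegistry

# Mamba keeps recurrent inner state and has no KV-cache-backed attention.
assert ModelRegistry.model_has_inner_state(["MambaForCausalLM"])
assert ModelRegistry.is_attention_free_model(["MambaForCausalLM"])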
ModelRegistry = _ModelRegistry({
model_arch: _LazyRegisteredModel(

View File

@ -52,15 +52,12 @@ class CacheEngine:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend.
self.attn_backend = get_attn_backend(
model_config.get_num_attention_heads(parallel_config),
self.head_size,
self.num_kv_heads,
model_config.get_sliding_window(),
model_config.dtype,
cache_config.cache_dtype,
self.block_size,
)
self.attn_backend = get_attn_backend(self.head_size,
model_config.get_sliding_window(),
model_config.dtype,
cache_config.cache_dtype,
self.block_size,
model_config.is_attention_free)
# Initialize the cache.
self.gpu_cache = self._allocate_kv_cache(

View File

@ -418,13 +418,12 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Multi-modal data support

View File

@ -56,13 +56,12 @@ class CPUCacheEngine:
# Get attention backend.
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Initialize the cache.

View File

@ -196,7 +196,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
} if self.has_inner_state else {}
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
with set_forward_context(model_input.attn_metadata):

View File

@ -17,7 +17,6 @@ import torch.nn as nn
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
@ -991,8 +990,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
parallel_config)
self.has_inner_state = model_config.has_inner_state
# When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in
@ -1003,22 +1001,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.graph_block_tables = np.zeros(
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
dtype=np.int32)
num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
self.attn_backend = get_attn_backend(
num_attn_heads,
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
) if num_attn_heads else None
if self.attn_backend:
self.attn_state = self.attn_backend.get_state_cls()(
weakref.proxy(self))
else:
self.attn_state = CommonAttentionState(weakref.proxy(self))
self.model_config.is_attention_free,
)
self.attn_state = self.attn_backend.get_state_cls()(
weakref.proxy(self))
# Multi-modal data support
self.input_registry = input_registry
@ -1498,7 +1490,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"previous_hidden_states"] = previous_hidden_states[:
batch_size]
if self.has_seqlen_agnostic:
if self.has_inner_state:
# Currently only used by Mamba-based models (e.g. Jamba) with CUDA graphs
capture_inputs.update({
"seqlen_agnostic_capture_inputs":
@ -1647,7 +1639,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
} if self.has_inner_state else {}
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True)
@ -1852,10 +1844,14 @@ class CUDAGraphRunner:
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
if self.backend_name != "placeholder-attn":
self.input_buffers["slot_mapping"].copy_(
attn_metadata.slot_mapping, non_blocking=True)
self.attn_state.prepare_graph_input_buffers(
self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
**kwargs)

View File

@ -74,13 +74,12 @@ class OpenVINOModelRunner:
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Multi-modal data support

View File

@ -70,13 +70,12 @@ class OpenVINOCacheEngine:
# Get attention backend.
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.head_size,
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Initialize the cache.

View File

@ -113,13 +113,12 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
dtype=np.int32)
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
False,
)
self.cached_step_outputs: List[torch.Tensor] = []

View File

@ -236,11 +236,15 @@ class Worker(LocalOrDistributedWorkerBase):
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes()
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
if cache_block_size == 0:
num_gpu_blocks = 0
num_cpu_blocks = 0
else:
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
@ -257,6 +261,7 @@ class Worker(LocalOrDistributedWorkerBase):
"""
raise_if_cache_size_invalid(num_gpu_blocks,
self.cache_config.block_size,
self.cache_config.is_attention_free,
self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks
@ -472,14 +477,18 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
"`dtype` flag in CLI, for example: --dtype=half.")
def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
max_model_len) -> None:
if num_gpu_blocks <= 0:
if is_attention_free and num_gpu_blocks != 0:
raise ValueError("No memory should be allocated for the cache blocks "
f"for an attention-free model, but {num_gpu_blocks}"
"blocks are allocated.")
if not is_attention_free and num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * num_gpu_blocks
if max_model_len > max_seq_len:
if not is_attention_free and max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "

View File

@ -372,13 +372,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
# Multi-modal data support
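
Taken together, every `get_attn_backend` call site above now passes the same reduced argument list with the new `is_attention_free` flag. A representative sketch of the new call shape follows; the literal values are illustrative placeholders, not taken from any particular config.

import torch
from vllm.attention import get_attn_backend

attn_backend = get_attn_backend(
    128,            # head size
    None,           # sliding window
    torch.float16,  # model dtype
    "auto",         # KV cache dtype
    16,             # block size
    False,          # is_attention_free (True for e.g. MambaForCausalLM)
)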