mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:45:00 +08:00
[Model] Support Mamba (#6484)
This commit is contained in:
parent
df3dcdf49d
commit
7342a7d7f8
@ -18,7 +18,13 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
|
||||
# Run basic model test
|
||||
docker exec cpu-test bash -c "
|
||||
pip install pytest matplotlib einops transformers_stream_generator
|
||||
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
|
||||
pytest -v -s tests/models -m \"not vlm\" \
|
||||
--ignore=tests/models/test_embedding.py \
|
||||
--ignore=tests/models/test_oot_registration.py \
|
||||
--ignore=tests/models/test_registry.py \
|
||||
--ignore=tests/models/test_jamba.py \
|
||||
--ignore=tests/models/test_mamba.py \
|
||||
--ignore=tests/models/test_danube3_4b.py" # Mamba kernels and Danube3-4B on CPU is not supported
|
||||
|
||||
# online inference
|
||||
docker exec cpu-test bash -c "
|
||||
|
||||
@ -27,6 +27,7 @@ docker exec cpu-test bash -c "
|
||||
pytest -v -s tests/models/decoder_only/language \
|
||||
--ignore=tests/models/test_fp8.py \
|
||||
--ignore=tests/models/decoder_only/language/test_jamba.py \
|
||||
--ignore=tests/models/decoder_only/language/test_mamba.py \
|
||||
--ignore=tests/models/decoder_only/language/test_granitemoe.py \
|
||||
--ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
|
||||
|
||||
|
||||
@ -152,6 +152,11 @@ Text Generation
|
||||
- :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - :code:`MambaForCausalLM`
|
||||
- Mamba
|
||||
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
|
||||
- ✅︎
|
||||
-
|
||||
* - :code:`MiniCPMForCausalLM`
|
||||
- MiniCPM
|
||||
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
|
||||
|
||||
@ -20,22 +20,22 @@ def test_env(name: str, device: str, monkeypatch):
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.is_cpu", return_value=True):
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
|
||||
torch.float16, 16)
|
||||
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
|
||||
16, False)
|
||||
assert backend.name == "TORCH_SDPA"
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.is_hip", return_value=True):
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
|
||||
torch.float16, 16)
|
||||
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
|
||||
16, False)
|
||||
assert backend.name == "ROCM_FLASH"
|
||||
elif device == "openvino":
|
||||
with patch("vllm.attention.selector.is_openvino", return_value=True):
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
|
||||
torch.float16, 16)
|
||||
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
|
||||
16, False)
|
||||
assert backend.name == "OPENVINO"
|
||||
else:
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
|
||||
torch.float16, 16)
|
||||
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == name
|
||||
|
||||
|
||||
@ -46,32 +46,37 @@ def test_flash_attn(monkeypatch):
|
||||
|
||||
# Unsupported CUDA arch
|
||||
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
|
||||
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported data type
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
|
||||
backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported kv cache data type
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
|
||||
backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported block size
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
|
||||
backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported sliding window
|
||||
backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
|
||||
backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# flash-attn is not installed
|
||||
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
|
||||
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
|
||||
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Unsupported head size
|
||||
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
|
||||
backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Attention-free models should bypass env and use PlaceholderAttention
|
||||
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
|
||||
True)
|
||||
assert backend.name != STR_FLASH_ATTN_VAL
|
||||
|
||||
|
||||
@ -79,4 +84,4 @@ def test_invalid_env(monkeypatch):
|
||||
"""Throw an exception if the backend name is invalid."""
|
||||
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
|
||||
with pytest.raises(ValueError):
|
||||
which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
|
||||
which_attn_to_use(16, None, torch.float16, None, 16, False)
|
||||
|
||||
295
tests/models/decoder_only/language/test_mamba.py
Normal file
295
tests/models/decoder_only/language/test_mamba.py
Normal file
@ -0,0 +1,295 @@
|
||||
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.
|
||||
|
||||
Run `pytest tests/models/test_mamba.py`.
|
||||
"""
|
||||
import pytest
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.worker.model_runner import _get_graph_batch_size
|
||||
|
||||
from ...utils import check_outputs_equal
|
||||
|
||||
MODELS = ["state-spaces/mamba-130m-hf"]
|
||||
|
||||
|
||||
# Use lower-level interfaces to create this greedy generator, as mamba will
|
||||
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
|
||||
def generate_greedy(model_name, example_prompts, max_tokens):
|
||||
# Create a text generation pipeline
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
|
||||
# Generate texts from the prompts
|
||||
outputs = []
|
||||
for prompt in example_prompts:
|
||||
# Tokenize the input prompt with truncation
|
||||
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
|
||||
input_ids = inputs["input_ids"].to(model.device)
|
||||
|
||||
# Generate text using the model's generate method directly
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
|
||||
generated_text = tokenizer.decode(generated_ids[0],
|
||||
skip_special_tokens=True)
|
||||
|
||||
outputs.append((generated_ids[0].tolist(), generated_text))
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
def test_models(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
hf_outputs = generate_greedy(model, example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
def test_batching(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
# To pass the small model tests, we need full precision.
|
||||
for_loop_outputs = []
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
for prompt in example_prompts:
|
||||
for_loop_outputs.append(
|
||||
vllm_model.generate_greedy([prompt], max_tokens)[0])
|
||||
|
||||
batched_outputs = vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=for_loop_outputs,
|
||||
outputs_1_lst=batched_outputs,
|
||||
name_0="for_loop_vllm",
|
||||
name_1="batched_vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [10])
|
||||
def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts,
|
||||
model: str, dtype: str,
|
||||
max_tokens: int) -> None:
|
||||
# Tests chunked prefill in conjunction with n>1. In this case, prefill is
|
||||
# populated with decoding tokens and we test that it doesn't fail.
|
||||
# This test might fail if cache is not allocated correctly for n > 1
|
||||
# decoding steps inside a chunked prefill forward pass (where we have both
|
||||
# prefill and decode together )
|
||||
sampling_params = SamplingParams(n=3,
|
||||
temperature=1,
|
||||
seed=0,
|
||||
max_tokens=max_tokens)
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_batched_tokens=30,
|
||||
max_num_seqs=10 # forces prefill chunks with decoding
|
||||
) as vllm_model:
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
|
||||
def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int) -> None:
|
||||
"""
|
||||
Checks exact match decode between huggingface model and vllm runner with
|
||||
chunked prefill.
|
||||
"""
|
||||
max_num_seqs = chunked_prefill_token_size
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
non_chunked = generate_greedy(model, example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_seqs=max_num_seqs) as vllm_model:
|
||||
chunked = vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens=max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=chunked,
|
||||
outputs_1_lst=non_chunked,
|
||||
name_0="chunked",
|
||||
name_1="non_chunked",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [15])
|
||||
def test_parallel_sampling(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
for_loop_outputs = []
|
||||
for _ in range(10):
|
||||
for_loop_outputs.append(
|
||||
# using example_prompts index 1 instead of 0 since with 0 the
|
||||
# logprobs get really close and the test doesn't pass
|
||||
vllm_model.generate_greedy([example_prompts[1]], max_tokens)
|
||||
[0])
|
||||
sampling_params = SamplingParams(n=10,
|
||||
temperature=0.001,
|
||||
seed=0,
|
||||
max_tokens=max_tokens)
|
||||
n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
|
||||
sampling_params)
|
||||
token_ids, texts = n_lt_1_outputs[0]
|
||||
n_lt_1_outputs = [(token_id, text)
|
||||
for token_id, text in zip(token_ids, texts)]
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=n_lt_1_outputs,
|
||||
outputs_1_lst=for_loop_outputs,
|
||||
name_0="vllm_n_lt_1_outputs",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [20])
|
||||
def test_mamba_cache_cg_padding(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
# This test is for verifying that mamba cache is padded to CG captured
|
||||
# batch size. If it's not, a torch RuntimeError will be raised because
|
||||
# tensor dimensions aren't compatible
|
||||
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
|
||||
example_prompts.append(example_prompts[0])
|
||||
|
||||
try:
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
except RuntimeError:
|
||||
pytest.fail(
|
||||
"Couldn't run batch size which is not equal to a Cuda Graph "
|
||||
"captured batch size. "
|
||||
"Could be related to mamba cache not padded correctly")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [20])
|
||||
def test_models_preemption_recompute(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
# Tests that outputs are identical with and w/o preemtions (recompute)
|
||||
assert dtype == "float"
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_model.model.llm_engine.scheduler[
|
||||
0].ENABLE_ARTIFICIAL_PREEMPT = True
|
||||
preempt_vllm_outputs = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
|
||||
vllm_model.model.llm_engine.scheduler[
|
||||
0].ENABLE_ARTIFICIAL_PREEMPT = False
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=preempt_vllm_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="vllm_preepmtions",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
example_prompts,
|
||||
) -> None:
|
||||
# This test is for verifying that the Mamba inner state management doesn't
|
||||
# collapse in case where the number of incoming requests and
|
||||
# finished_requests_ids is larger than the maximum Mamba block capacity.
|
||||
# This could generally happen due to the fact that Mamba does support
|
||||
# statelessness mechanism where it can cleanup new incoming requests in
|
||||
# a single step.
|
||||
try:
|
||||
with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
|
||||
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
|
||||
except ValueError:
|
||||
pytest.fail("Mamba inner state wasn't cleaned up properly between"
|
||||
"steps finished requests registered unnecessarily ")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_state_cleanup(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
example_prompts,
|
||||
) -> None:
|
||||
# This test is for verifying that the Mamba state is cleaned up between
|
||||
# steps, If its not cleaned, an error would be expected.
|
||||
try:
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
for _ in range(10):
|
||||
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
|
||||
except ValueError:
|
||||
pytest.fail("Mamba inner state wasn't cleaned up between states, "
|
||||
"could be related to finished_requests_ids")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_model_print(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
# This test is for verifying whether the model's extra_repr
|
||||
# can be printed correctly.
|
||||
print(vllm_model.model.llm_engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
324
vllm/attention/backends/placeholder_attn.py
Normal file
324
vllm/attention/backends/placeholder_attn.py
Normal file
@ -0,0 +1,324 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||
|
||||
# Placeholder attention backend for models like Mamba and embedding models that
|
||||
# lack attention.
|
||||
|
||||
|
||||
class PlaceholderAttentionBackend(AttentionBackend):
|
||||
"""Placeholder backend for when no attention is needed."""
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "placeholder-attn"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
|
||||
return PlaceholderAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
|
||||
return PlaceholderAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
|
||||
return PlaceholderAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (1, 1, 1, 1, 1)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlaceholderAttentionMetadata(AttentionMetadata):
|
||||
"""Attention metadata for prefill and decode batched together."""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int]
|
||||
|
||||
# Number of query tokens for each request in the batch.
|
||||
# Currently, we require that all requests have the same number of query
|
||||
# tokens during the decoding phase. When speculavie decoding is enabled,
|
||||
# decode_query_len might be greater than 1. In all other cases, it is 1.
|
||||
decode_query_len: Optional[int]
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
_cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.seq_start_loc is not None
|
||||
|
||||
# Placeholders
|
||||
slot_mapping = torch.empty(0)
|
||||
block_tables = torch.empty(0)
|
||||
|
||||
self._cached_prefill_metadata = PlaceholderAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
decode_query_len=0,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
# Placeholders
|
||||
slot_mapping = torch.empty(0)
|
||||
block_tables = torch.empty(0)
|
||||
|
||||
self._cached_decode_metadata = PlaceholderAttentionMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
decode_query_len=self.decode_query_len,
|
||||
max_query_len=None,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class PlaceholderAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
"""
|
||||
is_prompt = inter_data.is_prompt
|
||||
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
|
||||
if is_prompt:
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
assert query_len == 1, (
|
||||
"seq_len: {}, context_len: {}, query_len: {}".format(
|
||||
seq_len, context_len, query_len))
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
Args:
|
||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
||||
query_lens: The query lengths of the input sequences.
|
||||
cuda_graph_pad_size: The padding size for cuda graph.
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
logits_soft_cap = getattr(self.runner.model_config.hf_config,
|
||||
"attn_logit_softcapping", None)
|
||||
if logits_soft_cap is not None:
|
||||
raise ValueError(
|
||||
"Please use Flashinfer backend for models with logits_soft_cap"
|
||||
" (i.e., Gemma-2). Otherwise, the output might be wrong."
|
||||
" Set Flashinfer backend by "
|
||||
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
decode_query_lens = query_lens[self.num_prefills:]
|
||||
if len(decode_query_lens) > 0:
|
||||
decode_query_len = max(decode_query_lens)
|
||||
else:
|
||||
decode_query_len = 1
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
|
||||
if use_captured_graph:
|
||||
num_decode_tokens = batch_size
|
||||
|
||||
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||
|
||||
context_lens_tensor = torch.tensor(self.context_lens,
|
||||
dtype=torch.int,
|
||||
device=device)
|
||||
seq_lens_tensor = torch.tensor(seq_lens,
|
||||
dtype=torch.int,
|
||||
device=device)
|
||||
query_lens_tensor = torch.tensor(query_lens,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
torch.cumsum(seq_lens_tensor,
|
||||
dim=0,
|
||||
dtype=seq_start_loc.dtype,
|
||||
out=seq_start_loc[1:])
|
||||
torch.cumsum(query_lens_tensor,
|
||||
dim=0,
|
||||
dtype=query_start_loc.dtype,
|
||||
out=query_start_loc[1:])
|
||||
|
||||
# Placeholders
|
||||
slot_mapping = torch.empty(0)
|
||||
block_tables = torch.empty(0)
|
||||
|
||||
return PlaceholderAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
decode_query_len=decode_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
)
|
||||
|
||||
|
||||
class PlaceholderAttentionImpl(AttentionImpl):
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
return
|
||||
|
||||
def forward(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
@ -42,10 +42,12 @@ class Attention(nn.Module):
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
sliding_window = cache_config.sliding_window
|
||||
is_attention_free = cache_config.is_attention_free
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
sliding_window = None
|
||||
is_attention_free = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
|
||||
@ -76,9 +78,9 @@ class Attention(nn.Module):
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
|
||||
sliding_window, dtype, kv_cache_dtype,
|
||||
block_size, blocksparse_params
|
||||
attn_backend = get_attn_backend(head_size, sliding_window, dtype,
|
||||
kv_cache_dtype, block_size,
|
||||
is_attention_free, blocksparse_params
|
||||
is not None)
|
||||
impl_cls = attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
|
||||
@ -24,6 +24,7 @@ class _Backend(enum.Enum):
|
||||
FLASHINFER = enum.auto()
|
||||
PALLAS = enum.auto()
|
||||
IPEX = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
|
||||
|
||||
def backend_name_to_enum(backend_name: str) -> _Backend:
|
||||
@ -88,13 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_attn_backend(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
sliding_window: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
is_blocksparse: bool = False,
|
||||
) -> Type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
@ -105,9 +105,8 @@ def get_attn_backend(
|
||||
BlocksparseFlashAttentionBackend)
|
||||
return BlocksparseFlashAttentionBackend
|
||||
|
||||
backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
|
||||
sliding_window, dtype, kv_cache_dtype,
|
||||
block_size)
|
||||
backend = which_attn_to_use(head_size, sliding_window, dtype,
|
||||
kv_cache_dtype, block_size, is_attention_free)
|
||||
if backend == _Backend.FLASH_ATTN:
|
||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||
FlashAttentionBackend)
|
||||
@ -146,23 +145,31 @@ def get_attn_backend(
|
||||
logger.info("Using Pallas backend.")
|
||||
from vllm.attention.backends.pallas import PallasAttentionBackend
|
||||
return PallasAttentionBackend
|
||||
elif backend == _Backend.NO_ATTENTION:
|
||||
from vllm.attention.backends.placeholder_attn import (
|
||||
PlaceholderAttentionBackend)
|
||||
return PlaceholderAttentionBackend
|
||||
else:
|
||||
raise ValueError("Invalid attention backend.")
|
||||
|
||||
|
||||
def which_attn_to_use(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
sliding_window: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
) -> _Backend:
|
||||
"""Returns which flash attention backend to use."""
|
||||
# Default case.
|
||||
selected_backend = _Backend.FLASH_ATTN
|
||||
|
||||
# If there are no attention layers (e.g. we are running Mamba),
|
||||
# use the placeholder NO_ATTENTION
|
||||
if is_attention_free:
|
||||
return _Backend.NO_ATTENTION
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
|
||||
@ -196,6 +196,9 @@ class ModelConfig:
|
||||
if not self.skip_tokenizer_init:
|
||||
self._verify_tokenizer_mode()
|
||||
|
||||
self.is_attention_free = self._init_attention_free()
|
||||
self.has_inner_state = self._init_has_inner_state()
|
||||
|
||||
self.override_neuron_config = override_neuron_config if is_neuron(
|
||||
) else None
|
||||
self._verify_embedding_mode()
|
||||
@ -216,6 +219,14 @@ class ModelConfig:
|
||||
|
||||
return None
|
||||
|
||||
def _init_attention_free(self) -> bool:
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
return ModelRegistry.is_attention_free_model(architectures)
|
||||
|
||||
def _init_has_inner_state(self) -> bool:
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
return ModelRegistry.model_has_inner_state(architectures)
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
tokenizer_mode = self.tokenizer_mode.lower()
|
||||
if tokenizer_mode not in ["auto", "slow", "mistral"]:
|
||||
@ -438,6 +449,10 @@ class ModelConfig:
|
||||
# FlashAttention supports only head_size 32, 64, 128, 256,
|
||||
# we need to pad head_size 192 to 256
|
||||
return 256
|
||||
|
||||
if self.is_attention_free:
|
||||
return 0
|
||||
|
||||
if hasattr(self.hf_text_config, "head_dim"):
|
||||
return self.hf_text_config.head_dim
|
||||
# FIXME(woosuk): This may not be true for all models.
|
||||
@ -469,6 +484,9 @@ class ModelConfig:
|
||||
return getattr(self.hf_config.attn_config, "kv_n_heads",
|
||||
self.hf_config.num_attention_heads)
|
||||
|
||||
if self.is_attention_free:
|
||||
return 0
|
||||
|
||||
attributes = [
|
||||
# For Falcon:
|
||||
"n_head_kv",
|
||||
@ -511,31 +529,17 @@ class ModelConfig:
|
||||
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
|
||||
return end - start
|
||||
|
||||
def contains_seqlen_agnostic_layers(
|
||||
self, parallel_config: "ParallelConfig") -> bool:
|
||||
"""True for Mamba/SSM models (Jamba)"""
|
||||
return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
|
||||
|
||||
def get_layers_block_type(self,
|
||||
parallel_config: "ParallelConfig") -> List[str]:
|
||||
num_layers = self.get_num_layers(parallel_config)
|
||||
# Transformers supports layers_block_type @property
|
||||
return getattr(self.hf_config, "layers_block_type",
|
||||
["attention"] * num_layers)
|
||||
|
||||
def get_num_attention_layers(self,
|
||||
parallel_config: "ParallelConfig") -> int:
|
||||
return len([
|
||||
t for t in self.get_layers_block_type(parallel_config)
|
||||
if t == "attention"
|
||||
])
|
||||
if self.is_attention_free:
|
||||
return 0
|
||||
|
||||
def _get_num_seqlen_agnostic_layers(
|
||||
self, parallel_config: "ParallelConfig") -> int:
|
||||
return len([
|
||||
t for t in self.get_layers_block_type(parallel_config)
|
||||
if t != "attention"
|
||||
])
|
||||
num_layers = self.get_num_layers(parallel_config)
|
||||
|
||||
# Transformers supports layers_block_type @property
|
||||
layers = getattr(self.hf_config, "layers_block_type",
|
||||
["attention"] * num_layers)
|
||||
return len([t for t in layers if t == "attention"])
|
||||
|
||||
def get_multimodal_config(self) -> "MultiModalConfig":
|
||||
"""
|
||||
@ -585,6 +589,7 @@ class CacheConfig:
|
||||
gpu_memory_utilization: float,
|
||||
swap_space: float,
|
||||
cache_dtype: str,
|
||||
is_attention_free: bool = False,
|
||||
num_gpu_blocks_override: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
enable_prefix_caching: bool = False,
|
||||
@ -595,6 +600,7 @@ class CacheConfig:
|
||||
self.swap_space_bytes = swap_space * GiB_bytes
|
||||
self.num_gpu_blocks_override = num_gpu_blocks_override
|
||||
self.cache_dtype = cache_dtype
|
||||
self.is_attention_free = is_attention_free
|
||||
self.sliding_window = sliding_window
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
self.cpu_offload_gb = cpu_offload_gb
|
||||
|
||||
@ -36,10 +36,10 @@ class BlockSpaceManager(ABC):
|
||||
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
|
||||
return BlockSpaceManagerV2
|
||||
|
||||
if version == "embedding":
|
||||
from vllm.core.embedding_model_block_manager import (
|
||||
EmbeddingModelBlockSpaceManager)
|
||||
return EmbeddingModelBlockSpaceManager
|
||||
if version == "placeholder":
|
||||
from vllm.core.placeholder_block_space_manager import (
|
||||
PlaceholderBlockSpaceManager)
|
||||
return PlaceholderBlockSpaceManager
|
||||
|
||||
raise ValueError(f"Unknown version {version=}")
|
||||
|
||||
|
||||
@ -5,9 +5,10 @@ from vllm.sequence import Sequence, SequenceGroup
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
|
||||
"""An embedding version of BlockSpaceManager for use in environments
|
||||
with embedding models where block management is not required.
|
||||
class PlaceholderBlockSpaceManager(BlockSpaceManager):
|
||||
"""A version of BlockSpaceManager for use in environments
|
||||
where block management is not required.
|
||||
For example: embedding models or attention-free models like Mamba.
|
||||
|
||||
This class provides the same interface as BlockSpaceManager, but its
|
||||
methods perform no actions or return simple values like True in specific
|
||||
@ -40,7 +41,7 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
|
||||
seq: Sequence,
|
||||
num_lookahead_slots: int,
|
||||
) -> List[Tuple[int, int]]:
|
||||
return None # type: ignore
|
||||
return []
|
||||
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
pass
|
||||
@ -314,8 +314,9 @@ class Scheduler:
|
||||
version = "v1"
|
||||
if self.scheduler_config.use_v2_block_manager:
|
||||
version = "v2"
|
||||
if self.scheduler_config.embedding_mode:
|
||||
version = "embedding"
|
||||
if (self.scheduler_config.embedding_mode
|
||||
or self.cache_config.is_attention_free):
|
||||
version = "placeholder"
|
||||
|
||||
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
|
||||
version)
|
||||
|
||||
@ -912,6 +912,7 @@ class EngineArgs:
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
swap_space=self.swap_space,
|
||||
cache_dtype=self.kv_cache_dtype,
|
||||
is_attention_free=model_config.is_attention_free,
|
||||
num_gpu_blocks_override=self.num_gpu_blocks_override,
|
||||
sliding_window=model_config.get_sliding_window(),
|
||||
enable_prefix_caching=self.enable_prefix_caching,
|
||||
@ -945,13 +946,9 @@ class EngineArgs:
|
||||
use_sliding_window = (model_config.get_sliding_window()
|
||||
is not None)
|
||||
use_spec_decode = self.speculative_model is not None
|
||||
has_seqlen_agnostic_layers = (
|
||||
model_config.contains_seqlen_agnostic_layers(
|
||||
parallel_config))
|
||||
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||
and not self.enable_lora
|
||||
and not self.enable_prompt_adapter
|
||||
and not has_seqlen_agnostic_layers):
|
||||
and not self.enable_prompt_adapter):
|
||||
self.enable_chunked_prefill = True
|
||||
logger.warning(
|
||||
"Chunked prefill is enabled by default for models with "
|
||||
|
||||
@ -6,7 +6,8 @@ import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
|
||||
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
|
||||
Tuple, Union)
|
||||
|
||||
import filelock
|
||||
import gguf
|
||||
@ -559,6 +560,38 @@ def row_parallel_weight_loader(param: torch.Tensor,
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
|
||||
|
||||
def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
|
||||
"""Create a weight loader that shards the weights along the given axis"""
|
||||
|
||||
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
shard_size = param.data.shape[shard_axis]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
|
||||
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def composed_weight_loader(
|
||||
loader: LoaderFunction, fn: Callable[[torch.Tensor],
|
||||
torch.Tensor]) -> LoaderFunction:
|
||||
"""Create a weight loader that post-processes the weights after loading"""
|
||||
|
||||
def composed_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
loader(param, loaded_weight)
|
||||
param.data.copy_(fn(param))
|
||||
return
|
||||
|
||||
return composed_loader
|
||||
|
||||
|
||||
def initialize_dummy_weights(
|
||||
model: torch.nn.Module,
|
||||
low: float = -1e-3,
|
||||
|
||||
@ -271,7 +271,7 @@ class HasInnerState(Protocol):
|
||||
"""
|
||||
A flag that indicates this model has inner state.
|
||||
Models that has inner state usually need access to the scheduler_config
|
||||
for max_num_seqs ,etc... (Currently only used by Jamba)
|
||||
for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -307,3 +307,46 @@ def has_inner_state(
|
||||
return isinstance(model, _HasInnerStateType)
|
||||
|
||||
return isinstance(model, HasInnerState)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IsAttentionFree(Protocol):
|
||||
"""The interface required for all models like Mamba that lack attention,
|
||||
but do have state whose size is constant wrt the number of tokens."""
|
||||
|
||||
is_attention_free: ClassVar[Literal[True]] = True
|
||||
"""
|
||||
A flag that indicates this model has no attention.
|
||||
Used for block manager and attention backend selection.
|
||||
True for Mamba but not Jamba.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class _IsAttentionFreeType(Protocol):
|
||||
is_attention_free: ClassVar[Literal[True]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
|
||||
...
|
||||
|
||||
|
||||
def is_attention_free(
|
||||
model: Union[Type[object], object]
|
||||
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _IsAttentionFreeType)
|
||||
|
||||
return isinstance(model, IsAttentionFree)
|
||||
|
||||
@ -1,18 +1,16 @@
|
||||
# coding=utf-8
|
||||
"""Inference-only Jamba model."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
from transformers import JambaConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -29,7 +27,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheManager
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -99,16 +99,6 @@ class JambaMambaMixer(nn.Module):
|
||||
bias=True,
|
||||
skip_bias_add=True)
|
||||
|
||||
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
param.data.copy_(
|
||||
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
|
||||
dim=0)[tp_rank])
|
||||
|
||||
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||
weight_loader(param, -torch.exp(loaded_weight.float()))
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.A = nn.Parameter(
|
||||
torch.empty(
|
||||
@ -118,8 +108,10 @@ class JambaMambaMixer(nn.Module):
|
||||
))
|
||||
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
|
||||
|
||||
set_weight_attrs(self.D, {"weight_loader": weight_loader})
|
||||
set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
|
||||
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
||||
a_weight_loader = composed_weight_loader(
|
||||
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
|
||||
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
@ -571,10 +563,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
)
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
|
||||
# Maps between the request id and a dict that maps between the seq_id
|
||||
# and its index inside the self.mamba_cache
|
||||
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
@ -586,203 +576,36 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs):
|
||||
if not self.mamba_cache:
|
||||
self._prepare_mamba_cache()
|
||||
if self.mamba_cache is None:
|
||||
max_batch_size = (_get_graph_batch_size(
|
||||
self.scheduler_config.max_num_seqs) if self.scheduler_config
|
||||
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
|
||||
|
||||
if "seqlen_agnostic_capture_inputs" not in kwargs:
|
||||
# We get here only on Prefill/Eager mode runs
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
mamba_cache = self._release_finished_and_prepare_mamba_cache(
|
||||
finished_requests_ids, request_ids_to_seq_ids)
|
||||
else:
|
||||
# CUDA graph capturing runs
|
||||
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
|
||||
layers_type = self.config.layers_block_type
|
||||
num_mamba_layers = sum(
|
||||
[layer_type == "mamba" for layer_type in layers_type])
|
||||
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
|
||||
*self._get_mamba_cache_shape())
|
||||
|
||||
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
|
||||
input_ids, attn_metadata, **kwargs)
|
||||
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, mamba_cache[0],
|
||||
mamba_cache[1])
|
||||
attn_metadata, mamba_cache_tensors[0],
|
||||
mamba_cache_tensors[1])
|
||||
return hidden_states
|
||||
|
||||
def _swap_mamba_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.mamba_cache) > 0
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, [to_index,from_index]] = \
|
||||
cache_t[:, [from_index,to_index]]
|
||||
|
||||
def _copy_mamba_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.mamba_cache) > 0
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
||||
non_blocking=True)
|
||||
|
||||
def _move_out_if_already_occupied(self, index: int,
|
||||
all_occupied_indices: List[int]):
|
||||
if index in all_occupied_indices:
|
||||
first_free_index = self._first_free_index_in_mamba_cache()
|
||||
# In case occupied, move the occupied to a new empty block
|
||||
self._move_cache_index_and_mappings(from_index=index,
|
||||
to_index=first_free_index)
|
||||
|
||||
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
|
||||
seq_id: int,
|
||||
destination_index: int):
|
||||
"""
|
||||
Assign (req_id,seq_id) pair to a `destination_index` index, if
|
||||
already occupied, move the occupying index to a free index.
|
||||
"""
|
||||
all_occupied_indices = self._get_all_occupied_indices()
|
||||
if cur_rid not in self.mamba_cache_indices_mapping:
|
||||
self._move_out_if_already_occupied(
|
||||
index=destination_index,
|
||||
all_occupied_indices=all_occupied_indices)
|
||||
self.mamba_cache_indices_mapping[cur_rid] = {
|
||||
seq_id: destination_index
|
||||
}
|
||||
elif seq_id not in (seq_ids2indices :=
|
||||
self.mamba_cache_indices_mapping[cur_rid]):
|
||||
# parallel sampling , where n > 1, assume prefill have
|
||||
# already happened now we only need to copy the already
|
||||
# existing cache into the siblings seq_ids caches
|
||||
self._move_out_if_already_occupied(
|
||||
index=destination_index,
|
||||
all_occupied_indices=all_occupied_indices)
|
||||
index_exists = list(seq_ids2indices.values())[0]
|
||||
# case of decoding n>1, copy prefill cache to decoding indices
|
||||
self._copy_mamba_cache(from_index=index_exists,
|
||||
to_index=destination_index)
|
||||
self.mamba_cache_indices_mapping[cur_rid][
|
||||
seq_id] = destination_index
|
||||
else:
|
||||
# already exists
|
||||
cache_index_already_exists = self.mamba_cache_indices_mapping[
|
||||
cur_rid][seq_id]
|
||||
if cache_index_already_exists != destination_index:
|
||||
# In case the seq id already exists but not in
|
||||
# the right destination, swap it with what's occupying it
|
||||
self._swap_pair_indices_and_mappings(
|
||||
from_index=cache_index_already_exists,
|
||||
to_index=destination_index)
|
||||
|
||||
def _prepare_current_run_mamba_cache(
|
||||
self, request_ids_to_seq_ids: Dict[str, list[int]],
|
||||
finished_requests_ids: List[str]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
running_indices = []
|
||||
request_ids_to_seq_ids_flatten = [
|
||||
(req_id, seq_id)
|
||||
for req_id, seq_ids in request_ids_to_seq_ids.items()
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
batch_size = len(request_ids_to_seq_ids_flatten)
|
||||
for dest_index, (request_id,
|
||||
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
|
||||
if request_id in finished_requests_ids:
|
||||
# Do not allocate cache index for requests that run
|
||||
# and finish right after
|
||||
continue
|
||||
self._assign_seq_id_to_mamba_cache_in_specific_dest(
|
||||
request_id, seq_id, dest_index)
|
||||
running_indices.append(dest_index)
|
||||
|
||||
self._clean_up_first_bs_blocks(batch_size, running_indices)
|
||||
conv_state = self.mamba_cache[0][:, :batch_size]
|
||||
temporal_state = self.mamba_cache[1][:, :batch_size]
|
||||
|
||||
return (conv_state, temporal_state)
|
||||
|
||||
def _get_all_occupied_indices(self):
|
||||
return [
|
||||
cache_idx
|
||||
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
|
||||
for cache_idx in seq_ids2indices.values()
|
||||
]
|
||||
|
||||
def _clean_up_first_bs_blocks(self, batch_size: int,
|
||||
indices_for_current_run: List[int]):
|
||||
# move out all of the occupied but currently not running blocks
|
||||
# outside of the first n blocks
|
||||
destination_indices = range(batch_size)
|
||||
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||
for destination_index in destination_indices:
|
||||
if destination_index in self._get_all_occupied_indices() and \
|
||||
destination_index not in indices_for_current_run:
|
||||
# move not running indices outside of the batch
|
||||
all_other_indices = list(
|
||||
range(batch_size, max_possible_batch_size))
|
||||
first_avail_index = self._first_free_index_in_mamba_cache(
|
||||
all_other_indices)
|
||||
self._swap_indices(from_index=destination_index,
|
||||
to_index=first_avail_index)
|
||||
|
||||
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
|
||||
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
|
||||
self._update_mapping_index(from_index=from_index, to_index=to_index)
|
||||
|
||||
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
|
||||
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
|
||||
self._swap_mapping_index(from_index=from_index, to_index=to_index)
|
||||
|
||||
def _swap_mapping_index(self, from_index: int, to_index: int):
|
||||
for seq_ids2index in self.mamba_cache_indices_mapping.values():
|
||||
for seq_id, index in seq_ids2index.items():
|
||||
if from_index == index:
|
||||
seq_ids2index.update({seq_id: to_index})
|
||||
elif to_index == index:
|
||||
seq_ids2index.update({seq_id: from_index})
|
||||
|
||||
def _update_mapping_index(self, from_index: int, to_index: int):
|
||||
for seq_ids2index in self.mamba_cache_indices_mapping.values():
|
||||
for seq_id, index in seq_ids2index.items():
|
||||
if from_index == index:
|
||||
seq_ids2index.update({seq_id: to_index})
|
||||
return
|
||||
|
||||
def _release_finished_and_prepare_mamba_cache(
|
||||
self, finished_requests_ids,
|
||||
request_ids_to_seq_ids) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self._release_mamba_cache(finished_requests_ids)
|
||||
return self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
|
||||
finished_requests_ids)
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
"""
|
||||
Copy the relevant Mamba cache into the CUDA graph input buffer
|
||||
that was provided during the capture runs
|
||||
(JambaForCausalLM.mamba_gc_cache_buffer).
|
||||
"""
|
||||
self._release_finished_and_prepare_mamba_cache(
|
||||
kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"])
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
"""
|
||||
Provide the CUDA graph capture runs with a buffer in adjusted size.
|
||||
The buffer is used to maintain the Mamba Cache during the CUDA graph
|
||||
replay runs.
|
||||
"""
|
||||
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
|
||||
|
||||
def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
|
||||
for req_id in finished_seq_groups_req_ids:
|
||||
if req_id in self.mamba_cache_indices_mapping:
|
||||
self.mamba_cache_indices_mapping.pop(req_id)
|
||||
|
||||
def _first_free_index_in_mamba_cache(
|
||||
self, indices_range: Optional[List[int]] = None) -> int:
|
||||
assert self.mamba_cache is not None
|
||||
if indices_range is None:
|
||||
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||
indices_range = list(range(max_possible_batch_size))
|
||||
all_occupied_indices = self._get_all_occupied_indices()
|
||||
for i in indices_range:
|
||||
if i not in all_occupied_indices:
|
||||
return i
|
||||
raise Exception("Couldn't find a free spot in the mamba cache! This"
|
||||
"should never happen")
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def _get_mamba_cache_shape(
|
||||
self
|
||||
) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
|
||||
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
hidden_size = self.config.hidden_size
|
||||
conv_state_shape = (
|
||||
@ -790,31 +613,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
self.config.mamba_d_conv - 1,
|
||||
)
|
||||
temporal_state_shape = (
|
||||
self.config.mamba_expand * self.config.hidden_size // world_size,
|
||||
self.config.mamba_expand * hidden_size // world_size,
|
||||
self.config.mamba_d_state,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
def _prepare_mamba_cache(self):
|
||||
dtype = self.lm_head.weight.dtype
|
||||
layers_type = self.config.layers_block_type
|
||||
mamba_layers = sum(
|
||||
[layer_type == "mamba" for layer_type in layers_type])
|
||||
max_batch_size = (_get_graph_batch_size(
|
||||
self.scheduler_config.max_num_seqs) if self.scheduler_config else
|
||||
max(_BATCH_SIZES_TO_CAPTURE) + 2)
|
||||
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
|
||||
assert conv_state_shape is not None and temporal_state_shape is not None
|
||||
|
||||
self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
|
||||
conv_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda"),
|
||||
torch.empty(size=(mamba_layers, max_batch_size) +
|
||||
temporal_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda"))
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
499
vllm/model_executor/models/mamba.py
Normal file
499
vllm/model_executor/models/mamba.py
Normal file
@ -0,0 +1,499 @@
|
||||
# coding=utf-8
|
||||
"""PyTorch MAMBA model."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import MambaConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
composed_weight_loader, default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.interfaces import (HasInnerState,
|
||||
IsAttentionFree)
|
||||
from vllm.model_executor.models.mamba_cache import MambaCacheManager
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
|
||||
_get_graph_batch_size)
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaCacheParams:
|
||||
is_prompt: bool = False
|
||||
conv_state: torch.Tensor = torch.Tensor()
|
||||
ssm_state: torch.Tensor = torch.Tensor()
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||
class MambaMixer(nn.Module):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
the `contextualized_states`. A, D are input independent
|
||||
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
|
||||
for why A isn't selective) ∆, B, C are input-dependent
|
||||
(this is a key difference between Mamba and the linear time
|
||||
invariant S4, and is why Mamba is called
|
||||
**selective** state spaces)
|
||||
"""
|
||||
|
||||
def __init__(self, config: MambaConfig, layer_idx):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.ssm_state_size = config.state_size
|
||||
self.conv_kernel_size = config.conv_kernel
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.time_step_rank = int(config.time_step_rank)
|
||||
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=self.conv_kernel_size,
|
||||
output_size=self.intermediate_size,
|
||||
bias=config.use_conv_bias,
|
||||
)
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `set_weight_attrs`
|
||||
# doesn't allow to override it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(self.hidden_size,
|
||||
[self.intermediate_size] * 2,
|
||||
bias=config.use_bias)
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
self.x_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
self.time_step_rank + self.ssm_state_size * 2,
|
||||
bias=False,
|
||||
)
|
||||
# time step projection (discretization) -
|
||||
# In the forward we need to apply dt_proj without the bias,
|
||||
# as the bias is added in the selective scan kernel.
|
||||
self.dt_proj = ColumnParallelLinear(self.time_step_rank,
|
||||
self.intermediate_size,
|
||||
bias=True,
|
||||
skip_bias_add=True)
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.A = nn.Parameter(
|
||||
torch.empty(
|
||||
self.intermediate_size // tp_size,
|
||||
self.ssm_state_size,
|
||||
dtype=torch.float32,
|
||||
))
|
||||
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
|
||||
|
||||
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
||||
a_weight_loader = composed_weight_loader(
|
||||
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
|
||||
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=config.use_bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.activation = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
|
||||
ssm_state: torch.Tensor):
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
hidden_states, gate = projected_states.chunk(2, dim=-2)
|
||||
|
||||
# 2. Convolution sequence transformation
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if attn_metadata.query_start_loc is not None \
|
||||
and attn_metadata.context_lens_tensor is not None:
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
hidden_states = causal_conv1d_fn(
|
||||
hidden_states,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
hidden_states = causal_conv1d_update(
|
||||
hidden_states.transpose(0, 1),
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(0, 1)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
# 3.a. input varying initialization of time_step, B and C
|
||||
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
|
||||
|
||||
time_step, B, C = torch.split(
|
||||
ssm_parameters,
|
||||
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
|
||||
|
||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
|
||||
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
|
||||
self.dt_proj, "bias") else None)
|
||||
|
||||
if attn_metadata.query_start_loc is not None \
|
||||
and attn_metadata.context_lens_tensor is not None:
|
||||
scan_outputs = selective_scan_fn(
|
||||
hidden_states,
|
||||
ssm_state,
|
||||
discrete_time_step,
|
||||
self.A,
|
||||
B.transpose(-2, -1),
|
||||
C.transpose(-2, -1),
|
||||
self.D.float(),
|
||||
gate,
|
||||
time_proj_bias,
|
||||
delta_softplus=True,
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
scan_outputs = selective_state_update(
|
||||
ssm_state,
|
||||
hidden_states.transpose(0, 1),
|
||||
discrete_time_step.transpose(0, 1),
|
||||
self.A,
|
||||
B,
|
||||
C,
|
||||
self.D,
|
||||
gate.transpose(0, 1),
|
||||
time_proj_bias,
|
||||
dt_softplus=True,
|
||||
)
|
||||
scan_outputs = scan_outputs.transpose(0, 1)
|
||||
|
||||
# 4. Final linear projection
|
||||
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
|
||||
-1))[0]
|
||||
return contextualized_states
|
||||
|
||||
|
||||
class MambaMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MambaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
intermediate_size = config.intermediate_size
|
||||
hidden_act = config.hidden_act
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class MambaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: MambaConfig,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.config = config
|
||||
self.mixer = MambaMixer(config, layer_idx)
|
||||
|
||||
self.feed_forward = MambaMLP(config, quant_config=quant_config)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.pre_ff_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
conv_state: torch.Tensor,
|
||||
ssm_state: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
|
||||
ssm_state)
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.pre_ff_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.feed_forward(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class MambaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MambaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
self.embeddings = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
decoder_layers = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
decoder_layers.append(
|
||||
MambaDecoderLayer(config,
|
||||
layer_idx=i,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config))
|
||||
self.layers = nn.ModuleList(decoder_layers)
|
||||
self.norm_f = RMSNorm(config.hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
conv_state: torch.Tensor,
|
||||
ssm_state: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embeddings(input_ids)
|
||||
residual = None
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
current_ssm_state = ssm_state[i]
|
||||
current_conv_state = conv_state[i]
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
conv_state=current_conv_state,
|
||||
ssm_state=current_ssm_state,
|
||||
)
|
||||
hidden_states, _ = self.norm_f(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"qkv_proj",
|
||||
"o_proj",
|
||||
"embed_tokens",
|
||||
"lm_head",
|
||||
]
|
||||
embedding_modules = {
|
||||
"embeddings": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MambaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
scheduler_config: Optional[SchedulerConfig] = None,
|
||||
) -> None:
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Mamba does not support prefix caching"
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.backbone = MambaModel(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
lora_config=lora_config)
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
self.lm_head = self.backbone.embeddings
|
||||
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Optional[MambaCacheManager] = None
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs):
|
||||
if self.mamba_cache is None:
|
||||
max_batch_size = (_get_graph_batch_size(
|
||||
self.scheduler_config.max_num_seqs) if self.scheduler_config
|
||||
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
|
||||
self.mamba_cache = MambaCacheManager(
|
||||
self.lm_head.weight.dtype, self.config.num_hidden_layers,
|
||||
max_batch_size, *self._get_mamba_cache_shape())
|
||||
|
||||
mamba_cache_tensors = self.mamba_cache.current_run_tensors(
|
||||
input_ids, attn_metadata, **kwargs)
|
||||
|
||||
hidden_states = self.backbone(input_ids, positions, kv_caches,
|
||||
attn_metadata, mamba_cache_tensors[0],
|
||||
mamba_cache_tensors[1])
|
||||
|
||||
return hidden_states
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def _get_mamba_cache_shape(
|
||||
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
conv_state_shape = (
|
||||
self.config.intermediate_size // world_size,
|
||||
self.config.conv_kernel - 1,
|
||||
)
|
||||
temporal_state_shape = (
|
||||
self.config.intermediate_size // world_size,
|
||||
self.config.state_size,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "A_log" in name:
|
||||
name = name.replace("A_log", "A")
|
||||
|
||||
if ".self_attn." in name:
|
||||
name = name.replace(".self_attn", "")
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
222
vllm/model_executor/models/mamba_cache.py
Normal file
222
vllm/model_executor/models/mamba_cache.py
Normal file
@ -0,0 +1,222 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
|
||||
|
||||
class MambaCacheManager:
|
||||
|
||||
def __init__(self, dtype, num_mamba_layers, max_batch_size,
|
||||
conv_state_shape, temporal_state_shape):
|
||||
|
||||
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||
conv_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
|
||||
temporal_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
self.mamba_cache = (conv_state, temporal_state)
|
||||
|
||||
# Maps between the request id and a dict that maps between the seq_id
|
||||
# and its index inside the self.mamba_cache
|
||||
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
|
||||
|
||||
def current_run_tensors(self, input_ids: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata, **kwargs):
|
||||
"""
|
||||
Return the tensors for the current run's conv and ssm state.
|
||||
"""
|
||||
if "seqlen_agnostic_capture_inputs" not in kwargs:
|
||||
# We get here only on Prefill/Eager mode runs
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
mamba_cache_tensors = self._prepare_current_run_mamba_cache(
|
||||
request_ids_to_seq_ids, finished_requests_ids)
|
||||
|
||||
else:
|
||||
# CUDA graph capturing runs
|
||||
mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
|
||||
|
||||
return mamba_cache_tensors
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
"""
|
||||
Copy the relevant Mamba cache into the CUDA graph input buffer
|
||||
that was provided during the capture runs
|
||||
(JambaForCausalLM.mamba_gc_cache_buffer).
|
||||
"""
|
||||
assert all(
|
||||
key in kwargs
|
||||
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
|
||||
finished_requests_ids = kwargs["finished_requests_ids"]
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
|
||||
self._release_finished_requests(finished_requests_ids)
|
||||
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
|
||||
finished_requests_ids)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
"""
|
||||
Provide the CUDA graph capture runs with a buffer in adjusted size.
|
||||
The buffer is used to maintain the Mamba Cache during the CUDA graph
|
||||
replay runs.
|
||||
"""
|
||||
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
|
||||
|
||||
def _swap_mamba_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.mamba_cache) > 0
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, [to_index,from_index]] = \
|
||||
cache_t[:, [from_index,to_index]]
|
||||
|
||||
def _copy_mamba_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.mamba_cache) > 0
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
||||
non_blocking=True)
|
||||
|
||||
def _move_out_if_already_occupied(self, index: int,
|
||||
all_occupied_indices: List[int]):
|
||||
if index in all_occupied_indices:
|
||||
first_free_index = self._first_free_index_in_mamba_cache()
|
||||
# In case occupied, move the occupied to a new empty block
|
||||
self._move_cache_index_and_mappings(from_index=index,
|
||||
to_index=first_free_index)
|
||||
|
||||
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
|
||||
seq_id: int,
|
||||
destination_index: int):
|
||||
"""
|
||||
Assign (req_id,seq_id) pair to a `destination_index` index, if
|
||||
already occupied, move the occupying index to a free index.
|
||||
"""
|
||||
all_occupied_indices = self._get_all_occupied_indices()
|
||||
if cur_rid not in self.mamba_cache_indices_mapping:
|
||||
self._move_out_if_already_occupied(
|
||||
index=destination_index,
|
||||
all_occupied_indices=all_occupied_indices)
|
||||
self.mamba_cache_indices_mapping[cur_rid] = {
|
||||
seq_id: destination_index
|
||||
}
|
||||
elif seq_id not in (seq_ids2indices :=
|
||||
self.mamba_cache_indices_mapping[cur_rid]):
|
||||
# parallel sampling , where n > 1, assume prefill have
|
||||
# already happened now we only need to copy the already
|
||||
# existing cache into the siblings seq_ids caches
|
||||
self._move_out_if_already_occupied(
|
||||
index=destination_index,
|
||||
all_occupied_indices=all_occupied_indices)
|
||||
index_exists = list(seq_ids2indices.values())[0]
|
||||
# case of decoding n>1, copy prefill cache to decoding indices
|
||||
self._copy_mamba_cache(from_index=index_exists,
|
||||
to_index=destination_index)
|
||||
self.mamba_cache_indices_mapping[cur_rid][
|
||||
seq_id] = destination_index
|
||||
else:
|
||||
# already exists
|
||||
cache_index_already_exists = self.mamba_cache_indices_mapping[
|
||||
cur_rid][seq_id]
|
||||
if cache_index_already_exists != destination_index:
|
||||
# In case the seq id already exists but not in
|
||||
# the right destination, swap it with what's occupying it
|
||||
self._swap_pair_indices_and_mappings(
|
||||
from_index=cache_index_already_exists,
|
||||
to_index=destination_index)
|
||||
|
||||
def _prepare_current_run_mamba_cache(
|
||||
self, request_ids_to_seq_ids: Dict[str, list[int]],
|
||||
finished_requests_ids: List[str]):
|
||||
running_indices = []
|
||||
request_ids_to_seq_ids_flatten = [
|
||||
(req_id, seq_id)
|
||||
for req_id, seq_ids in request_ids_to_seq_ids.items()
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
batch_size = len(request_ids_to_seq_ids_flatten)
|
||||
for dest_index, (request_id,
|
||||
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
|
||||
if request_id in finished_requests_ids:
|
||||
# Do not allocate cache index for requests that run
|
||||
# and finish right after
|
||||
continue
|
||||
self._assign_seq_id_to_mamba_cache_in_specific_dest(
|
||||
request_id, seq_id, dest_index)
|
||||
running_indices.append(dest_index)
|
||||
|
||||
self._clean_up_first_bs_blocks(batch_size, running_indices)
|
||||
conv_state = self.mamba_cache[0][:, :batch_size]
|
||||
temporal_state = self.mamba_cache[1][:, :batch_size]
|
||||
|
||||
return (conv_state, temporal_state)
|
||||
|
||||
def _get_all_occupied_indices(self):
|
||||
return [
|
||||
cache_idx
|
||||
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
|
||||
for cache_idx in seq_ids2indices.values()
|
||||
]
|
||||
|
||||
def _clean_up_first_bs_blocks(self, batch_size: int,
|
||||
indices_for_current_run: List[int]):
|
||||
# move out all of the occupied but currently not running blocks
|
||||
# outside of the first n blocks
|
||||
destination_indices = range(batch_size)
|
||||
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||
for destination_index in destination_indices:
|
||||
if destination_index in self._get_all_occupied_indices() and \
|
||||
destination_index not in indices_for_current_run:
|
||||
# move not running indices outside of the batch
|
||||
all_other_indices = list(
|
||||
range(batch_size, max_possible_batch_size))
|
||||
first_avail_index = self._first_free_index_in_mamba_cache(
|
||||
all_other_indices)
|
||||
self._swap_indices(from_index=destination_index,
|
||||
to_index=first_avail_index)
|
||||
|
||||
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
|
||||
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
|
||||
self._update_mapping_index(from_index=from_index, to_index=to_index)
|
||||
|
||||
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
|
||||
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
|
||||
self._swap_mapping_index(from_index=from_index, to_index=to_index)
|
||||
|
||||
def _swap_mapping_index(self, from_index: int, to_index: int):
|
||||
for seq_ids2index in self.mamba_cache_indices_mapping.values():
|
||||
for seq_id, index in seq_ids2index.items():
|
||||
if from_index == index:
|
||||
seq_ids2index.update({seq_id: to_index})
|
||||
elif to_index == index:
|
||||
seq_ids2index.update({seq_id: from_index})
|
||||
|
||||
def _update_mapping_index(self, from_index: int, to_index: int):
|
||||
for seq_ids2index in self.mamba_cache_indices_mapping.values():
|
||||
for seq_id, index in seq_ids2index.items():
|
||||
if from_index == index:
|
||||
seq_ids2index.update({seq_id: to_index})
|
||||
return
|
||||
|
||||
def _release_finished_requests(self,
|
||||
finished_seq_groups_req_ids: List[str]):
|
||||
for req_id in finished_seq_groups_req_ids:
|
||||
if req_id in self.mamba_cache_indices_mapping:
|
||||
self.mamba_cache_indices_mapping.pop(req_id)
|
||||
|
||||
def _first_free_index_in_mamba_cache(
|
||||
self, indices_range: Optional[List[int]] = None) -> int:
|
||||
assert self.mamba_cache is not None
|
||||
if indices_range is None:
|
||||
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||
indices_range = list(range(max_possible_batch_size))
|
||||
all_occupied_indices = self._get_all_occupied_indices()
|
||||
for i in indices_range:
|
||||
if i not in all_occupied_indices:
|
||||
return i
|
||||
raise Exception("Couldn't find a free spot in the mamba cache! This"
|
||||
"should never happen")
|
||||
@ -14,7 +14,8 @@ import torch.nn as nn
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import supports_multimodal, supports_pp
|
||||
from .interfaces import (has_inner_state, is_attention_free,
|
||||
supports_multimodal, supports_pp)
|
||||
from .interfaces_base import is_embedding_model, is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -52,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
# For decapoda-research/llama-*
|
||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MambaForCausalLM": ("mamba", "MambaForCausalLM"),
|
||||
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
||||
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
|
||||
@ -157,6 +159,8 @@ class _ModelInfo:
|
||||
is_embedding_model: bool
|
||||
supports_multimodal: bool
|
||||
supports_pp: bool
|
||||
has_inner_state: bool
|
||||
is_attention_free: bool
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||
@ -165,6 +169,8 @@ class _ModelInfo:
|
||||
is_embedding_model=is_embedding_model(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_pp=supports_pp(model),
|
||||
has_inner_state=has_inner_state(model),
|
||||
is_attention_free=is_attention_free(model),
|
||||
)
|
||||
|
||||
|
||||
@ -380,6 +386,14 @@ class _ModelRegistry:
|
||||
) -> bool:
|
||||
return self.inspect_model_cls(architectures).supports_pp
|
||||
|
||||
def model_has_inner_state(self, architectures: Union[str,
|
||||
List[str]]) -> bool:
|
||||
return self.inspect_model_cls(architectures).has_inner_state
|
||||
|
||||
def is_attention_free_model(self, architectures: Union[str,
|
||||
List[str]]) -> bool:
|
||||
return self.inspect_model_cls(architectures).is_attention_free
|
||||
|
||||
|
||||
ModelRegistry = _ModelRegistry({
|
||||
model_arch: _LazyRegisteredModel(
|
||||
|
||||
@ -52,15 +52,12 @@ class CacheEngine:
|
||||
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||
|
||||
# Get attention backend.
|
||||
self.attn_backend = get_attn_backend(
|
||||
model_config.get_num_attention_heads(parallel_config),
|
||||
self.head_size,
|
||||
self.num_kv_heads,
|
||||
model_config.get_sliding_window(),
|
||||
model_config.dtype,
|
||||
cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
)
|
||||
self.attn_backend = get_attn_backend(self.head_size,
|
||||
model_config.get_sliding_window(),
|
||||
model_config.dtype,
|
||||
cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
model_config.is_attention_free)
|
||||
|
||||
# Initialize the cache.
|
||||
self.gpu_cache = self._allocate_kv_cache(
|
||||
|
||||
@ -418,13 +418,12 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
self.block_size = cache_config.block_size
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||
self.model_config.get_sliding_window(),
|
||||
self.model_config.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
)
|
||||
|
||||
# Multi-modal data support
|
||||
|
||||
@ -56,13 +56,12 @@ class CPUCacheEngine:
|
||||
|
||||
# Get attention backend.
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||
self.model_config.get_sliding_window(),
|
||||
self.model_config.dtype,
|
||||
cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
)
|
||||
|
||||
# Initialize the cache.
|
||||
|
||||
@ -196,7 +196,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
seqlen_agnostic_kwargs = {
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_seqlen_agnostic else {}
|
||||
} if self.has_inner_state else {}
|
||||
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
with set_forward_context(model_input.attn_metadata):
|
||||
|
||||
@ -17,7 +17,6 @@ import torch.nn as nn
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.abstract import AttentionState
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
@ -991,8 +990,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self.graph_memory_pool: Optional[Tuple[
|
||||
int, int]] = None # Set during graph capture.
|
||||
|
||||
self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
|
||||
parallel_config)
|
||||
self.has_inner_state = model_config.has_inner_state
|
||||
|
||||
# When using CUDA graph, the input block tables must be padded to
|
||||
# max_seq_len_to_capture. However, creating the block table in
|
||||
@ -1003,22 +1001,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
|
||||
dtype=np.int32)
|
||||
num_attn_heads = self.model_config.get_num_attention_heads(
|
||||
self.parallel_config)
|
||||
self.attn_backend = get_attn_backend(
|
||||
num_attn_heads,
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||
self.model_config.get_sliding_window(),
|
||||
self.model_config.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
) if num_attn_heads else None
|
||||
if self.attn_backend:
|
||||
self.attn_state = self.attn_backend.get_state_cls()(
|
||||
weakref.proxy(self))
|
||||
else:
|
||||
self.attn_state = CommonAttentionState(weakref.proxy(self))
|
||||
self.model_config.is_attention_free,
|
||||
)
|
||||
self.attn_state = self.attn_backend.get_state_cls()(
|
||||
weakref.proxy(self))
|
||||
|
||||
# Multi-modal data support
|
||||
self.input_registry = input_registry
|
||||
@ -1498,7 +1490,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
"previous_hidden_states"] = previous_hidden_states[:
|
||||
batch_size]
|
||||
|
||||
if self.has_seqlen_agnostic:
|
||||
if self.has_inner_state:
|
||||
# Only used by Mamba-based models CUDA graph atm (Jamba)
|
||||
capture_inputs.update({
|
||||
"seqlen_agnostic_capture_inputs":
|
||||
@ -1647,7 +1639,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
seqlen_agnostic_kwargs = {
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_seqlen_agnostic else {}
|
||||
} if self.has_inner_state else {}
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_start = torch.cuda.Event(enable_timing=True)
|
||||
@ -1852,10 +1844,14 @@ class CUDAGraphRunner:
|
||||
# Copy the input tensors to the input buffers.
|
||||
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
|
||||
self.input_buffers["positions"].copy_(positions, non_blocking=True)
|
||||
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
|
||||
non_blocking=True)
|
||||
|
||||
if self.backend_name != "placeholder-attn":
|
||||
self.input_buffers["slot_mapping"].copy_(
|
||||
attn_metadata.slot_mapping, non_blocking=True)
|
||||
|
||||
self.attn_state.prepare_graph_input_buffers(
|
||||
self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
|
||||
|
||||
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
|
||||
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
|
||||
**kwargs)
|
||||
|
||||
@ -74,13 +74,12 @@ class OpenVINOModelRunner:
|
||||
self.block_size = cache_config.block_size
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||
self.model_config.get_sliding_window(),
|
||||
self.model_config.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
)
|
||||
|
||||
# Multi-modal data support
|
||||
|
||||
@ -70,13 +70,12 @@ class OpenVINOCacheEngine:
|
||||
|
||||
# Get attention backend.
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||
self.head_size,
|
||||
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||
self.model_config.get_sliding_window(),
|
||||
self.model_config.dtype,
|
||||
self.cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
)
|
||||
|
||||
# Initialize the cache.
|
||||
|
||||
@ -113,13 +113,12 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
||||
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
|
||||
dtype=np.int32)
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||
self.model_config.get_sliding_window(),
|
||||
self.model_config.dtype,
|
||||
self.cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
False,
|
||||
)
|
||||
self.cached_step_outputs: List[torch.Tensor] = []
|
||||
|
||||
@ -236,11 +236,15 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
cache_block_size = self.get_cache_block_size_bytes()
|
||||
num_gpu_blocks = int(
|
||||
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory) // cache_block_size)
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
cache_block_size)
|
||||
if cache_block_size == 0:
|
||||
num_gpu_blocks = 0
|
||||
num_cpu_blocks = 0
|
||||
else:
|
||||
num_gpu_blocks = int(
|
||||
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory) // cache_block_size)
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
cache_block_size)
|
||||
num_gpu_blocks = max(num_gpu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
if self.model_runner.lora_manager:
|
||||
@ -257,6 +261,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
"""
|
||||
raise_if_cache_size_invalid(num_gpu_blocks,
|
||||
self.cache_config.block_size,
|
||||
self.cache_config.is_attention_free,
|
||||
self.model_config.max_model_len)
|
||||
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
@ -472,14 +477,18 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
"`dtype` flag in CLI, for example: --dtype=half.")
|
||||
|
||||
|
||||
def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
|
||||
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
|
||||
max_model_len) -> None:
|
||||
if num_gpu_blocks <= 0:
|
||||
if is_attention_free and num_gpu_blocks != 0:
|
||||
raise ValueError("No memory should be allocated for the cache blocks "
|
||||
f"for an attention-free model, but {num_gpu_blocks}"
|
||||
"blocks are allocated.")
|
||||
if not is_attention_free and num_gpu_blocks <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `gpu_memory_utilization` when "
|
||||
"initializing the engine.")
|
||||
max_seq_len = block_size * num_gpu_blocks
|
||||
if max_model_len > max_seq_len:
|
||||
if not is_attention_free and max_model_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The model's max seq len ({max_model_len}) "
|
||||
"is larger than the maximum number of tokens that can be "
|
||||
|
||||
@ -372,13 +372,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
||||
self.block_size = cache_config.block_size
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_num_attention_heads(self.parallel_config),
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.get_num_kv_heads(self.parallel_config),
|
||||
self.model_config.get_sliding_window(),
|
||||
self.model_config.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
)
|
||||
|
||||
# Multi-modal data support
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user