[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

commit 3534c39a20 (parent c586b55667)
@@ -112,8 +112,7 @@ enforcing eager mode and disabling prefix caching in V1.
 Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
 `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
 these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention
-backend in V1. It is also necessary to pass a non-standard block size for attention layers (this is not possible
-using the `vllm serve` CLI yet).
+backend in V1.
 
 #### Encoder-Decoder Models
 
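For illustration, a hybrid model from the list above can be exercised on V1 roughly as follows (a minimal offline-inference sketch, not part of this commit; it assumes FlashInfer is installed and uses one of the models named in the docs):

```python
# Minimal sketch: run a hybrid Mamba-2/attention model on V1.
# Assumes FlashInfer is installed; the model name is one of those listed above.
import os

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B-v1",
    enforce_eager=True,            # eager mode is still required for hybrid models
    enable_prefix_caching=False,   # prefix caching is not supported yet
)
out = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```

After this change the attention block size no longer has to be passed by hand; it is derived automatically from the mamba page size (see `HybridAttentionMambaModelConfig` below).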
@@ -61,14 +61,6 @@ V1_SUPPORTED_MODELS = [
     "tiiuae/Falcon-H1-0.5B-Base",
 ]
 
-ATTN_BLOCK_SIZES = {
-    "ibm-ai-platform/Bamba-9B-v1": 528,
-    "Zyphra/Zamba2-1.2B-instruct": 80,
-    "nvidia/Nemotron-H-8B-Base-8K": 528,
-    "ibm-granite/granite-4.0-tiny-preview": 400,
-    "tiiuae/Falcon-H1-0.5B-Base": 800,
-}
-
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
@@ -105,11 +97,6 @@ def test_models(
             example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
-        if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
-            block_size = ATTN_BLOCK_SIZES[model]
-        else:
-            block_size = 16
-
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
             if model in HYBRID_MODELS:
@@ -118,8 +105,7 @@ def test_models(
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
                              enforce_eager=True,
-                             enable_prefix_caching=False,
-                             block_size=block_size) as vllm_model:
+                             enable_prefix_caching=False) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
         else:
@@ -1630,6 +1630,9 @@ class CacheConfig:
     checkpoint if available. Otherwise, the scales will default to 1.0."""
     cpu_kvcache_space_bytes: Optional[int] = None
     """(CPU backend only) CPU key-value cache space."""
+    mamba_page_size_padded: Optional[int] = None
+    """ Optional override for mamba page size; used by hybrid mamba/attention
+    models to ensure exact alignment with attention page size."""
 
     # Will be set after profiling.
     num_gpu_blocks: Optional[int] = field(default=None, init=False)
@@ -4882,11 +4885,15 @@ class VllmConfig:
         if architecture is None:
             return
 
-        from vllm.model_executor.models.config import MODELS_CONFIG_MAP
+        from vllm.model_executor.models.config import (
+            MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig)
         cls = MODELS_CONFIG_MAP.get(architecture, None)
         if cls is not None:
             cls.verify_and_update_config(self)
 
+        if self.model_config.is_hybrid:
+            HybridAttentionMambaModelConfig.verify_and_update_config(self)
+
         if self.model_config.task == "classify":
             # Maybe convert ForCausalLM into ForSequenceClassification model.
             from vllm.model_executor.models.adapters import (
@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    extra_groups_for_head_shards, get_mamba_state_shape)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -146,18 +148,6 @@ class Mixer2RMSNormGated(CustomOp):
         return out
 
 
-def extra_groups_for_head_shards(ngroups: int, tp_size: int):
-    """Compute the increase in group numbers to account for
-    replication in order to accompany the head shards."""
-
-    # in the case ngoups % tp_size == 0, this will be zero
-    if ngroups % tp_size == 0:
-        return 0
-
-    # for n_groups == 1, this is exactly tp_size - n_groups
-    return tp_size - ngroups
-
-
 def mamba_v2_sharded_weight_loader(
     shard_spec: list[tuple[int, int, float]],
     tp_size: int,
@@ -707,30 +697,12 @@ class MambaMixer2(MambaBase, CustomOp):
         return out
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
-        world_size = get_tensor_model_parallel_world_size()
-
-        conv_state_shape, temporal_state_shape = None, None
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (self.n_groups +
-                    extra_groups_for_head_shards(self.n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (self.intermediate_size +
-                    2 * n_groups * self.ssm_state_size)
-        # contiguous along 'dim' axis
-        conv_state_shape = (
-            self.conv_kernel_size - 1,
-            divide(conv_dim, world_size),
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.num_heads, world_size),
-            self.head_dim,
-            self.ssm_state_size,
-        )
-        return conv_state_shape, temporal_state_shape
+        return get_mamba_state_shape(
+            intermediate_size=self.intermediate_size,
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            n_groups=self.n_groups,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            state_size=self.ssm_state_size,
+            conv_kernel=self.conv_kernel_size,
+        )
vllm/model_executor/layers/mamba/mamba_utils.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.distributed import divide
+
+
+def extra_groups_for_head_shards(ngroups: int, tp_size: int):
+    """Compute the increase in group numbers to account for
+    replication in order to accompany the head shards."""
+
+    # in the case ngoups % tp_size == 0, this will be zero
+    if ngroups % tp_size == 0:
+        return 0
+
+    # for n_groups == 1, this is exactly tp_size - n_groups
+    return tp_size - ngroups
+
+
+def get_mamba_state_shape(
+    intermediate_size: int,
+    tp_world_size: int,
+    n_groups: int,
+    num_heads: int,
+    head_dim: int,
+    state_size: int,
+    conv_kernel: int,
+    use_v1: bool = True,
+) -> tuple[tuple[int, int], tuple[int, int, int]]:
+    """ Get the shape of mamba state."""
+
+    # if n_groups is not divisible by world_size, need to extend the shards
+    # to ensure all groups needed by a head is sharded along with it
+    n_groups = (n_groups +
+                extra_groups_for_head_shards(n_groups, tp_world_size))
+
+    # - heads and n_groups are TP-ed
+    conv_dim = (intermediate_size + 2 * n_groups * state_size)
+    # contiguous along 'dim' axis
+    conv_state_shape = (
+        conv_kernel - 1,
+        divide(conv_dim, tp_world_size),
+    )
+
+    if not use_v1:
+        conv_state_shape = (conv_state_shape[1], conv_state_shape[0])
+
+    # These are not TP-ed as they depend on A, dt_bias, D
+    # - they are typically small
+    # e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
+    temporal_state_shape = (
+        divide(num_heads, tp_world_size),
+        head_dim,
+        state_size,
+    )
+
+    return conv_state_shape, temporal_state_shape
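To make the shape arithmetic concrete, here is a standalone sketch (not part of this commit) that mirrors `get_mamba_state_shape` using plain integer division and illustrative Mamba-2 dimensions:

```python
# Standalone sketch of the state-shape math above; all numbers are illustrative.
def mamba_state_shapes(intermediate_size, tp, n_groups, num_heads, head_dim,
                       state_size, conv_kernel):
    # replicate groups so every head shard carries the groups it needs
    # (same rule as extra_groups_for_head_shards)
    if n_groups % tp != 0:
        n_groups += tp - n_groups
    conv_dim = intermediate_size + 2 * n_groups * state_size
    conv_state = (conv_kernel - 1, conv_dim // tp)             # V1 layout
    temporal_state = (num_heads // tp, head_dim, state_size)   # heads are sharded
    return conv_state, temporal_state


# e.g. intermediate_size=4096, TP=2, n_groups=1, 128 heads of dim 64,
# state_size=128, conv_kernel=4:
print(mamba_state_shapes(4096, 2, 1, 128, 64, 128, 4))
# -> ((3, 2304), (64, 64, 128))
```

For V0 (`use_v1=False`) the convolutional state is simply stored transposed, i.e. `(2304, 3)` in this example.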
@@ -12,7 +12,7 @@ from transformers import BambaConfig
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -23,8 +23,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
-from vllm.model_executor.layers.mamba.mamba_mixer2 import (
-    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -435,6 +435,38 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     }
     embedding_padding_modules = ["lm_head"]
 
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
+
+        return get_mamba_state_shape(
+            intermediate_size=intermediate_size,
+            tp_world_size=parallel_config.tensor_parallel_size,
+            n_groups=hf_config.mamba_n_groups,
+            num_heads=hf_config.mamba_n_heads,
+            head_dim=hf_config.mamba_d_head,
+            state_size=hf_config.mamba_d_state,
+            conv_kernel=hf_config.mamba_d_conv,
+            use_v1=use_v1,
+        )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
@@ -491,10 +523,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 self.vllm_config.parallel_config,
                 LayerBlockType.mamba
             )
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype,
-                num_mamba_layers, *self._get_mamba_cache_shape())
+            mamba_state_shape = \
+                self.get_mamba_state_shape_from_config(
+                    self.vllm_config, use_v1=False)
+            self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                 self.lm_head.weight.dtype,
+                                                 num_mamba_layers,
+                                                 *mamba_state_shape)
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
@@ -510,38 +545,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = self.config.hidden_size
-
-        conv_state_shape, temporal_state_shape = None, None
-
-        intermediate_size = self.config.mamba_expand * hidden_size
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
-            self.config.mamba_n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (intermediate_size +
-                    2 * n_groups * self.config.mamba_d_state)
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            self.config.mamba_d_conv - 1,
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.config.mamba_n_heads, world_size),
-            self.config.mamba_d_head,
-            self.config.mamba_d_state,
-        )
-        return conv_state_shape, temporal_state_shape
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -3,9 +3,14 @@
 from copy import deepcopy
 from typing import TYPE_CHECKING
 
+import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.model_executor.models import ModelRegistry
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 
 logger = init_logger(__name__)
@@ -200,6 +205,91 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
     }
 
 
+class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
+
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Ensure that page size of attention layers is greater than or
+        equal to the mamba layers. If not, automatically set the attention
+        block size to ensure that it is. If the attention page size is
+        strictly greater than the mamba page size, we pad the mamba page size
+        to make them equal.
+
+        Args:
+            vllm_config: vLLM Config
+        """
+
+        if not envs.VLLM_USE_V1:
+            return
+
+        cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
+
+        if cache_config.cache_dtype == "auto":
+            kv_cache_dtype = model_config.dtype
+        else:
+            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
+        # get attention page size (for 1 token)
+        attn_page_size_1_token = FullAttentionSpec(
+            block_size=1,
+            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+            head_size=model_config.get_head_size(),
+            dtype=kv_cache_dtype,
+            use_mla=model_config.use_mla).page_size_bytes
+
+        model_cls = ModelRegistry.resolve_model_cls(
+            model_config._model_info.architecture)[0]
+
+        # get mamba page size
+        mamba_page_size = MambaSpec(
+            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
+            dtype=kv_cache_dtype,
+            block_size=model_config.max_model_len,
+        ).page_size_bytes
+
+        # some attention backends (e.g. FA) only support setting
+        # block size to multiple of 16, so let's suggest a value
+        # that would work (note: FA is currently not compatible
+        # with mamba layers, use FlashInfer instead).
+        attn_block_size = 16 * cdiv(mamba_page_size,
+                                    16 * attn_page_size_1_token)
+
+        # override attention block size if either (a) the
+        # user has not set it or (b) the user has set it
+        # too small.
+        if (cache_config.block_size is None
+                or cache_config.block_size < attn_block_size):
+            cache_config.block_size = attn_block_size
+            logger.info(
+                "Setting attention block size to %d tokens "
+                "to ensure that attention page size is >= mamba page size.",
+                attn_block_size)
+
+        # compute new attention page size
+        attn_page_size = \
+            cache_config.block_size * attn_page_size_1_token
+
+        assert attn_page_size >= mamba_page_size
+
+        if attn_page_size == mamba_page_size:
+            # don't need to pad mamba page size
+            return
+
+        # pad mamba page size to exactly match attention
+        if (cache_config.mamba_page_size_padded is None
+                or cache_config.mamba_page_size_padded != attn_page_size):
+            cache_config.mamba_page_size_padded = (attn_page_size)
+            mamba_padding_pct = 100 * (attn_page_size -
+                                       mamba_page_size) / mamba_page_size
+            logger.info(
+                "Padding mamba page size by %.2f%% to ensure "
+                "that mamba page size and attention page size are "
+                "exactly equal.", mamba_padding_pct)
+
+
 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
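A small worked example (illustrative numbers, not part of this commit) of the block-size and padding arithmetic performed by `HybridAttentionMambaModelConfig`:

```python
# Hypothetical sizes: the mamba state needs ~430 kB per sequence and one token
# of attention KV cache costs 256 bytes.
mamba_page_size = 430_000
attn_page_size_1_token = 256

# Same rule as above: round the required attention block size up to a multiple of 16
# (-(-a // b) is integer ceiling division, like vLLM's cdiv).
attn_block_size = 16 * -(-mamba_page_size // (16 * attn_page_size_1_token))
attn_page_size = attn_block_size * attn_page_size_1_token

padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
print(attn_block_size, attn_page_size, f"{padding_pct:.2f}%")
# -> 1680 430080 0.02%
```

With these numbers the attention block size is forced up to 1680 tokens and the mamba page size is padded by a fraction of a percent, so both cache groups end up with exactly the same page size.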
@@ -11,7 +11,7 @@ from transformers import FalconH1Config
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -22,8 +22,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
-from vllm.model_executor.layers.mamba.mamba_mixer2 import (
-    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -514,6 +514,42 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     }
     embedding_padding_modules = ["lm_head"]
 
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+
+        intermediate_size = (int(hf_config.mamba_expand *
+                                 hf_config.hidden_size)
+                             if hf_config.mamba_d_ssm is None else
+                             hf_config.mamba_d_ssm)
+
+        return get_mamba_state_shape(
+            intermediate_size=intermediate_size,
+            tp_world_size=parallel_config.tensor_parallel_size,
+            n_groups=hf_config.mamba_n_groups,
+            num_heads=hf_config.mamba_n_heads,
+            head_dim=hf_config.mamba_d_head,
+            state_size=hf_config.mamba_d_state,
+            conv_kernel=hf_config.mamba_d_conv,
+            use_v1=use_v1,
+        )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
@@ -580,12 +616,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         mamba_cache_params = None
         if not envs.VLLM_USE_V1:
             if self.mamba_cache is None:
+                mamba_state_shape = \
+                    self.get_mamba_state_shape_from_config(
+                        self.vllm_config, use_v1=False)
                 self.mamba_cache = MambaCacheManager(
                     self.vllm_config,
                     self.lm_head.weight.dtype if hasattr(
                         self.lm_head, 'weight') else torch.bfloat16,
                     self.config.num_hidden_layers,
-                    *self._get_mamba_cache_shape(),
+                    *mamba_state_shape,
                 )
             mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
@@ -606,39 +645,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = self.config.hidden_size
-
-        conv_state_shape, temporal_state_shape = None, None
-
-        intermediate_size = (int(self.config.mamba_expand *
-                                 hidden_size) if self.config.mamba_d_ssm
-                             is None else self.config.mamba_d_ssm)
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = self.config.mamba_n_groups + extra_groups_for_head_shards(
-            self.config.mamba_n_groups, world_size)
-
-        # - heads and n_groups are TP-ed
-        conv_dim = intermediate_size + 2 * n_groups * self.config.mamba_d_state
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            self.config.mamba_d_conv - 1,
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.config.mamba_n_heads, world_size),
-            self.config.mamba_d_head,
-            self.config.mamba_d_state,
-        )
-        return conv_state_shape, temporal_state_shape
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -12,7 +12,7 @@ from transformers import GraniteMoeHybridConfig
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -21,8 +21,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
-from vllm.model_executor.layers.mamba.mamba_mixer2 import (
-    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -524,6 +524,38 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
     }
     embedding_padding_modules = ["lm_head"]
 
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
+
+        return get_mamba_state_shape(
+            intermediate_size=intermediate_size,
+            tp_world_size=parallel_config.tensor_parallel_size,
+            n_groups=hf_config.mamba_n_groups,
+            num_heads=hf_config.mamba_n_heads,
+            head_dim=hf_config.mamba_d_head,
+            state_size=hf_config.mamba_d_state,
+            conv_kernel=hf_config.mamba_d_conv,
+            use_v1=use_v1,
+        )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
@@ -587,9 +619,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
                 self.model_config.get_num_layers_by_block_type(
                     self.vllm_config.parallel_config,
                     LayerBlockType.mamba))
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.model_config.dtype,
-                num_mamba_layers, *self._get_mamba_cache_shape())
+            mamba_state_shape = \
+                self.get_mamba_state_shape_from_config(
+                    self.vllm_config, use_v1=False)
+            self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                 self.model_config.dtype,
+                                                 num_mamba_layers,
+                                                 *mamba_state_shape)
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
@@ -605,38 +641,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = self.config.hidden_size
-
-        conv_state_shape, temporal_state_shape = None, None
-
-        intermediate_size = self.config.mamba_expand * hidden_size
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
-            self.config.mamba_n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (intermediate_size +
-                    2 * n_groups * self.config.mamba_d_state)
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            self.config.mamba_d_conv - 1,
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.config.mamba_n_heads, world_size),
-            self.config.mamba_d_head,
-            self.config.mamba_d_state,
-        )
-        return conv_state_shape, temporal_state_shape
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -22,6 +22,7 @@ from .interfaces_base import is_pooling_model
 
 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
+    from vllm.config import VllmConfig
     from vllm.model_executor.models.utils import WeightsMapper
     from vllm.sequence import IntermediateTensors
 
@@ -481,6 +482,25 @@ class IsHybrid(Protocol):
     , also indicates that the model's hf_config has
     'layers_block_type' """
 
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+        ...
+
 
 @runtime_checkable
 class _IsHybridType(Protocol):
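A minimal structural-typing sketch (hypothetical classes, not part of this commit) of how code can depend on the new classmethod through a protocol rather than on a concrete model class:

```python
# ToyHybridModel and the literal shapes are placeholders; the real protocol
# lives in vllm.model_executor.models.interfaces (IsHybrid).
from typing import Protocol


class HasMambaStateShape(Protocol):
    @classmethod
    def get_mamba_state_shape_from_config(
            cls, vllm_config, use_v1: bool = True
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        ...


class ToyHybridModel:
    @classmethod
    def get_mamba_state_shape_from_config(cls, vllm_config, use_v1=True):
        return (3, 2304), (64, 64, 128)  # hypothetical shapes


def state_elements_per_layer(model_cls: type[HasMambaStateShape],
                             vllm_config=None) -> int:
    """Element count implied by the reported conv and temporal state shapes."""
    conv, temporal = model_cls.get_mamba_state_shape_from_config(vllm_config)
    return conv[0] * conv[1] + temporal[0] * temporal[1] * temporal[2]


print(state_elements_per_layer(ToyHybridModel))  # 531200
```

This is essentially how `HybridAttentionMambaModelConfig` sizes the mamba page: it resolves the model class from the registry and calls the classmethod without instantiating the model.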
@@ -11,15 +11,14 @@ from transformers import MambaConfig
 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
-from vllm.model_executor.layers.mamba.mamba_mixer2 import (
-    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -198,6 +197,38 @@ class Mamba2Model(nn.Module):
 
 class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
 
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        intermediate_size = hf_config.expand * hf_config.hidden_size
+
+        return get_mamba_state_shape(
+            intermediate_size=intermediate_size,
+            tp_world_size=parallel_config.tensor_parallel_size,
+            n_groups=hf_config.n_groups,
+            num_heads=hf_config.num_heads,
+            head_dim=hf_config.head_dim,
+            state_size=hf_config.state_size,
+            conv_kernel=hf_config.conv_kernel,
+            use_v1=use_v1,
+        )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -253,9 +284,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
                 self.model_config.get_num_layers_by_block_type(
                     self.vllm_config.parallel_config,
                     LayerBlockType.mamba))
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype,
-                num_mamba_layers, *self._get_mamba_cache_shape())
+            mamba_state_shape = \
+                self.get_mamba_state_shape_from_config(
+                    self.vllm_config, use_v1=False)
+            self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                 self.lm_head.weight.dtype,
+                                                 num_mamba_layers,
+                                                 *mamba_state_shape)
 
             mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
         else:
@@ -274,39 +309,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-
-        conv_state_shape, temporal_state_shape = None, None
-
-        intermediate_size = getattr(
-            self.config, "intermediate_size",
-            self.config.expand * self.config.hidden_size)
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (
-            self.config.n_groups +
-            extra_groups_for_head_shards(self.config.n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size)
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            self.config.conv_kernel - 1,
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.config.num_heads, world_size),
-            self.config.head_dim,
-            self.config.state_size,
-        )
-        return conv_state_shape, temporal_state_shape
-
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states,
@@ -26,7 +26,7 @@ from torch import nn
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
@@ -37,8 +37,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
-from vllm.model_executor.layers.mamba.mamba_mixer2 import (
-    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -459,6 +459,38 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     }
     embedding_padding_modules = ["lm_head"]
 
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        intermediate_size = hf_config.expand * hf_config.hidden_size
+
+        return get_mamba_state_shape(
+            intermediate_size=intermediate_size,
+            tp_world_size=parallel_config.tensor_parallel_size,
+            n_groups=hf_config.n_groups,
+            num_heads=hf_config.mamba_num_heads,
+            head_dim=hf_config.mamba_head_dim,
+            state_size=hf_config.ssm_state_size,
+            conv_kernel=hf_config.conv_kernel,
+            use_v1=use_v1,
+        )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
@@ -515,10 +547,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 self.vllm_config.parallel_config,
                 LayerBlockType.mamba
             )
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype,
-                num_mamba_layers, *self._get_mamba_cache_shape())
+            mamba_state_shape = \
+                self.get_mamba_state_shape_from_config(
+                    self.vllm_config, use_v1=False)
+            self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                 self.lm_head.weight.dtype,
+                                                 num_mamba_layers,
+                                                 *mamba_state_shape)
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
@@ -534,39 +569,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = self.config.hidden_size
-
-        conv_state_shape, temporal_state_shape = None, None
-
-        intermediate_size = self.config.expand * hidden_size
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (
-            self.config.n_groups +
-            extra_groups_for_head_shards(self.config.n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (intermediate_size +
-                    2 * n_groups * self.config.ssm_state_size)
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            self.config.conv_kernel - 1,
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.config.mamba_num_heads, world_size),
-            self.config.mamba_head_dim,
-            self.config.ssm_state_size,
-        )
-        return conv_state_shape, temporal_state_shape
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -18,7 +18,7 @@ from transformers import Zamba2Config
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -30,8 +30,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
-from vllm.model_executor.layers.mamba.mamba_mixer2 import (
-    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -843,6 +843,39 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
             "1.weight": "B.weight",
         })
 
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
+
+        return get_mamba_state_shape(
+            intermediate_size=intermediate_size,
+            tp_world_size=parallel_config.tensor_parallel_size,
+            n_groups=hf_config.mamba_ngroups,
+            num_heads=hf_config.n_mamba_heads,
+            head_dim=hf_config.mamba_headdim,
+            state_size=hf_config.mamba_d_state,
+            conv_kernel=hf_config.mamba_d_conv,
+            use_v1=use_v1,
+        )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         """Initialize the Zamba2 model for causal language modeling.
 
@@ -925,9 +958,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
         if not envs.VLLM_USE_V1:
             if self.mamba_cache is None:
                 num_mamba_layers = self.config.num_hidden_layers
-                self.mamba_cache = MambaCacheManager(
-                    self.vllm_config, self.lm_head.weight.dtype,
-                    num_mamba_layers, *self._get_mamba_cache_shape())
+                mamba_state_shape = \
+                    self.get_mamba_state_shape_from_config(
+                        self.vllm_config, use_v1=False)
+                self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                     self.lm_head.weight.dtype,
+                                                     num_mamba_layers,
+                                                     *mamba_state_shape)
 
             # Get cache parameters for current run
             mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
@@ -968,49 +1005,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
         """
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        """Calculate shapes for Mamba's convolutional and state caches.
-
-        Returns:
-            Tuple containing:
-            - conv_state_shape: Shape for convolutional state cache
-            - temporal_state_shape: Shape for state space model cache
-        """
-        world_size = get_tensor_model_parallel_world_size()
-
-        intermediate_size = self.config.mamba_expand * self.config.hidden_size
-
-        # Extend groups if needed to ensure all groups needed by a head
-        # are sharded together
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards(
-            self.config.mamba_ngroups, world_size))
-
-        # Calculate conv state shape (includes groups)
-        # - heads and n_groups are TP-ed
-        conv_dim = (intermediate_size +
-                    2 * n_groups * self.config.mamba_d_state)
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            self.config.mamba_d_conv - 1,
-        )
-
-        # Calculate temporal state shape (per-head states)
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(divide(intermediate_size, self.config.mamba_headdim),
-                   world_size),
-            self.config.mamba_headdim,
-            self.config.mamba_d_state,
-        )
-
-        return conv_state_shape, temporal_state_shape
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -42,7 +42,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
+                        GiB_bytes, LazyLoader, async_tensor_h2d,
                         check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
@@ -2648,9 +2648,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     "Prefix caching is not supported for Mamba yet.")
             max_model_len = self.vllm_config.model_config.max_model_len
 
-            page_size_padded = self._maybe_pad_mamba_page_size(
-                attn_layers, mamba_layers, kv_cache_spec, max_model_len,
-                block_size)
+            page_size_padded = (
+                self.vllm_config.cache_config.mamba_page_size_padded)
 
             # Set block_size to max_model_len, so that mamba model will always
             # have only one block in the KV cache.
@@ -2662,54 +2661,3 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 page_size_padded=page_size_padded)
 
         return kv_cache_spec
-
-    def _maybe_pad_mamba_page_size(
-        self,
-        attn_layers: dict[str, Attention],
-        mamba_layers: dict[str, MambaBase],
-        kv_cache_spec: dict[str, KVCacheSpec],
-        max_model_len: int,
-        block_size: int,
-    ) -> Optional[int]:
-        """
-        Ensure that page size of attention KV cache groups is greater than or
-        equal to the mamba KV cache groups. If not, we suggest to the user
-        how to set the attention block size to ensure that it is.
-
-        If the attention page size is strictly greater than the mamba page size,
-        we pad the mamba page size to make them equal.
-
-        Args:
-            attn_layers: Attention layers
-            mamba_layers: Mamba layers
-            kv_cache_spec: KV cache spec (populated with attention layers)
-
-        Returns:
-            Optional[int]: Mamba page size with padding (None if no padding).
-        """
-
-        if len(attn_layers) == 0:
-            return None
-
-        attn_layer_name = next(iter(attn_layers))
-        attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
-        mamba_layer_name = next(iter(mamba_layers))
-        mamba_page_size = MambaSpec(
-            shapes=mamba_layers[mamba_layer_name].get_state_shape(),
-            dtype=self.kv_cache_dtype,
-            block_size=max_model_len).page_size_bytes
-        if attn_page_size < mamba_page_size:
-            # attention page size (for 16 tokens)
-            attn_page_size_16 = 16 * attn_page_size // block_size
-            # some attention backends (e.g. FA) only support setting
-            # block size to multiple of 16, so let's suggest a value
-            # that would work (note: FA is currently not compatible
-            # with mamba layers, use FlashInfer instead).
-            suggest_attn_block_size = 16 * cdiv(mamba_page_size,
-                                                attn_page_size_16)
-            raise ValueError(
-                "Attention block size should be increased to at least "
-                f"{suggest_attn_block_size} in order to match "
-                "the mamba page size")
-
-        return attn_page_size