[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Thomas Parnell 2025-08-15 14:57:06 +02:00 committed by GitHub
parent 22341b996e
commit 75531a6c13
23 changed files with 467 additions and 87 deletions

View File

@ -431,3 +431,65 @@ def test_full_cuda_graph(
name_0="hf" if hf_outputs is not None else "vllm-v0",
name_1="vllm-v1",
)
@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_fp32_state(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
num_logprobs: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
with hf_runner(model) as hf_model:
if model not in HF_UNSUPPORTED_MODELS:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
else:
hf_outputs = None
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32") as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
mamba_ssm_cache_dtype="float32",
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if hf_outputs is not None:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs,
name_0="hf",
name_1="vllm-v0",
)
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
check_logprobs_close(
outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs,
name_0="hf" if hf_outputs is not None else "vllm-v0",
name_1="vllm-v1",
)
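The test above drives the new option through vllm_runner; outside the test harness the same behaviour can be requested directly. A minimal sketch (model and prompt are illustrative; the keyword argument is forwarded to EngineArgs by the LLM entrypoint):

from vllm import LLM, SamplingParams

# Keep the conv state in the model dtype but hold the SSM state in float32,
# mirroring what the test exercises via mamba_ssm_cache_dtype="float32".
llm = LLM(model="Zyphra/Zamba2-1.2B-instruct",
          mamba_ssm_cache_dtype="float32")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)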

View File

@ -772,6 +772,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
head_dim=hf_config.mamba_d_head,
rms_norm_eps=hf_config.rms_norm_eps,
activation=hf_config.hidden_act,
cache_config=cache_config,
model_config=model_config,
prefix=key,
)
# suppress var not used error

View File

@ -29,7 +29,7 @@ from typing_extensions import Self, assert_never, runtime_checkable
import vllm.envs as envs
from vllm import version
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType,
from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
PassConfig)

View File

@ -23,6 +23,7 @@ logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
@ -93,6 +94,15 @@ class CacheConfig:
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
mamba_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (both the conv as well as the
ssm state). If set to 'auto', the data type will be inferred from the model
config."""
mamba_ssm_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (ssm state only, conv state will
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
for the ssm state will be determined by mamba_cache_dtype."""
# Will be set after profiling.
num_gpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for GPU memory."""
@ -123,6 +133,8 @@ class CacheConfig:
"""
factors: list[Any] = []
factors.append(self.cache_dtype)
factors.append(self.mamba_cache_dtype)
factors.append(self.mamba_ssm_cache_dtype)
# `cpu_offload_gb` does not use `torch.compile` yet.
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()

View File

@ -27,12 +27,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, ModelConfig, ModelDType, ModelImpl,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
get_field)
LoRAConfig, MambaDType, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
RunnerOption, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@ -422,6 +422,8 @@ class EngineArgs:
override_attention_dtype: str = ModelConfig.override_attention_dtype
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
additional_config: dict[str, Any] = \
get_field(VllmConfig, "additional_config")
@ -694,6 +696,10 @@ class EngineArgs:
**cache_kwargs["calculate_kv_scales"])
cache_group.add_argument("--kv-sharing-fast-prefill",
**cache_kwargs["kv_sharing_fast_prefill"])
cache_group.add_argument("--mamba-cache-dtype",
**cache_kwargs["mamba_cache_dtype"])
cache_group.add_argument("--mamba-ssm-cache-dtype",
**cache_kwargs["mamba_ssm_cache_dtype"])
# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
@ -1105,6 +1111,8 @@ class EngineArgs:
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
)
ray_runtime_env = None
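Both knobs are plumbed through EngineArgs and exposed on the CLI as --mamba-cache-dtype and --mamba-ssm-cache-dtype. A small sketch of setting them programmatically (model name illustrative):

from vllm.engine.arg_utils import EngineArgs

# Both fields default to "auto"; "float32" upgrades the corresponding state.
args = EngineArgs(model="Zyphra/Zamba2-1.2B-instruct",
                  mamba_cache_dtype="auto",
                  mamba_ssm_cache_dtype="float32")
print(args.mamba_cache_dtype, args.mamba_ssm_cache_dtype)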

View File

@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
@ -20,7 +20,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@ -56,6 +56,8 @@ class MambaMixer(MambaBase, CustomOp):
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
super().__init__()
self.time_step_rank = time_step_rank
@ -153,6 +155,8 @@ class MambaMixer(MambaBase, CustomOp):
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
def _ssm_transform(
@ -369,6 +373,15 @@ class MambaMixer(MambaBase, CustomOp):
return out
def get_state_dtype(self) -> tuple[torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.mamba1_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
self.cache_config.mamba_ssm_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba1_state_shape(
tp_world_size=get_tensor_model_parallel_world_size(),

View File

@ -8,7 +8,7 @@ from torch import nn
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp):
**selective** state spaces)
"""
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# For TP, the sharding plan is as follows:
@ -417,6 +417,8 @@ class MambaMixer2(MambaBase, CustomOp):
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
def forward_native(
@ -670,7 +672,7 @@ class MambaMixer2(MambaBase, CustomOp):
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
self.head_dim),
)
state_dtype=ssm_state.dtype)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
@ -732,6 +734,15 @@ class MambaMixer2(MambaBase, CustomOp):
# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.mamba2_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
self.cache_config.mamba_ssm_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=self.intermediate_size,
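A rough, standalone illustration (toy numbers, not vLLM code) of why the recurrent SSM state benefits from float32 while the surrounding activations can stay in the model dtype: repeated small state updates lose precision quickly in bf16.

import torch

# Toy accumulation: 10_000 increments of 1e-4 should sum to roughly 1.0.
steps, delta = 10_000, 1e-4
state_bf16 = torch.zeros((), dtype=torch.bfloat16)
state_fp32 = torch.zeros((), dtype=torch.float32)
for _ in range(steps):
    state_bf16 = state_bf16 + torch.tensor(delta, dtype=torch.bfloat16)
    state_fp32 = state_fp32 + delta
# The bf16 state stalls far below 1.0 once the increment falls under half an
# ulp of the running value; the fp32 state stays close to 1.0.
print(state_bf16.item(), state_fp32.item())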

View File

@ -1,6 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Union
import torch
from vllm.config import MambaDType, ModelDType
from vllm.distributed import divide
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
class MambaStateDtypeCalculator:
@classmethod
def linear_attention_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
# TODO (tdoublep) requires testing
if mamba_cache_dtype == "float32":
raise ValueError("fp32 state for minimax is not yet supported")
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, )
@classmethod
def mamba1_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
# TODO (tdoublep) requires kernel changes
if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32":
raise ValueError("fp32 state for mamba1 is not yet supported")
else:
return MambaStateDtypeCalculator.mamba2_state_dtype(
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype)
@classmethod
def mamba2_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
if mamba_ssm_cache_dtype == "auto":
temporal_state_dtype = conv_state_dtype
else:
temporal_state_dtype = (
STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
return (conv_state_dtype, temporal_state_dtype)
class MambaStateShapeCalculator:
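A quick sanity check of the resolution order implemented above, assuming a bfloat16 model dtype (the results follow from get_kv_cache_torch_dtype and STR_DTYPE_TO_TORCH_DTYPE):

import torch
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateDtypeCalculator

# "auto"/"auto": both states follow the model dtype.
assert MambaStateDtypeCalculator.mamba2_state_dtype(
    torch.bfloat16, "auto", "auto") == (torch.bfloat16, torch.bfloat16)
# "auto"/"float32": only the ssm state is upgraded; the conv state still
# follows mamba_cache_dtype (here resolved from the model dtype).
assert MambaStateDtypeCalculator.mamba2_state_dtype(
    torch.bfloat16, "auto", "float32") == (torch.bfloat16, torch.float32)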

View File

@ -41,6 +41,7 @@ def _mamba_chunk_scan_combined_fwd(x,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
out=None):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
batch, seqlen, nheads, headdim = x.shape
@ -118,7 +119,7 @@ def _mamba_chunk_scan_combined_fwd(x,
if initial_states is not None else None,
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=C.dtype,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
is_cont_batched=cu_seqlens is not None)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
@ -189,7 +190,8 @@ def mamba_chunk_scan_combined(x,
dt_limit=(0.0, float("inf")),
out=None,
return_final_states=False,
return_varlen_states=False):
return_varlen_states=False,
state_dtype=None):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
@ -206,6 +208,7 @@ def mamba_chunk_scan_combined(x,
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
out: Preallocated output tensor
state_dtype: The data type of the ssm state
"""
if not return_varlen_states:
@ -229,7 +232,8 @@ def mamba_chunk_scan_combined(x,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
out=out)
out=out,
state_dtype=state_dtype)
if not return_varlen_states:
if not return_final_states:
return

View File

@ -12,7 +12,7 @@ from transformers import BambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@ -26,7 +26,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -83,6 +83,7 @@ class BambaMixerDecoderLayer(nn.Module):
def __init__(self,
config: BambaConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
@ -100,6 +101,8 @@ class BambaMixerDecoderLayer(nn.Module):
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.mixer")
@ -138,6 +141,7 @@ class BambaAttentionDecoderLayer(nn.Module):
self,
config: BambaConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -266,6 +270,7 @@ class BambaModel(nn.Module):
super().__init__()
config: BambaConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -289,6 +294,7 @@ class BambaModel(nn.Module):
return layer_class(
config,
layer_idx,
model_config,
cache_config,
quant_config=quant_config,
prefix=prefix,
@ -437,6 +443,18 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba2_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,
@ -528,10 +546,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

View File

@ -318,7 +318,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# get mamba page size
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtype=kv_cache_dtype,
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=model_config.max_model_len,
).page_size_bytes

View File

@ -11,7 +11,7 @@ from transformers import FalconH1Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@ -25,7 +25,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -85,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module):
def __init__(
self,
config: FalconH1Config,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -108,6 +109,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
use_rms_norm=config.mamba_rms_norm,
prefix=f"{prefix}.mixer",
@ -317,6 +320,7 @@ class FalconH1ParallelHybrid(nn.Module):
self,
config: FalconH1Config,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -339,6 +343,7 @@ class FalconH1ParallelHybrid(nn.Module):
# Instantiate the SSM branch
self.mamba = FalconH1SSMDecoderLayer(
config=config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=ssm_prefix,
@ -408,6 +413,7 @@ class FalconH1Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: FalconH1Config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -435,6 +441,7 @@ class FalconH1Model(nn.Module):
return layer_class(
config,
layer_idx,
model_config,
cache_config,
quant_config=quant_config,
prefix=prefix,
@ -519,6 +526,18 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba2_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,
@ -624,12 +643,14 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(
self.vllm_config,
self.lm_head.weight.dtype if hasattr(
self.lm_head, 'weight') else torch.bfloat16,
self.config.num_hidden_layers,
*mamba_state_shape,
*mamba_state_dtype,
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

View File

@ -12,7 +12,7 @@ from transformers import GraniteMoeHybridConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@ -24,7 +24,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -50,6 +50,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
def __init__(self,
config: GraniteMoeHybridConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
@ -70,6 +71,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.mixer")
@ -137,6 +140,7 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module):
self,
config: GraniteMoeHybridConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -217,6 +221,7 @@ class GraniteMoeHybridAttention(nn.Module):
def __init__(
self,
config: GraniteMoeHybridConfig,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -316,6 +321,7 @@ class GraniteMoeHybridModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -340,6 +346,7 @@ class GraniteMoeHybridModel(nn.Module):
return layer_class(
config,
layer_idx,
model_config,
cache_config,
quant_config=quant_config,
prefix=prefix,
@ -527,6 +534,18 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba2_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,
@ -625,10 +644,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.model_config.dtype,
num_mamba_layers,
*mamba_state_shape)
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

View File

@ -10,7 +10,7 @@ from transformers import JambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -94,6 +94,7 @@ class JambaMambaDecoderLayer(nn.Module):
def __init__(self,
config: JambaConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
is_lora_enabled: Optional[bool] = False,
@ -114,6 +115,8 @@ class JambaMambaDecoderLayer(nn.Module):
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
is_lora_enabled = self.is_lora_enabled,
model_config=model_config,
cache_config=cache_config,
prefix=f"{prefix}.mixer",
)
@ -164,6 +167,7 @@ class JambaAttentionDecoderLayer(nn.Module):
def __init__(self,
config: JambaConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -280,6 +284,7 @@ class JambaModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -304,6 +309,7 @@ class JambaModel(nn.Module):
config.layers_block_type[layer_idx]]
return layer_class(config,
layer_idx,
model_config,
cache_config,
quant_config=quant_config,
prefix=prefix,
@ -520,9 +526,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
state_dtype = self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_layers, *state_shape)
num_layers, *state_shape,
*state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
@ -537,6 +545,18 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba1_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,

View File

@ -9,13 +9,13 @@ from torch import nn
from transformers import MambaConfig
from vllm import envs
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -40,6 +40,7 @@ class MambaDecoderLayer(nn.Module):
def __init__(self,
config: MambaConfig,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
is_lora_enabled: Optional[bool] = False,
@ -61,6 +62,8 @@ class MambaDecoderLayer(nn.Module):
rms_norm_eps=mixer_rms_eps,
activation=config.hidden_act,
is_lora_enabled=self.is_lora_enabled,
model_config=model_config,
cache_config=cache_config,
prefix=f"{prefix}.mixer")
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@ -88,6 +91,7 @@ class MambaModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -108,6 +112,7 @@ class MambaModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MambaDecoderLayer(config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
is_lora_enabled=is_lora_enabled,
@ -243,9 +248,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
self.vllm_config.parallel_config, LayerBlockType.mamba)
state_shape = self.get_mamba_state_shape_from_config(
self.vllm_config)
state_dtype = self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_layers, *state_shape)
num_layers, *state_shape,
*state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
@ -254,6 +261,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
return hidden_states
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba1_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,

View File

@ -11,7 +11,7 @@ from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
@ -20,7 +20,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -45,6 +45,8 @@ class Mamba2DecoderLayer(nn.Module):
def __init__(self,
config: MambaConfig,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
@ -62,6 +64,8 @@ class Mamba2DecoderLayer(nn.Module):
head_dim=config.head_dim,
rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.mixer")
@ -93,6 +97,8 @@ class Mamba2Model(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
is_lora_enabled = bool(lora_config)
@ -112,8 +118,11 @@ class Mamba2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Mamba2DecoderLayer(
config, quant_config=quant_config, prefix=prefix),
lambda prefix: Mamba2DecoderLayer(config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.norm_f = RMSNorm(config.hidden_size,
@ -200,6 +209,18 @@ class Mamba2Model(nn.Module):
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba2_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,
@ -290,10 +311,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:

View File

@ -24,9 +24,14 @@ class MambaCacheParams:
class MambaCacheManager(ConstantSizeCache):
def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
num_mamba_layers: int, conv_state_shape: tuple[int, int],
temporal_state_shape: tuple[int, int]):
def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
conv_state_shape: tuple[int, int],
temporal_state_shape: tuple[int, int],
conv_state_dtype: torch.dtype,
temporal_state_dtype: torch.dtype):
self.conv_state_dtype = conv_state_dtype
self.temporal_state_dtype = temporal_state_dtype
# Determine max batch size to set size of MambaCache
max_batch_size = vllm_config.scheduler_config.max_num_seqs
@ -40,11 +45,11 @@ class MambaCacheManager(ConstantSizeCache):
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
(conv_state_shape[1], conv_state_shape[0]),
dtype=dtype,
dtype=self.conv_state_dtype,
device="cuda").transpose(-1, -2)
temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
dtype=self.temporal_state_dtype,
device="cuda")
self._mamba_cache = (conv_state, temporal_state)

View File

@ -16,7 +16,8 @@ from transformers import MiniMaxConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
@ -36,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -338,6 +339,12 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def mamba_type(self) -> str:
return "linear_attention"
def get_state_dtype(self) -> tuple[torch.dtype]:
return MambaStateDtypeCalculator.linear_attention_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=self.num_heads,
@ -353,6 +360,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
max_position: int,
block_size: int,
num_hidden_layer: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
layer_idx: int = 0,
linear_layer_idx: int = 0,
@ -374,6 +383,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
self.qkv_proj = ColumnParallelLinear(
@ -657,6 +668,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
def __init__(
self,
config: MiniMaxConfig,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
expert_num: int = 1,
@ -693,6 +705,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
max_position=max_position_embeddings,
block_size=config.block if hasattr(config, "block") else 256,
num_hidden_layer=config.num_hidden_layers,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
layer_idx=self._ilayer,
linear_layer_idx=linear_layer_id,
@ -861,6 +875,7 @@ class MiniMaxText01Model(nn.Module):
def __init__(
self,
config: MiniMaxConfig,
model_config: Optional[ModelConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
scheduler_config=None,
@ -910,6 +925,7 @@ class MiniMaxText01Model(nn.Module):
decoder_kwargs = {
"quant_config": quant_config,
"layer_id": layer_idx,
"model_config": model_config,
"cache_config": cache_config
}
@ -1111,8 +1127,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
self.config.max_model_len = vllm_config.model_config.max_model_len
self.model = MiniMaxText01Model(
self.config,
quant_config,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=quant_config,
scheduler_config=vllm_config.scheduler_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
@ -1409,6 +1426,17 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
load_basic_weight(name, loaded_weight, self)
return loaded_params
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.linear_attention_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,

View File

@ -26,7 +26,7 @@ from torch import nn
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@ -40,7 +40,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@ -110,6 +110,7 @@ class NemotronHMLPDecoderLayer(nn.Module):
self,
config: NemotronHConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -149,6 +150,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
self,
config: NemotronHConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -167,6 +169,8 @@ class NemotronHMambaDecoderLayer(nn.Module):
head_dim=config.mamba_head_dim,
rms_norm_eps=config.rms_norm_eps,
activation=config.mamba_hidden_act,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
)
@ -198,6 +202,7 @@ class NemotronHAttention(nn.Module):
self,
config: NemotronHConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -270,6 +275,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
self,
config: NemotronHConfig,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -279,6 +285,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
self.mixer = NemotronHAttention(
config,
layer_idx,
model_config,
cache_config,
quant_config,
prefix=f"{prefix}.mixer",
@ -317,6 +324,7 @@ class NemotronHModel(nn.Module):
super().__init__()
config: NemotronHConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -340,6 +348,7 @@ class NemotronHModel(nn.Module):
return layer_class(
config,
layer_idx,
model_config,
cache_config,
quant_config=quant_config,
prefix=prefix,
@ -478,6 +487,18 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba2_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,
@ -569,10 +590,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
*mamba_state_shape,
*mamba_state_dtype)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

View File

@ -18,7 +18,7 @@ from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import GeluAndMul
@ -33,7 +33,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -478,6 +478,8 @@ class Zamba2MambaDecoderLayer(nn.Module):
def __init__(self,
config: Zamba2Config,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
"""Initialize the Mamba decoder layer.
@ -502,6 +504,8 @@ class Zamba2MambaDecoderLayer(nn.Module):
config.n_mamba_heads,
rms_norm_eps=config.rms_norm_eps,
activation="silu",
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.mixer")
@ -578,6 +582,8 @@ class Zamba2HybridLayer(nn.Module):
shared_transformer: Zamba2AttentionDecoderLayer,
config: Zamba2Config,
block_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
@ -596,6 +602,8 @@ class Zamba2HybridLayer(nn.Module):
bias=False,
quant_config=quant_config)
self.mamba_decoder = Zamba2MambaDecoderLayer(config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix)
@ -669,6 +677,7 @@ class Zamba2Model(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -718,11 +727,15 @@ class Zamba2Model(nn.Module):
Zamba2HybridLayer(block,
config,
block_idx,
quant_config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix))
else:
layers.append(
Zamba2MambaDecoderLayer(config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix))
self.layers = nn.ModuleList(layers)
@ -848,6 +861,18 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
"1.weight": "B.weight",
})
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba2_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,
@ -966,10 +991,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
mamba_state_dtype = \
self.get_mamba_state_dtype_from_config(
self.vllm_config)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
*mamba_state_shape,
*mamba_state_dtype)
# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

View File

@ -173,6 +173,7 @@ CYAN = '\033[1;36m'
RESET = '\033[0;0m'
STR_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,

View File

@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec):
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...]
dtype: torch.dtype
dtypes: tuple[torch.dtype, ...]
page_size_padded: Optional[int] = None
mamba_type: str = "mamba2"
@property
def page_size_bytes(self) -> int:
num_elements = sum(prod(shape) for shape in self.shapes)
page_size = num_elements * get_dtype_size(self.dtype)
page_size = sum(
prod(shape) * get_dtype_size(dtype)
for (shape, dtype) in zip(self.shapes, self.dtypes))
if self.page_size_padded is not None:
assert self.page_size_padded >= page_size
return self.page_size_padded
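A worked example of the per-dtype page size (shapes are illustrative, not taken from a real model config): a conv state of shape (1280, 4) in bfloat16 plus an ssm state of shape (32, 64, 128) in float32 now costs 1280*4*2 + 32*64*128*4 bytes, instead of pricing both states at a single cache dtype.

from math import prod
import torch

shapes = ((1280, 4), (32, 64, 128))   # (conv_state, ssm_state), illustrative
dtypes = (torch.bfloat16, torch.float32)
page_size = sum(prod(s) * torch.empty((), dtype=d).element_size()
                for s, d in zip(shapes, dtypes))
print(page_size)  # 10_240 + 1_048_576 = 1_058_816 bytes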

View File

@ -2884,23 +2884,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name]
dtype = kv_cache_spec.dtype
num_element_per_page = (kv_cache_spec.page_size_bytes //
get_dtype_size(dtype))
state_tensors = []
storage_offset = 0
for shape in kv_cache_spec.shapes:
storage_offset_bytes = 0
for (shape, dtype) in zip(kv_cache_spec.shapes,
kv_cache_spec.dtypes):
dtype_size = get_dtype_size(dtype)
num_element_per_page = (
kv_cache_spec.page_size_bytes // dtype_size)
target_shape = (num_blocks, *shape)
stride = torch.empty(target_shape).stride()
target_stride = (num_element_per_page, *stride[1:])
assert storage_offset_bytes % dtype_size == 0
tensor = torch.as_strided(
raw_tensor.view(dtype),
size=target_shape,
stride=target_stride,
storage_offset=storage_offset,
storage_offset=storage_offset_bytes // dtype_size,
)
state_tensors.append(tensor)
storage_offset += stride[0]
storage_offset_bytes += stride[0] * dtype_size
kv_caches[layer_name] = state_tensors
else:
@ -3087,7 +3089,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtype=self.kv_cache_dtype,
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type)
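The byte-offset bookkeeping above is easier to follow in isolation; a standalone sketch with toy sizes (not the runner's real shapes) that carves one page-aligned raw buffer into a conv view and an ssm view of different dtypes:

import torch

num_blocks, page_size_bytes = 4, 4096
raw = torch.zeros(num_blocks * page_size_bytes, dtype=torch.uint8)

shapes = ((16, 8), (8, 32))             # (conv_state, ssm_state) per block
dtypes = (torch.bfloat16, torch.float32)

views, offset_bytes = [], 0
for shape, dtype in zip(shapes, dtypes):
    dtype_size = torch.empty((), dtype=dtype).element_size()
    elems_per_page = page_size_bytes // dtype_size
    stride = torch.empty((num_blocks, *shape)).stride()
    assert offset_bytes % dtype_size == 0
    views.append(torch.as_strided(raw.view(dtype),
                                  size=(num_blocks, *shape),
                                  stride=(elems_per_page, *stride[1:]),
                                  storage_offset=offset_bytes // dtype_size))
    offset_bytes += stride[0] * dtype_size

conv_view, ssm_view = views
print(conv_view.shape, conv_view.dtype, ssm_view.shape, ssm_view.dtype)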