diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index e75677347f039..aee0a50336c09 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -431,3 +431,65 @@ def test_full_cuda_graph( name_0="hf" if hf_outputs is not None else "vllm-v0", name_1="vllm-v1", ) + + +@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_fp32_state( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + num_logprobs: int, +) -> None: + + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + with hf_runner(model) as hf_model: + if model not in HF_UNSUPPORTED_MODELS: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + else: + hf_outputs = None + + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + mamba_ssm_cache_dtype="float32") as vllm_model: + vllm_v0_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + if model in HYBRID_MODELS: + # required due to reorder_batch behaviour + m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + with vllm_runner(model, + max_num_seqs=MAX_NUM_SEQS, + mamba_ssm_cache_dtype="float32", + enable_prefix_caching=False) as vllm_model: + vllm_v1_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + if hf_outputs is not None: + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_v0_outputs, + name_0="hf", + name_1="vllm-v0", + ) + + ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs + check_logprobs_close( + outputs_0_lst=ref_outputs, + outputs_1_lst=vllm_v1_outputs, + name_0="hf" if hf_outputs is not None else "vllm-v0", + name_1="vllm-v1", + ) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index e97cdf482710a..4bcc63f293e03 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -772,6 +772,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): head_dim=hf_config.mamba_d_head, rms_norm_eps=hf_config.rms_norm_eps, activation=hf_config.hidden_act, + cache_config=cache_config, + model_config=model_config, prefix=key, ) # suppress var not used error diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index a2e93c344b3f3..82ef8db673fec 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -29,7 +29,7 @@ from typing_extensions import Self, assert_never, runtime_checkable import vllm.envs as envs from vllm import version -from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, +from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, PassConfig) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 69cb0d9732fac..ae11dec3ca5e2 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -23,6 +23,7 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +MambaDType = 
Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] @@ -93,6 +94,15 @@ class CacheConfig: """ Optional override for mamba page size; used by hybrid mamba/attention models to ensure exact alignment with attention page size.""" + mamba_cache_dtype: MambaDType = "auto" + """The data type to use for the Mamba cache (both the conv as well as the + ssm state). If set to 'auto', the data type will be inferred from the model + config.""" + mamba_ssm_cache_dtype: MambaDType = "auto" + """The data type to use for the Mamba cache (ssm state only, conv state will + still be controlled by mamba_cache_dtype). If set to 'auto', the data type + for the ssm state will be determined by mamba_cache_dtype.""" + # Will be set after profiling. num_gpu_blocks: Optional[int] = field(default=None, init=False) """The number of blocks to allocate for GPU memory.""" @@ -123,6 +133,8 @@ class CacheConfig: """ factors: list[Any] = [] factors.append(self.cache_dtype) + factors.append(self.mamba_cache_dtype) + factors.append(self.mamba_ssm_cache_dtype) # `cpu_offload_gb` does not use `torch.compile` yet. hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 31de2ede7a380..f8af6d36e0c06 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,12 +27,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, DeviceConfig, DistributedExecutorBackend, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, ModelConfig, ModelDType, ModelImpl, - MultiModalConfig, ObservabilityConfig, ParallelConfig, - PoolerConfig, PrefixCachingHashAlgo, RunnerOption, - SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - TaskOption, TokenizerMode, VllmConfig, get_attr_docs, - get_field) + LoRAConfig, MambaDType, ModelConfig, ModelDType, + ModelImpl, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, + RunnerOption, SchedulerConfig, SchedulerPolicy, + SpeculativeConfig, TaskOption, TokenizerMode, + VllmConfig, get_attr_docs, get_field) from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins @@ -422,6 +422,8 @@ class EngineArgs: override_attention_dtype: str = ModelConfig.override_attention_dtype calculate_kv_scales: bool = CacheConfig.calculate_kv_scales + mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype + mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype additional_config: dict[str, Any] = \ get_field(VllmConfig, "additional_config") @@ -694,6 +696,10 @@ class EngineArgs: **cache_kwargs["calculate_kv_scales"]) cache_group.add_argument("--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"]) + cache_group.add_argument("--mamba-cache-dtype", + **cache_kwargs["mamba_cache_dtype"]) + cache_group.add_argument("--mamba-ssm-cache-dtype", + **cache_kwargs["mamba_ssm_cache_dtype"]) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -1105,6 +1111,8 @@ class EngineArgs: cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, kv_sharing_fast_prefill=self.kv_sharing_fast_prefill, + mamba_cache_dtype=self.mamba_cache_dtype, + mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, ) ray_runtime_env = None diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py 
b/vllm/model_executor/layers/mamba/mamba_mixer.py index 3b17fb0ca8c79..3c7322260df43 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -9,7 +9,7 @@ from torch.nn.parameter import Parameter from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import get_current_vllm_config +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import ForwardContext, get_forward_context @@ -20,7 +20,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -56,6 +56,8 @@ class MambaMixer(MambaBase, CustomOp): rms_norm_eps: float = 1e-5, activation="silu", is_lora_enabled: bool = False, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, prefix: str = ""): super().__init__() self.time_step_rank = time_step_rank @@ -153,6 +155,8 @@ class MambaMixer(MambaBase, CustomOp): # The inner tuple is (conv_state, ssm_state) self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + self.model_config = model_config + self.cache_config = cache_config self.prefix = prefix def _ssm_transform( @@ -369,6 +373,15 @@ class MambaMixer(MambaBase, CustomOp): return out + def get_state_dtype(self) -> tuple[torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba1_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.mamba1_state_shape( tp_world_size=get_tensor_model_parallel_world_size(), diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 6bf0c18ebdb47..743e520ec8ee1 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -8,7 +8,7 @@ from torch import nn from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import get_current_vllm_config +from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, @@ -21,7 +21,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, update_metadata) from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated @@ -218,23 +218,23 @@ class MambaMixer2(MambaBase, CustomOp): **selective** state spaces) """ - def __init__( - self, - hidden_size: int, - 
ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - use_conv_bias: bool, - use_bias: bool, - n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, - rms_norm_eps: float = 1e-5, - activation: str = "silu", - use_rms_norm: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() # For TP, the sharding plan is as follows: @@ -417,6 +417,8 @@ class MambaMixer2(MambaBase, CustomOp): # The inner tuple is (conv_state, ssm_state) self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + self.model_config = model_config + self.cache_config = cache_config self.prefix = prefix def forward_native( @@ -670,7 +672,7 @@ class MambaMixer2(MambaBase, CustomOp): dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), - ) + state_dtype=ssm_state.dtype) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor @@ -732,6 +734,15 @@ class MambaMixer2(MambaBase, CustomOp): # 5. Final linear projection output[:num_actual_tokens], _ = self.out_proj(hidden_states) + def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: + assert self.model_config is not None + assert self.cache_config is not None + return MambaStateDtypeCalculator.mamba2_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + self.cache_config.mamba_ssm_cache_dtype, + ) + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=self.intermediate_size, diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index ad14017912381..66674d1a6f251 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,6 +1,58 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Union + +import torch + +from vllm.config import MambaDType, ModelDType from vllm.distributed import divide +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype + + +class MambaStateDtypeCalculator: + + @classmethod + def linear_attention_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + # TODO (tdoublep) requires testing + if mamba_cache_dtype == "float32": + raise ValueError("fp32 state for minimax is not yet supported") + state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (state_dtype, ) + + @classmethod + def mamba1_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + mamba_ssm_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + # TODO (tdoublep) requires kernel changes + if mamba_cache_dtype == "float32" or mamba_ssm_cache_dtype == "float32": + raise ValueError("fp32 state for mamba1 is not yet supported") + else: + return MambaStateDtypeCalculator.mamba2_state_dtype( + 
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype) + + @classmethod + def mamba2_state_dtype( + cls, + model_dtype: Union[ModelDType, torch.dtype], + mamba_cache_dtype: MambaDType, + mamba_ssm_cache_dtype: MambaDType, + ) -> tuple[torch.dtype, ...]: + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, + model_dtype) + if mamba_ssm_cache_dtype == "auto": + temporal_state_dtype = conv_state_dtype + else: + temporal_state_dtype = ( + STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]) + + return (conv_state_dtype, temporal_state_dtype) class MambaStateShapeCalculator: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index fd74cb837290b..d0b3e9e5235bf 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -41,6 +41,7 @@ def _mamba_chunk_scan_combined_fwd(x, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), + state_dtype=None, out=None): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape @@ -118,7 +119,7 @@ def _mamba_chunk_scan_combined_fwd(x, if initial_states is not None else None, seq_idx=seq_idx, chunk_size=chunk_size, - out_dtype=C.dtype, + out_dtype=state_dtype if state_dtype is not None else C.dtype, is_cont_batched=cu_seqlens is not None) states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]) @@ -189,7 +190,8 @@ def mamba_chunk_scan_combined(x, dt_limit=(0.0, float("inf")), out=None, return_final_states=False, - return_varlen_states=False): + return_varlen_states=False, + state_dtype=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -206,6 +208,7 @@ def mamba_chunk_scan_combined(x, cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt out: Preallocated output tensor + state_dtype: The data type of the ssm state """ if not return_varlen_states: @@ -229,7 +232,8 @@ def mamba_chunk_scan_combined(x, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit, - out=out) + out=out, + state_dtype=state_dtype) if not return_varlen_states: if not return_final_states: return diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 4a2ae07581f3e..e2cd31af5390a 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -12,7 +12,7 @@ from transformers import BambaConfig from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -26,7 +26,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -83,6 +83,7 @@ class 
BambaMixerDecoderLayer(nn.Module): def __init__(self, config: BambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: @@ -100,6 +101,8 @@ class BambaMixerDecoderLayer(nn.Module): head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -138,6 +141,7 @@ class BambaAttentionDecoderLayer(nn.Module): self, config: BambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -266,6 +270,7 @@ class BambaModel(nn.Module): super().__init__() config: BambaConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -289,6 +294,7 @@ class BambaModel(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -437,6 +443,18 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -528,10 +546,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 6f21cd267b0e6..882df7e8162c5 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -318,7 +318,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): # get mamba page size mamba_page_size = MambaSpec( shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), - dtype=kv_cache_dtype, + dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), block_size=model_config.max_model_len, ).page_size_bytes diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 85d64af5bd281..5e2b6d69124c8 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -11,7 +11,7 @@ from transformers import FalconH1Config from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -25,7 +25,7 @@ 
from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -85,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module): def __init__( self, config: FalconH1Config, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -108,6 +109,8 @@ class FalconH1SSMDecoderLayer(nn.Module): head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, use_rms_norm=config.mamba_rms_norm, prefix=f"{prefix}.mixer", @@ -317,6 +320,7 @@ class FalconH1ParallelHybrid(nn.Module): self, config: FalconH1Config, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -339,6 +343,7 @@ class FalconH1ParallelHybrid(nn.Module): # Instantiate the SSM branch self.mamba = FalconH1SSMDecoderLayer( config=config, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, prefix=ssm_prefix, @@ -408,6 +413,7 @@ class FalconH1Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: FalconH1Config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -435,6 +441,7 @@ class FalconH1Model(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -519,6 +526,18 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -624,12 +643,14 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager( self.vllm_config, - self.lm_head.weight.dtype if hasattr( - self.lm_head, 'weight') else torch.bfloat16, self.config.num_hidden_layers, *mamba_state_shape, + *mamba_state_dtype, ) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index e59502f12a1cc..5704496b9a5d4 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -12,7 +12,7 @@ from transformers import GraniteMoeHybridConfig from vllm import envs 
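
(For reference, not part of the diff: the per-state dtype resolution that MambaStateDtypeCalculator.mamba2_state_dtype implements above reduces to the short standalone sketch below. The helper name resolve_mamba2_state_dtypes and the reduced dtype table are illustrative only; the real code resolves "auto" through get_kv_cache_torch_dtype and vllm.utils.STR_DTYPE_TO_TORCH_DTYPE.)

import torch

# Reduced stand-in for vllm.utils.STR_DTYPE_TO_TORCH_DTYPE.
_DTYPES = {"float32": torch.float32, "bfloat16": torch.bfloat16, "half": torch.half}

def resolve_mamba2_state_dtypes(model_dtype: torch.dtype,
                                mamba_cache_dtype: str = "auto",
                                mamba_ssm_cache_dtype: str = "auto"):
    # Conv state follows mamba_cache_dtype; "auto" falls back to the model dtype.
    conv_dtype = (model_dtype if mamba_cache_dtype == "auto" else
                  _DTYPES[mamba_cache_dtype])
    # SSM state follows mamba_ssm_cache_dtype; "auto" falls back to the conv dtype.
    ssm_dtype = (conv_dtype if mamba_ssm_cache_dtype == "auto" else
                 _DTYPES[mamba_ssm_cache_dtype])
    return conv_dtype, ssm_dtype

# bf16 model with mamba_ssm_cache_dtype="float32": the conv state stays bf16,
# only the ssm state is promoted to fp32.
assert resolve_mamba2_state_dtypes(torch.bfloat16, "auto", "float32") == \
    (torch.bfloat16, torch.float32)
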
from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -24,7 +24,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -50,6 +50,7 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: @@ -70,6 +71,8 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module): head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -137,6 +140,7 @@ class GraniteMoeHybridAttentionDecoderLayer(nn.Module): self, config: GraniteMoeHybridConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -217,6 +221,7 @@ class GraniteMoeHybridAttention(nn.Module): def __init__( self, config: GraniteMoeHybridConfig, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -316,6 +321,7 @@ class GraniteMoeHybridModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -340,6 +346,7 @@ class GraniteMoeHybridModel(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -527,6 +534,18 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -625,10 +644,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.model_config.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = 
self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index fbd310121ad47..0b32d6f256590 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -10,7 +10,7 @@ from transformers import JambaConfig from vllm import envs from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE @@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -94,6 +94,7 @@ class JambaMambaDecoderLayer(nn.Module): def __init__(self, config: JambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, is_lora_enabled: Optional[bool] = False, @@ -114,6 +115,8 @@ class JambaMambaDecoderLayer(nn.Module): rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, is_lora_enabled = self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, prefix=f"{prefix}.mixer", ) @@ -164,6 +167,7 @@ class JambaAttentionDecoderLayer(nn.Module): def __init__(self, config: JambaConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -280,6 +284,7 @@ class JambaModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -304,6 +309,7 @@ class JambaModel(nn.Module): config.layers_block_type[layer_idx]] return layer_class(config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -520,9 +526,11 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.vllm_config.parallel_config, LayerBlockType.mamba) state_shape = self.get_mamba_state_shape_from_config( self.vllm_config) + state_dtype = self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, - num_layers, *state_shape) + num_layers, *state_shape, + *state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -537,6 +545,18 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba1_state_dtype( + vllm_config.model_config.dtype, + 
vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 80b63e15377a2..f4aaf0c6f467c 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -9,13 +9,13 @@ from torch import nn from transformers import MambaConfig from vllm import envs -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -40,6 +40,7 @@ class MambaDecoderLayer(nn.Module): def __init__(self, config: MambaConfig, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, is_lora_enabled: Optional[bool] = False, @@ -61,6 +62,8 @@ class MambaDecoderLayer(nn.Module): rms_norm_eps=mixer_rms_eps, activation=config.hidden_act, is_lora_enabled=self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, prefix=f"{prefix}.mixer") self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -88,6 +91,7 @@ class MambaModel(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -108,6 +112,7 @@ class MambaModel(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MambaDecoderLayer(config, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, is_lora_enabled=is_lora_enabled, @@ -243,9 +248,11 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): self.vllm_config.parallel_config, LayerBlockType.mamba) state_shape = self.get_mamba_state_shape_from_config( self.vllm_config) + state_dtype = self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, - num_layers, *state_shape) + num_layers, *state_shape, + *state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -254,6 +261,18 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): return hidden_states + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba1_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 75e92b01762da..3432cf29feac6 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -11,7 +11,7 @@ from 
transformers import MambaConfig from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm @@ -20,7 +20,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,6 +45,8 @@ class Mamba2DecoderLayer(nn.Module): def __init__(self, config: MambaConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: super().__init__() @@ -62,6 +64,8 @@ class Mamba2DecoderLayer(nn.Module): head_dim=config.head_dim, rms_norm_eps=config.layer_norm_epsilon, activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -93,6 +97,8 @@ class Mamba2Model(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config is_lora_enabled = bool(lora_config) @@ -112,8 +118,11 @@ class Mamba2Model(nn.Module): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Mamba2DecoderLayer( - config, quant_config=quant_config, prefix=prefix), + lambda prefix: Mamba2DecoderLayer(config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers") self.norm_f = RMSNorm(config.hidden_size, @@ -200,6 +209,18 @@ class Mamba2Model(nn.Module): class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -290,10 +311,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) else: diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 27685c59a3eac..6b16e3ce7d984 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -24,9 +24,14 @@ class MambaCacheParams: class 
MambaCacheManager(ConstantSizeCache): - def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, - num_mamba_layers: int, conv_state_shape: tuple[int, int], - temporal_state_shape: tuple[int, int]): + def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int, + conv_state_shape: tuple[int, int], + temporal_state_shape: tuple[int, int], + conv_state_dtype: torch.dtype, + temporal_state_dtype: torch.dtype): + + self.conv_state_dtype = conv_state_dtype + self.temporal_state_dtype = temporal_state_dtype # Determine max batch size to set size of MambaCache max_batch_size = vllm_config.scheduler_config.max_num_seqs @@ -40,11 +45,11 @@ class MambaCacheManager(ConstantSizeCache): assert conv_state_shape[0] > conv_state_shape[1] conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + (conv_state_shape[1], conv_state_shape[0]), - dtype=dtype, + dtype=self.conv_state_dtype, device="cuda").transpose(-1, -2) temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + temporal_state_shape, - dtype=dtype, + dtype=self.temporal_state_dtype, device="cuda") self._mamba_cache = (conv_state, temporal_state) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 3d14a6ad5c3a4..82e96844cd5f6 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -16,7 +16,8 @@ from transformers import MiniMaxConfig from vllm import envs from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, @@ -36,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -338,6 +339,12 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): def mamba_type(self) -> str: return "linear_attention" + def get_state_dtype(self) -> tuple[torch.dtype]: + return MambaStateDtypeCalculator.linear_attention_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.linear_attention_state_shape( num_heads=self.num_heads, @@ -353,6 +360,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): max_position: int, block_size: int, num_hidden_layer: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, layer_idx: int = 0, linear_layer_idx: int = 0, @@ -374,6 +383,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): self.tp_heads = self.total_num_heads // self.tp_size self.qkv_size = self.num_heads * self.head_dim self.tp_hidden = self.head_dim * self.tp_heads + self.model_config = model_config + self.cache_config = cache_config self.prefix = prefix self.qkv_proj = ColumnParallelLinear( @@ 
-657,6 +668,7 @@ class MiniMaxText01DecoderLayer(nn.Module): def __init__( self, config: MiniMaxConfig, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, expert_num: int = 1, @@ -693,6 +705,8 @@ class MiniMaxText01DecoderLayer(nn.Module): max_position=max_position_embeddings, block_size=config.block if hasattr(config, "block") else 256, num_hidden_layer=config.num_hidden_layers, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, layer_idx=self._ilayer, linear_layer_idx=linear_layer_id, @@ -861,6 +875,7 @@ class MiniMaxText01Model(nn.Module): def __init__( self, config: MiniMaxConfig, + model_config: Optional[ModelConfig] = None, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, scheduler_config=None, @@ -910,6 +925,7 @@ class MiniMaxText01Model(nn.Module): decoder_kwargs = { "quant_config": quant_config, "layer_id": layer_idx, + "model_config": model_config, "cache_config": cache_config } @@ -1111,8 +1127,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): self.config.max_model_len = vllm_config.model_config.max_model_len self.model = MiniMaxText01Model( self.config, - quant_config, + model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, + quant_config=quant_config, scheduler_config=vllm_config.scheduler_config, prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: @@ -1409,6 +1426,17 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): load_basic_weight(name, loaded_weight, self) return loaded_params + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.linear_attention_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 08315a13853c0..07cd5a4c6e24f 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -26,7 +26,7 @@ from torch import nn from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context @@ -40,7 +40,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -110,6 +110,7 @@ class NemotronHMLPDecoderLayer(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -149,6 +150,7 @@ class 
NemotronHMambaDecoderLayer(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -167,6 +169,8 @@ class NemotronHMambaDecoderLayer(nn.Module): head_dim=config.mamba_head_dim, rms_norm_eps=config.rms_norm_eps, activation=config.mamba_hidden_act, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer", ) @@ -198,6 +202,7 @@ class NemotronHAttention(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -270,6 +275,7 @@ class NemotronHAttentionDecoderLayer(nn.Module): self, config: NemotronHConfig, layer_idx: int, + model_config: Optional[ModelConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -279,6 +285,7 @@ class NemotronHAttentionDecoderLayer(nn.Module): self.mixer = NemotronHAttention( config, layer_idx, + model_config, cache_config, quant_config, prefix=f"{prefix}.mixer", @@ -317,6 +324,7 @@ class NemotronHModel(nn.Module): super().__init__() config: NemotronHConfig = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -340,6 +348,7 @@ class NemotronHModel(nn.Module): return layer_class( config, layer_idx, + model_config, cache_config, quant_config=quant_config, prefix=prefix, @@ -478,6 +487,18 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -569,10 +590,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 4cb0becf302f1..ed65944c109bd 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -18,7 +18,7 @@ from transformers import Zamba2Config from vllm import envs from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import GeluAndMul @@ -33,7 +33,7 @@ from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, 
prepare_mamba2_metadata) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateShapeCalculator) + MambaStateDtypeCalculator, MambaStateShapeCalculator) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -478,6 +478,8 @@ class Zamba2MambaDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "") -> None: """Initialize the Mamba decoder layer. @@ -502,6 +504,8 @@ class Zamba2MambaDecoderLayer(nn.Module): config.n_mamba_heads, rms_norm_eps=config.rms_norm_eps, activation="silu", + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.mixer") @@ -578,6 +582,8 @@ class Zamba2HybridLayer(nn.Module): shared_transformer: Zamba2AttentionDecoderLayer, config: Zamba2Config, block_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -596,6 +602,8 @@ class Zamba2HybridLayer(nn.Module): bias=False, quant_config=quant_config) self.mamba_decoder = Zamba2MambaDecoderLayer(config, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=prefix) @@ -669,6 +677,7 @@ class Zamba2Model(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config @@ -718,11 +727,15 @@ class Zamba2Model(nn.Module): Zamba2HybridLayer(block, config, block_idx, - quant_config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, prefix=prefix)) else: layers.append( Zamba2MambaDecoderLayer(config, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, prefix=prefix)) self.layers = nn.ModuleList(layers) @@ -848,6 +861,18 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): "1.weight": "B.weight", }) + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype]: + + return MambaStateDtypeCalculator.mamba2_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + vllm_config.cache_config.mamba_ssm_cache_dtype, + ) + @classmethod def get_mamba_state_shape_from_config( cls, @@ -966,10 +991,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): mamba_state_shape = \ self.get_mamba_state_shape_from_config( self.vllm_config, use_v1=False) + mamba_state_dtype = \ + self.get_mamba_state_dtype_from_config( + self.vllm_config) self.mamba_cache = MambaCacheManager(self.vllm_config, - self.lm_head.weight.dtype, num_mamba_layers, - *mamba_state_shape) + *mamba_state_shape, + *mamba_state_dtype) # Get cache parameters for current run mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index cae4eecc0deeb..a1f8ad164762d 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -173,6 +173,7 @@ CYAN = '\033[1;36m' RESET = '\033[0;0m' STR_DTYPE_TO_TORCH_DTYPE = { + "float32": torch.float32, 
"half": torch.half, "bfloat16": torch.bfloat16, "float": torch.float, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4ff96f9786b88..429416afa2483 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -182,14 +182,15 @@ class SlidingWindowSpec(AttentionSpec): @dataclass(frozen=True) class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] - dtype: torch.dtype + dtypes: tuple[torch.dtype] page_size_padded: Optional[int] = None mamba_type: str = "mamba2" @property def page_size_bytes(self) -> int: - num_elements = sum(prod(shape) for shape in self.shapes) - page_size = num_elements * get_dtype_size(self.dtype) + page_size = sum( + prod(shape) * get_dtype_size(dtype) + for (shape, dtype) in zip(self.shapes, self.dtypes)) if self.page_size_padded is not None: assert self.page_size_padded >= page_size return self.page_size_padded diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 703092ca9feeb..d5325287889fd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2884,23 +2884,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] - dtype = kv_cache_spec.dtype - num_element_per_page = (kv_cache_spec.page_size_bytes // - get_dtype_size(dtype)) state_tensors = [] - storage_offset = 0 - for shape in kv_cache_spec.shapes: + storage_offset_bytes = 0 + for (shape, dtype) in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): + dtype_size = get_dtype_size(dtype) + num_element_per_page = ( + kv_cache_spec.page_size_bytes // dtype_size) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) + assert storage_offset_bytes % dtype_size == 0 tensor = torch.as_strided( raw_tensor.view(dtype), size=target_shape, stride=target_stride, - storage_offset=storage_offset, + storage_offset=storage_offset_bytes // dtype_size, ) state_tensors.append(tensor) - storage_offset += stride[0] + storage_offset_bytes += stride[0] * dtype_size kv_caches[layer_name] = state_tensors else: @@ -3087,7 +3089,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), - dtype=self.kv_cache_dtype, + dtypes=mamba_module.get_state_dtype(), block_size=max_model_len, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type)