vllm/vllm/model_executor/models/mamba_cache.py
Thomas Parnell 75531a6c13
[V1] [Hybrid] Support using float32 for state in Hybrid Models (Mamba2, Mamba1, Minimax) (#22928)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
2025-08-15 12:57:06 +00:00


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

import torch

from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.model_executor.models.constant_size_cache import ConstantSizeCache

@dataclass
class MambaCacheParams:
    conv_state: torch.Tensor = torch.Tensor()
    ssm_state: torch.Tensor = torch.Tensor()
    state_indices_tensor: torch.Tensor = torch.Tensor()

    def at_layer_idx(self, layer_idx):
        return MambaCacheParams(self.conv_state[layer_idx],
                                self.ssm_state[layer_idx],
                                self.state_indices_tensor)
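
# Illustrative note (an assumption about typical call sites, not defined in
# this file): hybrid-model code usually hands each decoder layer a per-layer
# view of the shared cache via `params.at_layer_idx(i)`, which slices out
# layer i's conv_state and ssm_state while reusing the same
# state_indices_tensor across all layers.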

class MambaCacheManager(ConstantSizeCache):

    def __init__(self, vllm_config: VllmConfig, num_mamba_layers: int,
                 conv_state_shape: tuple[int, int],
                 temporal_state_shape: tuple[int, int],
                 conv_state_dtype: torch.dtype,
                 temporal_state_dtype: torch.dtype):

        self.conv_state_dtype = conv_state_dtype
        self.temporal_state_dtype = temporal_state_dtype

        # Determine max batch size to set the size of the MambaCache. With
        # CUDA graphs enabled, the batch is padded up to a captured graph
        # size, so the cache must cover the padded batch as well.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        if not vllm_config.model_config.enforce_eager:
            max_batch_size = vllm_config.pad_for_cudagraph(max_batch_size)

        # Initialize parent class
        super().__init__(max_batch_size)

        # assume conv_state = (dim, state_len)
        assert conv_state_shape[0] > conv_state_shape[1]
        conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                 (conv_state_shape[1], conv_state_shape[0]),
                                 dtype=self.conv_state_dtype,
                                 device="cuda").transpose(-1, -2)
        temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                     temporal_state_shape,
                                     dtype=self.temporal_state_dtype,
                                     device="cuda")

        self._mamba_cache = (conv_state, temporal_state)
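
        # Layout note (descriptive, with a hedged assumption about intent):
        # conv_state is allocated as (..., state_len, dim) and then viewed
        # through .transpose(-1, -2), so its logical shape is
        # (num_mamba_layers, max_batch_size, dim, state_len) while the `dim`
        # axis stays contiguous in memory -- presumably the layout the
        # causal-conv1d kernels expect. temporal_state (the SSM state) is a
        # plain contiguous (num_mamba_layers, max_batch_size,
        # *temporal_state_shape) buffer in temporal_state_dtype, which may be
        # float32 for hybrid models that need higher-precision state.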

    @property
    def cache(self):
        return self._mamba_cache

    def _copy_cache(self, from_index: int, to_index: int):
        for cache_t in self.cache:
            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)

    def current_run_tensors(self, **kwargs) -> MambaCacheParams:
        """
        Return the tensors for the current run's conv and ssm state.
        """
        cache_tensors, state_indices_tensor = super().current_run_tensors(
            **kwargs)
        return MambaCacheParams(cache_tensors[0], cache_tensors[1],
                                state_indices_tensor)
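
    # Illustrative call pattern (assumed from typical vLLM Mamba/Jamba model
    # code, not defined in this file): a model's forward() usually calls
    # `params = self.mamba_cache.current_run_tensors(**kwargs)` once per step
    # with the kwargs supplied by the model runner; the parent
    # ConstantSizeCache is expected to map the batch's requests onto fixed
    # cache slots and return the corresponding state_indices_tensor.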

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer of the adjusted
        (padded) batch size. The buffer is used to maintain the Mamba cache
        during the CUDA graph replay runs.
        """
        return self._mamba_cache, torch.as_tensor([PAD_SLOT_ID] * batch_size,
                                                  dtype=torch.int32,
                                                  device="cuda")