vllm/vllm/model_executor/models/minimax_cache.py
Simon Mo 02f0c7b220
[Misc] Add SPDX-FileCopyrightText (#19100)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-06-03 11:20:17 -07:00

37 lines
1.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field

import torch

from vllm.model_executor.models.constant_size_cache import ConstantSizeCache
@dataclass
class MinimaxCacheParams:
    """Bundle of a MiniMax cache tensor and its state-indices tensor.

    Both fields default to a fresh, empty ``torch.Tensor`` per instance.
    A ``default_factory`` is used so that instances never share one
    placeholder tensor object and no tensor is allocated at import time.
    """

    # Cache tensor; ``at_layer_idx`` indexes it along its first dimension.
    minimax_cache: torch.Tensor = field(default_factory=torch.Tensor)
    # Passed through unchanged by ``at_layer_idx``.
    # NOTE(review): presumably maps requests to cache slots -- confirm
    # against the callers.
    state_indices_tensor: torch.Tensor = field(default_factory=torch.Tensor)

    def at_layer_idx(self, layer_idx: int) -> "MinimaxCacheParams":
        """Return params restricted to the cache entry at ``layer_idx``.

        Args:
            layer_idx: Index into the first dimension of ``minimax_cache``.

        Returns:
            A new ``MinimaxCacheParams`` whose cache is
            ``minimax_cache[layer_idx, ...]`` (plain indexing, no explicit
            copy) and whose ``state_indices_tensor`` is the very same
            object as on ``self``.
        """
        return MinimaxCacheParams(self.minimax_cache[layer_idx, ...],
                                  self.state_indices_tensor)
class MinimaxCacheManager(ConstantSizeCache):
    """Constant-size cache manager backed by one pre-allocated tensor.

    The whole cache is a single tensor of shape ``cache_shape`` whose
    dimension 1 is the batch-slot dimension (its size is handed to the
    base class as the maximum batch size).
    """

    def __init__(self, dtype: torch.dtype, cache_shape: tuple[int, ...]):
        """Allocate the cache tensor on the CUDA device.

        Args:
            dtype: Element type of the cache tensor.
            cache_shape: Full shape of the cache tensor. Entry 1 is
                forwarded to ``ConstantSizeCache`` as the max batch size.
        """
        super().__init__(cache_shape[1])  # max_batch_size is cache_shape[1]
        # torch.empty leaves the storage uninitialized; nothing in this
        # class fills it with defined values.
        self._minimax_cache = torch.empty(size=cache_shape,
                                          dtype=dtype,
                                          device="cuda")

    @property
    def cache(self) -> torch.Tensor:
        """The underlying cache tensor (the same object on every access)."""
        return self._minimax_cache

    def _copy_cache(self, from_index: int, to_index: int):
        """Copy the cached state of slot ``from_index`` into ``to_index``.

        Both indices address dimension 1 of the cache tensor, the
        dimension whose size is passed to the base class as the max
        batch size.
        """
        assert len(self.cache) > 0
        # Index the batch-slot dimension (dim 1) of the full tensor.
        # Iterating over ``self.cache`` would yield slices along dim 0, and
        # ``slice[:, idx]`` would then address dim 2 of the cache instead
        # of the slot dimension.
        self.cache[:, to_index].copy_(self.cache[:, from_index],
                                      non_blocking=True)