[V1] [Hybrid] Support Minimax-Text-01 in V1 (#22151)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Thomas Parnell 2025-08-09 08:08:48 +02:00 committed by GitHub
parent 3157aebb63
commit 6ade99eafa
5 changed files with 234 additions and 42 deletions

View File

@@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
pid_d = tl.program_id(2) # dimension block index
# Load slot index for the current batch
slot_id = tl.load(slot_idx + pid_b)
slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
# Skip if slot_id is -1 (padding)
if slot_id == -1:
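For context on the cast above: slot_id is multiplied by per-slot strides when the kernel computes pointer offsets into the cache, and for large caches that product can exceed the int32 range. A minimal arithmetic sketch in plain Python; the per-slot size below is a hypothetical example, not MiniMax-Text-01's actual dimensions.

per_slot_elems = 64 * 128 * 128           # hypothetical state size per cache slot (heads x d x d)
slot_id = 4000                            # a plausible slot index in a large cache
offset = slot_id * per_slot_elems         # 4,194,304,000 elements into the cache tensor
int32_max = 2**31 - 1                     # 2,147,483,647
print(offset > int32_max)                 # True: 32-bit index arithmetic would wrap here
print((offset + 2**31) % 2**32 - 2**31)   # -100663296: the wrapped offset a 32-bit index would produce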

View File

@@ -5,6 +5,17 @@ from vllm.distributed import divide
class MambaStateShapeCalculator:
@classmethod
def linear_attention_state_shape(
cls,
num_heads: int,
tp_size: int,
head_dim: int,
) -> tuple[tuple[int, int, int], ...]:
state_shape = (num_heads // tp_size, head_dim, head_dim)
return (state_shape, )
@classmethod
def mamba1_state_shape(
cls,

View File

@@ -14,8 +14,9 @@ from einops import rearrange
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
@@ -33,6 +34,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -41,8 +45,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
@@ -327,7 +332,17 @@ class MiniMaxText01LinearKernel:
return rearrange(output.squeeze(0), "h n d -> n (h d)")
class MiniMaxText01LinearAttention(nn.Module):
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
return "linear_attention"
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=self.num_heads,
tp_size=self.tp_size,
head_dim=self.head_dim)
def __init__(
self,
@@ -359,6 +374,7 @@ class MiniMaxText01LinearAttention(nn.Module):
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
self.prefix = prefix
self.qkv_proj = ColumnParallelLinear(
hidden_size,
@@ -397,6 +413,12 @@ class MiniMaxText01LinearAttention(nn.Module):
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
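The registration above is what lets the V1 runtime route per-layer attention metadata (a dict keyed by layer prefix) and per-layer KV-cache tensors back to each linear-attention module. A toy sketch of the pattern; Registry is a stand-in for the compilation config, not a vLLM class, and the prefix string is hypothetical.

class Registry:
    def __init__(self) -> None:
        self.static_forward_context: dict[str, object] = {}

    def register(self, prefix: str, layer: object) -> None:
        if prefix in self.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        self.static_forward_context[prefix] = layer

registry = Registry()
registry.register("model.layers.0.self_attn", object())  # hypothetical prefix
# At runtime the per-layer metadata dict is keyed the same way:
# attn_metadata["model.layers.0.self_attn"] -> that layer's LinearAttentionMetadata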
@@ -434,13 +456,14 @@ class MiniMaxText01LinearAttention(nn.Module):
break
if _prefill_idx >= len(state_indices_tensor):
break
_start = attn_metadata.query_start_loc[_prefill_idx]
_end = attn_metadata.query_start_loc[_prefill_idx + 1]
slot_id = state_indices_tensor[_prefill_idx]
# prefills are packed at end of batch in V1
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slot_id = state_indices_tensor[_prefill_idx]
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
@@ -453,9 +476,13 @@ class MiniMaxText01LinearAttention(nn.Module):
layer_idx=self.layer_idx)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden.append(
self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
attn_metadata))
hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if envs.VLLM_USE_V1:
hidden.insert(0, hidden_decode)
else:
hidden.append(hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
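Both hunks above rely on V1's batch layout: decode requests are packed first and prefill requests after them, so prefill i's query span starts at query_start_loc[num_decode_tokens + i], and the decode output is inserted at the front of the hidden list to keep the concatenation in input-token order. A worked example with illustrative numbers follows; it assumes each decode request contributes exactly one token, which is what decode_threshold=1 guarantees.

query_start_loc = [0, 1, 2, 3, 8, 15]  # 3 decode requests (1 token each), then prefills of 5 and 7 tokens
num_decode_tokens = 3                  # equals num_decodes here, since each decode has one token
num_prefills = 2

for prefill_idx in range(num_prefills):
    start = query_start_loc[num_decode_tokens + prefill_idx]
    end = query_start_loc[num_decode_tokens + prefill_idx + 1]
    print(prefill_idx, start, end)     # prints 0 3 8, then 1 8 15

# Decode tokens occupy q[:num_decode_tokens]; their output goes at position 0 of the
# hidden list so the final concatenation matches the order of tokens in the input batch.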
@@ -465,11 +492,17 @@ class MiniMaxText01LinearAttention(nn.Module):
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
):]
if not envs.VLLM_USE_V1:
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
num_prefills = getattr(attn_metadata, "num_prefills", 0)
slot_id = state_indices_tensor[num_prefills:]
else:
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
return hidden
@@ -483,17 +516,49 @@ class MiniMaxText01LinearAttention(nn.Module):
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata,
"num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens
+ prefill_idx +
1]
query_len = q_end - q_start
context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len
if context_len == 0:
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx]
kv_cache[block_to_clear, ...] = 0
else:
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if attn_metadata is None:
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
device=q.device,
dtype=q.dtype)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor, attn_metadata)
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states)
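One note on the context_len == 0 branch above: a prefill with no previously processed tokens gets its cache slot zeroed, because the block it was just assigned may still hold another request's recurrent state. A small sketch of that rule with illustrative shapes and values.

import torch

kv_cache = torch.randn(4, 2, 8, 8)   # [num_slots, num_heads, head_dim, head_dim], illustrative
seq_lens = [12, 7]                   # total tokens per prefill request (context + new)
query_lens = [12, 3]                 # new tokens being processed this step
state_indices = [1, 3]               # cache slot assigned to each prefill

for i, (seq_len, query_len) in enumerate(zip(seq_lens, query_lens)):
    context_len = seq_len - query_len
    if context_len == 0:             # brand-new sequence: reset its recurrent state
        kv_cache[state_indices[i], ...] = 0

print(kv_cache[1].abs().sum().item())  # 0.0 -> cleared
print(kv_cache[3].abs().sum().item())  # > 0 -> previous state kept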
@@ -541,6 +606,7 @@ class MiniMaxText01Attention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.prefix = prefix
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -575,7 +641,12 @@ class MiniMaxText01Attention(nn.Module):
attn_metadata = forward_context.attn_metadata
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = attn_metadata.rotary_emb(positions, q, k)
if envs.VLLM_USE_V1:
if attn_metadata is not None:
q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
positions, q, k)
else:
q, k = attn_metadata.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -595,6 +666,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
) -> None:
self._ilayer = layer_id
self._irank = get_tensor_model_parallel_rank()
self.prefix = prefix
super().__init__()
self.hidden_size = config.hidden_size
@@ -876,8 +948,9 @@ class MiniMaxText01Model(nn.Module):
self._dtype = _dummy.dtype
del _dummy
self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
cache_shape=self.cache_shape)
if not envs.VLLM_USE_V1:
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)
rope_theta = getattr(config, "rope_theta", 10000)
head_dim = getattr(config, "head_dim", None)
@@ -944,23 +1017,27 @@ class MiniMaxText01Model(nn.Module):
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
if not envs.VLLM_USE_V1:
(
minimax_cache_tensors,
state_indices_tensor,
) = self.minimax_cache.current_run_tensors(**kwargs)
if getattr(attn_metadata, "num_prefills", 0) > 0:
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
**kwargs)
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
else:
minimax_cache_params = None
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
state_indices_tensor)
if get_pp_group().is_first_rank:
if inputs_embeds is None:
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
@@ -973,11 +1050,22 @@ class MiniMaxText01Model(nn.Module):
residual = intermediate_tensors["residual"]
minimax_cache_index = 0
attn_metadata.rotary_emb = self.rotary_emb
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if attn_metadata is not None:
# TODO (tdoublep): this whole thing with the rotary_emb is
# weird. we shouldn't be passing it via attn_metadata imo.
if envs.VLLM_USE_V1:
if isinstance(layer.self_attn, MiniMaxText01Attention):
attn_metadata[layer.prefix +
".attn"].rotary_emb = self.rotary_emb
else:
attn_metadata.rotary_emb = self.rotary_emb
_caches = None
if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
current_state_layer)
@@ -1002,8 +1090,7 @@ class MiniMaxText01Model(nn.Module):
return hidden_states
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
SupportsV0Only):
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
@@ -1321,3 +1408,28 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
load_basic_weight(name, loaded_weight, self)
return loaded_params
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, ...], ...]:
"""Calculate shape for MiniMaxText01LinearAttention cache.
Args:
vllm_config: vLLM config
use_v1: Whether to get shapes for V1 (or V0)
Returns:
Tuple containing:
- state_shape: Shape of the cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=hf_config.num_attention_heads,
tp_size=parallel_config.tensor_parallel_size,
head_dim=hf_config.head_dim,
)

View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class LinearAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
return LinearAttentionMetadataBuilder
@dataclass
class LinearAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
state_indices_tensor: torch.Tensor # shape: [batch,]
class LinearAttentionMetadataBuilder(
AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
assert isinstance(kv_cache_spec, MambaSpec)
self.kv_cache_spec = kv_cache_spec
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> LinearAttentionMetadata:
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
attn_metadata = LinearAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
state_indices_tensor=state_indices_tensor,
)
return attn_metadata
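A toy approximation of what split_decodes_and_prefills(..., decode_threshold=1) computes for this builder: requests with a single query token count as decodes, and the batch is assumed to be reordered so they come first (reorder_batch_threshold = 1). This is an illustration, not vLLM's actual implementation, and the block-table values are hypothetical.

import torch

query_start_loc = torch.tensor([0, 1, 2, 3, 8, 15])      # 3 decodes (1 token each) + prefills of 5 and 7
query_lens = query_start_loc[1:] - query_start_loc[:-1]  # tensor([1, 1, 1, 5, 7])

is_prefill = query_lens > 1                               # decode_threshold = 1
num_decodes = int(is_prefill.int().argmax()) if bool(is_prefill.any()) else len(query_lens)
num_prefills = len(query_lens) - num_decodes
num_decode_tokens = int(query_lens[:num_decodes].sum())
num_prefill_tokens = int(query_lens[num_decodes:].sum())

block_table = torch.tensor([[5], [9], [2], [7], [0]])     # one state block per request (MambaSpec)
state_indices_tensor = block_table[:, 0]                  # what the builder stores, shape [batch]

print(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)  # 3 2 3 12
print(state_indices_tensor.tolist())                                     # [5, 9, 2, 7, 0]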

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
@@ -8,9 +9,10 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
if mamba_type == "mamba1":
return Mamba1AttentionBackend
if mamba_type == "mamba2":
return Mamba2AttentionBackend
if mamba_type == "linear_attention":
return LinearAttentionBackend
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
"supported yet.")