mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 13:55:38 +08:00
[V1] [Hybrid] Support Minimax-Text-01 in V1 (#22151)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
3157aebb63
commit
6ade99eafa
@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
|
||||
pid_d = tl.program_id(2) # dimension block index
|
||||
|
||||
# Load slot index for the current batch
|
||||
slot_id = tl.load(slot_idx + pid_b)
|
||||
slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
|
||||
|
||||
# Skip if slot_id is -1 (padding)
|
||||
if slot_id == -1:
|
||||
|
||||
@ -5,6 +5,17 @@ from vllm.distributed import divide
|
||||
|
||||
class MambaStateShapeCalculator:
|
||||
|
||||
@classmethod
|
||||
def linear_attention_state_shape(
|
||||
cls,
|
||||
num_heads: int,
|
||||
tp_size: int,
|
||||
head_dim: int,
|
||||
) -> tuple[tuple[int, int, int], ...]:
|
||||
|
||||
state_shape = (num_heads // tp_size, head_dim, head_dim)
|
||||
return (state_shape, )
|
||||
|
||||
@classmethod
|
||||
def mamba1_state_shape(
|
||||
cls,
|
||||
|
||||
@ -14,8 +14,9 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group, get_tensor_model_parallel_rank,
|
||||
@ -33,6 +34,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -41,8 +45,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
|
||||
from .interfaces import HasInnerState, IsHybrid
|
||||
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||
|
||||
@ -327,7 +332,17 @@ class MiniMaxText01LinearKernel:
|
||||
return rearrange(output.squeeze(0), "h n d -> n (h d)")
|
||||
|
||||
|
||||
class MiniMaxText01LinearAttention(nn.Module):
|
||||
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "linear_attention"
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.linear_attention_state_shape(
|
||||
num_heads=self.num_heads,
|
||||
tp_size=self.tp_size,
|
||||
head_dim=self.head_dim)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -359,6 +374,7 @@ class MiniMaxText01LinearAttention(nn.Module):
|
||||
self.tp_heads = self.total_num_heads // self.tp_size
|
||||
self.qkv_size = self.num_heads * self.head_dim
|
||||
self.tp_hidden = self.head_dim * self.tp_heads
|
||||
self.prefix = prefix
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
@ -397,6 +413,12 @@ class MiniMaxText01LinearAttention(nn.Module):
|
||||
self.tp_heads:(self.tp_rank + 1) *
|
||||
self.tp_heads].contiguous()
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
@staticmethod
|
||||
def weight_direct_load(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
@ -434,13 +456,14 @@ class MiniMaxText01LinearAttention(nn.Module):
|
||||
break
|
||||
if _prefill_idx >= len(state_indices_tensor):
|
||||
break
|
||||
_start = attn_metadata.query_start_loc[_prefill_idx]
|
||||
_end = attn_metadata.query_start_loc[_prefill_idx + 1]
|
||||
slot_id = state_indices_tensor[_prefill_idx]
|
||||
# prefills are packed at end of batch in V1
|
||||
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
|
||||
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||
qs = q[_start:_end].transpose(0, 1).contiguous()
|
||||
ks = k[_start:_end].transpose(0, 1).contiguous()
|
||||
vs = v[_start:_end].transpose(0, 1).contiguous()
|
||||
slot_id = state_indices_tensor[_prefill_idx]
|
||||
slice_layer_cache = kv_cache[slot_id, ...]
|
||||
|
||||
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
|
||||
@ -453,9 +476,13 @@ class MiniMaxText01LinearAttention(nn.Module):
|
||||
layer_idx=self.layer_idx)
|
||||
hidden.append(out_slice.contiguous())
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
hidden.append(
|
||||
self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
|
||||
attn_metadata))
|
||||
hidden_decode = self._decode_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor,
|
||||
attn_metadata)
|
||||
if envs.VLLM_USE_V1:
|
||||
hidden.insert(0, hidden_decode)
|
||||
else:
|
||||
hidden.append(hidden_decode)
|
||||
|
||||
if not hidden:
|
||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||
@ -465,11 +492,17 @@ class MiniMaxText01LinearAttention(nn.Module):
|
||||
|
||||
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
||||
attn_metadata):
|
||||
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
|
||||
):]
|
||||
if not envs.VLLM_USE_V1:
|
||||
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
slot_id = state_indices_tensor[num_prefills:]
|
||||
else:
|
||||
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
|
||||
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
|
||||
slot_id, 32)
|
||||
return hidden
|
||||
@ -483,17 +516,49 @@ class MiniMaxText01LinearAttention(nn.Module):
|
||||
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
kv_cache = kv_caches.minimax_cache
|
||||
state_indices_tensor = kv_caches.state_indices_tensor
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
if num_prefills > 0:
|
||||
num_decode_tokens = getattr(attn_metadata,
|
||||
"num_decode_tokens", 0)
|
||||
for prefill_idx in range(num_prefills):
|
||||
q_start = attn_metadata.query_start_loc[
|
||||
num_decode_tokens + prefill_idx]
|
||||
q_end = attn_metadata.query_start_loc[num_decode_tokens
|
||||
+ prefill_idx +
|
||||
1]
|
||||
query_len = q_end - q_start
|
||||
context_len = attn_metadata.seq_lens[
|
||||
num_decode_tokens + prefill_idx] - query_len
|
||||
if context_len == 0:
|
||||
block_to_clear = state_indices_tensor[
|
||||
num_decode_tokens + prefill_idx]
|
||||
kv_cache[block_to_clear, ...] = 0
|
||||
else:
|
||||
kv_cache = kv_caches.minimax_cache
|
||||
state_indices_tensor = kv_caches.state_indices_tensor
|
||||
|
||||
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
||||
if not decode_only:
|
||||
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor,
|
||||
attn_metadata)
|
||||
if attn_metadata is None:
|
||||
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
|
||||
device=q.device,
|
||||
dtype=q.dtype)
|
||||
else:
|
||||
hidden = self._decode_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor, attn_metadata)
|
||||
if not decode_only:
|
||||
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor,
|
||||
attn_metadata)
|
||||
else:
|
||||
hidden = self._decode_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor,
|
||||
attn_metadata)
|
||||
|
||||
hidden = self.norm._forward(hidden)
|
||||
gate, _ = self.output_gate(hidden_states)
|
||||
@ -541,6 +606,7 @@ class MiniMaxText01Attention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
self.prefix = prefix
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@ -575,7 +641,12 @@ class MiniMaxText01Attention(nn.Module):
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = attn_metadata.rotary_emb(positions, q, k)
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
|
||||
positions, q, k)
|
||||
else:
|
||||
q, k = attn_metadata.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
@ -595,6 +666,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
) -> None:
|
||||
self._ilayer = layer_id
|
||||
self._irank = get_tensor_model_parallel_rank()
|
||||
self.prefix = prefix
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -876,8 +948,9 @@ class MiniMaxText01Model(nn.Module):
|
||||
self._dtype = _dummy.dtype
|
||||
del _dummy
|
||||
|
||||
self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
|
||||
cache_shape=self.cache_shape)
|
||||
if not envs.VLLM_USE_V1:
|
||||
self.minimax_cache = MinimaxCacheManager(
|
||||
dtype=torch.float32, cache_shape=self.cache_shape)
|
||||
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
head_dim = getattr(config, "head_dim", None)
|
||||
@ -944,23 +1017,27 @@ class MiniMaxText01Model(nn.Module):
|
||||
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
if not envs.VLLM_USE_V1 and attn_metadata is None:
|
||||
return None
|
||||
if "request_ids_to_seq_ids" not in kwargs:
|
||||
kwargs["request_ids_to_seq_ids"] = {}
|
||||
if "finished_requests_ids" not in kwargs:
|
||||
kwargs["finished_requests_ids"] = []
|
||||
|
||||
(
|
||||
minimax_cache_tensors,
|
||||
state_indices_tensor,
|
||||
) = self.minimax_cache.current_run_tensors(**kwargs)
|
||||
if getattr(attn_metadata, "num_prefills", 0) > 0:
|
||||
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
|
||||
**kwargs)
|
||||
if not envs.VLLM_USE_V1:
|
||||
(
|
||||
minimax_cache_tensors,
|
||||
state_indices_tensor,
|
||||
) = self.minimax_cache.current_run_tensors(**kwargs)
|
||||
if getattr(attn_metadata, "num_prefills", 0) > 0:
|
||||
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
|
||||
**kwargs)
|
||||
|
||||
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
|
||||
state_indices_tensor)
|
||||
else:
|
||||
minimax_cache_params = None
|
||||
|
||||
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
|
||||
state_indices_tensor)
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is None:
|
||||
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
|
||||
@ -973,11 +1050,22 @@ class MiniMaxText01Model(nn.Module):
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
minimax_cache_index = 0
|
||||
attn_metadata.rotary_emb = self.rotary_emb
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
if attn_metadata is not None:
|
||||
# TODO (tdoublep): this whole thing with the rotary_emb is
|
||||
# weird. we shouldn't be passing it via attn_metadata imo.
|
||||
if envs.VLLM_USE_V1:
|
||||
if isinstance(layer.self_attn, MiniMaxText01Attention):
|
||||
attn_metadata[layer.prefix +
|
||||
".attn"].rotary_emb = self.rotary_emb
|
||||
else:
|
||||
attn_metadata.rotary_emb = self.rotary_emb
|
||||
|
||||
_caches = None
|
||||
if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
|
||||
if not envs.VLLM_USE_V1 and isinstance(
|
||||
layer.self_attn, MiniMaxText01LinearAttention):
|
||||
current_state_layer = minimax_cache_index
|
||||
_caches = minimax_cache_params.at_layer_idx(
|
||||
current_state_layer)
|
||||
@ -1002,8 +1090,7 @@ class MiniMaxText01Model(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
|
||||
SupportsV0Only):
|
||||
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
|
||||
@ -1321,3 +1408,28 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
|
||||
|
||||
load_basic_weight(name, loaded_weight, self)
|
||||
return loaded_params
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
use_v1: bool = True,
|
||||
) -> tuple[tuple[int, ...], ...]:
|
||||
"""Calculate shape for MiniMaxText01LinearAttention cache.
|
||||
|
||||
Args:
|
||||
vllm_config: vLLM config
|
||||
use_v1: Get shapes for V1 (or V0)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- state_shape: Shape of the cache
|
||||
"""
|
||||
parallel_config = vllm_config.parallel_config
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
|
||||
return MambaStateShapeCalculator.linear_attention_state_shape(
|
||||
num_heads=hf_config.num_attention_heads,
|
||||
tp_size=parallel_config.tensor_parallel_size,
|
||||
head_dim=hf_config.head_dim,
|
||||
)
|
||||
|
||||
67
vllm/v1/attention/backends/linear_attn.py
Normal file
67
vllm/v1/attention/backends/linear_attn.py
Normal file
@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class LinearAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
|
||||
return LinearAttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinearAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
query_start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
|
||||
|
||||
class LinearAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[LinearAttentionMetadata]):
|
||||
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> LinearAttentionMetadata:
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(common_attn_metadata,
|
||||
decode_threshold=1))
|
||||
|
||||
attn_metadata = LinearAttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
)
|
||||
return attn_metadata
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
|
||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
||||
|
||||
@ -8,9 +9,10 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
||||
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
|
||||
if mamba_type == "mamba1":
|
||||
return Mamba1AttentionBackend
|
||||
|
||||
if mamba_type == "mamba2":
|
||||
return Mamba2AttentionBackend
|
||||
if mamba_type == "linear_attention":
|
||||
return LinearAttentionBackend
|
||||
|
||||
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
|
||||
"supported yet.")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user