mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 15:56:07 +08:00
[V1] [Hybrid] Support Minimax-Text-01 in V1 (#22151)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
3157aebb63
commit
6ade99eafa
@ -532,7 +532,7 @@ def _linear_attn_decode_kernel(
|
|||||||
pid_d = tl.program_id(2) # dimension block index
|
pid_d = tl.program_id(2) # dimension block index
|
||||||
|
|
||||||
# Load slot index for the current batch
|
# Load slot index for the current batch
|
||||||
slot_id = tl.load(slot_idx + pid_b)
|
slot_id = tl.load(slot_idx + pid_b).to(tl.int64)
|
||||||
|
|
||||||
# Skip if slot_id is -1 (padding)
|
# Skip if slot_id is -1 (padding)
|
||||||
if slot_id == -1:
|
if slot_id == -1:
|
||||||
|
|||||||
@ -5,6 +5,17 @@ from vllm.distributed import divide
|
|||||||
|
|
||||||
class MambaStateShapeCalculator:
|
class MambaStateShapeCalculator:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def linear_attention_state_shape(
|
||||||
|
cls,
|
||||||
|
num_heads: int,
|
||||||
|
tp_size: int,
|
||||||
|
head_dim: int,
|
||||||
|
) -> tuple[tuple[int, int, int], ...]:
|
||||||
|
|
||||||
|
state_shape = (num_heads // tp_size, head_dim, head_dim)
|
||||||
|
return (state_shape, )
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def mamba1_state_shape(
|
def mamba1_state_shape(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@ -14,8 +14,9 @@ from einops import rearrange
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_pp_group, get_tensor_model_parallel_rank,
|
get_pp_group, get_tensor_model_parallel_rank,
|
||||||
@ -33,6 +34,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||||
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
|
MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
@ -41,8 +45,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.models.utils import maybe_prefix
|
from vllm.model_executor.models.utils import maybe_prefix
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||||
|
|
||||||
from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
|
from .interfaces import HasInnerState, IsHybrid
|
||||||
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
|
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
|
||||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||||
|
|
||||||
@ -327,7 +332,17 @@ class MiniMaxText01LinearKernel:
|
|||||||
return rearrange(output.squeeze(0), "h n d -> n (h d)")
|
return rearrange(output.squeeze(0), "h n d -> n (h d)")
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01LinearAttention(nn.Module):
|
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mamba_type(self) -> str:
|
||||||
|
return "linear_attention"
|
||||||
|
|
||||||
|
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||||
|
return MambaStateShapeCalculator.linear_attention_state_shape(
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
tp_size=self.tp_size,
|
||||||
|
head_dim=self.head_dim)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -359,6 +374,7 @@ class MiniMaxText01LinearAttention(nn.Module):
|
|||||||
self.tp_heads = self.total_num_heads // self.tp_size
|
self.tp_heads = self.total_num_heads // self.tp_size
|
||||||
self.qkv_size = self.num_heads * self.head_dim
|
self.qkv_size = self.num_heads * self.head_dim
|
||||||
self.tp_hidden = self.head_dim * self.tp_heads
|
self.tp_hidden = self.head_dim * self.tp_heads
|
||||||
|
self.prefix = prefix
|
||||||
|
|
||||||
self.qkv_proj = ColumnParallelLinear(
|
self.qkv_proj = ColumnParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@ -397,6 +413,12 @@ class MiniMaxText01LinearAttention(nn.Module):
|
|||||||
self.tp_heads:(self.tp_rank + 1) *
|
self.tp_heads:(self.tp_rank + 1) *
|
||||||
self.tp_heads].contiguous()
|
self.tp_heads].contiguous()
|
||||||
|
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
|
if prefix in compilation_config.static_forward_context:
|
||||||
|
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||||
|
compilation_config.static_forward_context[prefix] = self
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def weight_direct_load(param: torch.Tensor,
|
def weight_direct_load(param: torch.Tensor,
|
||||||
loaded_weight: torch.Tensor) -> None:
|
loaded_weight: torch.Tensor) -> None:
|
||||||
@ -434,13 +456,14 @@ class MiniMaxText01LinearAttention(nn.Module):
|
|||||||
break
|
break
|
||||||
if _prefill_idx >= len(state_indices_tensor):
|
if _prefill_idx >= len(state_indices_tensor):
|
||||||
break
|
break
|
||||||
_start = attn_metadata.query_start_loc[_prefill_idx]
|
# prefills are packed at end of batch in V1
|
||||||
_end = attn_metadata.query_start_loc[_prefill_idx + 1]
|
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
|
||||||
slot_id = state_indices_tensor[_prefill_idx]
|
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||||
|
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||||
|
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||||
qs = q[_start:_end].transpose(0, 1).contiguous()
|
qs = q[_start:_end].transpose(0, 1).contiguous()
|
||||||
ks = k[_start:_end].transpose(0, 1).contiguous()
|
ks = k[_start:_end].transpose(0, 1).contiguous()
|
||||||
vs = v[_start:_end].transpose(0, 1).contiguous()
|
vs = v[_start:_end].transpose(0, 1).contiguous()
|
||||||
slot_id = state_indices_tensor[_prefill_idx]
|
|
||||||
slice_layer_cache = kv_cache[slot_id, ...]
|
slice_layer_cache = kv_cache[slot_id, ...]
|
||||||
|
|
||||||
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
|
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
|
||||||
@ -453,9 +476,13 @@ class MiniMaxText01LinearAttention(nn.Module):
|
|||||||
layer_idx=self.layer_idx)
|
layer_idx=self.layer_idx)
|
||||||
hidden.append(out_slice.contiguous())
|
hidden.append(out_slice.contiguous())
|
||||||
if attn_metadata.num_decode_tokens > 0:
|
if attn_metadata.num_decode_tokens > 0:
|
||||||
hidden.append(
|
hidden_decode = self._decode_infer(q, k, v, kv_cache,
|
||||||
self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
|
state_indices_tensor,
|
||||||
attn_metadata))
|
attn_metadata)
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
hidden.insert(0, hidden_decode)
|
||||||
|
else:
|
||||||
|
hidden.append(hidden_decode)
|
||||||
|
|
||||||
if not hidden:
|
if not hidden:
|
||||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||||
@ -465,11 +492,17 @@ class MiniMaxText01LinearAttention(nn.Module):
|
|||||||
|
|
||||||
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
||||||
attn_metadata):
|
attn_metadata):
|
||||||
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
if not envs.VLLM_USE_V1:
|
||||||
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||||
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||||
slot_id = state_indices_tensor[getattr(attn_metadata, "num_prefills", 0
|
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||||
):]
|
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||||
|
slot_id = state_indices_tensor[num_prefills:]
|
||||||
|
else:
|
||||||
|
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||||
|
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||||
|
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||||
|
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
|
||||||
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
|
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
|
||||||
slot_id, 32)
|
slot_id, 32)
|
||||||
return hidden
|
return hidden
|
||||||
@ -483,17 +516,49 @@ class MiniMaxText01LinearAttention(nn.Module):
|
|||||||
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
kv_cache = kv_caches.minimax_cache
|
if envs.VLLM_USE_V1:
|
||||||
state_indices_tensor = kv_caches.state_indices_tensor
|
if attn_metadata is not None:
|
||||||
|
assert isinstance(attn_metadata, dict)
|
||||||
|
attn_metadata = attn_metadata[self.prefix]
|
||||||
|
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||||
|
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||||
|
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||||
|
|
||||||
|
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||||
|
if num_prefills > 0:
|
||||||
|
num_decode_tokens = getattr(attn_metadata,
|
||||||
|
"num_decode_tokens", 0)
|
||||||
|
for prefill_idx in range(num_prefills):
|
||||||
|
q_start = attn_metadata.query_start_loc[
|
||||||
|
num_decode_tokens + prefill_idx]
|
||||||
|
q_end = attn_metadata.query_start_loc[num_decode_tokens
|
||||||
|
+ prefill_idx +
|
||||||
|
1]
|
||||||
|
query_len = q_end - q_start
|
||||||
|
context_len = attn_metadata.seq_lens[
|
||||||
|
num_decode_tokens + prefill_idx] - query_len
|
||||||
|
if context_len == 0:
|
||||||
|
block_to_clear = state_indices_tensor[
|
||||||
|
num_decode_tokens + prefill_idx]
|
||||||
|
kv_cache[block_to_clear, ...] = 0
|
||||||
|
else:
|
||||||
|
kv_cache = kv_caches.minimax_cache
|
||||||
|
state_indices_tensor = kv_caches.state_indices_tensor
|
||||||
|
|
||||||
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
||||||
if not decode_only:
|
if attn_metadata is None:
|
||||||
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
|
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
|
||||||
state_indices_tensor,
|
device=q.device,
|
||||||
attn_metadata)
|
dtype=q.dtype)
|
||||||
else:
|
else:
|
||||||
hidden = self._decode_infer(q, k, v, kv_cache,
|
if not decode_only:
|
||||||
state_indices_tensor, attn_metadata)
|
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
|
||||||
|
state_indices_tensor,
|
||||||
|
attn_metadata)
|
||||||
|
else:
|
||||||
|
hidden = self._decode_infer(q, k, v, kv_cache,
|
||||||
|
state_indices_tensor,
|
||||||
|
attn_metadata)
|
||||||
|
|
||||||
hidden = self.norm._forward(hidden)
|
hidden = self.norm._forward(hidden)
|
||||||
gate, _ = self.output_gate(hidden_states)
|
gate, _ = self.output_gate(hidden_states)
|
||||||
@ -541,6 +606,7 @@ class MiniMaxText01Attention(nn.Module):
|
|||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.rope_theta = rope_theta
|
self.rope_theta = rope_theta
|
||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window
|
||||||
|
self.prefix = prefix
|
||||||
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
self.qkv_proj = QKVParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@ -575,7 +641,12 @@ class MiniMaxText01Attention(nn.Module):
|
|||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
q, k = attn_metadata.rotary_emb(positions, q, k)
|
if envs.VLLM_USE_V1:
|
||||||
|
if attn_metadata is not None:
|
||||||
|
q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
|
||||||
|
positions, q, k)
|
||||||
|
else:
|
||||||
|
q, k = attn_metadata.rotary_emb(positions, q, k)
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
@ -595,6 +666,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._ilayer = layer_id
|
self._ilayer = layer_id
|
||||||
self._irank = get_tensor_model_parallel_rank()
|
self._irank = get_tensor_model_parallel_rank()
|
||||||
|
self.prefix = prefix
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -876,8 +948,9 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
self._dtype = _dummy.dtype
|
self._dtype = _dummy.dtype
|
||||||
del _dummy
|
del _dummy
|
||||||
|
|
||||||
self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
|
if not envs.VLLM_USE_V1:
|
||||||
cache_shape=self.cache_shape)
|
self.minimax_cache = MinimaxCacheManager(
|
||||||
|
dtype=torch.float32, cache_shape=self.cache_shape)
|
||||||
|
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
head_dim = getattr(config, "head_dim", None)
|
head_dim = getattr(config, "head_dim", None)
|
||||||
@ -944,23 +1017,27 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
|
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
if attn_metadata is None:
|
if not envs.VLLM_USE_V1 and attn_metadata is None:
|
||||||
return None
|
return None
|
||||||
if "request_ids_to_seq_ids" not in kwargs:
|
if "request_ids_to_seq_ids" not in kwargs:
|
||||||
kwargs["request_ids_to_seq_ids"] = {}
|
kwargs["request_ids_to_seq_ids"] = {}
|
||||||
if "finished_requests_ids" not in kwargs:
|
if "finished_requests_ids" not in kwargs:
|
||||||
kwargs["finished_requests_ids"] = []
|
kwargs["finished_requests_ids"] = []
|
||||||
|
|
||||||
(
|
if not envs.VLLM_USE_V1:
|
||||||
minimax_cache_tensors,
|
(
|
||||||
state_indices_tensor,
|
minimax_cache_tensors,
|
||||||
) = self.minimax_cache.current_run_tensors(**kwargs)
|
state_indices_tensor,
|
||||||
if getattr(attn_metadata, "num_prefills", 0) > 0:
|
) = self.minimax_cache.current_run_tensors(**kwargs)
|
||||||
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
|
if getattr(attn_metadata, "num_prefills", 0) > 0:
|
||||||
**kwargs)
|
self._clear_prefill_cache(attn_metadata, minimax_cache_tensors,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
|
||||||
|
state_indices_tensor)
|
||||||
|
else:
|
||||||
|
minimax_cache_params = None
|
||||||
|
|
||||||
minimax_cache_params = MinimaxCacheParams(minimax_cache_tensors,
|
|
||||||
state_indices_tensor)
|
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
|
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
|
||||||
@ -973,11 +1050,22 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
minimax_cache_index = 0
|
minimax_cache_index = 0
|
||||||
attn_metadata.rotary_emb = self.rotary_emb
|
|
||||||
for i in range(self.start_layer, self.end_layer):
|
for i in range(self.start_layer, self.end_layer):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
|
if attn_metadata is not None:
|
||||||
|
# TODO (tdoublep): this whole thing with the rotary_emb is
|
||||||
|
# weird. we shouldn't be passing it via attn_metadata imo.
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
if isinstance(layer.self_attn, MiniMaxText01Attention):
|
||||||
|
attn_metadata[layer.prefix +
|
||||||
|
".attn"].rotary_emb = self.rotary_emb
|
||||||
|
else:
|
||||||
|
attn_metadata.rotary_emb = self.rotary_emb
|
||||||
|
|
||||||
_caches = None
|
_caches = None
|
||||||
if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
|
if not envs.VLLM_USE_V1 and isinstance(
|
||||||
|
layer.self_attn, MiniMaxText01LinearAttention):
|
||||||
current_state_layer = minimax_cache_index
|
current_state_layer = minimax_cache_index
|
||||||
_caches = minimax_cache_params.at_layer_idx(
|
_caches = minimax_cache_params.at_layer_idx(
|
||||||
current_state_layer)
|
current_state_layer)
|
||||||
@ -1002,8 +1090,7 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
|
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||||
SupportsV0Only):
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||||
|
|
||||||
@ -1321,3 +1408,28 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
|
|||||||
|
|
||||||
load_basic_weight(name, loaded_weight, self)
|
load_basic_weight(name, loaded_weight, self)
|
||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_mamba_state_shape_from_config(
|
||||||
|
cls,
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
use_v1: bool = True,
|
||||||
|
) -> tuple[tuple[int, ...], ...]:
|
||||||
|
"""Calculate shape for MiniMaxText01LinearAttention cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vllm_config: vLLM config
|
||||||
|
use_v1: Get shapes for V1 (or V0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- state_shape: Shape of the cache
|
||||||
|
"""
|
||||||
|
parallel_config = vllm_config.parallel_config
|
||||||
|
hf_config = vllm_config.model_config.hf_config
|
||||||
|
|
||||||
|
return MambaStateShapeCalculator.linear_attention_state_shape(
|
||||||
|
num_heads=hf_config.num_attention_heads,
|
||||||
|
tp_size=parallel_config.tensor_parallel_size,
|
||||||
|
head_dim=hf_config.head_dim,
|
||||||
|
)
|
||||||
|
|||||||
67
vllm/v1/attention/backends/linear_attn.py
Normal file
67
vllm/v1/attention/backends/linear_attn.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
split_decodes_and_prefills)
|
||||||
|
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
|
||||||
|
return LinearAttentionMetadataBuilder
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LinearAttentionMetadata:
|
||||||
|
num_prefills: int
|
||||||
|
num_prefill_tokens: int
|
||||||
|
num_decodes: int
|
||||||
|
num_decode_tokens: int
|
||||||
|
query_start_loc: torch.Tensor
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
|
||||||
|
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionMetadataBuilder(
|
||||||
|
AttentionMetadataBuilder[LinearAttentionMetadata]):
|
||||||
|
|
||||||
|
reorder_batch_threshold: ClassVar[int] = 1
|
||||||
|
|
||||||
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
|
assert isinstance(kv_cache_spec, MambaSpec)
|
||||||
|
self.kv_cache_spec = kv_cache_spec
|
||||||
|
|
||||||
|
def build(self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
fast_build: bool = False) -> LinearAttentionMetadata:
|
||||||
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
|
|
||||||
|
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||||
|
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
split_decodes_and_prefills(common_attn_metadata,
|
||||||
|
decode_threshold=1))
|
||||||
|
|
||||||
|
attn_metadata = LinearAttentionMetadata(
|
||||||
|
num_prefills=num_prefills,
|
||||||
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
|
num_decodes=num_decodes,
|
||||||
|
num_decode_tokens=num_decode_tokens,
|
||||||
|
query_start_loc=query_start_loc,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
state_indices_tensor=state_indices_tensor,
|
||||||
|
)
|
||||||
|
return attn_metadata
|
||||||
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
|
||||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
|
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
|
||||||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
||||||
|
|
||||||
@ -8,9 +9,10 @@ from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
|||||||
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
|
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
|
||||||
if mamba_type == "mamba1":
|
if mamba_type == "mamba1":
|
||||||
return Mamba1AttentionBackend
|
return Mamba1AttentionBackend
|
||||||
|
|
||||||
if mamba_type == "mamba2":
|
if mamba_type == "mamba2":
|
||||||
return Mamba2AttentionBackend
|
return Mamba2AttentionBackend
|
||||||
|
if mamba_type == "linear_attention":
|
||||||
|
return LinearAttentionBackend
|
||||||
|
|
||||||
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
|
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
|
||||||
"supported yet.")
|
"supported yet.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user