diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
new file mode 100644
index 0000000000000..d93cef1a27ad4
--- /dev/null
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -0,0 +1,442 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import math
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
+from typing import TYPE_CHECKING
+
+import torch
+import torch.distributed
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from vllm import envs
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
+from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
+from vllm.distributed.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.lightning_attn import (
+    lightning_attention, linear_decode_forward_triton)
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateDtypeCalculator, MambaStateShapeCalculator)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
+from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
+
+import torch
+import torch.distributed
+
+from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
+
+
+class MiniMaxText01RMSNormTP(CustomOp):
+    name = "MiniMaxText01RMSNormTP"
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.tp_world = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.weight = nn.Parameter(torch.ones(int(hidden_size /
+                                                  self.tp_world)))
+
+        self.weight.weight_loader = self.weight_loader
+        self.variance_epsilon = eps
+        return
+
+    @staticmethod
+    def weight_loader(
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+    ) -> None:
+        tp_world = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+
+        shard_size = loaded_weight.shape[0] // tp_world
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        param.data.copy_(loaded_weight[shard])
+        return
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
+        if self.tp_world > 1:
+            variance = tensor_model_parallel_all_reduce(
+                variance) / self.tp_world
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+
+        weight = self.weight
+        if x.size(-1) != self.weight.size(0):
+            if self.weight.size(0) < x.size(-1):
+                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
+                full_weight = self.weight.repeat(repeat_count)
+                weight = full_weight[:x.size(-1)]
+            else:
+                weight = self.weight[:x.size(-1)]
+
+        x = x.to(orig_dtype) * weight
+        return x
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert residual is None, "RMSNorm does not support residual connection."
+        return self._forward(x)
+
+
+class MiniMaxText01LinearKernel:
+
+    @staticmethod
+    def jit_linear_forward_prefix(q: torch.Tensor,
+                                  k: torch.Tensor,
+                                  v: torch.Tensor,
+                                  kv_caches: torch.Tensor,
+                                  slope_rate: torch.Tensor,
+                                  block_size: int,
+                                  layer_idx: Optional[int] = None,
+                                  **kwargs) -> torch.Tensor:
+
+        slope_rate = slope_rate.to(torch.float32)
+        should_pad_dim = q.dim() == 3
+        if should_pad_dim:
+            q = q.unsqueeze(0)
+            k = k.unsqueeze(0)
+            v = v.unsqueeze(0)
+        b, h, n, d = q.shape
+        e = d
+        kv_history = kv_caches.reshape(1, h, d, e).contiguous()
+        output, kv_history = lightning_attention(q,
+                                                 k,
+                                                 v,
+                                                 slope_rate,
+                                                 block_size=block_size,
+                                                 kv_history=kv_history)
+        kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
+        assert output.shape[0] == 1, "batch size must be 1"
+        return rearrange(output.squeeze(0), "h n d -> n (h d)")
+
+
+class MiniMaxText01LinearAttention(nn.Module, MambaBase):
+
+    @property
+    def mamba_type(self) -> str:
+        return "linear_attention"
+
+    def get_attn_backend(self) -> type["AttentionBackend"]:
+        from vllm.v1.attention.backends.linear_attn import (
+            LinearAttentionBackend)
+        return LinearAttentionBackend
+
+    def get_state_dtype(self) -> tuple[torch.dtype]:
+        assert self.model_config is not None
+        assert self.cache_config is not None
+        return MambaStateDtypeCalculator.linear_attention_state_dtype(
+            self.model_config.dtype,
+            self.cache_config.mamba_cache_dtype,
+        )
+
+    def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
+        return MambaStateShapeCalculator.linear_attention_state_shape(
+            num_heads=self.num_heads,
+            tp_size=self.tp_size,
+            head_dim=self.head_dim)
+
+    def __init__(
+        self,
+        hidden_size: int,
+        hidden_inner_size: int,
+        num_heads: int,
+        head_dim: int,
+        max_position: int,
+        block_size: int,
+        num_hidden_layer: int,
+        model_config: Optional[ModelConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_idx: int = 0,
+        linear_layer_idx: int = 0,
+        prefix: str = "linear_attn",
+    ) -> None:
+        super().__init__()
+
+        self.layer_idx = layer_idx
+        self.BLOCK = block_size
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.total_num_heads = num_heads
+        self.hidden_inner_size = hidden_inner_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        assert self.total_num_heads % self.tp_size == 0
+        self.tp_heads = self.total_num_heads // self.tp_size
+        self.qkv_size = self.num_heads * self.head_dim
+        self.tp_hidden = self.head_dim * self.tp_heads
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.prefix = prefix
+
+        self.qkv_proj = ColumnParallelLinear(
+            hidden_size,
+            self.hidden_inner_size * 3,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.output_gate = ColumnParallelLinear(
+            hidden_size,
+            self.hidden_inner_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.output_gate",
+        )
+        self.out_proj = RowParallelLinear(
+            self.hidden_inner_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+        self.norm = MiniMaxText01RMSNormTP(
+            self.hidden_inner_size,
+            eps=1e-5,
+        )
+
+        slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
+            self.num_heads)
+        if num_hidden_layer <= 1:
+            self.slope_rate = slope_rate * (1 + 1e-5)
+        else:
+            self.slope_rate = slope_rate * (1 - layer_idx /
+                                            (num_hidden_layer - 1) + 1e-5)
+        self.tp_slope = self.slope_rate[self.tp_rank *
+                                        self.tp_heads:(self.tp_rank + 1) *
+                                        self.tp_heads].contiguous()
+
+        if envs.VLLM_USE_V1:
+            compilation_config = get_current_vllm_config().compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+
+    @staticmethod
+    def weight_direct_load(param: torch.Tensor,
+                           loaded_weight: torch.Tensor) -> None:
+        assert param.size() == loaded_weight.size()
+        param.data.copy_(loaded_weight)
+        return
+
+    @staticmethod
+    def _build_slope_tensor(n_attention_heads: int):
+
+        def get_slopes(n):
+
+            def get_slopes_power_of_2(n):
+                start = 2**(-(2**-(math.log2(n) - 3)))
+                ratio = start
+                return [start * ratio**i for i in range(n)]
+
+            if math.log2(n).is_integer():
+                return get_slopes_power_of_2(n)
+            else:
+                closest_power_of_2 = 2**math.floor(math.log2(n))
+                return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
+                    2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
+
+        slopes = torch.tensor(get_slopes(n_attention_heads),
+                              dtype=torch.float32).reshape(
+                                  n_attention_heads, 1, 1)
+        return slopes
+
+    def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
+                               attn_metadata):
+        hidden = []
+        for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
+            if _prefill_idx >= len(attn_metadata.query_start_loc):
+                break
+            if _prefill_idx >= len(state_indices_tensor):
+                break
+            # prefills are packed at end of batch in V1
+            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
+            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
+            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
+            slot_id = state_indices_tensor[offset + _prefill_idx]
+            qs = q[_start:_end].transpose(0, 1).contiguous()
+            ks = k[_start:_end].transpose(0, 1).contiguous()
+            vs = v[_start:_end].transpose(0, 1).contiguous()
+            slice_layer_cache = kv_cache[slot_id, ...]
+
+            out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
+                qs,
+                ks,
+                vs,
+                slice_layer_cache,
+                self.tp_slope,
+                self.BLOCK,
+                layer_idx=self.layer_idx)
+            hidden.append(out_slice.contiguous())
+        if attn_metadata.num_decode_tokens > 0:
+            hidden_decode = self._decode_infer(q, k, v, kv_cache,
+                                               state_indices_tensor,
+                                               attn_metadata)
+            if envs.VLLM_USE_V1:
+                hidden.insert(0, hidden_decode)
+            else:
+                hidden.append(hidden_decode)
+
+        if not hidden:
+            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
+
+        hidden = torch.concat(hidden, dim=0).contiguous()
+        return hidden
+
+    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
+                      attn_metadata):
+        if not envs.VLLM_USE_V1:
+            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
+            num_prefills = getattr(attn_metadata, "num_prefills", 0)
+            slot_id = state_indices_tensor[num_prefills:]
+        else:
+            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
+            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
+        hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
+                                              slot_id, 32)
+        return hidden
+
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: MinimaxCacheParams) -> None:
+        if not envs.VLLM_USE_V1:
+            self._forward(hidden_states, output, positions, kv_caches)
+        else:
+            torch.ops.vllm.linear_attention(
+                hidden_states,
+                output,
+                positions,
+                self.prefix,
+            )
+
+    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                 positions: torch.Tensor,
+                 kv_caches: Optional[MinimaxCacheParams]) -> None:
+        forward_context = get_forward_context()
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+        if envs.VLLM_USE_V1 and attn_metadata is not None:
+            assert isinstance(attn_metadata, dict)
+            attn_metadata = attn_metadata[self.prefix]
+            assert isinstance(attn_metadata, LinearAttentionMetadata)
+            num_actual_tokens = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
+        else:
+            num_actual_tokens = hidden_states.shape[0]
+
+        qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
+        qkv32 = qkv.to(torch.float32)
+        qkvact = torch.nn.functional.silu(qkv32)
+        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
+        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+                state_indices_tensor = attn_metadata.state_indices_tensor
+
+                num_prefills = getattr(attn_metadata, "num_prefills", 0)
+                if num_prefills > 0:
+                    num_decode_tokens = getattr(attn_metadata,
+                                                "num_decode_tokens", 0)
+                    for prefill_idx in range(num_prefills):
+                        q_start = attn_metadata.query_start_loc[
+                            num_decode_tokens + prefill_idx]
+                        q_end = attn_metadata.query_start_loc[num_decode_tokens
+                                                              + prefill_idx +
+                                                              1]
+                        query_len = q_end - q_start
+                        context_len = attn_metadata.seq_lens[
+                            num_decode_tokens + prefill_idx] - query_len
+                        if context_len == 0:
+                            block_to_clear = state_indices_tensor[
+                                num_decode_tokens + prefill_idx]
+                            kv_cache[block_to_clear, ...] = 0
+        else:
+            assert kv_caches is not None
+            kv_cache = kv_caches.minimax_cache
+            state_indices_tensor = kv_caches.state_indices_tensor
+
+        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
+        if attn_metadata is None:
+            hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
+                                 device=q.device,
+                                 dtype=q.dtype)
+        else:
+            if not decode_only:
+                hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
+                                                     state_indices_tensor,
+                                                     attn_metadata)
+            else:
+                hidden = self._decode_infer(q, k, v, kv_cache,
+                                            state_indices_tensor,
+                                            attn_metadata)
+        hidden = self.norm._forward(hidden)
+        gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
+        hidden = F.sigmoid(gate) * hidden
+        hidden = hidden.to(hidden_states.dtype)
+
+        output[:num_actual_tokens], _ = self.out_proj(hidden)
+
+
+def linear_attention(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    positions: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self._forward(hidden_states=hidden_states,
+                  output=output,
+                  positions=positions,
+                  kv_caches=None)
+
+
+def linear_attention_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    positions: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="linear_attention",
+    op_func=linear_attention,
+    mutates_args=["output"],
+    fake_impl=linear_attention_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 93ef13d5d16a0..ef1fe86c5b5c0 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -1,45 +1,37 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Inference-only MiniMaxText01 model."""
-import math
 from collections.abc import Iterable
 from itertools import islice
 from typing import TYPE_CHECKING, Optional, Union
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
+    pass
 
 import regex as re
 import torch
 import torch.distributed
-import torch.nn.functional as F
-from einops import rearrange
 from torch import nn
 from transformers import MiniMaxConfig
 
 from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ModelConfig, VllmConfig
-from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size)
-from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.model_executor.custom_op import CustomOp
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.lightning_attn import (
-    lightning_attention, linear_decode_forward_triton)
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.linear_attn import (
+    MiniMaxText01LinearAttention)
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -50,10 +42,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import direct_register_custom_op
-from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
 
 from .interfaces import HasInnerState, IsHybrid
 from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
@@ -87,66 +76,6 @@ def weight_loader_with_alias(alias: str):
     return wrapper
 
 
-class MiniMaxText01RMSNormTP(CustomOp):
-    name = "MiniMaxText01RMSNormTP"
-
-    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
-        super().__init__()
-        self.tp_world = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.weight = nn.Parameter(torch.ones(int(hidden_size /
-                                                  self.tp_world)))
-
-        self.weight.weight_loader = self.weight_loader
-        self.variance_epsilon = eps
-        return
-
-    @staticmethod
-    def weight_loader(
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-    ) -> None:
-        tp_world = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-
-        shard_size = loaded_weight.shape[0] // tp_world
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        param.data.copy_(loaded_weight[shard])
-        return
-
-    def _forward(
-        self,
-        x: torch.Tensor,
-    ) -> torch.Tensor:
-        orig_dtype = x.dtype
-        x = x.to(torch.float32)
-        variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
-        if self.tp_world > 1:
-            variance = tensor_model_parallel_all_reduce(
-                variance) / self.tp_world
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
-
-        weight = self.weight
-        if x.size(-1) != self.weight.size(0):
-            if self.weight.size(0) < x.size(-1):
-                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
-                full_weight = self.weight.repeat(repeat_count)
-                weight = full_weight[:x.size(-1)]
-            else:
-                weight = self.weight[:x.size(-1)]
-
-        x = x.to(orig_dtype) * weight
-        return x
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        assert residual is None, "RMSNorm does not support residual connection."
-        return self._forward(x)
-
-
 class MiniMaxText01MLP(nn.Module):
 
     def __init__(
@@ -253,307 +182,6 @@ class MiniMaxText01MoE(nn.Module):
         return final_hidden
 
 
-class MiniMaxText01LinearKernel:
-
-    @staticmethod
-    def jit_linear_forward_prefix(q: torch.Tensor,
-                                  k: torch.Tensor,
-                                  v: torch.Tensor,
-                                  kv_caches: torch.Tensor,
-                                  slope_rate: torch.Tensor,
-                                  block_size: int,
-                                  layer_idx: int = None,
-                                  **kwargs) -> torch.Tensor:
-
-        slope_rate = slope_rate.to(torch.float32)
-        should_pad_dim = q.dim() == 3
-        if should_pad_dim:
-            q = q.unsqueeze(0)
-            k = k.unsqueeze(0)
-            v = v.unsqueeze(0)
-        b, h, n, d = q.shape
-        e = d
-        kv_history = kv_caches.reshape(1, h, d, e).contiguous()
-        output, kv_history = lightning_attention(q,
-                                                 k,
-                                                 v,
-                                                 slope_rate,
-                                                 block_size=block_size,
-                                                 kv_history=kv_history)
-        kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
-        assert output.shape[0] == 1, "batch size must be 1"
-        return rearrange(output.squeeze(0), "h n d -> n (h d)")
-
-
-class MiniMaxText01LinearAttention(nn.Module, MambaBase):
-
-    @property
-    def mamba_type(self) -> str:
-        return "linear_attention"
-
-    def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.linear_attn import (
-            LinearAttentionBackend)
-        return LinearAttentionBackend
-
-    def get_state_dtype(self) -> tuple[torch.dtype]:
-        return MambaStateDtypeCalculator.linear_attention_state_dtype(
-            self.model_config.dtype,
-            self.cache_config.mamba_cache_dtype,
-        )
-
-    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
-        return MambaStateShapeCalculator.linear_attention_state_shape(
-            num_heads=self.num_heads,
-            tp_size=self.tp_size,
-            head_dim=self.head_dim)
-
-    def __init__(
-        self,
-        hidden_size: int,
-        hidden_inner_size: int,
-        num_heads: int,
-        head_dim: int,
-        max_position: int,
-        block_size: int,
-        num_hidden_layer: int,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        layer_idx: int = 0,
-        linear_layer_idx: int = 0,
-        prefix: str = "linear_attn",
-    ) -> None:
-        super().__init__()
-
-        self.layer_idx = layer_idx
-        self.BLOCK = block_size
-        self.hidden_size = hidden_size
-        self.num_heads = num_heads
-        self.head_dim = head_dim
-        self.total_num_heads = num_heads
-        self.hidden_inner_size = hidden_inner_size
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
-
-        assert self.total_num_heads % self.tp_size == 0
-        self.tp_heads = self.total_num_heads // self.tp_size
-        self.qkv_size = self.num_heads * self.head_dim
-        self.tp_hidden = self.head_dim * self.tp_heads
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.prefix = prefix
-
-        self.qkv_proj = ColumnParallelLinear(
-            hidden_size,
-            self.hidden_inner_size * 3,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-        self.output_gate = ColumnParallelLinear(
-            hidden_size,
-            self.hidden_inner_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.output_gate",
-        )
-        self.out_proj = RowParallelLinear(
-            self.hidden_inner_size,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.out_proj",
-        )
-        self.norm = MiniMaxText01RMSNormTP(
-            self.hidden_inner_size,
-            eps=1e-5,
-        )
-
-        slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
-            self.num_heads)
-        if num_hidden_layer <= 1:
-            self.slope_rate = slope_rate * (1 + 1e-5)
-        else:
-            self.slope_rate = slope_rate * (1 - layer_idx /
-                                            (num_hidden_layer - 1) + 1e-5)
-        self.tp_slope = self.slope_rate[self.tp_rank *
-                                        self.tp_heads:(self.tp_rank + 1) *
-                                        self.tp_heads].contiguous()
-
-        if envs.VLLM_USE_V1:
-            compilation_config = get_current_vllm_config().compilation_config
-            if prefix in compilation_config.static_forward_context:
-                raise ValueError(f"Duplicate layer name: {prefix}")
-            compilation_config.static_forward_context[prefix] = self
-
-    @staticmethod
-    def weight_direct_load(param: torch.Tensor,
-                           loaded_weight: torch.Tensor) -> None:
-        assert param.size() == loaded_weight.size()
-        param.data.copy_(loaded_weight)
-        return
-
-    @staticmethod
-    def _build_slope_tensor(n_attention_heads: int):
-
-        def get_slopes(n):
-
-            def get_slopes_power_of_2(n):
-                start = 2**(-(2**-(math.log2(n) - 3)))
-                ratio = start
-                return [start * ratio**i for i in range(n)]
-
-            if math.log2(n).is_integer():
-                return get_slopes_power_of_2(n)
-            else:
-                closest_power_of_2 = 2**math.floor(math.log2(n))
-                return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
-                    2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
-
-        slopes = torch.tensor(get_slopes(n_attention_heads),
-                              dtype=torch.float32).reshape(
-                                  n_attention_heads, 1, 1)
-        return slopes
-
-    def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
-                               attn_metadata):
-        hidden = []
-        for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
-            if _prefill_idx >= len(attn_metadata.query_start_loc):
-                break
-            if _prefill_idx >= len(state_indices_tensor):
-                break
-            # prefills are packed at end of batch in V1
-            offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
-            _start = attn_metadata.query_start_loc[offset + _prefill_idx]
-            _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
-            slot_id = state_indices_tensor[offset + _prefill_idx]
-            qs = q[_start:_end].transpose(0, 1).contiguous()
-            ks = k[_start:_end].transpose(0, 1).contiguous()
-            vs = v[_start:_end].transpose(0, 1).contiguous()
-            slice_layer_cache = kv_cache[slot_id, ...]
-
-            out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
-                qs,
-                ks,
-                vs,
-                slice_layer_cache,
-                self.tp_slope,
-                self.BLOCK,
-                layer_idx=self.layer_idx)
-            hidden.append(out_slice.contiguous())
-        if attn_metadata.num_decode_tokens > 0:
-            hidden_decode = self._decode_infer(q, k, v, kv_cache,
-                                               state_indices_tensor,
-                                               attn_metadata)
-            if envs.VLLM_USE_V1:
-                hidden.insert(0, hidden_decode)
-            else:
-                hidden.append(hidden_decode)
-
-        if not hidden:
-            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
-
-        hidden = torch.concat(hidden, dim=0).contiguous()
-        return hidden
-
-    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
-                      attn_metadata):
-        if not envs.VLLM_USE_V1:
-            q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
-            num_prefills = getattr(attn_metadata, "num_prefills", 0)
-            slot_id = state_indices_tensor[num_prefills:]
-        else:
-            q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
-            slot_id = state_indices_tensor[:attn_metadata.num_decodes]
-        hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
-                                              slot_id, 32)
-        return hidden
-
-    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams) -> None:
-        if not envs.VLLM_USE_V1:
-            self._forward(hidden_states, output, positions, kv_caches)
-        else:
-            torch.ops.vllm.linear_attention(
-                hidden_states,
-                output,
-                positions,
-                self.prefix,
-            )
-
-    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                 positions: torch.Tensor,
-                 kv_caches: Optional[MinimaxCacheParams]) -> None:
-        forward_context = get_forward_context()
-        attn_metadata: AttentionMetadata = forward_context.attn_metadata
-        if envs.VLLM_USE_V1 and attn_metadata is not None:
-            assert isinstance(attn_metadata, dict)
-            attn_metadata = attn_metadata[self.prefix]
-            assert isinstance(attn_metadata, LinearAttentionMetadata)
-            num_actual_tokens = attn_metadata.num_prefill_tokens + \
-                attn_metadata.num_decode_tokens
-        else:
-            num_actual_tokens = hidden_states.shape[0]
-
-        qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
-        qkv32 = qkv.to(torch.float32)
-        qkvact = torch.nn.functional.silu(qkv32)
-        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
-        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-        if envs.VLLM_USE_V1:
-            if attn_metadata is not None:
-                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
-                state_indices_tensor = attn_metadata.state_indices_tensor
-
-                num_prefills = getattr(attn_metadata, "num_prefills", 0)
-                if num_prefills > 0:
-                    num_decode_tokens = getattr(attn_metadata,
-                                                "num_decode_tokens", 0)
-                    for prefill_idx in range(num_prefills):
-                        q_start = attn_metadata.query_start_loc[
-                            num_decode_tokens + prefill_idx]
-                        q_end = attn_metadata.query_start_loc[num_decode_tokens
-                                                              + prefill_idx +
-                                                              1]
-                        query_len = q_end - q_start
-                        context_len = attn_metadata.seq_lens[
-                            num_decode_tokens + prefill_idx] - query_len
-                        if context_len == 0:
-                            block_to_clear = state_indices_tensor[
-                                num_decode_tokens + prefill_idx]
-                            kv_cache[block_to_clear, ...] = 0
-        else:
-            kv_cache = kv_caches.minimax_cache
-            state_indices_tensor = kv_caches.state_indices_tensor
-
-        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
-        if attn_metadata is None:
-            hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
-                                 device=q.device,
-                                 dtype=q.dtype)
-        else:
-            if not decode_only:
-                hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
-                                                     state_indices_tensor,
-                                                     attn_metadata)
-            else:
-                hidden = self._decode_infer(q, k, v, kv_cache,
-                                            state_indices_tensor,
-                                            attn_metadata)
-        hidden = self.norm._forward(hidden)
-        gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
-        hidden = F.sigmoid(gate) * hidden
-        hidden = hidden.to(hidden_states.dtype)
-        output[:num_actual_tokens], _ = self.out_proj(hidden)
-
-
 class MiniMaxText01Attention(nn.Module):
 
     def __init__(
@@ -1397,35 +1025,3 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
             tp_size=parallel_config.tensor_parallel_size,
             head_dim=hf_config.head_dim,
         )
-
-
-def linear_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
-    positions: torch.Tensor,
-    layer_name: str,
-) -> None:
-    forward_context: ForwardContext = get_forward_context()
-    self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states,
-                  output=output,
-                  positions=positions,
-                  kv_caches=None)
-
-
-def linear_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
-    positions: torch.Tensor,
-    layer_name: str,
-) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="linear_attention",
-    op_func=linear_attention,
-    mutates_args=["output"],
-    fake_impl=linear_attention_fake,
-    dispatch_key=current_platform.dispatch_key,
-)
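
Reviewer note: below is a minimal standalone sketch of the decay-rate construction used by the moved layer, assuming illustrative values of 8 heads and 4 linear-attention layers (these numbers are hypothetical, not taken from the MiniMax-Text-01 config). It reproduces the ALiBi-style geometric slopes from _build_slope_tensor and the per-layer scaling applied in MiniMaxText01LinearAttention.__init__, so the resulting slope_rate values can be inspected outside vLLM.

import math

import torch


def build_slope_tensor(n_attention_heads: int) -> torch.Tensor:
    # Same geometric-slope recipe as _build_slope_tensor in the diff above.
    def get_slopes(n):

        def get_slopes_power_of_2(n):
            start = 2**(-(2**-(math.log2(n) - 3)))
            return [start * start**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        closest_power_of_2 = 2**math.floor(math.log2(n))
        return (get_slopes_power_of_2(closest_power_of_2) +
                get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

    return torch.tensor(get_slopes(n_attention_heads),
                        dtype=torch.float32).reshape(n_attention_heads, 1, 1)


# Illustrative sizes only (hypothetical, not the real model configuration).
num_heads, num_hidden_layer = 8, 4
base = build_slope_tensor(num_heads)
for layer_idx in range(num_hidden_layer):
    # Deeper layers get a smaller slope, i.e. slower decay and longer memory,
    # matching: slope_rate * (1 - layer_idx / (num_hidden_layer - 1) + 1e-5).
    slope_rate = base * (1 - layer_idx / (num_hidden_layer - 1) + 1e-5)
    print(layer_idx, [round(s, 5) for s in slope_rate.flatten().tolist()])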