mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 07:25:01 +08:00
[V1] [Hybrid] Move MiniMaxLinearAttention into layers/mamba (#23831)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
f1bddbd852
commit
4071c76cf3
442
vllm/model_executor/layers/mamba/linear_attn.py
Normal file
442
vllm/model_executor/layers/mamba/linear_attn.py
Normal file
@ -0,0 +1,442 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.attention import AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||||
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||||
|
from vllm.distributed.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
from vllm.model_executor.layers.lightning_attn import (
|
||||||
|
lightning_attention, linear_decode_forward_triton)
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||||
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxText01RMSNormTP(CustomOp):
|
||||||
|
name = "MiniMaxText01RMSNormTP"
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.tp_world = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
self.weight = nn.Parameter(torch.ones(int(hidden_size /
|
||||||
|
self.tp_world)))
|
||||||
|
|
||||||
|
self.weight.weight_loader = self.weight_loader
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def weight_loader(
|
||||||
|
param: nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
tp_world = get_tensor_model_parallel_world_size()
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
|
shard_size = loaded_weight.shape[0] // tp_world
|
||||||
|
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||||
|
param.data.copy_(loaded_weight[shard])
|
||||||
|
return
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
|
||||||
|
if self.tp_world > 1:
|
||||||
|
variance = tensor_model_parallel_all_reduce(
|
||||||
|
variance) / self.tp_world
|
||||||
|
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
|
||||||
|
weight = self.weight
|
||||||
|
if x.size(-1) != self.weight.size(0):
|
||||||
|
if self.weight.size(0) < x.size(-1):
|
||||||
|
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
|
||||||
|
full_weight = self.weight.repeat(repeat_count)
|
||||||
|
weight = full_weight[:x.size(-1)]
|
||||||
|
else:
|
||||||
|
weight = self.weight[:x.size(-1)]
|
||||||
|
|
||||||
|
x = x.to(orig_dtype) * weight
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
assert residual is None, "RMSNorm does not support residual connection."
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxText01LinearKernel:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def jit_linear_forward_prefix(q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
kv_caches: torch.Tensor,
|
||||||
|
slope_rate: torch.Tensor,
|
||||||
|
block_size: int,
|
||||||
|
layer_idx: Optional[int] = None,
|
||||||
|
**kwargs) -> torch.Tensor:
|
||||||
|
|
||||||
|
slope_rate = slope_rate.to(torch.float32)
|
||||||
|
should_pad_dim = q.dim() == 3
|
||||||
|
if should_pad_dim:
|
||||||
|
q = q.unsqueeze(0)
|
||||||
|
k = k.unsqueeze(0)
|
||||||
|
v = v.unsqueeze(0)
|
||||||
|
b, h, n, d = q.shape
|
||||||
|
e = d
|
||||||
|
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
|
||||||
|
output, kv_history = lightning_attention(q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
slope_rate,
|
||||||
|
block_size=block_size,
|
||||||
|
kv_history=kv_history)
|
||||||
|
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
|
||||||
|
assert output.shape[0] == 1, "batch size must be 1"
|
||||||
|
return rearrange(output.squeeze(0), "h n d -> n (h d)")
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mamba_type(self) -> str:
|
||||||
|
return "linear_attention"
|
||||||
|
|
||||||
|
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||||
|
from vllm.v1.attention.backends.linear_attn import (
|
||||||
|
LinearAttentionBackend)
|
||||||
|
return LinearAttentionBackend
|
||||||
|
|
||||||
|
def get_state_dtype(self) -> tuple[torch.dtype]:
|
||||||
|
assert self.model_config is not None
|
||||||
|
assert self.cache_config is not None
|
||||||
|
return MambaStateDtypeCalculator.linear_attention_state_dtype(
|
||||||
|
self.model_config.dtype,
|
||||||
|
self.cache_config.mamba_cache_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
|
||||||
|
return MambaStateShapeCalculator.linear_attention_state_shape(
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
tp_size=self.tp_size,
|
||||||
|
head_dim=self.head_dim)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
hidden_inner_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
max_position: int,
|
||||||
|
block_size: int,
|
||||||
|
num_hidden_layer: int,
|
||||||
|
model_config: Optional[ModelConfig] = None,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
layer_idx: int = 0,
|
||||||
|
linear_layer_idx: int = 0,
|
||||||
|
prefix: str = "linear_attn",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.BLOCK = block_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
self.hidden_inner_size = hidden_inner_size
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
|
assert self.total_num_heads % self.tp_size == 0
|
||||||
|
self.tp_heads = self.total_num_heads // self.tp_size
|
||||||
|
self.qkv_size = self.num_heads * self.head_dim
|
||||||
|
self.tp_hidden = self.head_dim * self.tp_heads
|
||||||
|
self.model_config = model_config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.prefix = prefix
|
||||||
|
|
||||||
|
self.qkv_proj = ColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
self.hidden_inner_size * 3,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
|
)
|
||||||
|
self.output_gate = ColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
self.hidden_inner_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.output_gate",
|
||||||
|
)
|
||||||
|
self.out_proj = RowParallelLinear(
|
||||||
|
self.hidden_inner_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.out_proj",
|
||||||
|
)
|
||||||
|
self.norm = MiniMaxText01RMSNormTP(
|
||||||
|
self.hidden_inner_size,
|
||||||
|
eps=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
|
||||||
|
self.num_heads)
|
||||||
|
if num_hidden_layer <= 1:
|
||||||
|
self.slope_rate = slope_rate * (1 + 1e-5)
|
||||||
|
else:
|
||||||
|
self.slope_rate = slope_rate * (1 - layer_idx /
|
||||||
|
(num_hidden_layer - 1) + 1e-5)
|
||||||
|
self.tp_slope = self.slope_rate[self.tp_rank *
|
||||||
|
self.tp_heads:(self.tp_rank + 1) *
|
||||||
|
self.tp_heads].contiguous()
|
||||||
|
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
|
if prefix in compilation_config.static_forward_context:
|
||||||
|
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||||
|
compilation_config.static_forward_context[prefix] = self
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def weight_direct_load(param: torch.Tensor,
|
||||||
|
loaded_weight: torch.Tensor) -> None:
|
||||||
|
assert param.size() == loaded_weight.size()
|
||||||
|
param.data.copy_(loaded_weight)
|
||||||
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_slope_tensor(n_attention_heads: int):
|
||||||
|
|
||||||
|
def get_slopes(n):
|
||||||
|
|
||||||
|
def get_slopes_power_of_2(n):
|
||||||
|
start = 2**(-(2**-(math.log2(n) - 3)))
|
||||||
|
ratio = start
|
||||||
|
return [start * ratio**i for i in range(n)]
|
||||||
|
|
||||||
|
if math.log2(n).is_integer():
|
||||||
|
return get_slopes_power_of_2(n)
|
||||||
|
else:
|
||||||
|
closest_power_of_2 = 2**math.floor(math.log2(n))
|
||||||
|
return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
|
||||||
|
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
|
||||||
|
|
||||||
|
slopes = torch.tensor(get_slopes(n_attention_heads),
|
||||||
|
dtype=torch.float32).reshape(
|
||||||
|
n_attention_heads, 1, 1)
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
||||||
|
attn_metadata):
|
||||||
|
hidden = []
|
||||||
|
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
|
||||||
|
if _prefill_idx >= len(attn_metadata.query_start_loc):
|
||||||
|
break
|
||||||
|
if _prefill_idx >= len(state_indices_tensor):
|
||||||
|
break
|
||||||
|
# prefills are packed at end of batch in V1
|
||||||
|
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
|
||||||
|
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||||
|
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||||
|
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||||
|
qs = q[_start:_end].transpose(0, 1).contiguous()
|
||||||
|
ks = k[_start:_end].transpose(0, 1).contiguous()
|
||||||
|
vs = v[_start:_end].transpose(0, 1).contiguous()
|
||||||
|
slice_layer_cache = kv_cache[slot_id, ...]
|
||||||
|
|
||||||
|
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
|
||||||
|
qs,
|
||||||
|
ks,
|
||||||
|
vs,
|
||||||
|
slice_layer_cache,
|
||||||
|
self.tp_slope,
|
||||||
|
self.BLOCK,
|
||||||
|
layer_idx=self.layer_idx)
|
||||||
|
hidden.append(out_slice.contiguous())
|
||||||
|
if attn_metadata.num_decode_tokens > 0:
|
||||||
|
hidden_decode = self._decode_infer(q, k, v, kv_cache,
|
||||||
|
state_indices_tensor,
|
||||||
|
attn_metadata)
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
hidden.insert(0, hidden_decode)
|
||||||
|
else:
|
||||||
|
hidden.append(hidden_decode)
|
||||||
|
|
||||||
|
if not hidden:
|
||||||
|
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
|
hidden = torch.concat(hidden, dim=0).contiguous()
|
||||||
|
return hidden
|
||||||
|
|
||||||
|
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
||||||
|
attn_metadata):
|
||||||
|
if not envs.VLLM_USE_V1:
|
||||||
|
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||||
|
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||||
|
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
||||||
|
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||||
|
slot_id = state_indices_tensor[num_prefills:]
|
||||||
|
else:
|
||||||
|
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||||
|
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||||
|
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||||
|
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
|
||||||
|
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
|
||||||
|
slot_id, 32)
|
||||||
|
return hidden
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: MinimaxCacheParams) -> None:
|
||||||
|
if not envs.VLLM_USE_V1:
|
||||||
|
self._forward(hidden_states, output, positions, kv_caches)
|
||||||
|
else:
|
||||||
|
torch.ops.vllm.linear_attention(
|
||||||
|
hidden_states,
|
||||||
|
output,
|
||||||
|
positions,
|
||||||
|
self.prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: Optional[MinimaxCacheParams]) -> None:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||||
|
if envs.VLLM_USE_V1 and attn_metadata is not None:
|
||||||
|
assert isinstance(attn_metadata, dict)
|
||||||
|
attn_metadata = attn_metadata[self.prefix]
|
||||||
|
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||||
|
num_actual_tokens = attn_metadata.num_prefill_tokens + \
|
||||||
|
attn_metadata.num_decode_tokens
|
||||||
|
else:
|
||||||
|
num_actual_tokens = hidden_states.shape[0]
|
||||||
|
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
|
||||||
|
qkv32 = qkv.to(torch.float32)
|
||||||
|
qkvact = torch.nn.functional.silu(qkv32)
|
||||||
|
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
|
||||||
|
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
if attn_metadata is not None:
|
||||||
|
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||||
|
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||||
|
|
||||||
|
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||||
|
if num_prefills > 0:
|
||||||
|
num_decode_tokens = getattr(attn_metadata,
|
||||||
|
"num_decode_tokens", 0)
|
||||||
|
for prefill_idx in range(num_prefills):
|
||||||
|
q_start = attn_metadata.query_start_loc[
|
||||||
|
num_decode_tokens + prefill_idx]
|
||||||
|
q_end = attn_metadata.query_start_loc[num_decode_tokens
|
||||||
|
+ prefill_idx +
|
||||||
|
1]
|
||||||
|
query_len = q_end - q_start
|
||||||
|
context_len = attn_metadata.seq_lens[
|
||||||
|
num_decode_tokens + prefill_idx] - query_len
|
||||||
|
if context_len == 0:
|
||||||
|
block_to_clear = state_indices_tensor[
|
||||||
|
num_decode_tokens + prefill_idx]
|
||||||
|
kv_cache[block_to_clear, ...] = 0
|
||||||
|
else:
|
||||||
|
assert kv_caches is not None
|
||||||
|
kv_cache = kv_caches.minimax_cache
|
||||||
|
state_indices_tensor = kv_caches.state_indices_tensor
|
||||||
|
|
||||||
|
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
||||||
|
if attn_metadata is None:
|
||||||
|
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
|
||||||
|
device=q.device,
|
||||||
|
dtype=q.dtype)
|
||||||
|
else:
|
||||||
|
if not decode_only:
|
||||||
|
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
|
||||||
|
state_indices_tensor,
|
||||||
|
attn_metadata)
|
||||||
|
else:
|
||||||
|
hidden = self._decode_infer(q, k, v, kv_cache,
|
||||||
|
state_indices_tensor,
|
||||||
|
attn_metadata)
|
||||||
|
hidden = self.norm._forward(hidden)
|
||||||
|
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
|
||||||
|
hidden = F.sigmoid(gate) * hidden
|
||||||
|
hidden = hidden.to(hidden_states.dtype)
|
||||||
|
|
||||||
|
output[:num_actual_tokens], _ = self.out_proj(hidden)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_attention(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
layer_name: str,
|
||||||
|
) -> None:
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
self = forward_context.no_compile_layers[layer_name]
|
||||||
|
self._forward(hidden_states=hidden_states,
|
||||||
|
output=output,
|
||||||
|
positions=positions,
|
||||||
|
kv_caches=None)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_attention_fake(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
layer_name: str,
|
||||||
|
) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="linear_attention",
|
||||||
|
op_func=linear_attention,
|
||||||
|
mutates_args=["output"],
|
||||||
|
fake_impl=linear_attention_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
@ -1,45 +1,37 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Inference-only MiniMaxText01 model."""
|
"""Inference-only MiniMaxText01 model."""
|
||||||
import math
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
pass
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
import torch.nn.functional as F
|
|
||||||
from einops import rearrange
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import MiniMaxConfig
|
from transformers import MiniMaxConfig
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||||
get_current_vllm_config)
|
|
||||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_pp_group, get_tensor_model_parallel_rank,
|
get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.lightning_attn import (
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
lightning_attention, linear_decode_forward_triton)
|
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
||||||
MergedColumnParallelLinear,
|
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
from vllm.model_executor.layers.mamba.linear_attn import (
|
||||||
|
MiniMaxText01LinearAttention)
|
||||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
@ -50,10 +42,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.utils import maybe_prefix
|
from vllm.model_executor.models.utils import maybe_prefix
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
|
||||||
|
|
||||||
from .interfaces import HasInnerState, IsHybrid
|
from .interfaces import HasInnerState, IsHybrid
|
||||||
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
|
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
|
||||||
@ -87,66 +76,6 @@ def weight_loader_with_alias(alias: str):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01RMSNormTP(CustomOp):
|
|
||||||
name = "MiniMaxText01RMSNormTP"
|
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.tp_world = get_tensor_model_parallel_world_size()
|
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
self.weight = nn.Parameter(torch.ones(int(hidden_size /
|
|
||||||
self.tp_world)))
|
|
||||||
|
|
||||||
self.weight.weight_loader = self.weight_loader
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
return
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def weight_loader(
|
|
||||||
param: nn.Parameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
tp_world = get_tensor_model_parallel_world_size()
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
shard_size = loaded_weight.shape[0] // tp_world
|
|
||||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
|
||||||
param.data.copy_(loaded_weight[shard])
|
|
||||||
return
|
|
||||||
|
|
||||||
def _forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
orig_dtype = x.dtype
|
|
||||||
x = x.to(torch.float32)
|
|
||||||
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
|
|
||||||
if self.tp_world > 1:
|
|
||||||
variance = tensor_model_parallel_all_reduce(
|
|
||||||
variance) / self.tp_world
|
|
||||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
|
||||||
|
|
||||||
weight = self.weight
|
|
||||||
if x.size(-1) != self.weight.size(0):
|
|
||||||
if self.weight.size(0) < x.size(-1):
|
|
||||||
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
|
|
||||||
full_weight = self.weight.repeat(repeat_count)
|
|
||||||
weight = full_weight[:x.size(-1)]
|
|
||||||
else:
|
|
||||||
weight = self.weight[:x.size(-1)]
|
|
||||||
|
|
||||||
x = x.to(orig_dtype) * weight
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
residual: Optional[torch.Tensor] = None,
|
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
|
||||||
assert residual is None, "RMSNorm does not support residual connection."
|
|
||||||
return self._forward(x)
|
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01MLP(nn.Module):
|
class MiniMaxText01MLP(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -253,307 +182,6 @@ class MiniMaxText01MoE(nn.Module):
|
|||||||
return final_hidden
|
return final_hidden
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01LinearKernel:
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def jit_linear_forward_prefix(q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
kv_caches: torch.Tensor,
|
|
||||||
slope_rate: torch.Tensor,
|
|
||||||
block_size: int,
|
|
||||||
layer_idx: int = None,
|
|
||||||
**kwargs) -> torch.Tensor:
|
|
||||||
|
|
||||||
slope_rate = slope_rate.to(torch.float32)
|
|
||||||
should_pad_dim = q.dim() == 3
|
|
||||||
if should_pad_dim:
|
|
||||||
q = q.unsqueeze(0)
|
|
||||||
k = k.unsqueeze(0)
|
|
||||||
v = v.unsqueeze(0)
|
|
||||||
b, h, n, d = q.shape
|
|
||||||
e = d
|
|
||||||
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
|
|
||||||
output, kv_history = lightning_attention(q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
slope_rate,
|
|
||||||
block_size=block_size,
|
|
||||||
kv_history=kv_history)
|
|
||||||
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
|
|
||||||
assert output.shape[0] == 1, "batch size must be 1"
|
|
||||||
return rearrange(output.squeeze(0), "h n d -> n (h d)")
|
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def mamba_type(self) -> str:
|
|
||||||
return "linear_attention"
|
|
||||||
|
|
||||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
|
||||||
from vllm.v1.attention.backends.linear_attn import (
|
|
||||||
LinearAttentionBackend)
|
|
||||||
return LinearAttentionBackend
|
|
||||||
|
|
||||||
def get_state_dtype(self) -> tuple[torch.dtype]:
|
|
||||||
return MambaStateDtypeCalculator.linear_attention_state_dtype(
|
|
||||||
self.model_config.dtype,
|
|
||||||
self.cache_config.mamba_cache_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
|
||||||
return MambaStateShapeCalculator.linear_attention_state_shape(
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
tp_size=self.tp_size,
|
|
||||||
head_dim=self.head_dim)
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
hidden_inner_size: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_dim: int,
|
|
||||||
max_position: int,
|
|
||||||
block_size: int,
|
|
||||||
num_hidden_layer: int,
|
|
||||||
model_config: Optional[ModelConfig] = None,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
layer_idx: int = 0,
|
|
||||||
linear_layer_idx: int = 0,
|
|
||||||
prefix: str = "linear_attn",
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.layer_idx = layer_idx
|
|
||||||
self.BLOCK = block_size
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = head_dim
|
|
||||||
self.total_num_heads = num_heads
|
|
||||||
self.hidden_inner_size = hidden_inner_size
|
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
assert self.total_num_heads % self.tp_size == 0
|
|
||||||
self.tp_heads = self.total_num_heads // self.tp_size
|
|
||||||
self.qkv_size = self.num_heads * self.head_dim
|
|
||||||
self.tp_hidden = self.head_dim * self.tp_heads
|
|
||||||
self.model_config = model_config
|
|
||||||
self.cache_config = cache_config
|
|
||||||
self.prefix = prefix
|
|
||||||
|
|
||||||
self.qkv_proj = ColumnParallelLinear(
|
|
||||||
hidden_size,
|
|
||||||
self.hidden_inner_size * 3,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.qkv_proj",
|
|
||||||
)
|
|
||||||
self.output_gate = ColumnParallelLinear(
|
|
||||||
hidden_size,
|
|
||||||
self.hidden_inner_size,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.output_gate",
|
|
||||||
)
|
|
||||||
self.out_proj = RowParallelLinear(
|
|
||||||
self.hidden_inner_size,
|
|
||||||
hidden_size,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.out_proj",
|
|
||||||
)
|
|
||||||
self.norm = MiniMaxText01RMSNormTP(
|
|
||||||
self.hidden_inner_size,
|
|
||||||
eps=1e-5,
|
|
||||||
)
|
|
||||||
|
|
||||||
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
|
|
||||||
self.num_heads)
|
|
||||||
if num_hidden_layer <= 1:
|
|
||||||
self.slope_rate = slope_rate * (1 + 1e-5)
|
|
||||||
else:
|
|
||||||
self.slope_rate = slope_rate * (1 - layer_idx /
|
|
||||||
(num_hidden_layer - 1) + 1e-5)
|
|
||||||
self.tp_slope = self.slope_rate[self.tp_rank *
|
|
||||||
self.tp_heads:(self.tp_rank + 1) *
|
|
||||||
self.tp_heads].contiguous()
|
|
||||||
|
|
||||||
if envs.VLLM_USE_V1:
|
|
||||||
compilation_config = get_current_vllm_config().compilation_config
|
|
||||||
if prefix in compilation_config.static_forward_context:
|
|
||||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
|
||||||
compilation_config.static_forward_context[prefix] = self
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def weight_direct_load(param: torch.Tensor,
|
|
||||||
loaded_weight: torch.Tensor) -> None:
|
|
||||||
assert param.size() == loaded_weight.size()
|
|
||||||
param.data.copy_(loaded_weight)
|
|
||||||
return
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_slope_tensor(n_attention_heads: int):
|
|
||||||
|
|
||||||
def get_slopes(n):
|
|
||||||
|
|
||||||
def get_slopes_power_of_2(n):
|
|
||||||
start = 2**(-(2**-(math.log2(n) - 3)))
|
|
||||||
ratio = start
|
|
||||||
return [start * ratio**i for i in range(n)]
|
|
||||||
|
|
||||||
if math.log2(n).is_integer():
|
|
||||||
return get_slopes_power_of_2(n)
|
|
||||||
else:
|
|
||||||
closest_power_of_2 = 2**math.floor(math.log2(n))
|
|
||||||
return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
|
|
||||||
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
|
|
||||||
|
|
||||||
slopes = torch.tensor(get_slopes(n_attention_heads),
|
|
||||||
dtype=torch.float32).reshape(
|
|
||||||
n_attention_heads, 1, 1)
|
|
||||||
return slopes
|
|
||||||
|
|
||||||
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
|
||||||
attn_metadata):
|
|
||||||
hidden = []
|
|
||||||
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
|
|
||||||
if _prefill_idx >= len(attn_metadata.query_start_loc):
|
|
||||||
break
|
|
||||||
if _prefill_idx >= len(state_indices_tensor):
|
|
||||||
break
|
|
||||||
# prefills are packed at end of batch in V1
|
|
||||||
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
|
|
||||||
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
|
||||||
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
|
||||||
slot_id = state_indices_tensor[offset + _prefill_idx]
|
|
||||||
qs = q[_start:_end].transpose(0, 1).contiguous()
|
|
||||||
ks = k[_start:_end].transpose(0, 1).contiguous()
|
|
||||||
vs = v[_start:_end].transpose(0, 1).contiguous()
|
|
||||||
slice_layer_cache = kv_cache[slot_id, ...]
|
|
||||||
|
|
||||||
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
|
|
||||||
qs,
|
|
||||||
ks,
|
|
||||||
vs,
|
|
||||||
slice_layer_cache,
|
|
||||||
self.tp_slope,
|
|
||||||
self.BLOCK,
|
|
||||||
layer_idx=self.layer_idx)
|
|
||||||
hidden.append(out_slice.contiguous())
|
|
||||||
if attn_metadata.num_decode_tokens > 0:
|
|
||||||
hidden_decode = self._decode_infer(q, k, v, kv_cache,
|
|
||||||
state_indices_tensor,
|
|
||||||
attn_metadata)
|
|
||||||
if envs.VLLM_USE_V1:
|
|
||||||
hidden.insert(0, hidden_decode)
|
|
||||||
else:
|
|
||||||
hidden.append(hidden_decode)
|
|
||||||
|
|
||||||
if not hidden:
|
|
||||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
|
||||||
|
|
||||||
hidden = torch.concat(hidden, dim=0).contiguous()
|
|
||||||
return hidden
|
|
||||||
|
|
||||||
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
|
||||||
attn_metadata):
|
|
||||||
if not envs.VLLM_USE_V1:
|
|
||||||
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
|
||||||
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
|
||||||
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
|
|
||||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
|
||||||
slot_id = state_indices_tensor[num_prefills:]
|
|
||||||
else:
|
|
||||||
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
|
||||||
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
|
||||||
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
|
||||||
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
|
|
||||||
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
|
|
||||||
slot_id, 32)
|
|
||||||
return hidden
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
kv_caches: MinimaxCacheParams) -> None:
|
|
||||||
if not envs.VLLM_USE_V1:
|
|
||||||
self._forward(hidden_states, output, positions, kv_caches)
|
|
||||||
else:
|
|
||||||
torch.ops.vllm.linear_attention(
|
|
||||||
hidden_states,
|
|
||||||
output,
|
|
||||||
positions,
|
|
||||||
self.prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
kv_caches: Optional[MinimaxCacheParams]) -> None:
|
|
||||||
forward_context = get_forward_context()
|
|
||||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
|
||||||
if envs.VLLM_USE_V1 and attn_metadata is not None:
|
|
||||||
assert isinstance(attn_metadata, dict)
|
|
||||||
attn_metadata = attn_metadata[self.prefix]
|
|
||||||
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
|
||||||
num_actual_tokens = attn_metadata.num_prefill_tokens + \
|
|
||||||
attn_metadata.num_decode_tokens
|
|
||||||
else:
|
|
||||||
num_actual_tokens = hidden_states.shape[0]
|
|
||||||
|
|
||||||
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
|
|
||||||
qkv32 = qkv.to(torch.float32)
|
|
||||||
qkvact = torch.nn.functional.silu(qkv32)
|
|
||||||
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
|
|
||||||
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
|
||||||
if envs.VLLM_USE_V1:
|
|
||||||
if attn_metadata is not None:
|
|
||||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
|
||||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
|
||||||
|
|
||||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
|
||||||
if num_prefills > 0:
|
|
||||||
num_decode_tokens = getattr(attn_metadata,
|
|
||||||
"num_decode_tokens", 0)
|
|
||||||
for prefill_idx in range(num_prefills):
|
|
||||||
q_start = attn_metadata.query_start_loc[
|
|
||||||
num_decode_tokens + prefill_idx]
|
|
||||||
q_end = attn_metadata.query_start_loc[num_decode_tokens
|
|
||||||
+ prefill_idx +
|
|
||||||
1]
|
|
||||||
query_len = q_end - q_start
|
|
||||||
context_len = attn_metadata.seq_lens[
|
|
||||||
num_decode_tokens + prefill_idx] - query_len
|
|
||||||
if context_len == 0:
|
|
||||||
block_to_clear = state_indices_tensor[
|
|
||||||
num_decode_tokens + prefill_idx]
|
|
||||||
kv_cache[block_to_clear, ...] = 0
|
|
||||||
else:
|
|
||||||
kv_cache = kv_caches.minimax_cache
|
|
||||||
state_indices_tensor = kv_caches.state_indices_tensor
|
|
||||||
|
|
||||||
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
|
||||||
if attn_metadata is None:
|
|
||||||
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
|
|
||||||
device=q.device,
|
|
||||||
dtype=q.dtype)
|
|
||||||
else:
|
|
||||||
if not decode_only:
|
|
||||||
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
|
|
||||||
state_indices_tensor,
|
|
||||||
attn_metadata)
|
|
||||||
else:
|
|
||||||
hidden = self._decode_infer(q, k, v, kv_cache,
|
|
||||||
state_indices_tensor,
|
|
||||||
attn_metadata)
|
|
||||||
hidden = self.norm._forward(hidden)
|
|
||||||
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
|
|
||||||
hidden = F.sigmoid(gate) * hidden
|
|
||||||
hidden = hidden.to(hidden_states.dtype)
|
|
||||||
output[:num_actual_tokens], _ = self.out_proj(hidden)
|
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01Attention(nn.Module):
|
class MiniMaxText01Attention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -1397,35 +1025,3 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
|||||||
tp_size=parallel_config.tensor_parallel_size,
|
tp_size=parallel_config.tensor_parallel_size,
|
||||||
head_dim=hf_config.head_dim,
|
head_dim=hf_config.head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def linear_attention(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
output: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
layer_name: str,
|
|
||||||
) -> None:
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
|
||||||
self = forward_context.no_compile_layers[layer_name]
|
|
||||||
self._forward(hidden_states=hidden_states,
|
|
||||||
output=output,
|
|
||||||
positions=positions,
|
|
||||||
kv_caches=None)
|
|
||||||
|
|
||||||
|
|
||||||
def linear_attention_fake(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
output: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
layer_name: str,
|
|
||||||
) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="linear_attention",
|
|
||||||
op_func=linear_attention,
|
|
||||||
mutates_args=["output"],
|
|
||||||
fake_impl=linear_attention_fake,
|
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user