mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 10:44:25 +08:00
[V1] [Hybrid] Enable compile and piecewise CUDA graph for MiniMax-Text models (#22589)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
52883ed084
commit
dd58932280
@ -339,6 +339,7 @@ class CompilationConfig:
|
|||||||
"vllm.mamba_mixer2",
|
"vllm.mamba_mixer2",
|
||||||
"vllm.mamba_mixer",
|
"vllm.mamba_mixer",
|
||||||
"vllm.short_conv",
|
"vllm.short_conv",
|
||||||
|
"vllm.linear_attention",
|
||||||
]
|
]
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Inference-only MiniMaxText01 model."""
|
"""Inference-only MiniMaxText01 model."""
|
||||||
import copy
|
|
||||||
import math
|
import math
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
@ -19,13 +18,14 @@ from transformers import MiniMaxConfig
|
|||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||||
get_current_vllm_config)
|
get_current_vllm_config)
|
||||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_pp_group, get_tensor_model_parallel_rank,
|
get_pp_group, get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
@ -43,12 +43,15 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
|||||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.utils import maybe_prefix
|
from vllm.model_executor.models.utils import maybe_prefix
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||||
|
|
||||||
from .interfaces import HasInnerState, IsHybrid
|
from .interfaces import HasInnerState, IsHybrid
|
||||||
@ -143,61 +146,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
|||||||
return self._forward(x)
|
return self._forward(x)
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01RotaryEmbedding(CustomOp):
|
|
||||||
name = "MiniMaxText01RotaryEmbedding"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
head_size: int,
|
|
||||||
rotary_dim: int,
|
|
||||||
max_position: int,
|
|
||||||
base: float,
|
|
||||||
is_neox_style: bool,
|
|
||||||
cache_dtype: torch.dtype,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.head_size = head_size
|
|
||||||
self.rotary_dim = rotary_dim
|
|
||||||
self.max_position_embeddings = max_position
|
|
||||||
self.base = base
|
|
||||||
self.is_neox_style = is_neox_style
|
|
||||||
self.cache_dtype = cache_dtype
|
|
||||||
cache = self._compute_cos_sin_cache().to(cache_dtype)
|
|
||||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
|
||||||
|
|
||||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
|
||||||
"""Compute the inverse frequency."""
|
|
||||||
inv_freq = 1.0 / (base**(torch.arange(
|
|
||||||
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
|
||||||
return inv_freq
|
|
||||||
|
|
||||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
|
||||||
"""Compute the cos and sin cache."""
|
|
||||||
inv_freq = self._compute_inv_freq(self.base)
|
|
||||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
|
||||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
|
||||||
cos = freqs.cos()
|
|
||||||
sin = freqs.sin()
|
|
||||||
cache = torch.cat((cos, sin), dim=-1)
|
|
||||||
return cache
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
|
|
||||||
query_cast = query.to(self.cache_dtype)
|
|
||||||
key_cast = key.to(self.cache_dtype)
|
|
||||||
ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
|
|
||||||
self.cos_sin_cache, self.is_neox_style)
|
|
||||||
query = query_cast.to(query.dtype)
|
|
||||||
key = key_cast.to(key.dtype)
|
|
||||||
return query, key
|
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01MLP(nn.Module):
|
class MiniMaxText01MLP(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -526,20 +474,40 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
|||||||
slot_id, 32)
|
slot_id, 32)
|
||||||
return hidden
|
return hidden
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
|
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||||
kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
|
positions: torch.Tensor,
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
kv_caches: MinimaxCacheParams) -> None:
|
||||||
|
if not envs.VLLM_USE_V1:
|
||||||
|
self._forward(hidden_states, output, positions, kv_caches)
|
||||||
|
else:
|
||||||
|
torch.ops.vllm.linear_attention(
|
||||||
|
hidden_states,
|
||||||
|
output,
|
||||||
|
positions,
|
||||||
|
self.prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: Optional[MinimaxCacheParams]) -> None:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||||
|
if envs.VLLM_USE_V1 and attn_metadata is not None:
|
||||||
|
assert isinstance(attn_metadata, dict)
|
||||||
|
attn_metadata = attn_metadata[self.prefix]
|
||||||
|
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||||
|
num_actual_tokens = attn_metadata.num_prefill_tokens + \
|
||||||
|
attn_metadata.num_decode_tokens
|
||||||
|
else:
|
||||||
|
num_actual_tokens = hidden_states.shape[0]
|
||||||
|
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
|
||||||
qkv32 = qkv.to(torch.float32)
|
qkv32 = qkv.to(torch.float32)
|
||||||
qkvact = torch.nn.functional.silu(qkv32)
|
qkvact = torch.nn.functional.silu(qkv32)
|
||||||
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
|
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
|
||||||
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
||||||
forward_context = get_forward_context()
|
|
||||||
attn_metadata = forward_context.attn_metadata
|
|
||||||
if envs.VLLM_USE_V1:
|
if envs.VLLM_USE_V1:
|
||||||
if attn_metadata is not None:
|
if attn_metadata is not None:
|
||||||
assert isinstance(attn_metadata, dict)
|
|
||||||
attn_metadata = attn_metadata[self.prefix]
|
|
||||||
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
|
||||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||||
|
|
||||||
@ -578,13 +546,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
|||||||
hidden = self._decode_infer(q, k, v, kv_cache,
|
hidden = self._decode_infer(q, k, v, kv_cache,
|
||||||
state_indices_tensor,
|
state_indices_tensor,
|
||||||
attn_metadata)
|
attn_metadata)
|
||||||
|
|
||||||
hidden = self.norm._forward(hidden)
|
hidden = self.norm._forward(hidden)
|
||||||
gate, _ = self.output_gate(hidden_states)
|
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
|
||||||
hidden = F.sigmoid(gate) * hidden
|
hidden = F.sigmoid(gate) * hidden
|
||||||
hidden = hidden.to(hidden_states.dtype)
|
hidden = hidden.to(hidden_states.dtype)
|
||||||
hidden, _ = self.out_proj(hidden)
|
output[:num_actual_tokens], _ = self.out_proj(hidden)
|
||||||
return hidden
|
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01Attention(nn.Module):
|
class MiniMaxText01Attention(nn.Module):
|
||||||
@ -652,23 +618,23 @@ class MiniMaxText01Attention(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
head_size=self.head_dim,
|
||||||
|
rotary_dim=rotary_dim,
|
||||||
|
max_position=max_position,
|
||||||
|
base=int(rope_theta),
|
||||||
|
is_neox_style=True,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
|
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||||
**kwargs) -> torch.Tensor:
|
positions: torch.Tensor, **kwargs) -> None:
|
||||||
forward_context = get_forward_context()
|
|
||||||
attn_metadata = forward_context.attn_metadata
|
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
if envs.VLLM_USE_V1:
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
if attn_metadata is not None:
|
|
||||||
q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
|
|
||||||
positions, q, k)
|
|
||||||
else:
|
|
||||||
q, k = attn_metadata.rotary_emb(positions, q, k)
|
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
output, _ = self.o_proj(attn_output)
|
output[:], _ = self.o_proj(attn_output)
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class MiniMaxText01DecoderLayer(nn.Module):
|
class MiniMaxText01DecoderLayer(nn.Module):
|
||||||
@ -816,16 +782,15 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|||||||
is_warmup: bool = False,
|
is_warmup: bool = False,
|
||||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
|
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
|
||||||
attn_metadata = forward_context.attn_metadata
|
|
||||||
layernorm_input = hidden_states
|
layernorm_input = hidden_states
|
||||||
layernorm_output = self.input_layernorm(layernorm_input)
|
layernorm_output = self.input_layernorm(layernorm_input)
|
||||||
residual = layernorm_output if self.postnorm else layernorm_input
|
residual = layernorm_output if self.postnorm else layernorm_input
|
||||||
self_attention_output = self.self_attn(
|
self_attention_output = torch.empty_like(layernorm_output)
|
||||||
|
self.self_attn(
|
||||||
hidden_states=layernorm_output,
|
hidden_states=layernorm_output,
|
||||||
|
output=self_attention_output,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
kv_caches=kv_caches,
|
kv_caches=kv_caches,
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
residual = residual * self.layernorm_attention_alpha
|
residual = residual * self.layernorm_attention_alpha
|
||||||
@ -839,8 +804,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|||||||
if self.expert_num == 1:
|
if self.expert_num == 1:
|
||||||
hidden_states = self.mlp(layernorm_output)
|
hidden_states = self.mlp(layernorm_output)
|
||||||
else:
|
else:
|
||||||
moe_hidden_states = self.block_sparse_moe(
|
moe_layernorm_output = layernorm_output.clone()
|
||||||
copy.deepcopy(layernorm_output))
|
moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
|
||||||
if self.shared_moe:
|
if self.shared_moe:
|
||||||
before_moe_dtype = layernorm_output.dtype
|
before_moe_dtype = layernorm_output.dtype
|
||||||
moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
|
moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
|
||||||
@ -878,18 +843,16 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
class MiniMaxText01Model(nn.Module):
|
class MiniMaxText01Model(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
self,
|
|
||||||
config: MiniMaxConfig,
|
|
||||||
model_config: Optional[ModelConfig] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
cache_config: Optional[CacheConfig] = None,
|
|
||||||
scheduler_config=None,
|
|
||||||
prefix: str = "",
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
config: MiniMaxConfig = vllm_config.model_config.hf_config
|
||||||
|
model_config = vllm_config.model_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
scheduler_config = vllm_config.scheduler_config
|
||||||
|
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
@ -976,24 +939,6 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
self.minimax_cache = MinimaxCacheManager(
|
self.minimax_cache = MinimaxCacheManager(
|
||||||
dtype=torch.float32, cache_shape=self.cache_shape)
|
dtype=torch.float32, cache_shape=self.cache_shape)
|
||||||
|
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
|
||||||
head_dim = getattr(config, "head_dim", None)
|
|
||||||
if head_dim is None:
|
|
||||||
head_dim = config.hidden_size // config.num_attention_heads
|
|
||||||
if hasattr(config, "max_model_len") and isinstance(
|
|
||||||
config.max_model_len, int):
|
|
||||||
max_position_embeddings = min(config.max_position_embeddings,
|
|
||||||
config.max_model_len)
|
|
||||||
self.rotary_emb = MiniMaxText01RotaryEmbedding(
|
|
||||||
head_dim,
|
|
||||||
rotary_dim=config.rotary_dim
|
|
||||||
if hasattr(config, "rotary_dim") else head_dim,
|
|
||||||
max_position=max_position_embeddings,
|
|
||||||
base=int(rope_theta),
|
|
||||||
is_neox_style=True,
|
|
||||||
cache_dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
norm_kwargs = {}
|
norm_kwargs = {}
|
||||||
if hasattr(config, "rms_norm_eps"):
|
if hasattr(config, "rms_norm_eps"):
|
||||||
norm_kwargs["eps"] = config.rms_norm_eps
|
norm_kwargs["eps"] = config.rms_norm_eps
|
||||||
@ -1043,12 +988,11 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
if not envs.VLLM_USE_V1 and attn_metadata is None:
|
if not envs.VLLM_USE_V1 and attn_metadata is None:
|
||||||
return None
|
return None
|
||||||
if "request_ids_to_seq_ids" not in kwargs:
|
|
||||||
kwargs["request_ids_to_seq_ids"] = {}
|
|
||||||
if "finished_requests_ids" not in kwargs:
|
|
||||||
kwargs["finished_requests_ids"] = []
|
|
||||||
|
|
||||||
if not envs.VLLM_USE_V1:
|
if not envs.VLLM_USE_V1:
|
||||||
|
if "request_ids_to_seq_ids" not in kwargs:
|
||||||
|
kwargs["request_ids_to_seq_ids"] = {}
|
||||||
|
if "finished_requests_ids" not in kwargs:
|
||||||
|
kwargs["finished_requests_ids"] = []
|
||||||
(
|
(
|
||||||
minimax_cache_tensors,
|
minimax_cache_tensors,
|
||||||
state_indices_tensor,
|
state_indices_tensor,
|
||||||
@ -1077,16 +1021,6 @@ class MiniMaxText01Model(nn.Module):
|
|||||||
|
|
||||||
for i in range(self.start_layer, self.end_layer):
|
for i in range(self.start_layer, self.end_layer):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
if attn_metadata is not None:
|
|
||||||
# TODO (tdoublep): this whole thing with the rotary_emb is
|
|
||||||
# weird. we shouldn't be passing it via attn_metadata imo.
|
|
||||||
if envs.VLLM_USE_V1:
|
|
||||||
if isinstance(layer.self_attn, MiniMaxText01Attention):
|
|
||||||
attn_metadata[layer.prefix +
|
|
||||||
".attn"].rotary_emb = self.rotary_emb
|
|
||||||
else:
|
|
||||||
attn_metadata.rotary_emb = self.rotary_emb
|
|
||||||
|
|
||||||
_caches = None
|
_caches = None
|
||||||
if not envs.VLLM_USE_V1 and isinstance(
|
if not envs.VLLM_USE_V1 and isinstance(
|
||||||
layer.self_attn, MiniMaxText01LinearAttention):
|
layer.self_attn, MiniMaxText01LinearAttention):
|
||||||
@ -1120,7 +1054,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
lora_config = vllm_config.lora_config
|
lora_config = vllm_config.lora_config
|
||||||
self.config = config
|
self.config = config
|
||||||
self.lora_config = lora_config
|
self.lora_config = lora_config
|
||||||
@ -1133,13 +1066,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
|||||||
self.unpadded_vocab_size = self.config.vocab_size
|
self.unpadded_vocab_size = self.config.vocab_size
|
||||||
if hasattr(vllm_config.model_config, "max_model_len"):
|
if hasattr(vllm_config.model_config, "max_model_len"):
|
||||||
self.config.max_model_len = vllm_config.model_config.max_model_len
|
self.config.max_model_len = vllm_config.model_config.max_model_len
|
||||||
self.model = MiniMaxText01Model(
|
self.model = MiniMaxText01Model(vllm_config=vllm_config,
|
||||||
self.config,
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
model_config=vllm_config.model_config,
|
|
||||||
cache_config=vllm_config.cache_config,
|
|
||||||
quant_config=quant_config,
|
|
||||||
scheduler_config=vllm_config.scheduler_config,
|
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
self.unpadded_vocab_size,
|
self.unpadded_vocab_size,
|
||||||
@ -1469,3 +1397,35 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
|||||||
tp_size=parallel_config.tensor_parallel_size,
|
tp_size=parallel_config.tensor_parallel_size,
|
||||||
head_dim=hf_config.head_dim,
|
head_dim=hf_config.head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_attention(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
layer_name: str,
|
||||||
|
) -> None:
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
self = forward_context.no_compile_layers[layer_name]
|
||||||
|
self._forward(hidden_states=hidden_states,
|
||||||
|
output=output,
|
||||||
|
positions=positions,
|
||||||
|
kv_caches=None)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_attention_fake(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
layer_name: str,
|
||||||
|
) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="linear_attention",
|
||||||
|
op_func=linear_attention,
|
||||||
|
mutates_args=["output"],
|
||||||
|
fake_impl=linear_attention_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user