use snippets

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
shen-shanshan 2025-12-24 03:21:23 +00:00
parent e1c9d6e7e0
commit a2c7852e4c
22 changed files with 185 additions and 41 deletions

View File

@ -57,46 +57,76 @@ For example:
## Types of Supported CustomOp in vLLM
| Category | OP Name | OP Class |
|----------|---------|----------|
| Attention | `mm_encoder_attn` | `MMEncoderAttention` |
| Attention | `multi_head_latent_attention` | `MultiHeadLatentAttentionWrapper` |
| Activation | `fatrelu_and_mul` | `FatreluAndMul` |
| Activation | `silu_and_mul` | `SiluAndMul` |
| Activation | `mul_and_silu` | `MulAndSilu` |
| Activation | `gelu_and_mul_sparse` | `GeluAndMulSparse` |
| Activation | `gelu_and_mul` | `GeluAndMul` |
| Activation | `swigluoai_and_mul` | `SwigluOAIAndMul` |
| Activation | `gelu_new` | `NewGELU` |
| Activation | `gelu_fast` | `FastGELU` |
| Activation | `quick_gelu` | `QuickGELU` |
| Activation | `relu2` | `ReLUSquaredActivation` |
| Activation | `xielu` | `XIELU` |
| Conv | `conv2d` | `Conv2dLayer` |
| Conv | `conv3d` | `Conv3dLayer` |
| Conv | `short_conv` | `ShortConv` |
| Embedding | `vocab_parallel_embedding` | `VocabParallelEmbedding` |
| Embedding | `parallel_lm_head` | `ParallelLMHead` |
| Linear | `row_parallel_linear` | `RowParallelLinear` |
| Linear | `column_parallel_linear` | `ColumnParallelLinear` |
| Linear | `replicated_linear` | `ReplicatedLinear` |
| Logits Processor | `logits_processor` | `LogitsProcessor` |
| Mamba | `mamba_mixer` | `MambaMixer` |
| Mamba | `mamba_mixer2` | `MambaMixer2` |
| Mamba | `plamo2_mamba_mixer` | `Plamo2MambaMixer` |
| Mamba | `mixer2_gated_rms_norm` | `Mixer2RMSNormGated` |
| MoE | `fused_moe` | `FusedMoE` |
| MoE | `modular_fused_moe` | `FusedMoEModularMethod` |
| MoE | `unquantized_fused_moe` | `UnquantizedFusedMoEMethod` |
| MoE | `transformers_fused_moe` | `TransformersFusedMoE` |
| MoE | `grouped_topk` | `GroupedTopk` |
| Norm | `rms_norm` | `RMSNorm` |
| Norm | `gemma_rms_norm` | `GemmaRMSNorm` |
| Norm | `rms_norm_gated` | `RMSNormGated` |
| Quantization | `quant_fp8` | `QuantFP8` |
| Rope | `rotary_embedding` | `RotaryEmbeddingBase` |
| Rope | `dual_chunk_rotary_embedding` | `DualChunkRotaryEmbedding` |
| Rope | `apply_rotary_emb` | `ApplyRotaryEmb` |
**1. Attention:**
--8<-- "../../vllm/attention/layers/mm_encoder_attention.py:mm_encoder_attn"
--8<-- "../../vllm/model_executor/layers/mla.py:multi_head_latent_attention"
**2. Activation:**
--8<-- "../../vllm/model_executor/layers/activation.py:silu_and_mul"
--8<-- "../../vllm/model_executor/layers/activation.py:mul_and_silu"
--8<-- "../../vllm/model_executor/layers/activation.py:gelu_new"
--8<-- "../../vllm/model_executor/layers/activation.py:gelu_fast"
--8<-- "../../vllm/model_executor/layers/activation.py:quick_gelu"
--8<-- "../../vllm/model_executor/layers/activation.py:gelu_and_mul"
--8<-- "../../vllm/model_executor/layers/activation.py:gelu_and_mul_sparse"
--8<-- "../../vllm/model_executor/layers/activation.py:relu2"
--8<-- "../../vllm/model_executor/layers/activation.py:xielu"
--8<-- "../../vllm/model_executor/layers/activation.py:swigluoai_and_mul"
--8<-- "../../vllm/model_executor/layers/activation.py:fatrelu_and_mul"
**3. MM-Conv:**
--8<-- "../../vllm/model_executor/layers/conv.py:conv2d"
--8<-- "../../vllm/model_executor/layers/conv.py:conv3d"
**4. Embedding:**
--8<-- "../../vllm/model_executor/layers/vocab_parallel_embedding.py:vocab_parallel_embedding"
--8<-- "../../vllm/model_executor/layers/vocab_parallel_embedding.py:parallel_lm_head"
**5. Linear:**
--8<-- "../../vllm/model_executor/layers/linear.py:row_parallel_linear"
--8<-- "../../vllm/model_executor/layers/linear.py:column_parallel_linear"
--8<-- "../../vllm/model_executor/layers/linear.py:replicated_linear"
**6. Logits Processor:**
--8<-- "../../vllm/model_executor/layers/logits_processor.py:logits_processor"
**7. Mamba:**
--8<-- "../../vllm/model_executor/layers/mamba/mamba_mixer.py:mamba_mixer"
--8<-- "../../vllm/model_executor/layers/mamba/mamba_mixer2.py:mamba_mixer2"
--8<-- "../../vllm/model_executor/layers/mamba/mamba_mixer2.py:mixer2_gated_rms_norm"
--8<-- "../../vllm/model_executor/models/plamo2.py:plamo2_mamba_mixer"
--8<-- "../../vllm/model_executor/layers/mamba/short_conv.py:short_conv"
**8. MoE:**
--8<-- "../../vllm/model_executor/layers/fused_moe/layer.py:fused_moe"
--8<-- "../../vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py:modular_fused_moe"
--8<-- "../../vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py:unquantized_fused_moe"
--8<-- "../../vllm/model_executor/models/transformers/moe.py:transformers_fused_moe"
--8<-- "../../vllm/model_executor/layers/fused_moe/fused_moe.py:grouped_topk"
**9. Norm:**
--8<-- "../../vllm/model_executor/layers/layernorm.py:rms_norm"
--8<-- "../../vllm/model_executor/layers/layernorm.py:rms_norm_gated"
--8<-- "../../vllm/model_executor/layers/layernorm.py:gemma_rms_norm"
**10. Quantization:**
--8<-- "../../vllm/model_executor/layers/quantization/input_quant_fp8.py:quant_fp8"
**11. Rope:**
--8<-- "../../vllm/model_executor/layers/rotary_embedding/base.py:rotary_embedding"
--8<-- "../../vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py:dual_chunk_rotary_embedding"
--8<-- "../../vllm/model_executor/layers/rotary_embedding/common.py:apply_rotary_emb"
## Guidelines for Implementing a New CustomOp

View File

@ -18,10 +18,13 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
logger = init_logger(__name__)
# --8<-- [start:mm_encoder_attn]
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
# --8<-- [end:mm_encoder_attn]
def __init__(
self,
num_heads: int,

View File

@ -22,6 +22,7 @@ from vllm.utils.collection_utils import LazyDict
logger = init_logger(__name__)
# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
@ -35,6 +36,8 @@ class FatreluAndMul(CustomOp):
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
# --8<-- [end:fatrelu_and_mul]
def __init__(self, threshold: float = 0.0):
super().__init__()
self.threshold = threshold
@ -58,6 +61,7 @@ class FatreluAndMul(CustomOp):
return out
# --8<-- [start:silu_and_mul]
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
@ -69,6 +73,8 @@ class SiluAndMul(CustomOp):
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
# --8<-- [end:silu_and_mul]
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
@ -101,6 +107,7 @@ class SiluAndMul(CustomOp):
return out
# --8<-- [start:mul_and_silu]
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
"""An activation function for SwiGLU.
@ -112,6 +119,8 @@ class MulAndSilu(CustomOp):
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
# --8<-- [end:mul_and_silu]
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
@ -139,6 +148,7 @@ class MulAndSilu(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
# --8<-- [start:gelu_and_mul_sparse]
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
"""An activation function for GeluAndMulSparse.
@ -153,6 +163,8 @@ class GeluAndMulSparse(CustomOp):
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
# --8<-- [end:gelu_and_mul_sparse]
def __init__(self, activation_sparsity: float, approximate: str = "none"):
super().__init__()
# Gelu.
@ -195,6 +207,7 @@ class GeluAndMulSparse(CustomOp):
return self.forward_native(x)
# --8<-- [start:gelu_and_mul]
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
@ -206,6 +219,8 @@ class GeluAndMul(CustomOp):
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
# --8<-- [end:gelu_and_mul]
def __init__(self, approximate: str = "none"):
super().__init__()
self.approximate = approximate
@ -257,9 +272,12 @@ class GeluAndMul(CustomOp):
return f"approximate={repr(self.approximate)}"
# --8<-- [start:swigluoai_and_mul]
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
# https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
# --8<-- [end:swigluoai_and_mul]
def __init__(self, alpha: float = 1.702, limit: float = 7.0):
super().__init__()
self.alpha = alpha
@ -286,8 +304,11 @@ class SwigluOAIAndMul(CustomOp):
return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
# --8<-- [end:gelu_new]
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
@ -311,8 +332,11 @@ class NewGELU(CustomOp):
return self.op(x)
# --8<-- [start:gelu_fast]
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
# --8<-- [end:gelu_fast]
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
@ -335,9 +359,12 @@ class FastGELU(CustomOp):
return self.op(x)
# --8<-- [start:quick_gelu]
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
# --8<-- [end:quick_gelu]
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
@ -365,12 +392,15 @@ class QuickGELU(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
# --8<-- [start:relu2]
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
# --8<-- [end:relu2]
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return torch.square(F.relu(x))
@ -380,6 +410,7 @@ class ReLUSquaredActivation(CustomOp):
return self.forward_native(x)
# --8<-- [start:xielu]
@CustomOp.register("xielu")
class XIELU(CustomOp):
"""
@ -388,6 +419,8 @@ class XIELU(CustomOp):
Otherwise, we emit a single warning and use xIELU Python
"""
# --8<-- [end:xielu]
def __init__(
self,
alpha_p_init: float = 0.8,

View File

@ -105,10 +105,13 @@ class ConvLayerBase(CustomOp):
return s
# --8<-- [start:conv2d]
@CustomOp.register("conv2d")
class Conv2dLayer(ConvLayerBase):
"""Conv layer with Conv2d."""
# --8<-- [end:conv2d]
num_dim = 2
def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
@ -204,10 +207,13 @@ class CausalConv2dLayer(Conv2dLayer):
return x
# --8<-- [start:conv3d]
@CustomOp.register("conv3d")
class Conv3dLayer(ConvLayerBase):
"""Conv layer with Conv3d."""
# --8<-- [end:conv3d]
num_dim = 3
def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:

View File

@ -1283,10 +1283,13 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
# --8<-- [start:grouped_topk]
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
"""GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""
# --8<-- [end:grouped_topk]
def __init__(
self,
topk: int,

View File

@ -20,8 +20,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
logger = init_logger(__name__)
# --8<-- [start:modular_fused_moe]
@CustomOp.register("modular_fused_moe")
class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
# --8<-- [end:modular_fused_moe]
def __init__(
self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
):

View File

@ -297,6 +297,7 @@ def maybe_roundup_hidden_size(
return hidden_size
# --8<-- [start:fused_moe]
@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models.
@ -320,6 +321,8 @@ class FusedMoE(CustomOp):
enable_eplb: Whether to enable expert parallelism load balancer.
"""
# --8<-- [end:fused_moe]
def __init__(
self,
num_experts: int, # Global number of experts

View File

@ -46,10 +46,13 @@ else:
logger = init_logger(__name__)
# --8<-- [start:unquantized_fused_moe]
@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
# --8<-- [end:unquantized_fused_moe]
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)

View File

@ -88,6 +88,7 @@ def dispatch_rocm_rmsnorm_func(
return rms_norm
# --8<-- [start:rms_norm]
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
@ -96,6 +97,8 @@ class RMSNorm(CustomOp):
Refer to https://arxiv.org/abs/1910.07467
"""
# --8<-- [end:rms_norm]
def __init__(
self,
hidden_size: int,
@ -253,6 +256,7 @@ class RMSNorm(CustomOp):
return s
# --8<-- [start:gemma_rms_norm]
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.
@ -262,6 +266,8 @@ class GemmaRMSNorm(CustomOp):
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
"""
# --8<-- [end:gemma_rms_norm]
def __init__(
self,
hidden_size: int,
@ -321,6 +327,7 @@ class GemmaRMSNorm(CustomOp):
return self.forward_native(x, residual)
# --8<-- [start:rms_norm_gated]
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
"""RMS Normalization with optional gating.
@ -331,6 +338,8 @@ class RMSNormGated(CustomOp):
- Optional gating with SiLU activation
"""
# --8<-- [end:rms_norm_gated]
def __init__(
self,
hidden_size: int,

View File

@ -296,6 +296,7 @@ class LinearBase(CustomOp):
param.tp_size = self.tp_size
# --8<-- [start:replicated_linear]
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
@ -313,6 +314,8 @@ class ReplicatedLinear(LinearBase):
disable_tp: Take no effect for replicated linear layers.
"""
# --8<-- [end:replicated_linear]
def __init__(
self,
input_size: int,
@ -413,6 +416,7 @@ class ReplicatedLinear(LinearBase):
return s
# --8<-- [start:column_parallel_linear]
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
@ -440,6 +444,8 @@ class ColumnParallelLinear(LinearBase):
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
# --8<-- [end:column_parallel_linear]
def __init__(
self,
input_size: int,
@ -1276,6 +1282,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight)
# --8<-- [start:row_parallel_linear]
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
@ -1310,6 +1317,8 @@ class RowParallelLinear(LinearBase):
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
# --8<-- [end:row_parallel_linear]
def __init__(
self,
input_size: int,

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
from vllm.platforms import current_platform
# --8<-- [start:logits_processor]
@CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
"""Process logits and apply logits processors from sampling metadata.
@ -23,6 +24,8 @@ class LogitsProcessor(CustomOp):
3. Apply logits processors (if any).
"""
# --8<-- [end:logits_processor]
def __init__(
self,
vocab_size: int,

View File

@ -39,6 +39,7 @@ from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:mamba_mixer]
@CustomOp.register("mamba_mixer")
class MambaMixer(MambaBase, CustomOp):
"""
@ -51,6 +52,8 @@ class MambaMixer(MambaBase, CustomOp):
**selective** state spaces)
"""
# --8<-- [end:mamba_mixer]
def __init__(
self,
hidden_size: int,

View File

@ -49,8 +49,11 @@ from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
# --8<-- [start:mixer2_gated_rms_norm]
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
# --8<-- [end:mixer2_gated_rms_norm]
def __init__(
self,
full_hidden_size: int,
@ -214,6 +217,7 @@ def mamba_v2_sharded_weight_loader(
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:mamba_mixer2]
@CustomOp.register("mamba_mixer2")
class MambaMixer2(MambaBase, CustomOp):
"""
@ -226,6 +230,8 @@ class MambaMixer2(MambaBase, CustomOp):
**selective** state spaces)
"""
# --8<-- [end:mamba_mixer2]
def __init__(
self,
hidden_size: int,

View File

@ -27,8 +27,11 @@ from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata
# --8<-- [start:short_conv]
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
# --8<-- [end:short_conv]
def __init__(
self,
config,

View File

@ -29,6 +29,7 @@ class MLAModules:
indexer_rotary_emb: torch.nn.Module | None = None
# --8<-- [start:multi_head_latent_attention]
@CustomOp.register("multi_head_latent_attention")
class MultiHeadLatentAttentionWrapper(CustomOp):
"""MLA layer registered as CustomOp to allow OOT backends to add
@ -47,6 +48,8 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
3. Return the output tensor.
"""
# --8<-- [end:multi_head_latent_attention]
def __init__(
self,
hidden_size: int,

View File

@ -19,6 +19,7 @@ _FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
# --8<-- [start:quant_fp8]
@CustomOp.register("quant_fp8")
class QuantFP8(CustomOp):
"""
@ -26,6 +27,8 @@ class QuantFP8(CustomOp):
This CustomOp supports both static and dynamic quantization.
"""
# --8<-- [end:quant_fp8]
def __init__(
self,
static: bool,

View File

@ -10,10 +10,13 @@ from vllm.model_executor.custom_op import CustomOp
from .common import ApplyRotaryEmb
# --8<-- [start:rotary_embedding]
@CustomOp.register("rotary_embedding")
class RotaryEmbeddingBase(CustomOp):
"""Original rotary positional embedding."""
# --8<-- [end:rotary_embedding]
def __init__(
self,
head_size: int,

View File

@ -118,8 +118,11 @@ direct_register_custom_op(
)
# --8<-- [start:apply_rotary_emb]
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
# --8<-- [end:apply_rotary_emb]
def __init__(
self,
enforce_enable: bool = False,

View File

@ -9,10 +9,13 @@ from vllm.model_executor.custom_op import CustomOp
from .common import rotate_gptj, rotate_neox
# --8<-- [start:dual_chunk_rotary_embedding]
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
# --8<-- [end:dual_chunk_rotary_embedding]
def __init__(
self,
head_size: int,

View File

@ -181,6 +181,7 @@ def get_masked_input_and_mask(
return input_, ~vocab_mask
# --8<-- [start:vocab_parallel_embedding]
@CustomOp.register("vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
"""Embedding parallelized in the vocabulary dimension.
@ -221,6 +222,8 @@ class VocabParallelEmbedding(CustomOp):
prefix: full name of the layer in the state dict
""" # noqa: E501
# --8<-- [end:vocab_parallel_embedding]
def __init__(
self,
num_embeddings: int,
@ -492,6 +495,7 @@ class VocabParallelEmbedding(CustomOp):
return s
# --8<-- [start:parallel_lm_head]
@CustomOp.register("parallel_lm_head")
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
@ -509,6 +513,8 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size: padding size for the vocabulary.
"""
# --8<-- [end:parallel_lm_head]
def __init__(
self,
num_embeddings: int,

View File

@ -97,8 +97,11 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
# Adapted from:
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register(name="plamo2_mamba_mixer")
# --8<-- [start:plamo2_mamba_mixer]
@CustomOp.register("plamo2_mamba_mixer")
class Plamo2MambaMixer(MambaBase, CustomOp):
# --8<-- [end:plamo2_mamba_mixer]
def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None:
super().__init__()
self.config = vllm_config.model_config.hf_config

View File

@ -37,10 +37,13 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig
# --8<-- [start:transformers_fused_moe]
@CustomOp.register("transformers_fused_moe")
class TransformersFusedMoE(FusedMoE):
"""Custom FusedMoE for the Transformers modeling backend."""
# --8<-- [end:transformers_fused_moe]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._topk_ids: torch.Tensor = None