From f0d738f0cc460b14981aab5350b86130c6e7c5ac Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Tue, 23 Dec 2025 09:07:23 +0000 Subject: [PATCH] add custom op doc Signed-off-by: shen-shanshan <467638484@qq.com> --- docs/design/custom_op.md | 236 +++++++++++++++++++++++++++++++ vllm/config/compilation.py | 3 +- vllm/model_executor/custom_op.py | 9 +- 3 files changed, 244 insertions(+), 4 deletions(-) create mode 100644 docs/design/custom_op.md diff --git a/docs/design/custom_op.md b/docs/design/custom_op.md new file mode 100644 index 0000000000000..fee7f89171b23 --- /dev/null +++ b/docs/design/custom_op.md @@ -0,0 +1,236 @@ +# CustomOp + +`CustomOp` is an abstract class used for dispatching the forward method of various operations to the appropriate backend. It also offers a mechanism for both vLLM and OOT (Out-Of-Tree) plugins to register their custom operations. + +This document will introduce how CustomOp works in vLLM and how to implement a new `CustomOp`. + +## How CustomOp Works in vLLM + +`CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively. + +??? code + + ```python + class CustomOp(nn.Module): + + op_registry: dict[str, type["CustomOp"]] = {} + op_registry_oot: dict[str, type["CustomOp"]] = {} + ``` + +We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, We can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later. + +When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled, it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. 
Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use the PyTorch-native implementation of this forward method. + +- **CPU platform:** dispatch to `forward_cpu()`. +- **CUDA platform:** dispatch to `forward_cuda()`. +- **ROCm platform:** dispatch to `forward_hip()`. If `forward_hip()` is not implemented, it will use `forward_cuda()` as a fallback. +- **XPU platform:** dispatch to `forward_xpu()`. +- **TPU platform:** dispatch to `forward_tpu()`. +- **OOT platform:** dispatch to `forward_oot()`. This will only be called on OOT platforms. +- **Default:** dispatch to `forward_native()` as a final fallback for all platforms. + +Furthermore, vLLM decides whether to enable or disable a `CustomOp` by `compilation_config.custom_ops`. To be specific, if a `CustomOp` is not registered (i.e., it uses the default config), it will be enabled if there is an `all` in `compilation_config.custom_ops` or will be disabled if there is a `none`. + +!!! note + Note that `all` and `none` cannot coexist in `compilation_config.custom_ops`. + +By default, if `compilation_config.backend == "inductor"` and `compilation_config.mode != CompilationMode.NONE`, a `none` will be appended into `compilation_config.custom_ops`, otherwise an `all` will be appended. In other words, this means `CustomOp` will be disabled on some platforms (i.e., those that use `inductor` as the default backend for `torch.compile`) when running with graph mode. In this case, Inductor generates (fused) Triton kernels for those disabled custom ops. + +!!! note + For multi-modal models, vLLM has enforce-enabled some custom ops to use device-specific deep-optimized kernels for better performance in the ViT part, such as `MMEncoderAttention` and `ApplyRotaryEmb`. We can also pass an `enforce_enable=True` param to the `__init__()` method of the `CustomOp` to enforce enabling itself at the object level. 
+ + Note that this `enforce_enable` mechanism will be removed after we add a separate `compilation_config` for the multi-modal part. + +## How to Customise Your Configuration for CustomOp + +vLLM also offers users fine-grained control over which custom ops to enable or disable, by manually passing a `--compilation_config.custom_ops '["..."]'` when launching a server. + +For example: + +- Use `--compilation_config.custom_ops '["all"]'` to enable all custom ops. +- Use `--compilation_config.custom_ops '["none"]'` to disable all custom ops. +- Use `--compilation_config.custom_ops '["all","-op1"]'` to enable all custom ops except op1 (i.e., a `-` prefix means "disable"). +- Use `--compilation_config.custom_ops '["none","+op1","+op2"]'` to only enable op1 and op2 (i.e., a `+` prefix means "enable"). + +## Types of Supported CustomOp in vLLM + +| Category | OP Name | OP Class | +|----------|---------|----------| +| Attention | `mm_encoder_attn` | `MMEncoderAttention` | +| Attention | `multi_head_latent_attention` | `MultiHeadLatentAttentionWrapper` | +| Activation | `fatrelu_and_mul` | `FatreluAndMul` | +| Activation | `silu_and_mul` | `SiluAndMul` | +| Activation | `mul_and_silu` | `MulAndSilu` | +| Activation | `gelu_and_mul_sparse` | `GeluAndMulSparse` | +| Activation | `gelu_and_mul` | `GeluAndMul` | +| Activation | `swigluoai_and_mul` | `SwigluOAIAndMul` | +| Activation | `gelu_new` | `NewGELU` | +| Activation | `gelu_fast` | `FastGELU` | +| Activation | `quick_gelu` | `QuickGELU` | +| Activation | `relu2` | `ReLUSquaredActivation` | +| Activation | `xielu` | `XIELU` | +| Conv | `conv2d` | `Conv2dLayer` | +| Conv | `conv3d` | `Conv3dLayer` | +| Conv | `short_conv` | `ShortConv` | +| Embedding | `vocab_parallel_embedding` | `VocabParallelEmbedding` | +| Embedding | `parallel_lm_head` | `ParallelLMHead` | +| Linear | `row_parallel_linear` | `RowParallelLinear` | +| Linear | `column_parallel_linear` | `ColumnParallelLinear` | +| Linear | `replicated_linear` 
| `ReplicatedLinear` | +| Logits Processor | `logits_processor` | `LogitsProcessor` | +| Mamba | `mamba_mixer` | `MambaMixer` | +| Mamba | `mamba_mixer2` | `MambaMixer2` | +| Mamba | `plamo2_mamba_mixer` | `Plamo2MambaMixer` | +| Mamba | `mixer2_gated_rms_norm` | `Mixer2RMSNormGated` | +| MoE | `fused_moe` | `FusedMoE` | +| MoE | `modular_fused_moe` | `FusedMoEModularMethod` | +| MoE | `unquantized_fused_moe` | `UnquantizedFusedMoEMethod` | +| MoE | `transformers_fused_moe` | `TransformersFusedMoE` | +| MoE | `grouped_topk` | `GroupedTopk` | +| Norm | `rms_norm` | `RMSNorm` | +| Norm | `gemma_rms_norm` | `GemmaRMSNorm` | +| Norm | `rms_norm_gated` | `RMSNormGated` | +| Quantization | `quant_fp8` | `QuantFP8` | +| Rope | `rotary_embedding` | `RotaryEmbeddingBase` | +| Rope | `dual_chunk_rotary_embedding` | `DualChunkRotaryEmbedding` | +| Rope | `apply_rotary_emb` | `ApplyRotaryEmb` | + +## Guidelines for Implementing a New CustomOp + +### Implement a New CustomOp in vLLM + +This part is a tutorial of how to implement a New `CustomOp` in vLLM. + +Steps: + +1. Implement a new op class, which extends from `CustomOp` base class. +2. Add the `@CustomOp.register("op_name")` decorator on this op class to register it into `CustomOp` system. +3. Implement different `forward_xxx()` method according to your needs. + +Taking `MMEncoderAttention` as an example: + +??? code + + ```python + @CustomOp.register("mm_encoder_attn") + class MMEncoderAttention(CustomOp): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float | None = None, + num_kv_heads: int | None = None, + prefix: str = "", + multimodal_config: MultiModalConfig | None = None, + ) -> None: + super().__init__() + # Init... + + def forward_native( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + # Call TORCH_SDPA implementation... 
+ + def forward_cuda( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + # Call FA or TORCH_SDPA implementation... + + def forward_cpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + # Call TORCH_SDPA implementation... + + def forward_xpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + # Call FA implementation... + + def forward_tpu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + ) -> torch.Tensor: + # Call PALLAS implementation... + ``` + +### Register a New CustomOp in OOT Device Plugins + +Currently, thanks to [vLLM's hardware-plugin mechanism](./plugin_system.md), there are various OOT device plugins emerging out to enable vLLM seamlessly runs on different hardwares. You can also find more details about this mechanism at [Introducing vLLM Hardware Plugin, Best Practice from Ascend NPU](https://blog.vllm.ai/2025/05/12/hardware-plugin.html). + +- **Official device plugins:** [vllm-ascend](https://github.com/vllm-project/vllm-ascend) (for Huawei Ascend NPU), [vllm-spyre](https://github.com/vllm-project/vllm-spyre) +(for Spyre), [vllm-gaudi](https://github.com/vllm-project/vllm-gaudi) (for Intel Gaudi), [vllm-neuron](https://github.com/vllm-project/vllm-neuron) (for AWS Neuron), [vllm-meta](https://github.com/vllm-project/vllm-metal) (for Apple Silicon), etc. 
+- **Non-official device plugins:** [vllm-metax](https://github.com/MetaX-MACA/vLLM-metax) (for MetaX GPU), [vllm-kunlun](https://github.com/baidu/vLLM-Kunlun) (for Baidu Kunlun XPU), etc. + +In this case, `CustomOp` can enable these hardware manufacturers to seamlessly replace vLLM's operations with their deep-optimized kernels for specific devices at runtime, by just registering an OOT `CustomOp` and implementing the `forward_oot()` method. + +Now, this part will show you how to register an OOT `CustomOp` for a device plugin. + +Taking `MMEncoderAttention` as an example: + +1. Implement a `CustomMMEncoderAttention` class which extends `MMEncoderAttention` and implements its `forward_oot()` method. +2. Register your `CustomMMEncoderAttention` into vLLM to replace `MMEncoderAttention`. + +??? code + + ```python + from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention + from vllm.model_executor.custom_op import CustomOp + + + @CustomOp.register_oot("MMEncoderAttention") + class CustomMMEncoderAttention(MMEncoderAttention): + + def __init__(...): + super().__init__(...) + + def forward_oot(...): + # Call optimized device-specific kernels. + ... + ``` + +In this case, a new item `{"MMEncoderAttention": CustomMMEncoderAttention}` will be added into `op_registry_oot`. When initializing an `MMEncoderAttention` op object, if the class name (i.e., `MMEncoderAttention`) is contained in the keys of `op_registry_oot`, vLLM will replace it with our registered class (i.e., `CustomMMEncoderAttention`) and instantiate it. + +After that, when this `MMEncoderAttention` op is called, your `forward_oot()` will be called if it is enabled. Thus, you will get the expected performance on your hardware without directly modifying vLLM. + +In addition, you can also register all your `CustomOp` in one place for better management. + +??? 
code + + ```python + from vllm.model_executor.custom_op import CustomOp + + + REGISTERED_CUSTOM_OPS = { + "CustomOP1": YourCustomOp1, + "CustomOP2": YourCustomOp2, + "CustomOP3": YourCustomOp3, + } + + for op_name, op_cls in REGISTERED_CUSTOM_OPS.items(): + CustomOp.register_oot(_decorated_op_cls=op_cls, name=op_name) + ``` diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index cd527e4198557..56e69541e6b81 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -404,7 +404,8 @@ class CompilationConfig: - 'none,+op1,+op2' to enable only op1 and op2 By default, all custom ops are enabled when running without Inductor and - disabled when running with Inductor: mode>=VLLM_COMPILE and backend="inductor". + disabled when running with Inductor: mode!=CompilationMode.NONE and + backend="inductor". Inductor generates (fused) Triton kernels for disabled custom ops.""" splitting_ops: list[str] | None = None """A list of ops to exclude from cudagraphs, used in piecewise compilation. diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 66250f816f459..371b691759348 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -86,9 +86,12 @@ class CustomOp(nn.Module): # specific backend. Currently, we do not support dynamic dispatching. compilation_config = get_cached_compilation_config() - # CustomOp object can be enforce enabled, e.g., enable device-specific - # kernels in ViT models when enabling graph mode. By default, it will - # follow the compilation_config to determine whether enable itself. + # NOTE(shen-shanshan): CustomOp object can be enforce enabled, e.g., + # enable device-specific kernels in ViT models when enabling graph + # mode. By default, it will follow the compilation_config to determine + # whether to enable itself. + # This enforce_enable mechanism will be removed after we add a + # separate compilation_config for multi-modal part. 
enabled = self._enforce_enable or self.enabled() if enabled: compilation_config.enabled_custom_ops.update([self.__class__.name])