Mirror of <https://git.datalinker.icu/vllm-project/vllm.git>, synced 2026-03-27 05:31:19 +08:00.
Merge of b0846c1c2375a4e40ac171700f89fffd88fbec75 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29 (commit `a5f89df8c1`), adding `docs/design/custom_op.md` (new file, 318 lines).
# CustomOp

`CustomOp` is an abstract class that dispatches the forward method of various operations to the appropriate backend. It also offers a mechanism for both vLLM and OOT (Out-Of-Tree) plugins to register their custom operations.

This document introduces how `CustomOp` works in vLLM and how to implement a new `CustomOp`.
## How CustomOp Works in vLLM

`CustomOp` maintains two class-level dictionaries of all custom ops (i.e., op classes indexed by their registered names), one for vLLM and one for OOT plugins:

??? code

    ```python
    class CustomOp(nn.Module):

        op_registry: dict[str, type["CustomOp"]] = {}
        op_registry_oot: dict[str, type["CustomOp"]] = {}
    ```
We can use the `@CustomOp.register("op_name")` decorator to register an op class in the `CustomOp` system; the `op_name` and its class are then added to the `op_registry` dictionary. Similarly, an OOT op can be registered with `@CustomOp.register_oot("op_name")`; we introduce this mechanism in detail later.
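For intuition, registration can be sketched as a class-method decorator that records the op class under its name. This is a simplified, hypothetical sketch, not vLLM's actual implementation (which lives in `vllm/model_executor/custom_op.py` and handles more details):

```python
class CustomOp:
    # Simplified registries, mirroring the dictionaries shown above.
    op_registry: dict[str, type["CustomOp"]] = {}
    op_registry_oot: dict[str, type["CustomOp"]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(op_cls: type["CustomOp"]) -> type["CustomOp"]:
            assert name not in cls.op_registry, f"Duplicate op name: {name}"
            # Remember the registered name and record the class.
            op_cls.name = name
            cls.op_registry[name] = op_cls
            return op_cls

        return decorator


@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    pass


print(CustomOp.op_registry["silu_and_mul"].__name__)  # SiluAndMul
```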
When a `CustomOp` is called (i.e., its `forward()` method is invoked) and the op is enabled, it automatically dispatches the forward call to the appropriate backend according to `current_platform`. Otherwise (i.e., the op is disabled), it only calls `forward_native()`, the PyTorch-native implementation of the forward method.

- **CPU platform:** dispatch to `forward_cpu()`.
- **CUDA platform:** dispatch to `forward_cuda()`.
- **ROCm platform:** dispatch to `forward_hip()`. If `forward_hip()` is not implemented, `forward_cuda()` is used as a fallback.
- **XPU platform:** dispatch to `forward_xpu()`.
- **TPU platform:** dispatch to `forward_tpu()`.
- **OOT platform:** dispatch to `forward_oot()`. This is only called on OOT platforms.
- **Default:** dispatch to `forward_native()` as the final fallback for all platforms.
!!! note

    The dispatch logic is not absolute: because of class inheritance, a derived class may override this behavior.
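Conceptually, the dispatch above can be pictured as resolving the forward method once at construction time. The following is an illustrative sketch only; `FakePlatform` and `SketchOp` are hypothetical stand-ins, and vLLM's real logic consults `current_platform` inside `custom_op.py`:

```python
class FakePlatform:
    """Stand-in for vLLM's current_platform (illustration only)."""

    def __init__(self, device: str):
        self.device = device

    def is_cpu(self) -> bool:
        return self.device == "cpu"

    def is_cuda(self) -> bool:
        return self.device == "cuda"

    def is_rocm(self) -> bool:
        return self.device == "rocm"


class SketchOp:
    def __init__(self, enabled: bool, platform: FakePlatform):
        # The backend is resolved once here; dispatching is not
        # re-evaluated dynamically after construction.
        if not enabled:
            self._forward_method = self.forward_native
        elif platform.is_cpu():
            self._forward_method = self.forward_cpu
        elif platform.is_rocm():
            # forward_hip() falls back to forward_cuda() when absent.
            self._forward_method = getattr(self, "forward_hip", self.forward_cuda)
        elif platform.is_cuda():
            self._forward_method = self.forward_cuda
        else:
            self._forward_method = self.forward_native

    def forward(self, x):
        return self._forward_method(x)

    def forward_native(self, x):
        return ("native", x)

    def forward_cpu(self, x):
        return self.forward_native(x)

    def forward_cuda(self, x):
        return ("cuda", x)


op = SketchOp(enabled=True, platform=FakePlatform("cuda"))
print(op.forward(1))  # ('cuda', 1)
```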
Furthermore, vLLM decides whether to enable or disable a `CustomOp` via `compilation_config.custom_ops`. Specifically, if a `CustomOp` is not explicitly listed (i.e., the default config is used), it is enabled when `all` appears in `compilation_config.custom_ops` and disabled when `none` appears.

!!! note

    `all` and `none` cannot coexist in `compilation_config.custom_ops`.

By default, if `compilation_config.backend == "inductor"` and `compilation_config.mode != CompilationMode.NONE`, `none` is appended to `compilation_config.custom_ops`; otherwise `all` is appended. In other words, `CustomOp` is disabled on some platforms (i.e., those that use `inductor` as the default backend for `torch.compile`) when running in torch compile mode. In this case, Inductor generates (fused) Triton kernels for the disabled custom ops.
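The resolution of a single op's enabled state can be sketched with a hypothetical helper; `op_enabled` below is illustrative only (the real check is `CustomOp.enabled()`):

```python
def op_enabled(op_name: str, custom_ops: list[str]) -> bool:
    """Resolve whether `op_name` is enabled under `custom_ops`.

    Explicit "+op"/"-op" entries take precedence over the
    "all"/"none" default.
    """
    assert not ("all" in custom_ops and "none" in custom_ops), (
        "'all' and 'none' cannot coexist"
    )
    if "+" + op_name in custom_ops:
        return True
    if "-" + op_name in custom_ops:
        return False
    # Default: enabled only when "all" is present.
    return "all" in custom_ops


print(op_enabled("rms_norm", ["none", "+rms_norm"]))  # True
print(op_enabled("rms_norm", ["all", "-rms_norm"]))   # False
```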
!!! note

    For multi-modal models, vLLM force-enables some custom ops, such as `MMEncoderAttention` and `ApplyRotaryEmb`, so that the ViT part uses deeply optimized device-specific kernels for better performance. We can also pass `enforce_enable=True` to the `__init__()` method of a `CustomOp` to force-enable it at the object level.

    Note that this `enforce_enable` mechanism will be removed once we add a separate `compilation_config` for the multi-modal part.
## How to Customize Your Configuration for CustomOp

vLLM also offers users fine-grained control over which custom ops to enable or disable, by manually passing `--compilation_config.custom_ops '["..."]'` when launching a server.

For example:

- Use `--compilation_config.custom_ops '["all"]'` to enable all custom ops.
- Use `--compilation_config.custom_ops '["none"]'` to disable all custom ops.
- Use `--compilation_config.custom_ops '["all", "-op1"]'` to enable all custom ops except `op1` (a `-` prefix means "disable").
- Use `--compilation_config.custom_ops '["none", "+op1", "+op2"]'` to enable only `op1` and `op2` (a `+` prefix means "enable").
## Types of Supported CustomOp in vLLM

**1. Attention:**

```python
--8<-- "vllm/attention/layers/mm_encoder_attention.py:mm_encoder_attn"

--8<-- "vllm/model_executor/layers/mla.py:multi_head_latent_attention"
```

**2. Activation:**

```python
--8<-- "vllm/model_executor/layers/activation.py:silu_and_mul"

--8<-- "vllm/model_executor/layers/activation.py:mul_and_silu"

--8<-- "vllm/model_executor/layers/activation.py:gelu_new"

--8<-- "vllm/model_executor/layers/activation.py:gelu_fast"

--8<-- "vllm/model_executor/layers/activation.py:quick_gelu"

--8<-- "vllm/model_executor/layers/activation.py:gelu_and_mul"

--8<-- "vllm/model_executor/layers/activation.py:gelu_and_mul_sparse"

--8<-- "vllm/model_executor/layers/activation.py:relu2"

--8<-- "vllm/model_executor/layers/activation.py:xielu"

--8<-- "vllm/model_executor/layers/activation.py:swigluoai_and_mul"

--8<-- "vllm/model_executor/layers/activation.py:fatrelu_and_mul"
```

**3. MM-Conv:**

```python
--8<-- "vllm/model_executor/layers/conv.py:conv2d"

--8<-- "vllm/model_executor/layers/conv.py:conv3d"
```

**4. Embedding:**

```python
--8<-- "vllm/model_executor/layers/vocab_parallel_embedding.py:vocab_parallel_embedding"

--8<-- "vllm/model_executor/layers/vocab_parallel_embedding.py:parallel_lm_head"
```

**5. Linear:**

```python
--8<-- "vllm/model_executor/layers/linear.py:row_parallel_linear"

--8<-- "vllm/model_executor/layers/linear.py:column_parallel_linear"

--8<-- "vllm/model_executor/layers/linear.py:replicated_linear"
```

**6. Logits Processor:**

```python
--8<-- "vllm/model_executor/layers/logits_processor.py:logits_processor"
```

**7. Mamba:**

```python
--8<-- "vllm/model_executor/layers/mamba/mamba_mixer.py:mamba_mixer"

--8<-- "vllm/model_executor/layers/mamba/mamba_mixer2.py:mamba_mixer2"

--8<-- "vllm/model_executor/layers/mamba/mamba_mixer2.py:mixer2_gated_rms_norm"

--8<-- "vllm/model_executor/models/plamo2.py:plamo2_mamba_mixer"

--8<-- "vllm/model_executor/layers/mamba/short_conv.py:short_conv"
```

**8. MoE:**

```python
--8<-- "vllm/model_executor/layers/fused_moe/layer.py:fused_moe"

--8<-- "vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py:modular_fused_moe"

--8<-- "vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py:unquantized_fused_moe"

--8<-- "vllm/model_executor/models/transformers/moe.py:transformers_fused_moe"

--8<-- "vllm/model_executor/layers/fused_moe/fused_moe.py:grouped_topk"
```

**9. Norm:**

```python
--8<-- "vllm/model_executor/layers/layernorm.py:rms_norm"

--8<-- "vllm/model_executor/layers/layernorm.py:rms_norm_gated"

--8<-- "vllm/model_executor/layers/layernorm.py:gemma_rms_norm"
```

**10. Quantization:**

```python
--8<-- "vllm/model_executor/layers/quantization/input_quant_fp8.py:quant_fp8"
```

**11. Rope:**

```python
--8<-- "vllm/model_executor/layers/rotary_embedding/base.py:rotary_embedding"

--8<-- "vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py:dual_chunk_rotary_embedding"

--8<-- "vllm/model_executor/layers/rotary_embedding/common.py:apply_rotary_emb"
```
## Guidelines for Implementing a New CustomOp

### Implement a New CustomOp in vLLM

This section is a tutorial on how to implement a new `CustomOp` in vLLM.

Steps:

1. Implement a new op class that extends the `CustomOp` base class.
2. Add the `@CustomOp.register("op_name")` decorator to the op class to register it in the `CustomOp` system.
3. Implement the `forward_xxx()` methods you need.
Taking `MMEncoderAttention` as an example:

??? code

    ```python
    @CustomOp.register("mm_encoder_attn")
    class MMEncoderAttention(CustomOp):

        def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float | None = None,
            num_kv_heads: int | None = None,
            prefix: str = "",
            multimodal_config: MultiModalConfig | None = None,
        ) -> None:
            super().__init__()
            # Init...

        def forward_native(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            cu_seqlens: torch.Tensor | None = None,
            max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
        ) -> torch.Tensor:
            # Call TORCH_SDPA implementation...
            ...

        def forward_cuda(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            cu_seqlens: torch.Tensor | None = None,
            max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
        ) -> torch.Tensor:
            # Call FA or TORCH_SDPA implementation...
            ...

        def forward_cpu(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            cu_seqlens: torch.Tensor | None = None,
            max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
        ) -> torch.Tensor:
            # Call TORCH_SDPA implementation...
            ...

        def forward_xpu(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            cu_seqlens: torch.Tensor | None = None,
            max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
        ) -> torch.Tensor:
            # Call FA implementation...
            ...

        def forward_tpu(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            cu_seqlens: torch.Tensor | None = None,
            max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
        ) -> torch.Tensor:
            # Call PALLAS implementation...
            ...
    ```
### Register a New CustomOp in OOT Device Plugins

Currently, thanks to [vLLM's hardware-plugin mechanism](./plugin_system.md), various OOT device plugins have emerged that enable vLLM to run seamlessly on different hardware. You can find more details about this mechanism in [Introducing vLLM Hardware Plugin, Best Practice from Ascend NPU](https://blog.vllm.ai/2025/05/12/hardware-plugin.html).

- **Official device plugins:** [vllm-ascend](https://github.com/vllm-project/vllm-ascend) (for Huawei Ascend NPU), [vllm-spyre](https://github.com/vllm-project/vllm-spyre) (for Spyre), [vllm-gaudi](https://github.com/vllm-project/vllm-gaudi) (for Intel Gaudi), [vllm-neuron](https://github.com/vllm-project/vllm-neuron) (for AWS Neuron), [vllm-metal](https://github.com/vllm-project/vllm-metal) (for Apple Silicon), etc.
- **Non-official device plugins:** [vllm-metax](https://github.com/MetaX-MACA/vLLM-metax) (for MetaX GPU), [vllm-kunlun](https://github.com/baidu/vLLM-Kunlun) (for Baidu Kunlun XPU), etc.

In this way, `CustomOp` enables hardware manufacturers to seamlessly replace vLLM's operations at runtime with kernels deeply optimized for their specific devices, simply by registering an OOT `CustomOp` and implementing its `forward_oot()` method.

This section shows how to register an OOT `CustomOp` for a device plugin.
Taking `MMEncoderAttention` as an example:

1. Implement a `CustomMMEncoderAttention` class that extends `MMEncoderAttention` and implements its `forward_oot()` method.
2. Register `CustomMMEncoderAttention` in vLLM to replace `MMEncoderAttention`.

??? code

    ```python
    from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
    from vllm.model_executor.custom_op import CustomOp


    @CustomOp.register_oot("MMEncoderAttention")
    class CustomMMEncoderAttention(MMEncoderAttention):

        def __init__(...):
            super().__init__(...)

        def forward_oot(...):
            # Call optimized device-specific kernels.
            ...
    ```
In this case, a new item `{"MMEncoderAttention": CustomMMEncoderAttention}` is added to `op_registry_oot`. When an `MMEncoderAttention` op object is initialized, if its class name (i.e., `MMEncoderAttention`) is a key of `op_registry_oot`, vLLM replaces it with the registered class (i.e., `CustomMMEncoderAttention`) and instantiates that instead.

After that, when this `MMEncoderAttention` op is called, your `forward_oot()` is invoked if the op is enabled. You thus get the expected performance on your hardware without directly modifying vLLM.

In addition, you can register all of your `CustomOp` classes in one place for better management.
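The replacement lookup can be pictured with a simplified `__new__`-based sketch. This is illustrative only; vLLM's actual substitution logic may differ in detail, and the classes below are minimal stand-ins:

```python
class CustomOp:
    op_registry_oot: dict[str, type["CustomOp"]] = {}

    def __new__(cls, *args, **kwargs):
        # If an OOT plugin registered a replacement under this class
        # name, instantiate the replacement class instead.
        replacement = CustomOp.op_registry_oot.get(cls.__name__)
        if replacement is not None and not issubclass(cls, replacement):
            return super().__new__(replacement)
        return super().__new__(cls)


class MMEncoderAttention(CustomOp):
    """Stand-in for vLLM's MMEncoderAttention (illustration only)."""


class CustomMMEncoderAttention(MMEncoderAttention):
    def forward_oot(self, x):
        # A device plugin would call its optimized kernel here.
        return ("oot", x)


# What @CustomOp.register_oot("MMEncoderAttention") effectively does:
CustomOp.op_registry_oot["MMEncoderAttention"] = CustomMMEncoderAttention

op = MMEncoderAttention()
print(type(op).__name__)  # CustomMMEncoderAttention
```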
??? code

    ```python
    from vllm.model_executor.custom_op import CustomOp


    REGISTERED_CUSTOM_OPS = {
        "CustomOP1": YourCustomOp1,
        "CustomOP2": YourCustomOp2,
        "CustomOP3": YourCustomOp3,
    }

    for op_name, op_cls in REGISTERED_CUSTOM_OPS.items():
        CustomOp.register_oot(_decorated_op_cls=op_cls, name=op_name)
    ```
The same commit also adds snippet anchors to the source files referenced by the document's `--8<--` includes, and updates the related docstrings and comments:

```diff
@@ -18,10 +18,13 @@ from vllm.model_executor.models.vision import get_vit_attn_backend
 logger = init_logger(__name__)


+# --8<-- [start:mm_encoder_attn]
 @CustomOp.register("mm_encoder_attn")
 class MMEncoderAttention(CustomOp):
     """Multi-headed attention without any cache, used for multimodal encoder."""

+    # --8<-- [end:mm_encoder_attn]
+
     def __init__(
         self,
         num_heads: int,
@@ -404,7 +404,8 @@ class CompilationConfig:
     - 'none,+op1,+op2' to enable only op1 and op2

     By default, all custom ops are enabled when running without Inductor and
-    disabled when running with Inductor: mode>=VLLM_COMPILE and backend="inductor".
+    disabled when running with Inductor: mode>CompilationMode.NONE and
+    backend="inductor".
     Inductor generates (fused) Triton kernels for disabled custom ops."""
     splitting_ops: list[str] | None = None
     """A list of ops to exclude from cudagraphs, used in piecewise compilation.
```
```diff
@@ -86,9 +86,12 @@ class CustomOp(nn.Module):
         # specific backend. Currently, we do not support dynamic dispatching.
         compilation_config = get_cached_compilation_config()

-        # CustomOp object can be enforce enabled, e.g., enable device-specific
-        # kernels in ViT models when enabling graph mode. By default, it will
-        # follow the compilation_config to determine whether enable itself.
+        # NOTE(shen-shanshan): CustomOp object can be enforce enabled, e.g.,
+        # enable device-specific kernels in ViT models when enabling graph
+        # mode. By default, it will follow the compilation_config to determine
+        # whether enable itself.
+        # This enforce_enable mechanism will be removed after we adding a
+        # separate compilation_config for multi-modal part.
         enabled = self._enforce_enable or self.enabled()
         if enabled:
             compilation_config.enabled_custom_ops.update([self.__class__.name])
```
```diff
@@ -22,6 +22,7 @@ from vllm.utils.collection_utils import LazyDict
 logger = init_logger(__name__)


+# --8<-- [start:fatrelu_and_mul]
 @CustomOp.register("fatrelu_and_mul")
 class FatreluAndMul(CustomOp):
     """An activation function for FATReLU.
@@ -35,6 +36,8 @@ class FatreluAndMul(CustomOp):
         return: (num_tokens, d) or (batch_size, seq_len, d)
     """

+    # --8<-- [end:fatrelu_and_mul]
+
     def __init__(self, threshold: float = 0.0):
         super().__init__()
         self.threshold = threshold
@@ -58,6 +61,7 @@ class FatreluAndMul(CustomOp):
         return out


+# --8<-- [start:silu_and_mul]
 @CustomOp.register("silu_and_mul")
 class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.
@@ -69,6 +73,8 @@ class SiluAndMul(CustomOp):
         return: (num_tokens, d) or (batch_size, seq_len, d)
     """

+    # --8<-- [end:silu_and_mul]
+
     def __init__(self):
         super().__init__()
         if current_platform.is_cuda_alike():
@@ -101,6 +107,7 @@ class SiluAndMul(CustomOp):
         return out


+# --8<-- [start:mul_and_silu]
 @CustomOp.register("mul_and_silu")
 class MulAndSilu(CustomOp):
     """An activation function for SwiGLU.
@@ -112,6 +119,8 @@ class MulAndSilu(CustomOp):
         return: (num_tokens, d) or (batch_size, seq_len, d)
     """

+    # --8<-- [end:mul_and_silu]
+
     def __init__(self):
         super().__init__()
         if current_platform.is_cuda_alike():
@@ -139,6 +148,7 @@ class MulAndSilu(CustomOp):
     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


+# --8<-- [start:gelu_and_mul_sparse]
 @CustomOp.register("gelu_and_mul_sparse")
 class GeluAndMulSparse(CustomOp):
     """An activation function for GeluAndMulSparse.
@@ -153,6 +163,8 @@ class GeluAndMulSparse(CustomOp):
         return: (num_tokens, d) or (batch_size, seq_len, d)
     """

+    # --8<-- [end:gelu_and_mul_sparse]
+
     def __init__(self, activation_sparsity: float, approximate: str = "none"):
         super().__init__()
         # Gelu.
@@ -195,6 +207,7 @@ class GeluAndMulSparse(CustomOp):
         return self.forward_native(x)


+# --8<-- [start:gelu_and_mul]
 @CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
@@ -206,6 +219,8 @@ class GeluAndMul(CustomOp):
         return: (batch_size, seq_len, d) or (num_tokens, d)
     """

+    # --8<-- [end:gelu_and_mul]
+
     def __init__(self, approximate: str = "none"):
         super().__init__()
         self.approximate = approximate
@@ -257,9 +272,12 @@ class GeluAndMul(CustomOp):
         return f"approximate={repr(self.approximate)}"


+# --8<-- [start:swigluoai_and_mul]
 @CustomOp.register("swigluoai_and_mul")
 class SwigluOAIAndMul(CustomOp):
     # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
+    # --8<-- [end:swigluoai_and_mul]
+
     def __init__(self, alpha: float = 1.702, limit: float = 7.0):
         super().__init__()
         self.alpha = alpha
@@ -286,8 +304,11 @@ class SwigluOAIAndMul(CustomOp):
         return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"


+# --8<-- [start:gelu_new]
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
+    # --8<-- [end:gelu_new]
+
     def __init__(self):
         super().__init__()
         if current_platform.is_cuda_alike() or current_platform.is_cpu():
@@ -311,8 +332,11 @@ class NewGELU(CustomOp):
         return self.op(x)


+# --8<-- [start:gelu_fast]
 @CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):
+    # --8<-- [end:gelu_fast]
+
     def __init__(self):
         super().__init__()
         if current_platform.is_cuda_alike() or current_platform.is_cpu():
@@ -335,9 +359,12 @@ class FastGELU(CustomOp):
         return self.op(x)


+# --8<-- [start:quick_gelu]
 @CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+    # --8<-- [end:quick_gelu]
+
     def __init__(self):
         super().__init__()
         if current_platform.is_cuda_alike() or current_platform.is_cpu():
@@ -365,12 +392,15 @@ class QuickGELU(CustomOp):
     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


+# --8<-- [start:relu2]
 @CustomOp.register("relu2")
 class ReLUSquaredActivation(CustomOp):
     """
     Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
     """

+    # --8<-- [end:relu2]
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return torch.square(F.relu(x))
@@ -380,6 +410,7 @@ class ReLUSquaredActivation(CustomOp):
         return self.forward_native(x)


+# --8<-- [start:xielu]
 @CustomOp.register("xielu")
 class XIELU(CustomOp):
     """
@@ -388,6 +419,8 @@ class XIELU(CustomOp):
     Otherwise, we emit a single warning and use xIELU Python
     """

+    # --8<-- [end:xielu]
+
     def __init__(
         self,
         alpha_p_init: float = 0.8,
```
```diff
@@ -105,10 +105,13 @@ class ConvLayerBase(CustomOp):
         return s


+# --8<-- [start:conv2d]
 @CustomOp.register("conv2d")
 class Conv2dLayer(ConvLayerBase):
     """Conv layer with Conv2d."""

+    # --8<-- [end:conv2d]
+
     num_dim = 2

     def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
@@ -204,10 +207,13 @@ class CausalConv2dLayer(Conv2dLayer):
         return x


+# --8<-- [start:conv3d]
 @CustomOp.register("conv3d")
 class Conv3dLayer(ConvLayerBase):
     """Conv layer with Conv3d."""

+    # --8<-- [end:conv3d]
+
     num_dim = 3

     def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
```
```diff
@@ -1283,10 +1283,13 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


+# --8<-- [start:grouped_topk]
 @CustomOp.register("grouped_topk")
 class GroupedTopk(CustomOp):
     """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""

+    # --8<-- [end:grouped_topk]
+
     def __init__(
         self,
         topk: int,
@@ -20,8 +20,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 logger = init_logger(__name__)


+# --8<-- [start:modular_fused_moe]
 @CustomOp.register("modular_fused_moe")
 class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
+    # --8<-- [end:modular_fused_moe]
+
     def __init__(
         self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
     ):
@@ -297,6 +297,7 @@ def maybe_roundup_hidden_size(
     return hidden_size


+# --8<-- [start:fused_moe]
 @CustomOp.register("fused_moe")
 class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.
@@ -320,6 +321,8 @@ class FusedMoE(CustomOp):
         enable_eplb: Whether to enable expert parallelism load balancer.
     """

+    # --8<-- [end:fused_moe]
+
     def __init__(
         self,
         num_experts: int,  # Global number of experts
@@ -46,10 +46,13 @@ else:
 logger = init_logger(__name__)


+# --8<-- [start:unquantized_fused_moe]
 @CustomOp.register("unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

+    # --8<-- [end:unquantized_fused_moe]
+
     def __init__(self, moe: FusedMoEConfig):
         super().__init__(moe)
```
```diff
@@ -88,6 +88,7 @@ def dispatch_rocm_rmsnorm_func(
     return rms_norm


+# --8<-- [start:rms_norm]
 @CustomOp.register("rms_norm")
 class RMSNorm(CustomOp):
     """Root mean square normalization.
@@ -96,6 +97,8 @@ class RMSNorm(CustomOp):
     Refer to https://arxiv.org/abs/1910.07467
     """

+    # --8<-- [end:rms_norm]
+
     def __init__(
         self,
         hidden_size: int,
@@ -253,6 +256,7 @@ class RMSNorm(CustomOp):
         return s


+# --8<-- [start:gemma_rms_norm]
 @CustomOp.register("gemma_rms_norm")
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.
@@ -262,6 +266,8 @@ class GemmaRMSNorm(CustomOp):
     2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
     """

+    # --8<-- [end:gemma_rms_norm]
+
     def __init__(
         self,
         hidden_size: int,
@@ -321,6 +327,7 @@ class GemmaRMSNorm(CustomOp):
         return self.forward_native(x, residual)


+# --8<-- [start:rms_norm_gated]
 @CustomOp.register("rms_norm_gated")
 class RMSNormGated(CustomOp):
     """RMS Normalization with optional gating.
@@ -331,6 +338,8 @@ class RMSNormGated(CustomOp):
     - Optional gating with SiLU activation
     """

+    # --8<-- [end:rms_norm_gated]
+
     def __init__(
         self,
         hidden_size: int,
```
```diff
@@ -296,6 +296,7 @@ class LinearBase(CustomOp):
             param.tp_size = self.tp_size


+# --8<-- [start:replicated_linear]
 @CustomOp.register("replicated_linear")
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
@@ -313,6 +314,8 @@ class ReplicatedLinear(LinearBase):
         disable_tp: Take no effect for replicated linear layers.
     """

+    # --8<-- [end:replicated_linear]
+
     def __init__(
         self,
         input_size: int,
@@ -413,6 +416,7 @@ class ReplicatedLinear(LinearBase):
         return s


+# --8<-- [start:column_parallel_linear]
 @CustomOp.register("column_parallel_linear")
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
@@ -440,6 +444,8 @@ class ColumnParallelLinear(LinearBase):
         disable_tp: If true, weights matrix won't be sharded through tp rank.
     """

+    # --8<-- [end:column_parallel_linear]
+
     def __init__(
         self,
         input_size: int,
@@ -1276,6 +1282,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             param_data.copy_(loaded_weight)


+# --8<-- [start:row_parallel_linear]
 @CustomOp.register("row_parallel_linear")
 class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
@@ -1310,6 +1317,8 @@ class RowParallelLinear(LinearBase):
         disable_tp: If true, weights matrix won't be sharded through tp rank.
     """

+    # --8<-- [end:row_parallel_linear]
+
     def __init__(
         self,
         input_size: int,
```
```diff
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
 from vllm.platforms import current_platform


+# --8<-- [start:logits_processor]
 @CustomOp.register("logits_processor")
 class LogitsProcessor(CustomOp):
     """Process logits and apply logits processors from sampling metadata.
@@ -23,6 +24,8 @@ class LogitsProcessor(CustomOp):
     3. Apply logits processors (if any).
     """

+    # --8<-- [end:logits_processor]
+
     def __init__(
         self,
         vocab_size: int,
```
```diff
@@ -39,6 +39,7 @@ from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata


 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
+# --8<-- [start:mamba_mixer]
 @CustomOp.register("mamba_mixer")
 class MambaMixer(MambaBase, CustomOp):
     """
@@ -51,6 +52,8 @@ class MambaMixer(MambaBase, CustomOp):
     **selective** state spaces)
     """

+    # --8<-- [end:mamba_mixer]
+
     def __init__(
         self,
         hidden_size: int,
@@ -49,8 +49,11 @@ from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata


 # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
+# --8<-- [start:mixer2_gated_rms_norm]
 @CustomOp.register("mixer2_gated_rms_norm")
 class Mixer2RMSNormGated(CustomOp):
+    # --8<-- [end:mixer2_gated_rms_norm]
+
     def __init__(
         self,
         full_hidden_size: int,
@@ -214,6 +217,7 @@ def mamba_v2_sharded_weight_loader(


 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
+# --8<-- [start:mamba_mixer2]
 @CustomOp.register("mamba_mixer2")
 class MambaMixer2(MambaBase, CustomOp):
     """
@@ -226,6 +230,8 @@ class MambaMixer2(MambaBase, CustomOp):
     **selective** state spaces)
     """

+    # --8<-- [end:mamba_mixer2]
+
     def __init__(
         self,
         hidden_size: int,
@@ -27,8 +27,11 @@ from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata


+# --8<-- [start:short_conv]
 @CustomOp.register("short_conv")
 class ShortConv(MambaBase, CustomOp):
+    # --8<-- [end:short_conv]
+
     def __init__(
         self,
         config,
```
```diff
@@ -29,6 +29,7 @@ class MLAModules:
     indexer_rotary_emb: torch.nn.Module | None = None


+# --8<-- [start:multi_head_latent_attention]
 @CustomOp.register("multi_head_latent_attention")
 class MultiHeadLatentAttentionWrapper(CustomOp):
     """MLA layer registered as CustomOp to allow OOT backends to add
@@ -47,6 +48,8 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
     3. Return the output tensor.
     """

+    # --8<-- [end:multi_head_latent_attention]
+
     def __init__(
         self,
         hidden_size: int,
```
```diff
@@ -19,6 +19,7 @@ _FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
 _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)


+# --8<-- [start:quant_fp8]
 @CustomOp.register("quant_fp8")
 class QuantFP8(CustomOp):
     """
@@ -26,6 +27,8 @@ class QuantFP8(CustomOp):
     This CustomOp supports both static and dynamic quantization.
     """

+    # --8<-- [end:quant_fp8]
+
     def __init__(
         self,
         static: bool,
```
```diff
@@ -10,10 +10,13 @@ from vllm.model_executor.custom_op import CustomOp
 from .common import ApplyRotaryEmb


+# --8<-- [start:rotary_embedding]
 @CustomOp.register("rotary_embedding")
 class RotaryEmbeddingBase(CustomOp):
     """Original rotary positional embedding."""

+    # --8<-- [end:rotary_embedding]
+
     def __init__(
         self,
         head_size: int,
@@ -118,8 +118,11 @@ direct_register_custom_op(
 )


+# --8<-- [start:apply_rotary_emb]
 @CustomOp.register("apply_rotary_emb")
 class ApplyRotaryEmb(CustomOp):
+    # --8<-- [end:apply_rotary_emb]
+
     def __init__(
         self,
         enforce_enable: bool = False,
@@ -9,10 +9,13 @@ from vllm.model_executor.custom_op import CustomOp
 from .common import rotate_gptj, rotate_neox


+# --8<-- [start:dual_chunk_rotary_embedding]
 @CustomOp.register("dual_chunk_rotary_embedding")
 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""

+    # --8<-- [end:dual_chunk_rotary_embedding]
+
     def __init__(
         self,
         head_size: int,
```
```diff
@@ -181,6 +181,7 @@ def get_masked_input_and_mask(
     return input_, ~vocab_mask


+# --8<-- [start:vocab_parallel_embedding]
 @CustomOp.register("vocab_parallel_embedding")
 class VocabParallelEmbedding(CustomOp):
     """Embedding parallelized in the vocabulary dimension.
@@ -221,6 +222,8 @@ class VocabParallelEmbedding(CustomOp):
         prefix: full name of the layer in the state dict
     """  # noqa: E501

+    # --8<-- [end:vocab_parallel_embedding]
+
     def __init__(
         self,
         num_embeddings: int,
@@ -492,6 +495,7 @@ class VocabParallelEmbedding(CustomOp):
         return s


+# --8<-- [start:parallel_lm_head]
 @CustomOp.register("parallel_lm_head")
 class ParallelLMHead(VocabParallelEmbedding):
     """Parallelized LM head.
@@ -509,6 +513,8 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: padding size for the vocabulary.
     """

+    # --8<-- [end:parallel_lm_head]
+
     def __init__(
         self,
         num_embeddings: int,
```
```diff
@@ -97,8 +97,11 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
 # Adapted from:
 # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
 # transformers.models.mamba.modeling_mamba.MambaMixer
-@CustomOp.register(name="plamo2_mamba_mixer")
+# --8<-- [start:plamo2_mamba_mixer]
+@CustomOp.register("plamo2_mamba_mixer")
 class Plamo2MambaMixer(MambaBase, CustomOp):
+    # --8<-- [end:plamo2_mamba_mixer]
+
     def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None:
         super().__init__()
         self.config = vllm_config.model_config.hf_config
```
```diff
@@ -37,10 +37,13 @@ if TYPE_CHECKING:
     from vllm.config import VllmConfig


+# --8<-- [start:transformers_fused_moe]
 @CustomOp.register("transformers_fused_moe")
 class TransformersFusedMoE(FusedMoE):
     """Custom FusedMoE for the Transformers modeling backend."""

+    # --8<-- [end:transformers_fused_moe]
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._topk_ids: torch.Tensor = None
```