mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 20:28:42 +08:00
[CI/Build] Fix pre-commit failure in docs (#21897)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
e3bc17ceea
commit
16f3250527
@ -1,6 +1,7 @@
|
||||
# Fused MoE Modular Kernel
|
||||
|
||||
## Introduction
|
||||
|
||||
FusedMoEModularKernel is implemented [here](gh-file:/vllm/model_executor/layers/fused_moe/modular_kernel.py)
|
||||
|
||||
Based on the format of the input activations, FusedMoE implementations are broadly classified into 2 types.
|
||||
@ -31,7 +32,8 @@ As can be seen from the diagrams, there are a lot of operations and there can be
|
||||
|
||||
The rest of the document will focus on the Contiguous / Non-Batched case. Extrapolating to the Batched case should be straight-forward.
|
||||
|
||||
## ModularKernel Components:
|
||||
## ModularKernel Components
|
||||
|
||||
FusedMoEModularKernel splits the FusedMoE operation into 3 parts,
|
||||
|
||||
1. TopKWeightAndReduce
|
||||
@ -39,6 +41,7 @@ FusedMoEModularKernel splits the FusedMoE operation into 3 parts,
|
||||
3. FusedMoEPermuteExpertsUnpermute
|
||||
|
||||
### TopKWeightAndReduce
|
||||
|
||||
The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEPermuteExpertsUnpermute` is responsible for the Unpermute and `FusedMoEPrepareAndFinalize` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEPermuteExpertsUnpermute`. But some implementations choose to do it `FusedMoEPrepareAndFinalize`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class.
|
||||
|
||||
Please find the implementations of TopKWeightAndReduce [here](gh-file:vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py).
|
||||
@ -50,12 +53,14 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts
|
||||
* `FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl` method returns `TopKWeightAndReduceContiguous` / `TopKWeightAndReduceNaiveBatched` / `TopKWeightAndReduceDelegate` if the `FusedMoEPermuteExpertsUnpermute` implementation needs the `FusedMoEPrepareAndFinalize::finalize()` to do the weight application and reduction.
|
||||
|
||||
### FusedMoEPrepareAndFinalize
|
||||
|
||||
The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare` and `finalize` functions.
|
||||
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
|
||||
|
||||

|
||||
|
||||
### FusedMoEPermuteExpertsUnpermute
|
||||
|
||||
The `FusedMoEPermuteExpertsUnpermute` class is where the crux of the MoE operations happen. The `FusedMoEPermuteExpertsUnpermute` abstract class exposes a few important functions,
|
||||
|
||||
* apply()
|
||||
@ -63,6 +68,7 @@ The `FusedMoEPermuteExpertsUnpermute` class is where the crux of the MoE operati
|
||||
* finalize_weight_and_reduce_impl()
|
||||
|
||||
#### apply()
|
||||
|
||||
The `apply` method is where the implementations perform
|
||||
|
||||
* Permute
|
||||
@ -74,50 +80,56 @@ The `apply` method is where the implementations perform
|
||||
* Maybe TopK Weight Application + Reduction
|
||||
|
||||
#### workspace_shapes()
|
||||
|
||||
The core FusedMoE implementation performs a series of operations. It would be inefficient to create output memory for each of these operations separately. To that effect, implementations are required to declare 2 workspace shapes, the workspace datatype and the FusedMoE output shape as outputs of the workspace_shapes() method. This information is used to allocate the workspace tensors and the output tensor in `FusedMoEModularKernel::forward()` and passed on to the `FusedMoEPermuteExpertsUnpermute::apply()` method. The workspaces could then be used as intermediate buffers in the FusedMoE implementation.
|
||||
|
||||
#### finalize_weight_and_reduce_impl()
|
||||
|
||||
It is sometimes efficient to perform TopK weight application and Reduction inside the `FusedMoEPermuteExpertsUnpermute::apply()`. Find an example [here](https://github.com/vllm-project/vllm/pull/20228). We have a `TopKWeightAndReduce` abstract class to facilitate such implementations. Please refer to the TopKWeightAndReduce section.
|
||||
`FusedMoEPermuteExpertsUnpermute::finalize_weight_and_reduce_impl()` returns the `TopKWeightAndReduce` object that the implementation wants the `FusedMoEPrepareAndFinalize::finalize()` to use.
|
||||
|
||||

|
||||
|
||||
### FusedMoEModularKernel
|
||||
|
||||
`FusedMoEModularKernel` is composed of the `FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` objects.
|
||||
`FusedMoEModularKernel` pseudocode/sketch,
|
||||
|
||||
```
|
||||
FusedMoEModularKernel::__init__(self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute):
|
||||
```py
|
||||
class FusedMoEModularKernel:
|
||||
def __init__(self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
|
||||
FusedMoEModularKernel::forward(self, DP_A):
|
||||
def forward(self, DP_A):
|
||||
|
||||
Aq, A_scale, _, _, _ = self.prepare_finalize.prepare(DP_A, ...)
|
||||
Aq, A_scale, _, _, _ = self.prepare_finalize.prepare(DP_A, ...)
|
||||
|
||||
workspace13_shape, workspace2_shape, _, _ = self.fused_experts.workspace_shapes(...)
|
||||
workspace13_shape, workspace2_shape, _, _ = self.fused_experts.workspace_shapes(...)
|
||||
|
||||
# allocate workspaces
|
||||
workspace_13 = torch.empty(workspace13_shape, ...)
|
||||
workspace_2 = torch.empty(workspace2_shape, ...)
|
||||
# allocate workspaces
|
||||
workspace_13 = torch.empty(workspace13_shape, ...)
|
||||
workspace_2 = torch.empty(workspace2_shape, ...)
|
||||
|
||||
# execute fused_experts
|
||||
fe_out = self.fused_experts.apply(Aq, A_scale, workspace13, workspace2, ...)
|
||||
# execute fused_experts
|
||||
fe_out = self.fused_experts.apply(Aq, A_scale, workspace13, workspace2, ...)
|
||||
|
||||
# war_impl is an object of type TopKWeightAndReduceNoOp if the fused_experts implementations performs the TopK Weight Application and Reduction.
|
||||
war_impl = self.fused_experts.finalize_weight_and_reduce_impl()
|
||||
# war_impl is an object of type TopKWeightAndReduceNoOp if the fused_experts implementations
|
||||
# performs the TopK Weight Application and Reduction.
|
||||
war_impl = self.fused_experts.finalize_weight_and_reduce_impl()
|
||||
|
||||
output = self.prepare_finalize.finalize(fe_out, war_impl,...)
|
||||
|
||||
return output
|
||||
output = self.prepare_finalize.finalize(fe_out, war_impl,...)
|
||||
|
||||
return output
|
||||
```
|
||||
|
||||
## How-To
|
||||
|
||||
### How To Add a FusedMoEPrepareAndFinalize Type
|
||||
|
||||
Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example,
|
||||
|
||||
* PplxPrepareAndFinalize type is backed by Pplx All2All kernels,
|
||||
@ -125,9 +137,11 @@ Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & C
|
||||
* DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels.
|
||||
|
||||
#### Step 1: Add an All2All manager
|
||||
|
||||
The purpose of the All2All Manager is to setup the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py).
|
||||
|
||||
#### Step 2: Add a FusedMoEPrepareAndFinalize Type
|
||||
|
||||
This section describes the significance of the various functions exposed by the `FusedMoEPrepareAndFinalize` abstract class.
|
||||
|
||||
`FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.
|
||||
@ -145,6 +159,7 @@ This section describes the significance of the various functions exposed by the
|
||||
We suggest picking an already existing `FusedMoEPrepareAndFinalize` implementation that matches your All2All implementation closely and using it as a reference.
|
||||
|
||||
### How To Add a FusedMoEPermuteExpertsUnpermute Type
|
||||
|
||||
FusedMoEPermuteExpertsUnpermute performs the core of the FusedMoE operations. The various functions exposed by the abstract class and their significance is as follows,
|
||||
|
||||
`FusedMoEPermuteExpertsUnpermute::activation_formats()`: Return the supported Input and Output activation formats. i.e. Contiguous / Batched format.
|
||||
@ -159,12 +174,14 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
|
||||
`FusedMoEPermuteExpertsUnpermute::apply`: Refer to `FusedMoEPermuteExpertsUnpermute` section above.
|
||||
|
||||
### FusedMoEModularKernel Initialization
|
||||
|
||||
`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
|
||||
|
||||
* select_gemm_impl, and
|
||||
* init_prepare_finalize
|
||||
|
||||
#### select_gemm_impl
|
||||
|
||||
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
|
||||
Please refer to the implementations in,
|
||||
|
||||
@ -176,12 +193,14 @@ Please refer to the implementations in,
|
||||
dervied classes.
|
||||
|
||||
#### init_prepare_finalize
|
||||
|
||||
Based on the input and env settings, the `init_prepare_finalize` method creates the appropriate `FusedMoEPrepareAndFinalize` object. The method then queries `select_gemm_impl` for the appropriate `FusedMoEPermuteExpertsUnpermute` object and builds the `FusedMoEModularKernel` object
|
||||
|
||||
Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vllm/blob/1cbf951ba272c230823b947631065b826409fa62/vllm/model_executor/layers/fused_moe/layer.py#L188).
|
||||
**Important**: The `FusedMoEMethodBase` derived classes use the `FusedMoEMethodBase::fused_experts` object in their `apply` methods. When settings permit the construction of a valid `FusedMoEModularKernel` object, we override `FusedMoEMethodBase::fused_experts` with it. This essentially makes the derived classes agnostic to what FusedMoE implementation is used.
|
||||
|
||||
### How To Unit Test
|
||||
|
||||
We have `FusedMoEModularKernel` unit tests at [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py).
|
||||
|
||||
The unit test iterates through all combinations of `FusedMoEPrepareAndFinalize` and `FusedMoEPremuteExpertsUnpermute` types and if they are
|
||||
@ -196,18 +215,21 @@ If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnp
|
||||
Doing this will add the new implementation to the test suite.
|
||||
|
||||
### How To Check `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` Compatibility
|
||||
|
||||
The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script.
|
||||
Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts`
|
||||
As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked
|
||||
with incompatible types, the script will error.
|
||||
|
||||
### How To Profile
|
||||
|
||||
Please take a look at [profile_modular_kernel.py](gh-file:tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py)
|
||||
The script can be used to generate Torch traces for a single `FusedMoEModularKernel::forward()` call for any compatible
|
||||
`FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` types.
|
||||
Example: `python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts`
|
||||
|
||||
## FusedMoEPrepareAndFinalize Implementations
|
||||
|
||||
The following table lists the `FusedMoEPrepareAndFinalize` implementations at the time of writing,
|
||||
|
||||
| Implementation | Type | Comments |
|
||||
@ -220,6 +242,7 @@ The following table lists the `FusedMoEPrepareAndFinalize` implementations at th
|
||||
| BatchedPrepareAndFinalize | Batched | A reference prepare/finalize class that reorganizes the tokens into expert batched format, i.e. E x max_num_tokens x K. (Doesn’t use any all2all kernels. This is primarily used in unit testing) |
|
||||
|
||||
## FusedMoEPermuteExpertsUnpermute
|
||||
|
||||
The following table lists the `FusedMoEPermuteExpertsUnpermute` implementations at the time of writing,
|
||||
|
||||
| Implementation | Type | Comment |
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user