[Misc] MoE ModularKernel : Introduce TopKWeightAndReduce (#20648)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent 574ad60db9
commit f0c98cae27
@@ -32,6 +32,8 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
from vllm.platforms import current_platform
from vllm.utils import round_up

@@ -371,6 +373,7 @@ def pplx_prepare_finalize(
        chunk_topk_weight,
        chunk_topk_ids,
        False,
+        weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
    )

    torch.cuda.synchronize()
@@ -7,6 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton

@@ -217,6 +219,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def supports_expert_map(self) -> bool:
        return False

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
    def workspace_shapes(
        self,
        a: torch.Tensor,
@@ -88,6 +88,25 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        return ((bdge is None or bdge.supports_expert_map())
                and (bte is None or bte.supports_expert_map()))

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        bdge = self.batched_deep_gemm_experts
+        bte = self.batched_triton_experts
+        bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
+        bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
+        is_bdge_war = bdge_war is not None
+        is_bte_war = bte_war is not None
+
+        if is_bdge_war and is_bte_war:
+            assert bdge_war == bte_war, (
+                "Both implementations should agree on WeightAndReduce impls. "
+                f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}")
+
+        if bdge_war is not None:
+            return bdge_war
+
+        assert bte_war is not None
+        return bte_war
+
    def workspace_shapes(
        self,
        a: torch.Tensor,
@@ -11,6 +11,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
                                                        _fp8_quantize,
                                                        _resize_cache)

@@ -255,6 +257,10 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
    def supports_expert_map(self) -> bool:
        return not self.use_batched_format

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
    def workspace_shapes(
        self,
        a: torch.Tensor,
@@ -12,6 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    _moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache, per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up

@@ -85,6 +87,10 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def supports_expert_map(self) -> bool:
        return True

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
    def workspace_shapes(
        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
        topk: int, global_num_experts: int, local_num_experts: int
@@ -6,8 +6,9 @@ import deep_ep
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)
@@ -187,45 +188,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                expert_topk_weights)

-    def _apply_weights_and_reduce(self, num_tokens: int,
-                                  fused_expert_output: torch.Tensor,
-                                  topk_weights: torch.Tensor,
-                                  apply_router_weight_on_input: bool,
-                                  output_dtype: torch.dtype):
-
-        hidden_dim = fused_expert_output.size(-1)
-        if fused_expert_output.ndim == 2:
-            fused_expert_output = fused_expert_output.view(
-                num_tokens, -1, hidden_dim)
-
-        if not apply_router_weight_on_input:
-            # The DeepEP combine kernels don't do the topk weight
-            # multiplication. We multiply the weights locally.
-            m_x_topk = fused_expert_output.size(0)
-            fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))
-
-        out = torch.empty((num_tokens, hidden_dim),
-                          device=fused_expert_output.device,
-                          dtype=output_dtype)
-        ops.moe_sum(fused_expert_output, out)
-
-        return out
-
    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool) -> None:
+                 apply_router_weight_on_input: bool,
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:

        assert self.handle is not None

        # fused_expert_output can have 0 tokens - This happens when none of the
        # tokens from the all2all reach this EP rank.
        if fused_expert_output.numel() != 0:
-            fused_expert_output = self._apply_weights_and_reduce(
-                num_tokens=topk_ids.size(0),
+            if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+                weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+            fused_expert_output = weight_and_reduce_impl.apply(
+                output=None,
                fused_expert_output=fused_expert_output,
                topk_weights=topk_weights,
+                topk_ids=topk_ids,
                apply_router_weight_on_input=apply_router_weight_on_input,
-                output_dtype=output.dtype)
+            )

        combined_x, _, event = self.buffer.combine(
            x=fused_expert_output,
@@ -7,6 +7,8 @@ import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input, normalize_batched_scales_shape)

@@ -166,8 +168,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool) -> None:
-
+                 apply_router_weight_on_input: bool,
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
+        assert isinstance(
+            weight_and_reduce_impl, TopKWeightAndReduceDelegate
+        ), ("Weight application and reduction happens in the combine kernel.")
        assert self.handle is not None

        combine_topk_weights = topk_weights
@@ -11,6 +11,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import (
    get_config_dtype_str, try_get_optimal_moe_config)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape,
    normalize_scales_shape)

@@ -600,25 +602,17 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
-        num_tokens = topk_ids.size(0)
-        num_local_experts = fused_expert_output.size(0)
-        K = fused_expert_output.size(-1)
-        assert output.size(0) == num_tokens and output.size(1) == K
-
-        output.fill_(0)
-
-        first_expert = num_local_experts * self.rank
-        last_expert = first_expert + num_local_experts
-
-        for expert_id in range(first_expert, last_expert):
-            matching_tokens = topk_ids == expert_id
-            topks = torch.any(matching_tokens, dim=1).flatten()
-            rows = torch.count_nonzero(topks)
-            rhs = fused_expert_output[expert_id - first_expert, :rows, :]
-            if not apply_router_weight_on_input:
-                rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
-            output[topks] = output[topks] + rhs
+        if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+            weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
+        weight_and_reduce_impl.apply(
+            output=output,
+            fused_expert_output=fused_expert_output,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+        )


class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):

@@ -670,6 +664,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def supports_expert_map(self) -> bool:
        return False

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
    def workspace_shapes(
        self,
        a: torch.Tensor,

@@ -877,6 +875,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def supports_expert_map(self) -> bool:
        return False

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
    def workspace_shapes(
        self,
        a: torch.Tensor,
@@ -25,6 +25,8 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (

@@ -1596,6 +1598,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def supports_expert_map(self) -> bool:
        return True

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
    def workspace_shapes(
        self,
        a: torch.Tensor,
@@ -23,7 +23,7 @@ from vllm.utils import cdiv
#
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
#
-# Each component will be independent of the others except for
+# Each component will be independent of (but may inform) the others except for
# [Quantize-Dispatch] and `[Combine] (see below). The components can then be
# mixed and matched with so that DP+EP can be supported easily for multiple
# MoE kernel implementations.

@@ -32,13 +32,19 @@ from vllm.utils import cdiv
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
#   inputs (e.g. quantization, distribution) and finalization of Moe outputs.
#   The prepare method must take care of any needed quantization and the
-#   finalize method must apply weights and do the final reduction of the output.
+#   finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
+#   may apply weights and/or do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
-#   MoE operation. One important feature to note is that this class does not
-#   apply topk weights or reduce the final output.
+#   MoE operation, i.e matmul + act_mul + optionally quant + matmul.
+#   Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
+#   the weight application and/or reduction. The class communicates this
+#   to [Finalize] via a TopKWeightAndReduce object.
# * FusedMoEModularKernel - an interface class that combines a
#   FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
#   provide the standard fused MoE kernel interface.
+# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
+#   by the FusedMoEPermuteExpertsUnpermute implementation that is passed
+#   on to [Finalize].
#
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class FusedMoEPrepareAndFinalize since they could use collective
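The handshake those comments describe is easiest to see end to end. Below is a minimal, torch-only sketch of it (an editor's illustration, not part of this commit; the Toy* names are hypothetical stand-ins for the vLLM classes): the experts object nominates a weight-and-reduce implementation, and finalize() runs whatever it is handed.

from typing import Optional

import torch


class ToyWeightAndReduce:
    """Stand-in for a TopKWeightAndReduce impl: weight, then sum over topk."""

    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:
        # fused_expert_output: (m, topk, K) -> (m, K)
        if not apply_router_weight_on_input:
            fused_expert_output = fused_expert_output * topk_weights.unsqueeze(-1)
        reduced = fused_expert_output.sum(dim=1)
        if output is None:
            return reduced
        output.copy_(reduced)
        return output


class ToyExperts:
    """Stand-in for FusedMoEPermuteExpertsUnpermute: it does not weight or
    reduce itself, so it tells [Finalize] how to do it."""

    def finalize_weight_and_reduce_impl(self) -> ToyWeightAndReduce:
        return ToyWeightAndReduce()


class ToyPrepareAndFinalize:
    """Stand-in for FusedMoEPrepareAndFinalize: finalize() simply applies the
    implementation chosen by the experts object."""

    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: ToyWeightAndReduce) -> None:
        weight_and_reduce_impl.apply(
            output=output,
            fused_expert_output=fused_expert_output,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            apply_router_weight_on_input=apply_router_weight_on_input)


m, topk, K = 4, 2, 8
out = torch.empty(m, K)
ToyPrepareAndFinalize().finalize(
    out, torch.randn(m, topk, K), torch.rand(m, topk),
    torch.randint(0, 16, (m, topk)), False,
    ToyExperts().finalize_weight_and_reduce_impl())
assert out.shape == (m, K)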
@@ -117,6 +123,24 @@ class ExpertTokensMetadata:
            expert_num_tokens_cpu=expert_num_tokens_cpu)


+class TopKWeightAndReduce(ABC):
+    """
+    An abstract base class for weight application and reduction implementations.
+    """
+
+    @abstractmethod
+    def apply(self, output: Optional[torch.Tensor],
+              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor,
+              apply_router_weight_on_input: bool) -> torch.Tensor:
+        """
+        Apply topk_weights to the fused_experts_outputs and/or reduce.
+        If an output tensor is not passed, it will be created in the
+        function.
+        """
+        raise NotImplementedError
+
+
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
    """

@@ -173,6 +197,7 @@ class FusedMoEPrepareAndFinalize(ABC):
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: TopKWeightAndReduce,
    ) -> None:
        """
        Perform any combine plus apply weights and perform a reduction on the

@@ -184,6 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
        - topk_ids: The topk_ids.
        - apply_router_weight_on_input: When False, apply the weights to
          fused_expert_output.
+        - weight_and_reduce_impl: An optional TopKWeightAndReduce
+          implementation.
        """
        raise NotImplementedError

@@ -323,6 +350,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
        return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
            self.supports_chunking()

+    def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
+        raise NotImplementedError
+
    @abstractmethod
    def apply(
        self,

@@ -702,7 +732,9 @@ class FusedMoEModularKernel(torch.nn.Module):
            a2_scale=a2_scale,
            expert_tokens_meta=expert_tokens_meta)

-        self.prepare_finalize.finalize(output, fused_out, topk_weights,
-                                       topk_ids, apply_router_weight_on_input)
+        self.prepare_finalize.finalize(
+            output, fused_out, topk_weights, topk_ids,
+            apply_router_weight_on_input,
+            self.fused_experts.finalize_weight_and_reduce_impl())

        return output
@@ -8,6 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
    _validate_scale_shape, moe_kernel_quantize_input)
from vllm.utils import cdiv, round_up

@@ -222,7 +224,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
+        assert isinstance(
+            weight_and_reduce_impl, TopKWeightAndReduceDelegate
+        ), ("Weight application and reduction happens in the combine kernel.")
+
        # This argument is optional
        # There's not much point setting this unless it is != topk_ids.size(0)
        bound_m: Optional[torch.Tensor] = None
@@ -6,8 +6,8 @@ import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
-    _moe_unpermute_and_reduce)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)

@@ -62,6 +62,13 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
-        _moe_unpermute_and_reduce(output, fused_expert_output, None,
-                                  topk_weights, apply_router_weight_on_input)
+        if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+            weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+        weight_and_reduce_impl.apply(
+            output=output,
+            fused_expert_output=fused_expert_output,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=apply_router_weight_on_input)
vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py (new file, 139 lines)
@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk


class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
    """
    Useful in the case when some FusedMoEPermuteExpertsUnpermute
    implementation does not perform weight application and reduction
    but cannot address the needs of all the compatible PrepareAndFinalize
    implementations.
    For example, BatchedTritonExperts is compatible with both
    PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
    does the weight-application + reduction as part of the pplx combine kernel.
    But the BatchedPrepareAndFinalize needs an implementation. To facilitate
    this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
    so the PrepareAndFinalize implementations could choose how to
    weight + reduce.
    """

    def __eq__(self, other):
        return isinstance(other, TopKWeightAndReduceDelegate)

    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:
        raise RuntimeError("The caller is expected to choose an appropriate "
                           "TopKWeightAndReduce implementation.")
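As the docstring says, a Delegate means the experts implementation expressed no preference; the PrepareAndFinalize side then substitutes its own default, as DeepEPHTPrepareAndFinalize and MoEPrepareAndFinalizeNoEP do elsewhere in this commit. A minimal sketch of that resolution pattern (editor's illustration, not part of this commit; assumes vLLM is installed, a CUDA device, and a dtype accepted by ops.moe_sum such as float16):

import torch

from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)

# What the experts implementation handed to finalize().
impl = TopKWeightAndReduceDelegate()

# finalize() swaps the delegate for its own default before applying it.
if isinstance(impl, TopKWeightAndReduceDelegate):
    impl = TopKWeightAndReduceContiguous()

m, topk, K = 4, 2, 8
out = impl.apply(
    output=None,
    fused_expert_output=torch.randn(m, topk, K, dtype=torch.float16,
                                    device="cuda"),
    topk_weights=torch.rand(m, topk, dtype=torch.float16, device="cuda"),
    topk_ids=torch.randint(0, 16, (m, topk), device="cuda"),
    apply_router_weight_on_input=False)
assert out.shape == (m, K)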


class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
    """
    The fused_experts outputs have already been weight applied and reduced.
    This implementation is a no-op.
    """

    def __eq__(self, other):
        return isinstance(other, TopKWeightAndReduceNoOP)

    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:
        # Relax this if an explicit copy is necessary. Note that,
        # if a copy is employed we have to make sure that the
        # tensors don't overlap
        assert output is None
        return fused_expert_output


class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
    """
    TopKWeightAndReduce implementation for a fused_experts output
    of shape (m, topk, K)
    """

    def __eq__(self, other):
        return isinstance(other, TopKWeightAndReduceContiguous)

    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:

        m, num_topk = topk_ids.size()
        k = fused_expert_output.size(-1)
        if fused_expert_output.ndim == 2:
            fused_expert_output = fused_expert_output.view(m, num_topk, k)

        assert fused_expert_output.size() == (m, num_topk, k), (
            f"Expected fused_expert_output size {(m, num_topk, k)}. But got "
            f"{fused_expert_output.size()}")

        if not apply_router_weight_on_input:
            fused_expert_output.mul_(topk_weights.view(m, -1, 1))

        if output is None:
            output = torch.empty((m, k),
                                 device=fused_expert_output.device,
                                 dtype=fused_expert_output.dtype)
        assert output.size() == (m, k), (
            f"Expected output size {(m, k)}. But got {output.size()}")

        ops.moe_sum(fused_expert_output, output)
        return output
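For reference, the reduction above is simply a weighted sum over the topk axis; the torch-only restatement below (an editor's sketch, not vLLM code) computes the same (m, K) result that the in-place weight multiply plus ops.moe_sum produce for an (m, topk, K) input:

import torch

# fused_expert_output laid out as (m, topk, K); topk_weights as (m, topk).
m, topk, K = 4, 2, 8
fused_expert_output = torch.randn(m, topk, K)
topk_weights = torch.rand(m, topk)

# Equivalent of: fused_expert_output.mul_(topk_weights.view(m, -1, 1))
# followed by ops.moe_sum(fused_expert_output, output).
output = (fused_expert_output * topk_weights.view(m, -1, 1)).sum(dim=1)
assert output.shape == (m, K)

# The 2-D case ((m * topk, K), e.g. a flattened expert output) is handled in
# the class above by the .view(m, num_topk, k) reshape before reducing.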


class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce):
    """
    TopKWeightAndReduce implementation for a fused_experts output
    of shape (num_experts, batch_size, K)
    """

    def __init__(self, rank: int):
        self.rank = rank

    def __eq__(self, other):
        return (isinstance(other, TopKWeightAndReduceNaiveBatched)
                and (other.rank == self.rank))

    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:
        assert fused_expert_output.ndim == 3
        num_tokens = topk_ids.size(0)
        num_local_experts = fused_expert_output.size(0)
        K = fused_expert_output.size(-1)

        if output is None:
            output = torch.zeros((num_tokens, K),
                                 device=fused_expert_output.device,
                                 dtype=fused_expert_output.dtype)
        else:
            output.fill_(0)

        assert output.size() == (num_tokens, K), (
            f"Expected output size {(num_tokens, K)}, but got {output.size()}")

        first_expert = num_local_experts * self.rank
        last_expert = first_expert + num_local_experts

        for expert_id in range(first_expert, last_expert):
            matching_tokens = topk_ids == expert_id
            topks = torch.any(matching_tokens, dim=1).flatten()
            rows = torch.count_nonzero(topks)
            rhs = fused_expert_output[expert_id - first_expert, :rows, :]
            if not apply_router_weight_on_input:
                rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
            output[topks] = output[topks] + rhs

        return output
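The batched layout consumed above packs, for each local expert, the outputs of the tokens routed to that expert into the leading rows of the expert's slice. A small sketch of driving the class with such a layout (an editor's illustration, not part of this commit; assumes vLLM is installed, a single rank so rank=0, and CPU tensors, which is fine since the implementation is pure torch):

import torch

from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceNaiveBatched)

num_tokens, topk, K, num_local_experts = 4, 2, 8, 3
# Each token picks two distinct experts.
topk_ids = torch.tensor([[0, 1], [1, 2], [0, 2], [2, 1]])
topk_weights = torch.rand(num_tokens, topk)

# Batched layout: slice e holds the outputs of the tokens routed to local
# expert e, packed into its leading rows (worst case: all num_tokens rows).
fused_expert_output = torch.zeros(num_local_experts, num_tokens, K)
for e in range(num_local_experts):
    n = int((topk_ids == e).any(dim=1).sum())
    fused_expert_output[e, :n] = torch.randn(n, K)

out = TopKWeightAndReduceNaiveBatched(rank=0).apply(
    output=None,
    fused_expert_output=fused_expert_output,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    apply_router_weight_on_input=False)
assert out.shape == (num_tokens, K)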
@@ -69,6 +69,25 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        return ((dge is None or dge.supports_expert_map())
                and (te is None or te.supports_expert_map()))

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        dge = self.deep_gemm_expert
+        te = self.triton_expert
+        dge_war = dge.finalize_weight_and_reduce_impl() if dge else None
+        te_war = te.finalize_weight_and_reduce_impl() if te else None
+        is_dge_war = dge_war is not None
+        is_te_war = te_war is not None
+
+        if is_dge_war and is_te_war:
+            assert dge_war == te_war, (
+                "Both implementations should agree on WeightAndReduce impls. "
+                f"Got dge_war: {dge_war}, and te_war: {te_war}")
+
+        if dge_war is not None:
+            return dge_war
+
+        assert te_war is not None
+        return te_war
+
    def workspace_shapes(
        self,
        a: torch.Tensor,