[Misc] MoE ModularKernel: Introduce TopKWeightAndReduce (#20648)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Varun Sundar Rabindranath authored on 2025-07-10 17:40:38 -04:00; committed by GitHub
parent 574ad60db9
commit f0c98cae27
14 changed files with 297 additions and 59 deletions

View File

@ -32,6 +32,8 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.platforms import current_platform
from vllm.utils import round_up
@ -371,6 +373,7 @@ def pplx_prepare_finalize(
chunk_topk_weight,
chunk_topk_ids,
False,
weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
)
torch.cuda.synchronize()

View File

@ -7,6 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
@ -217,6 +219,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
a: torch.Tensor,

View File

@ -88,6 +88,25 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return ((bdge is None or bdge.supports_expert_map())
and (bte is None or bte.supports_expert_map()))
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
bdge = self.batched_deep_gemm_experts
bte = self.batched_triton_experts
bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
is_bdge_war = bdge_war is not None
is_bte_war = bte_war is not None
if is_bdge_war and is_bte_war:
assert bdge_war == bte_war, (
"Both implementations should agree on WeightAndReduce impls. "
f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}")
if bdge_war is not None:
return bdge_war
assert bte_war is not None
return bte_war
def workspace_shapes(
self,
a: torch.Tensor,

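The agreement check above works because the TopKWeightAndReduce implementations compare by value rather than by identity: each concrete class in the new topk_weight_and_reduce module overrides __eq__, so the impls returned by two independently constructed sub-experts can still satisfy the assert. A minimal sketch of that property (it assumes only that vLLM with this patch is importable):

from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)

# Separate instances of the same impl (and configuration) compare equal,
# so composite experts can assert that their sub-experts agree.
assert TopKWeightAndReduceDelegate() == TopKWeightAndReduceDelegate()
assert (TopKWeightAndReduceNaiveBatched(rank=0)
        == TopKWeightAndReduceNaiveBatched(rank=0))
# Different impls, or the same impl with a different configuration, do not.
assert TopKWeightAndReduceDelegate() != TopKWeightAndReduceNaiveBatched(rank=0)
assert (TopKWeightAndReduceNaiveBatched(rank=0)
        != TopKWeightAndReduceNaiveBatched(rank=1))
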
View File

@ -11,6 +11,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
_resize_cache)
@ -255,6 +257,10 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return not self.use_batched_format
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
a: torch.Tensor,

View File

@ -12,6 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up
@ -85,6 +87,10 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int

View File

@ -6,8 +6,9 @@ import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
@ -187,45 +188,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)
def _apply_weights_and_reduce(self, num_tokens: int,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
apply_router_weight_on_input: bool,
output_dtype: torch.dtype):
hidden_dim = fused_expert_output.size(-1)
if fused_expert_output.ndim == 2:
fused_expert_output = fused_expert_output.view(
num_tokens, -1, hidden_dim)
if not apply_router_weight_on_input:
# The DeepEP combine kernels don't do the topk weight
# multiplication. We multiply the weights locally.
m_x_topk = fused_expert_output.size(0)
fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))
out = torch.empty((num_tokens, hidden_dim),
device=fused_expert_output.device,
dtype=output_dtype)
ops.moe_sum(fused_expert_output, out)
return out
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None:
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
assert self.handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if fused_expert_output.numel() != 0:
fused_expert_output = self._apply_weights_and_reduce(
num_tokens=topk_ids.size(0),
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
fused_expert_output = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
output_dtype=output.dtype)
)
combined_x, _, event = self.buffer.combine(
x=fused_expert_output,

View File

@ -7,6 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input, normalize_batched_scales_shape)
@ -166,8 +168,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> None:
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
assert self.handle is not None
combine_topk_weights = topk_weights

View File

@ -11,6 +11,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape,
normalize_scales_shape)
@ -600,25 +602,17 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
num_tokens = topk_ids.size(0)
num_local_experts = fused_expert_output.size(0)
K = fused_expert_output.size(-1)
assert output.size(0) == num_tokens and output.size(1) == K
output.fill_(0)
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
matching_tokens = topk_ids == expert_id
topks = torch.any(matching_tokens, dim=1).flatten()
rows = torch.count_nonzero(topks)
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
if not apply_router_weight_on_input:
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
output[topks] = output[topks] + rhs
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
@ -670,6 +664,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
a: torch.Tensor,
@ -877,6 +875,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
a: torch.Tensor,

View File

@ -25,6 +25,8 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@ -1596,6 +1598,10 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
a: torch.Tensor,

View File

@ -23,7 +23,7 @@ from vllm.utils import cdiv
#
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
#
# Each component will be independent of the others except for
# Each component will be independent of (but may inform) the others except for
# [Quantize-Dispatch] and [Combine] (see below). The components can then be
# mixed and matched so that DP+EP can be supported easily for multiple
# MoE kernel implementations.
@ -32,13 +32,19 @@ from vllm.utils import cdiv
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of Moe outputs.
# The prepare method must take care of any needed quantization and the
# finalize method must apply weights and do the final reduction of the output.
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
# may apply weights and/or do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
# MoE operation. One important feature to note is that this class does not
# apply topk weights or reduce the final output.
# MoE operation, i.e. matmul + act_mul + optionally quant + matmul.
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
# the weight application and/or reduction. The class communicates this
# to [Finalize] via a TopKWeightAndReduce object.
# * FusedMoEModularKernel - an interface class that combines a
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
# provide the standard fused MoE kernel interface.
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
# on to [Finalize].
#
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective
@ -117,6 +123,24 @@ class ExpertTokensMetadata:
expert_num_tokens_cpu=expert_num_tokens_cpu)
class TopKWeightAndReduce(ABC):
"""
An abstract base class for weight application and reduction implementations.
"""
@abstractmethod
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
"""
Apply topk_weights to fused_expert_output and/or reduce.
If an output tensor is not passed, it will be created by the
function.
"""
raise NotImplementedError
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
@ -173,6 +197,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
@ -184,6 +209,8 @@ class FusedMoEPrepareAndFinalize(ABC):
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to
fused_expert_output.
- weight_and_reduce_impl: The TopKWeightAndReduce implementation to
use; a TopKWeightAndReduceDelegate lets this finalizer choose.
"""
raise NotImplementedError
@ -323,6 +350,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
self.supports_chunking()
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
raise NotImplementedError
@abstractmethod
def apply(
self,
@ -702,7 +732,9 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)
self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl())
return output
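
For reference, the contract that the new TopKWeightAndReduce abstract class imposes is small: a value-comparable object with an apply() method that may allocate its own output when none is passed in. Below is a hedged sketch of a custom implementation; the class name SumWeightAndReduce and its simple weighted-sum behaviour are illustrative only and not part of this commit:

from typing import Optional

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk


class SumWeightAndReduce(mk.TopKWeightAndReduce):
    """Hypothetical impl: weight a (m, topk, K) output and sum over topk."""

    def __eq__(self, other):
        # Value-based equality so composite experts can compare impls.
        return isinstance(other, SumWeightAndReduce)

    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:
        m, num_topk = topk_ids.size()
        k = fused_expert_output.size(-1)
        x = fused_expert_output.view(m, num_topk, k)
        if not apply_router_weight_on_input:
            # The router weights were not folded into the input; apply them now.
            x = x * topk_weights.view(m, num_topk, 1)
        reduced = x.sum(dim=1)
        if output is None:
            # Per the apply() docstring, allocate the output when not provided.
            return reduced
        output.copy_(reduced)
        return output

A FusedMoEPermuteExpertsUnpermute implementation that wants finalize() to reduce its output this way would return SumWeightAndReduce() from finalize_weight_and_reduce_impl(); returning TopKWeightAndReduceDelegate() instead defers the choice to the PrepareAndFinalize, and TopKWeightAndReduceNoOP() signals that the expert kernel has already applied the weights and reduced.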

View File

@ -8,6 +8,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
_validate_scale_shape, moe_kernel_quantize_input)
from vllm.utils import cdiv, round_up
@ -222,7 +224,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m: Optional[torch.Tensor] = None

View File

@ -6,8 +6,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
@ -62,6 +62,13 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
_moe_unpermute_and_reduce(output, fused_expert_output, None,
topk_weights, apply_router_weight_on_input)
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input)

View File

@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
"""
Useful when a FusedMoEPermuteExpertsUnpermute implementation does not
perform weight application and reduction itself, but cannot pick a
single implementation that suits all of its compatible
PrepareAndFinalize implementations.
For example, BatchedTritonExperts is compatible with both
PplxPrepareAndFinalize and BatchedPrepareAndFinalize.
PplxPrepareAndFinalize does the weight application + reduction as part
of the pplx combine kernel, but BatchedPrepareAndFinalize needs an
explicit implementation. To support both, BatchedTritonExperts returns
a TopKWeightAndReduceDelegate so that each PrepareAndFinalize
implementation can choose how to apply the weights and reduce.
"""
def __eq__(self, other):
return isinstance(other, TopKWeightAndReduceDelegate)
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
raise RuntimeError("The caller is expected to choose an appropriate "
"TopKWeightAndReduce implementation.")
class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
"""
The fused_experts output has already had the topk weights applied
and has been reduced. This implementation is a no-op.
"""
def __eq__(self, other):
return isinstance(other, TopKWeightAndReduceNoOP)
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
# Relax this if an explicit copy is necessary. Note that,
# if a copy is employed we have to make sure that the
# tensors don't overlap
assert output is None
return fused_expert_output
class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
"""
TopKWeightAndReduce implementation for a fused_experts output
of shape (m, topk, K)
"""
def __eq__(self, other):
return isinstance(other, TopKWeightAndReduceContiguous)
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
m, num_topk = topk_ids.size()
k = fused_expert_output.size(-1)
if fused_expert_output.ndim == 2:
fused_expert_output = fused_expert_output.view(m, num_topk, k)
assert fused_expert_output.size() == (m, num_topk, k), (
f"Expected fused_expert_output size {(m, num_topk, k)}. But got "
f"{fused_expert_output.size()}")
if not apply_router_weight_on_input:
fused_expert_output.mul_(topk_weights.view(m, -1, 1))
if output is None:
output = torch.empty((m, k),
device=fused_expert_output.device,
dtype=fused_expert_output.dtype)
assert output.size() == (m, k), (
f"Expected output size {(m, k)}. But got {output.size()}")
ops.moe_sum(fused_expert_output, output)
return output
class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce):
"""
TopKWeightAndReduce implementation for a fused_experts output
of shape (num_experts, batch_size, K)
"""
def __init__(self, rank: int):
self.rank = rank
def __eq__(self, other):
return (isinstance(other, TopKWeightAndReduceNaiveBatched)
and (other.rank == self.rank))
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
assert fused_expert_output.ndim == 3
num_tokens = topk_ids.size(0)
num_local_experts = fused_expert_output.size(0)
K = fused_expert_output.size(-1)
if output is None:
output = torch.zeros((num_tokens, K),
device=fused_expert_output.device,
dtype=fused_expert_output.dtype)
else:
output.fill_(0)
assert output.size() == (num_tokens, K), (
f"Expected output size {(num_tokens, K)}, but got {output.size()}")
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
matching_tokens = topk_ids == expert_id
topks = torch.any(matching_tokens, dim=1).flatten()
rows = torch.count_nonzero(topks)
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
if not apply_router_weight_on_input:
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
output[topks] = output[topks] + rhs
return output
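
A small shape-level walkthrough of the batched implementation above. TopKWeightAndReduceNaiveBatched is plain PyTorch, so, assuming vLLM is importable, it runs without the custom moe_sum kernel that TopKWeightAndReduceContiguous relies on; the tensor sizes below are illustrative only:

import torch

from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceNaiveBatched)

num_tokens, num_topk, K, num_local_experts = 4, 2, 8, 3
# Batched expert output: one (batch, K) slab per local expert.
fused_expert_output = torch.randn(num_local_experts, num_tokens, K)
# Each token picks num_topk distinct local experts.
topk_ids = torch.tensor([[0, 1], [1, 2], [0, 2], [2, 1]])
topk_weights = torch.rand(num_tokens, num_topk)

impl = TopKWeightAndReduceNaiveBatched(rank=0)
out = impl.apply(output=None,
                 fused_expert_output=fused_expert_output,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 apply_router_weight_on_input=False)
# The weighted per-expert rows are scattered back and summed per token.
assert out.size() == (num_tokens, K)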

View File

@ -69,6 +69,25 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
return ((dge is None or dge.supports_expert_map())
and (te is None or te.supports_expert_map()))
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
dge = self.deep_gemm_expert
te = self.triton_expert
dge_war = dge.finalize_weight_and_reduce_impl() if dge else None
te_war = te.finalize_weight_and_reduce_impl() if te else None
is_dge_war = dge_war is not None
is_te_war = te_war is not None
if is_dge_war and is_te_war:
assert dge_war == te_war, (
"Both implementations should agree on WeightAndReduce impls. "
f"Got dge_war: {dge_war}, and te_war: {te_war}")
if dge_war is not None:
return dge_war
assert te_war is not None
return te_war
def workspace_shapes(
self,
a: torch.Tensor,