vllm/vllm/model_executor/layers/fused_moe/prepare_finalize.py

61 lines
2.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
quant_dtype: Optional[torch.dtype] = None,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
):
super().__init__()
self.per_channel_quant = per_channel_quant
self.block_shape = block_shape
self.quant_dtype = quant_dtype
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
self.quant_dtype,
self.per_channel_quant,
self.block_shape)
return a1q, a1q_scale, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
_moe_unpermute_and_reduce(output, fused_expert_output, None,
topk_weights, apply_router_weight_on_input)