From a0efd3106cdd75c29f6463193968a60daa3102db Mon Sep 17 00:00:00 2001
From: Bill Nell
Date: Fri, 30 May 2025 02:08:21 +0000
Subject: [PATCH] hack fix MoEConfig.quant_dtype

Signed-off-by: Bill Nell
---
 vllm/model_executor/layers/fused_moe/layer.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 8eef20c75c432..ae72826ee9765 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -195,6 +195,7 @@ class MoEConfig:
     moe_parallel_config: FusedMoEParallelConfig
 
     in_dtype: torch.dtype  # The post quantization activation type.
+    quant_dtype: Optional[torch.dtype] = None
 
     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128
@@ -271,11 +272,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
             # For blocked per token: set to
             #   ceil_div(hidden_dim, block_size) * sizeof(float32)
             # For per-token: set to sizeof(float32)
-            if moe.quant_dtype.itemsize == 1:
-                scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
+            if moe.quant_dtype is not None and moe.quant_dtype.itemsize == 1:
+                hidden_dim_bytes = moe.hidden_dim * moe.quant_dtype.itemsize
+                hidden_scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
                                torch.float32.itemsize)
             else:
-                scale_bytes = 0
+                hidden_dim_bytes = moe.hidden_dim * moe.in_dtype.itemsize
+                hidden_scale_bytes = 0
 
             all_to_all_args = dict(
                 max_num_tokens=moe.max_num_tokens,
@@ -286,8 +289,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
                 hidden_dim=moe.hidden_dim,
-                hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
-                hidden_dim_scale_bytes=scale_bytes,
+                hidden_dim_bytes=hidden_dim_bytes,
+                hidden_dim_scale_bytes=hidden_scale_bytes,
             )
 
             if not all2all_manager.internode:
@@ -793,7 +796,8 @@ class FusedMoE(torch.nn.Module):
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
 
-        quant_dtype = vllm_config.model_config.dtype
+        logger.debug("MODEL DTYPE %s", vllm_config.model_config.dtype)
+        quant_dtype: Optional[torch.dtype] = None
         if quant_config is not None:
             input_activations = get_quant_config_input_activations(
                 quant_config)
@@ -804,6 +808,12 @@ class FusedMoE(torch.nn.Module):
             elif input_activations.type == QuantizationType.INT:
                 quant_dtype = torch.int8
 
+        # Total hack
+        if quant_config.__class__.__name__ == "Fp8Config":
+            quant_dtype = torch.float8_e4m3fn
+
+        logger.info("QUANT_DTYPE %s", quant_dtype)
+
         moe = MoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
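
Reviewer notes (not part of the patch):

The per-token byte sizing this patch threads into the pplx all-to-all
args can be exercised in isolation. The following is a minimal standalone
sketch, not vLLM's actual helper: `cdiv` here mirrors vllm.utils.cdiv,
`per_token_bytes` is a hypothetical name, and the dtype handling assumes a
PyTorch new enough that torch.dtype exposes `.itemsize` (as the patch
itself does). Single-byte quantized activations (fp8/int8) ship hidden_dim
quantized bytes per token plus one float32 scale per block of block_size
elements; unquantized activations stay in in_dtype and carry no scales.

    from typing import Optional

    import torch

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, mirroring vllm.utils.cdiv.
        return -(-a // b)

    def per_token_bytes(hidden_dim: int,
                        in_dtype: torch.dtype,
                        quant_dtype: Optional[torch.dtype],
                        block_size: int = 128) -> tuple[int, int]:
        # Returns (hidden_dim_bytes, hidden_dim_scale_bytes) for one token.
        # Single-byte quantized activations carry one float32 scale per
        # block of `block_size` elements; unquantized activations stay in
        # `in_dtype` and need no scales.
        if quant_dtype is not None and quant_dtype.itemsize == 1:
            return (hidden_dim * quant_dtype.itemsize,
                    cdiv(hidden_dim, block_size) * torch.float32.itemsize)
        return hidden_dim * in_dtype.itemsize, 0

    # fp8 at hidden_dim=4096: 4096 data bytes + 32 blocks * 4-byte scales.
    assert per_token_bytes(4096, torch.bfloat16,
                           torch.float8_e4m3fn) == (4096, 128)
    # Unquantized bf16 path: 8192 data bytes, no scales.
    assert per_token_bytes(4096, torch.bfloat16, None) == (8192, 0)

With quant_dtype now Optional and defaulting to None, the unquantized path
falls through to in_dtype sizing instead of dereferencing
moe.quant_dtype.itemsize on None, which is the crash the first hunk fixes.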
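Similarly, the quant_dtype selection added to FusedMoE.__init__ can be read
as a small pure function. The sketch below is hypothetical:
QuantizationType, get_quant_config_input_activations, and Fp8Config are
vLLM names taken from the diff, stubbed here only so the snippet is
self-contained, and the FLOAT branch is an assumption since the hunk's
context shows only the INT branch.

    from dataclasses import dataclass
    from enum import Enum
    from typing import Any, Optional

    import torch

    # Stand-ins for vLLM / compressed-tensors types (illustration only).
    class QuantizationType(str, Enum):
        FLOAT = "float"
        INT = "int"

    @dataclass
    class QuantizationArgs:
        type: QuantizationType

    def get_quant_config_input_activations(
            quant_config: Any) -> Optional[QuantizationArgs]:
        # vLLM derives this from the quant config; stubbed here.
        return getattr(quant_config, "input_activations", None)

    def resolve_quant_dtype(
            quant_config: Optional[Any]) -> Optional[torch.dtype]:
        # Default is None, i.e. unquantized activations -- matching the
        # new MoEConfig.quant_dtype default this patch introduces.
        quant_dtype: Optional[torch.dtype] = None
        if quant_config is not None:
            input_activations = get_quant_config_input_activations(
                quant_config)
            if input_activations is not None:
                # FLOAT branch assumed; only the INT branch is visible
                # in the hunk's context.
                if input_activations.type == QuantizationType.FLOAT:
                    quant_dtype = torch.float8_e4m3fn
                elif input_activations.type == QuantizationType.INT:
                    quant_dtype = torch.int8
        # Mirrors the patch's "Total hack": Fp8Config is matched by class
        # name. Safe even for quant_config=None (NoneType never matches).
        if quant_config.__class__.__name__ == "Fp8Config":
            quant_dtype = torch.float8_e4m3fn
        return quant_dtype

The class-name match is the self-described hack of the commit subject:
Fp8Config evidently does not report its input activations through the
generic path above, so the patch special-cases it until a cleaner query
exists.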