hack fix MoEConfig.quant_dtype

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-30 02:08:21 +00:00
parent e69879996f
commit a0efd3106c

View File

@ -195,6 +195,7 @@ class MoEConfig:
moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The post quantization activation type.
quant_dtype: Optional[torch.dtype] = None
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
@ -271,11 +272,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
if moe.quant_dtype.itemsize == 1:
scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
if moe.quant_dtype is not None and moe.quant_dtype.itemsize == 1:
hidden_dim_bytes = moe.hidden_dim * moe.quant_dtype.itemsize
hidden_scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
torch.float32.itemsize)
else:
scale_bytes = 0
hidden_dim_bytes = moe.hidden_dim * moe.in_dtype.itemsize
hidden_scale_bytes = 0
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
@ -286,8 +289,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
hidden_dim_scale_bytes=scale_bytes,
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=hidden_scale_bytes,
)
if not all2all_manager.internode:
@ -793,7 +796,8 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
quant_dtype = vllm_config.model_config.dtype
logger.debug("MODEL DTYPE %s", vllm_config.model_config.dtype)
quant_dtype: Optional[torch.dtype] = None
if quant_config is not None:
input_activations = get_quant_config_input_activations(
quant_config)
@ -804,6 +808,12 @@ class FusedMoE(torch.nn.Module):
elif input_activations.type == QuantizationType.INT:
quant_dtype = torch.int8
# Total hack
if quant_config.__class__.__name__ == "Fp8Config":
quant_dtype = torch.float8_e4m3fn
logger.info("QUANT_DTYPE %s", quant_dtype)
moe = MoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,