hack fix MoEConfig.quant_dtype
Signed-off-by: Bill Nell <bnell@redhat.com>
parent e69879996f
commit a0efd3106c
@@ -195,6 +195,7 @@ class MoEConfig:
     moe_parallel_config: FusedMoEParallelConfig
 
     in_dtype: torch.dtype  # The post quantization activation type.
+    quant_dtype: Optional[torch.dtype] = None
 
     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128
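For context, a minimal sketch of how the now-optional field is meant to be used. This is a trimmed-down stand-in for MoEConfig, not the real vLLM class; the class name, the chosen field subset, and the example values are illustrative assumptions.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MoEConfigSketch:
    """Trimmed-down stand-in for vLLM's MoEConfig (illustration only)."""
    num_experts: int
    experts_per_token: int
    hidden_dim: int
    in_dtype: torch.dtype                       # post-quantization activation type
    quant_dtype: Optional[torch.dtype] = None   # None => activations are not quantized
    block_size: int = 128                       # blocked-quantization group size


# Unquantized model: quant_dtype stays None, so downstream code must not
# dereference quant_dtype.itemsize unconditionally.
cfg_bf16 = MoEConfigSketch(num_experts=64, experts_per_token=8,
                           hidden_dim=4096, in_dtype=torch.bfloat16)

# FP8-quantized activations: quant_dtype carries the on-the-wire dtype.
cfg_fp8 = MoEConfigSketch(num_experts=64, experts_per_token=8,
                          hidden_dim=4096, in_dtype=torch.bfloat16,
                          quant_dtype=torch.float8_e4m3fn)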
@@ -271,11 +272,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
             # For blocked per token: set to
             #   ceil_div(hidden_dim, block_size) * sizeof(float32)
             # For per-token: set to sizeof(float32)
-            if moe.quant_dtype.itemsize == 1:
-                scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
-                               torch.float32.itemsize)
+            if moe.quant_dtype is not None and moe.quant_dtype.itemsize == 1:
+                hidden_dim_bytes = moe.hidden_dim * moe.quant_dtype.itemsize
+                hidden_scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
+                                      torch.float32.itemsize)
             else:
-                scale_bytes = 0
+                hidden_dim_bytes = moe.hidden_dim * moe.in_dtype.itemsize
+                hidden_scale_bytes = 0
 
             all_to_all_args = dict(
                 max_num_tokens=moe.max_num_tokens,
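A quick worked example of the arithmetic the new branch performs, assuming hidden_dim=4096, the default block_size=128, and an fp8 quant dtype (these values are assumptions for illustration; cdiv below is a local stand-in for vLLM's ceiling-division helper):

import torch


def cdiv(a: int, b: int) -> int:
    # local stand-in for vLLM's cdiv: ceiling division
    return -(-a // b)


hidden_dim = 4096                    # assumed example value
block_size = 128                     # MoEConfig default
quant_dtype = torch.float8_e4m3fn    # itemsize == 1
in_dtype = torch.bfloat16            # itemsize == 2

if quant_dtype is not None and quant_dtype.itemsize == 1:
    # quantized path: 1 byte per element plus one fp32 scale per block
    hidden_dim_bytes = hidden_dim * quant_dtype.itemsize                          # 4096
    hidden_scale_bytes = cdiv(hidden_dim, block_size) * torch.float32.itemsize    # 32 * 4 = 128
else:
    # unquantized path: payload uses the activation dtype, no scales
    hidden_dim_bytes = hidden_dim * in_dtype.itemsize                             # 8192
    hidden_scale_bytes = 0

assert (hidden_dim_bytes, hidden_scale_bytes) == (4096, 128)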
@@ -286,8 +289,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
                 hidden_dim=moe.hidden_dim,
-                hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
-                hidden_dim_scale_bytes=scale_bytes,
+                hidden_dim_bytes=hidden_dim_bytes,
+                hidden_dim_scale_bytes=hidden_scale_bytes,
             )
 
             if not all2all_manager.internode:
@@ -793,7 +796,8 @@ class FusedMoE(torch.nn.Module):
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
 
-        quant_dtype = vllm_config.model_config.dtype
+        logger.debug("MODEL DTYPE %s", vllm_config.model_config.dtype)
+        quant_dtype: Optional[torch.dtype] = None
         if quant_config is not None:
             input_activations = get_quant_config_input_activations(
                 quant_config)
@@ -804,6 +808,12 @@ class FusedMoE(torch.nn.Module):
             elif input_activations.type == QuantizationType.INT:
                 quant_dtype = torch.int8
 
+        # Total hack
+        if quant_config.__class__.__name__ == "Fp8Config":
+            quant_dtype = torch.float8_e4m3fn
+
+        logger.info("QUANT_DTYPE %s", quant_dtype)
+
         moe = MoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
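Read together, the hunks at lines 793 and 804 leave quant_dtype resolution roughly as sketched below. This is a paraphrase, not the vLLM code: the FLOAT branch is inferred from context (only the INT branch is visible in the hunk), QuantizationType here is a local stand-in for the real enum, and input_activations stands for whatever get_quant_config_input_activations returns, of which only .type is used.

from enum import Enum
from typing import Optional

import torch


class QuantizationType(str, Enum):
    # Local stand-in for the quantization-type enum, for illustration only.
    INT = "int"
    FLOAT = "float"


def resolve_quant_dtype(quant_config, input_activations) -> Optional[torch.dtype]:
    # Sketch of the selection order the diff introduces; not the actual vLLM code.
    quant_dtype: Optional[torch.dtype] = None

    if quant_config is not None and input_activations is not None:
        if input_activations.type == QuantizationType.FLOAT:   # branch inferred, not shown in the hunk
            quant_dtype = torch.float8_e4m3fn
        elif input_activations.type == QuantizationType.INT:
            quant_dtype = torch.int8

    # The "Total hack" from the hunk at 804: Fp8Config is matched by class name
    # and forces fp8 activations regardless of what the branch above decided.
    # (If quant_config is None, __class__.__name__ is "NoneType", so this is safe.)
    if quant_config.__class__.__name__ == "Fp8Config":
        quant_dtype = torch.float8_e4m3fn

    return quant_dtype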