mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 12:01:19 +08:00
hack fix MoEConfig.quant_dtype
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
e69879996f
commit
a0efd3106c
@ -195,6 +195,7 @@ class MoEConfig:
|
|||||||
moe_parallel_config: FusedMoEParallelConfig
|
moe_parallel_config: FusedMoEParallelConfig
|
||||||
|
|
||||||
in_dtype: torch.dtype # The post quantization activation type.
|
in_dtype: torch.dtype # The post quantization activation type.
|
||||||
|
quant_dtype: Optional[torch.dtype] = None
|
||||||
|
|
||||||
# TODO: add more quantization params, blocked, per-token, etc.
|
# TODO: add more quantization params, blocked, per-token, etc.
|
||||||
block_size: int = 128
|
block_size: int = 128
|
||||||
@ -271,11 +272,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
# For blocked per token: set to
|
# For blocked per token: set to
|
||||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||||
# For per-token: set to sizeof(float32)
|
# For per-token: set to sizeof(float32)
|
||||||
if moe.quant_dtype.itemsize == 1:
|
if moe.quant_dtype is not None and moe.quant_dtype.itemsize == 1:
|
||||||
scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
|
hidden_dim_bytes = moe.hidden_dim * moe.quant_dtype.itemsize
|
||||||
|
hidden_scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
|
||||||
torch.float32.itemsize)
|
torch.float32.itemsize)
|
||||||
else:
|
else:
|
||||||
scale_bytes = 0
|
hidden_dim_bytes = moe.hidden_dim * moe.in_dtype.itemsize
|
||||||
|
hidden_scale_bytes = 0
|
||||||
|
|
||||||
all_to_all_args = dict(
|
all_to_all_args = dict(
|
||||||
max_num_tokens=moe.max_num_tokens,
|
max_num_tokens=moe.max_num_tokens,
|
||||||
@ -286,8 +289,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
# dp_size actually means tp_size, bug in pplx kernels
|
# dp_size actually means tp_size, bug in pplx kernels
|
||||||
dp_size=all2all_manager.tp_group.world_size,
|
dp_size=all2all_manager.tp_group.world_size,
|
||||||
hidden_dim=moe.hidden_dim,
|
hidden_dim=moe.hidden_dim,
|
||||||
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
|
hidden_dim_bytes=hidden_dim_bytes,
|
||||||
hidden_dim_scale_bytes=scale_bytes,
|
hidden_dim_scale_bytes=hidden_scale_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not all2all_manager.internode:
|
if not all2all_manager.internode:
|
||||||
@ -793,7 +796,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||||
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
||||||
|
|
||||||
quant_dtype = vllm_config.model_config.dtype
|
logger.debug("MODEL DTYPE %s", vllm_config.model_config.dtype)
|
||||||
|
quant_dtype: Optional[torch.dtype] = None
|
||||||
if quant_config is not None:
|
if quant_config is not None:
|
||||||
input_activations = get_quant_config_input_activations(
|
input_activations = get_quant_config_input_activations(
|
||||||
quant_config)
|
quant_config)
|
||||||
@ -804,6 +808,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
elif input_activations.type == QuantizationType.INT:
|
elif input_activations.type == QuantizationType.INT:
|
||||||
quant_dtype = torch.int8
|
quant_dtype = torch.int8
|
||||||
|
|
||||||
|
# Total hack
|
||||||
|
if quant_config.__class__.__name__ == "Fp8Config":
|
||||||
|
quant_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
logger.info("QUANT_DTYPE %s", quant_dtype)
|
||||||
|
|
||||||
moe = MoEConfig(
|
moe = MoEConfig(
|
||||||
num_experts=self.global_num_experts,
|
num_experts=self.global_num_experts,
|
||||||
experts_per_token=top_k,
|
experts_per_token=top_k,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user