mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 18:32:31 +08:00
[Bugfix] Fixes for new marlin moe usage (#18017)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
d8487ef557
commit
1df491c522
@ -57,9 +57,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
"input_activations")
|
"input_activations")
|
||||||
|
|
||||||
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||||
|
# group_size=None means channelwise
|
||||||
|
group_size = weight_quant.group_size or -1
|
||||||
# Prefer to use the MarlinMoE kernel when it is supported.
|
# Prefer to use the MarlinMoE kernel when it is supported.
|
||||||
if not check_moe_marlin_supports_layer(layer,
|
if not check_moe_marlin_supports_layer(layer, group_size):
|
||||||
weight_quant.group_size):
|
|
||||||
if (weight_quant.strategy in QuantizationStrategy.GROUP and
|
if (weight_quant.strategy in QuantizationStrategy.GROUP and
|
||||||
weight_quant.actorder in (ActivationOrdering.GROUP,
|
weight_quant.actorder in (ActivationOrdering.GROUP,
|
||||||
ActivationOrdering.DYNAMIC)):
|
ActivationOrdering.DYNAMIC)):
|
||||||
|
|||||||
@ -610,9 +610,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert activation == "silu", "Only SiLU activation is supported."
|
||||||
if apply_router_weight_on_input is not None:
|
if apply_router_weight_on_input:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Apply router weight on input is not supported for"
|
"Apply router weight on input is not supported for "
|
||||||
"fused Marlin MoE method.")
|
"fused Marlin MoE method.")
|
||||||
|
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user