mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 07:35:01 +08:00
[Bugfix] Fix a couple PPLX+CUTLASS MoE bugs (#20825)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
parent
42d440c22b
commit
3b3b778d4a
@ -204,7 +204,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
out_expert_x_scale=expert_x_scale,
|
out_expert_x_scale=expert_x_scale,
|
||||||
dp_x=a1q,
|
dp_x=a1q,
|
||||||
dp_x_scale=a1q_scale,
|
dp_x_scale=a1q_scale,
|
||||||
indices=topk_ids,
|
indices=topk_ids.view(dtype=torch.uint32),
|
||||||
bound_m=bound_m,
|
bound_m=bound_m,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -249,7 +249,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
topk_weights = torch.ones_like(topk_weights)
|
topk_weights = torch.ones_like(topk_weights)
|
||||||
|
|
||||||
self.a2a.combine(out_tokens=output,
|
self.a2a.combine(out_tokens=output,
|
||||||
indices=topk_ids,
|
indices=topk_ids.view(dtype=torch.uint32),
|
||||||
weights=topk_weights,
|
weights=topk_weights,
|
||||||
expert_y=fused_expert_output,
|
expert_y=fused_expert_output,
|
||||||
bound_m=bound_m)
|
bound_m=bound_m)
|
||||||
|
|||||||
@ -737,10 +737,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
|||||||
"For FP8 Fused MoE layer, we require either per tensor or "
|
"For FP8 Fused MoE layer, we require either per tensor or "
|
||||||
"channelwise, dynamic per token quantization.")
|
"channelwise, dynamic per token quantization.")
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
|
||||||
cutlass_moe_fp8)
|
|
||||||
self.topk_indices_dtype = None
|
self.topk_indices_dtype = None
|
||||||
self.fused_experts = cutlass_moe_fp8 # type: ignore
|
self.fused_experts = None # type: ignore
|
||||||
self.disable_expert_map = False
|
self.disable_expert_map = False
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
@ -936,21 +934,40 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
|||||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||||
|
|
||||||
return self.fused_experts(
|
if self.fused_experts is None:
|
||||||
x,
|
# If no modular kernel is provided, use cutlass_moe_fp8
|
||||||
layer.w13_weight,
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
layer.w2_weight,
|
cutlass_moe_fp8)
|
||||||
topk_weights,
|
return cutlass_moe_fp8(
|
||||||
topk_ids,
|
x,
|
||||||
per_act_token=per_act_token,
|
layer.w13_weight,
|
||||||
activation=activation,
|
layer.w2_weight,
|
||||||
global_num_experts=global_num_experts,
|
topk_weights,
|
||||||
expert_map=None if self.disable_expert_map else expert_map,
|
topk_ids,
|
||||||
w1_scale=layer.w13_weight_scale,
|
per_act_token=per_act_token,
|
||||||
w2_scale=layer.w2_weight_scale,
|
activation=activation,
|
||||||
a1_scale=a1_scale,
|
global_num_experts=global_num_experts,
|
||||||
a2_scale=a2_scale,
|
expert_map=None if self.disable_expert_map else expert_map,
|
||||||
)
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.fused_experts(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
activation=activation,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=None if self.disable_expert_map else expert_map,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user