Support Llama 4 for cutlass_moe_fp4 (#20453)
Signed-off-by: mgoin <mgoin64@gmail.com>
commit 31b96d1c64
parent e59ba9e142
@@ -411,13 +411,23 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


-def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
-                    w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor,
-                    w1_alphas: torch.Tensor, a2_gscale: torch.Tensor,
-                    w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor,
-                    w2_alphas: torch.Tensor, topk_weights: torch.Tensor,
-                    topk_ids: torch.Tensor, m: int, n: int, k: int, e: int,
-                    device: torch.device):
+def cutlass_moe_fp4(a: torch.Tensor,
+                    a1_gscale: torch.Tensor,
+                    w1_fp4: torch.Tensor,
+                    w1_blockscale: torch.Tensor,
+                    w1_alphas: torch.Tensor,
+                    a2_gscale: torch.Tensor,
+                    w2_fp4: torch.Tensor,
+                    w2_blockscale: torch.Tensor,
+                    w2_alphas: torch.Tensor,
+                    topk_weights: torch.Tensor,
+                    topk_ids: torch.Tensor,
+                    m: int,
+                    n: int,
+                    k: int,
+                    e: int,
+                    device: torch.device,
+                    apply_router_weight_on_input: bool = False):
     """
     MoE implementation for FP4 Inputs
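The new `apply_router_weight_on_input` flag folds the router weight into the activations before the expert GEMMs, rather than scaling the expert outputs afterwards; that input-side placement is the convention Llama 4 uses with top-1 routing. A minimal, self-contained sketch of the idea, using a hypothetical linear stand-in expert (for a linear expert the two placements agree exactly; real expert MLPs are nonlinear, which is why the placement must match the model):

import torch

def toy_expert(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for one expert; linear on purpose so the
    # input-side and output-side weight placements agree exactly.
    return 2.0 * x

m, k = 4, 8
a = torch.randn(m, k)
topk_weights = torch.rand(m, 1)  # one router weight per token (top-k = 1)

weight_on_input = toy_expert(a * topk_weights)   # scale before the expert
weight_on_output = toy_expert(a) * topk_weights  # scale after the expert
assert torch.allclose(weight_on_input, weight_on_output)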
@@ -480,6 +490,12 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

+    if apply_router_weight_on_input:
+        # TODO: this only works for topK=1, will need to update for topK>1
+        assert num_topk == 1, \
+            "apply_router_weight_on_input is only implemented for topk=1"
+        a.mul_(topk_weights.to(out_dtype))
+
     # problem shapes should have [m, n, k]
     # Note that problem sizes are based on logical number of elements.
     ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
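The scaling itself is a broadcasted in-place multiply. A small sketch with assumed shapes (names mirror the diff; values are illustrative):

import torch

m, k, num_topk = 4, 8, 1
out_dtype = torch.half
a = torch.randn(m, k, dtype=out_dtype)
topk_weights = torch.rand(m, num_topk, dtype=torch.float32)

assert num_topk == 1, \
    "apply_router_weight_on_input is only implemented for topk=1"
# (m, 1) broadcasts across the hidden dimension, so each token's row of
# activations is scaled by its single router weight, in place.
a.mul_(topk_weights.to(out_dtype))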
@@ -517,8 +533,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     del int_fp4, int_blockscale

     c2 = ops.shuffle_rows(c2, c_map)
-    out = (c2.view(m, num_topk, k) *
-           topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
+    if not apply_router_weight_on_input:
+        out = (c2.view(m, num_topk, k) *
+               topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
+    else:
+        out = c2.view(m, num_topk, k).sum(dim=1)
     return out.to(dtype=out_dtype)
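After the second GEMM, the per-expert outputs are un-permuted and reduced across the top-k experts; the router weight is applied here only if it was not already folded into the inputs. A runnable sketch of just this reduction, with assumed shapes and no CUTLASS kernels involved:

import torch

m, num_topk, k = 4, 2, 8
out_dtype = torch.half
c2 = torch.randn(m * num_topk, k, dtype=out_dtype)   # per-expert outputs
topk_weights = torch.rand(m, num_topk, dtype=torch.float32)

def reduce_topk(apply_router_weight_on_input: bool) -> torch.Tensor:
    if not apply_router_weight_on_input:
        # Weight each expert's output by its router weight, then sum.
        return (c2.view(m, num_topk, k) *
                topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
    # Weights were already applied to the activations; a plain sum remains.
    return c2.view(m, num_topk, k).sum(dim=1)

print(reduce_topk(False).shape)  # torch.Size([4, 8])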
@@ -295,6 +295,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         if enable_eplb:
             raise NotImplementedError("EPLB not supported for "
                                       "`CompressedTensorsW4A4MoeMethod` yet.")
+        assert activation == "silu", "Only SiLU activation is supported."

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -326,10 +327,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             global_num_experts=global_num_experts,
             expert_map=expert_map)

-        assert activation == "silu", "Only SiLU activation is supported."
-        assert not apply_router_weight_on_input, (
-            "Router weight on input is not "
-            "supported for CompressedTensorsW4A4MoeMethod.")
         assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "CompressedTensorsW4A4MoeMethod.")
@@ -339,22 +336,25 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):

         # Cutlass moe takes in activations in BF16/Half precision
         # and fp4 quantized weights loaded from the checkpoint
-        return cutlass_moe_fp4(a=x,
-                               w1_fp4=layer.w13_weight,
-                               w1_blockscale=layer.w13_blockscale_swizzled,
-                               w1_alphas=layer.g1_alphas,
-                               w2_fp4=layer.w2_weight,
-                               w2_blockscale=layer.w2_blockscale_swizzled,
-                               w2_alphas=layer.g2_alphas,
-                               topk_weights=topk_weights,
-                               topk_ids=topk_ids,
-                               m=x.shape[0],
-                               n=layer.w2_weight.shape[2] * 2,
-                               k=x.shape[1],
-                               e=layer.w13_weight.shape[0],
-                               a1_gscale=layer.w13_input_scale_quant,
-                               a2_gscale=layer.w2_input_scale_quant,
-                               device=x.device).to(x.dtype)
+        return cutlass_moe_fp4(
+            a=x,
+            w1_fp4=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alphas=layer.g1_alphas,
+            w2_fp4=layer.w2_weight,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alphas=layer.g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            m=x.shape[0],
+            n=layer.w2_weight.shape[2] * 2,
+            k=x.shape[1],
+            e=layer.w13_weight.shape[0],
+            a1_gscale=layer.w13_input_scale_quant,
+            a2_gscale=layer.w2_input_scale_quant,
+            device=x.device,
+            apply_router_weight_on_input=apply_router_weight_on_input).to(
+                x.dtype)


 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
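One detail worth noting in these call sites is `n=layer.w2_weight.shape[2] * 2`: FP4 (e2m1) weights are stored two values per byte, so the logical GEMM dimension is twice the packed tensor's last dimension. A tiny sketch of that arithmetic, with assumed sizes:

import torch

num_experts, rows, n = 8, 128, 256  # assumed, illustrative sizes
# FP4 checkpoints pack two 4-bit e2m1 values into each uint8 byte,
# so the stored last dimension is half the logical one.
w2_fp4 = torch.zeros(num_experts, rows, n // 2, dtype=torch.uint8)
assert w2_fp4.shape[2] * 2 == n  # mirrors n=layer.w2_weight.shape[2] * 2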
@@ -673,21 +673,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         if self.use_marlin:
-            topk_weights, topk_ids = FusedMoE.select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-            )
-
             return torch.ops.vllm.fused_marlin_moe(
                 x,
                 layer.w13_weight,
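In `ModelOptNvFp4FusedMoE.apply`, expert selection is hoisted above the `use_marlin` branch so the Marlin and CUTLASS paths share a single routing computation. Schematically, with hypothetical simplified stand-ins for `FusedMoE.select_experts` and the two back ends:

from typing import Callable, Tuple
import torch

def select_experts(router_logits: torch.Tensor,
                   top_k: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for FusedMoE.select_experts: softmax routing + top-k.
    return torch.softmax(router_logits, dim=-1).topk(top_k, dim=-1)

def apply(x: torch.Tensor, router_logits: torch.Tensor, use_marlin: bool,
          marlin_path: Callable, cutlass_path: Callable) -> torch.Tensor:
    # Routing happens once, before branching on the back end, so both
    # paths consume identical topk_weights / topk_ids.
    topk_weights, topk_ids = select_experts(router_logits)
    backend = marlin_path if use_marlin else cutlass_path
    return backend(x, topk_weights, topk_ids)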
@@ -704,44 +704,31 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)

-        assert activation == "silu", "Only SiLU activation is supported."
-        assert not apply_router_weight_on_input, (
-            "Router weight on input is not "
-            "supported for ModelOptNvFp4FusedMoE.")
         assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "ModelOptNvFp4FusedMoE.")

-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
-
         from vllm.model_executor.layers.fused_moe.cutlass_moe import (
             cutlass_moe_fp4)

         # Cutlass moe takes in activations in BF16/Half precision
         # and fp4 quantized weights loaded from the checkpoint
-        return cutlass_moe_fp4(a=x,
-                               w1_fp4=layer.w13_weight,
-                               w1_blockscale=layer.w13_blockscale_swizzled,
-                               w1_alphas=layer.g1_alphas,
-                               w2_fp4=layer.w2_weight,
-                               w2_blockscale=layer.w2_blockscale_swizzled,
-                               w2_alphas=layer.g2_alphas,
-                               topk_weights=topk_weights,
-                               topk_ids=topk_ids,
-                               m=x.shape[0],
-                               n=layer.w2_weight.shape[2] * 2,
-                               k=x.shape[1],
-                               e=layer.w13_weight.shape[0],
-                               a1_gscale=layer.w13_input_scale_quant,
-                               a2_gscale=layer.w2_input_scale_quant,
-                               device=x.device).to(x.dtype)
+        return cutlass_moe_fp4(
+            a=x,
+            w1_fp4=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alphas=layer.g1_alphas,
+            w2_fp4=layer.w2_weight,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alphas=layer.g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            m=x.shape[0],
+            n=layer.w2_weight.shape[2] * 2,
+            k=x.shape[1],
+            e=layer.w13_weight.shape[0],
+            a1_gscale=layer.w13_input_scale_quant,
+            a2_gscale=layer.w2_input_scale_quant,
+            device=x.device,
+            apply_router_weight_on_input=apply_router_weight_on_input).to(
+                x.dtype)