Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-23 09:45:58 +08:00)
refactor: Change scaling factors calculation for flashinfer FusedMoE (#22812)

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>

parent 0fe85087a9
commit b4cef5e6c7
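
This commit precomputes the FlashInfer per-tensor MoE scaling factors once at weight-processing time (register_moe_scaling_factors) instead of rederiving them from the raw quantization scales on every forward call, and the fused-MoE op now receives the combined scalars directly. A minimal sketch of the arithmetic, using the formulas and layer attribute names that appear in the diff below (the tensor shapes and values here are illustrative, not taken from vLLM):

import torch

# Per-tensor FP8 scales as they exist on a quantized FusedMoE layer after
# weight loading (attribute names follow the diff; shapes/values are made up).
num_experts = 8
w13_input_scale = torch.tensor(0.02)        # activation scale into gemm1
w13_weight_scale = torch.rand(num_experts)  # per-expert weight scale for gemm1
w2_input_scale = torch.tensor(0.05)         # activation scale into gemm2
w2_weight_scale = torch.rand(num_experts)   # per-expert weight scale for gemm2

# The three combined scalars the kernel consumes, exactly as computed by the
# new get_moe_scaling_factors() helper added in this commit:
output1_scales_scalar = w13_weight_scale * w13_input_scale * (1.0 / w2_input_scale)
output1_scales_gate_scalar = w13_weight_scale * w13_input_scale
output2_scales_scalar = w2_input_scale * w2_weight_scale
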
@@ -1189,10 +1189,10 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
         hidden_states: torch.Tensor,
         input_scale: torch.Tensor,
         gemm1_weights: torch.Tensor,
-        gemm1_weights_scale: torch.Tensor,
-        activation_scale: torch.Tensor,
         gemm2_weights: torch.Tensor,
-        gemm2_weights_scale: torch.Tensor,
+        output1_scales_scalar: torch.Tensor,
+        output1_scales_gate_scalar: torch.Tensor,
+        output2_scales_scalar: torch.Tensor,
         num_experts: int,
         top_k: int,
         num_expert_group: Optional[int],
@@ -1206,17 +1206,12 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
     num_expert_group = num_expert_group if num_expert_group is not None else 0
     topk_group = topk_group if topk_group is not None else 0
 
-    quant_hidden_states, input_scale = moe_kernel_quantize_input(
+    quant_hidden_states, _ = moe_kernel_quantize_input(
         hidden_states,
         input_scale,
         quant_dtype=torch.float8_e4m3fn,
         per_act_token_quant=False)
 
-    output1_scales_scalar = gemm1_weights_scale * input_scale * (
-        1.0 / activation_scale)
-    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
-    output2_scales_scalar = activation_scale * gemm2_weights_scale
-
     from vllm.utils.flashinfer import (
         flashinfer_trtllm_fp8_per_tensor_scale_moe)
     return flashinfer_trtllm_fp8_per_tensor_scale_moe(
@@ -1244,24 +1239,24 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
 
 def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
         routing_logits: torch.Tensor,
-        routing_bias: torch.Tensor,
+        routing_bias: Optional[torch.Tensor],
         hidden_states: torch.Tensor,
+        input_scale: torch.Tensor,
         gemm1_weights: torch.Tensor,
+        gemm2_weights: torch.Tensor,
         output1_scales_scalar: torch.Tensor,
         output1_scales_gate_scalar: torch.Tensor,
-        gemm2_weights: torch.Tensor,
         output2_scales_scalar: torch.Tensor,
         num_experts: int,
         top_k: int,
-        num_expert_group: int,
-        topk_group: int,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
         intermediate_size: int,
         local_expert_offset: int,
         local_num_experts: int,
-        routed_scaling_factor: float = 1.0,
-        use_routing_scales_on_input: bool = False,
-        tile_tokens_dim: int = 8,
-        routing_method_type: int = 0) -> torch.Tensor:
+        use_routing_scales_on_input: bool,
+        routing_method_type: int,
+        routed_scaling_factor: float = 1.0) -> torch.Tensor:
     pass
 
 
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
-    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
-    swap_w13_to_w31)
+    apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
+    rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -694,6 +694,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w2_weight = layer.w2_weight.data
             w2_weight_scale_inv = layer.w2_weight_scale_inv.data
             if not self.block_quant:
+                register_moe_scaling_factors(layer)
                 rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
             else:
                 w13_weight = layer.w13_weight.data
@@ -25,8 +25,8 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_kernel,
     flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
-    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
-    swap_w13_to_w31)
+    apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
+    rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     apply_fp4_marlin_linear, is_fp4_marlin_supported,
     prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -430,6 +430,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
             rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
                                               layer.w2_weight)
+            register_moe_scaling_factors(layer)
 
     def apply(
         self,
@@ -82,6 +82,12 @@ def apply_flashinfer_per_tensor_scale_fp8(
     apply_router_weight_on_input: bool,
 ) -> torch.Tensor:
     from flashinfer.fused_moe import RoutingMethodType
+    assert layer.output1_scales_scalar is not None, (
+        "Expected output1_scales_scalar to be initialized")
+    assert layer.output1_scales_scalar is not None, (
+        "Expected output1_scales_gate_scalar to be initialized")
+    assert layer.output1_scales_scalar is not None, (
+        "Expected output2_scales_scalar to be initialized")
 
     from vllm.model_executor.models.llama4 import Llama4MoE
     assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
@@ -92,10 +98,10 @@ def apply_flashinfer_per_tensor_scale_fp8(
         hidden_states=hidden_states,
         input_scale=layer.w13_input_scale,
         gemm1_weights=layer.w13_weight,
-        gemm1_weights_scale=layer.w13_weight_scale,
         gemm2_weights=layer.w2_weight,
-        gemm2_weights_scale=layer.w2_weight_scale,
-        activation_scale=layer.w2_input_scale,
+        output1_scales_scalar=layer.output1_scales_scalar,
+        output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
+        output2_scales_scalar=layer.output2_scales_scalar,
         num_experts=global_num_experts,
         top_k=top_k,
         num_expert_group=num_expert_group,
@@ -105,4 +111,36 @@ def apply_flashinfer_per_tensor_scale_fp8(
         local_num_experts=layer.local_num_experts,
         use_routing_scales_on_input=apply_router_weight_on_input,
         routing_method_type=RoutingMethodType.Llama4,
     )
+
+
+def get_moe_scaling_factors(
+    input_scale: torch.Tensor,
+    gemm1_weights_scale: torch.Tensor,
+    activation_scale: torch.Tensor,
+    gemm2_weights_scale: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    output1_scales_scalar = gemm1_weights_scale * input_scale * (
+        1.0 / activation_scale)
+    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
+    output2_scales_scalar = activation_scale * gemm2_weights_scale
+
+    return output1_scales_scalar, output1_scales_gate_scalar, \
+        output2_scales_scalar
+
+
+def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
+    output1_scales, output1_gate_scales, output2_scales = \
+        get_moe_scaling_factors(
+            layer.w13_input_scale, layer.w13_weight_scale,
+            layer.w2_input_scale, layer.w2_weight_scale
+        )
+    layer.register_parameter(
+        'output1_scales_scalar',
+        torch.nn.Parameter(output1_scales, requires_grad=False))
+    layer.register_parameter(
+        'output1_scales_gate_scalar',
+        torch.nn.Parameter(output1_gate_scales, requires_grad=False))
+    layer.register_parameter(
+        'output2_scales_scalar',
+        torch.nn.Parameter(output2_scales, requires_grad=False))
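
A usage sketch of the new helper (not part of the commit): any module that exposes the four per-tensor scales named above can have the combined scalars attached as frozen parameters, which apply_flashinfer_per_tensor_scale_fp8 then reads directly. The dummy layer and its scale values below are illustrative only.

import torch

from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    register_moe_scaling_factors)


class DummyMoELayer(torch.nn.Module):
    # Illustrative stand-in for a quantized FusedMoE layer; only the scale
    # attributes read by register_moe_scaling_factors are defined here.
    def __init__(self, num_experts: int = 8):
        super().__init__()
        self.w13_input_scale = torch.nn.Parameter(torch.tensor(0.02),
                                                  requires_grad=False)
        self.w13_weight_scale = torch.nn.Parameter(torch.rand(num_experts),
                                                   requires_grad=False)
        self.w2_input_scale = torch.nn.Parameter(torch.tensor(0.05),
                                                 requires_grad=False)
        self.w2_weight_scale = torch.nn.Parameter(torch.rand(num_experts),
                                                  requires_grad=False)


layer = DummyMoELayer()
register_moe_scaling_factors(layer)

# The combined scales are now frozen parameters on the layer, ready to be
# passed straight to the FlashInfer per-tensor-scale MoE kernel.
print(layer.output1_scales_scalar.shape,
      layer.output1_scales_gate_scalar.shape,
      layer.output2_scales_scalar.shape)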