[NVIDIA] Add SM100 Flashinfer MoE per tensor scale fp8 backend (#21458)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 5daffe7cf6
commit 207b750e19

@@ -30,6 +30,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    calculate_tile_tokens_dim)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform
@@ -1065,22 +1067,6 @@ direct_register_custom_op(
 )
 
 
-def next_positive_power_of_2(x: int) -> int:
-    if x < 1:
-        return 1
-    return 1 << (x - 1).bit_length()
-
-
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 def flashinfer_fused_moe_blockscale_fp8(
         routing_logits: torch.Tensor,
         routing_bias: torch.Tensor,
@@ -1128,8 +1114,8 @@ def flashinfer_fused_moe_blockscale_fp8(
         local_expert_offset=expert_offset,
         local_num_experts=local_num_experts,
         routed_scaling_factor=routed_scaling,
-        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
-                                             global_num_experts),
+        tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
+                                                  global_num_experts),
         routing_method_type=2,  # DeepSeek-styled routing method
         use_shuffled_weight=False,
     )
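
Note (not part of the diff): the tile sizing that calculate_tile_tokens_dim performs, and that the removed _get_tile_tokens_dim used to perform, is plain integer arithmetic. A minimal standalone sketch, using a local power-of-two helper instead of the flashinfer import:

def _next_pow2(x: int) -> int:
    # Hypothetical stand-in for flashinfer's next_positive_power_of_2.
    return 1 if x < 1 else 1 << (x - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Assume a perfectly balanced router, round up to a power of two,
    # then clamp to the 8-64 range the kernel supports.
    return min(max(_next_pow2((num_tokens * top_k) // num_experts), 8), 64)


# 256 tokens routed to 8 of 256 experts -> 8 tokens/expert -> tile of 8;
# 4096 tokens under the same routing -> 128 -> clamped to 64.
print(tile_tokens_dim(256, 8, 256), tile_tokens_dim(4096, 8, 256))  # 8 64
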
@@ -1164,6 +1150,97 @@ direct_register_custom_op(
 )
 
 
+def flashinfer_fused_moe_per_tensor_scale_fp8(
+        routing_logits: torch.Tensor,
+        routing_bias: Optional[torch.Tensor],
+        hidden_states: torch.Tensor,
+        input_scale: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        gemm1_weights_scale: torch.Tensor,
+        activation_scale: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        gemm2_weights_scale: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: Optional[int],
+        topk_group: Optional[int],
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        use_routing_scales_on_input: bool,
+        routing_method_type: int,
+        routed_scaling_factor: float = 1.0) -> torch.Tensor:
+    num_expert_group = num_expert_group if num_expert_group is not None else 0
+    topk_group = topk_group if topk_group is not None else 0
+
+    quant_hidden_states, input_scale = moe_kernel_quantize_input(
+        hidden_states,
+        input_scale,
+        quant_dtype=torch.float8_e4m3fn,
+        per_act_token_quant=False)
+
+    output1_scales_scalar = gemm1_weights_scale * input_scale * (
+        1.0 / activation_scale)
+    output1_scales_gate_scalar = gemm1_weights_scale * input_scale
+    output2_scales_scalar = activation_scale * gemm2_weights_scale
+
+    from vllm.utils.flashinfer import (
+        flashinfer_trtllm_fp8_per_tensor_scale_moe)
+    return flashinfer_trtllm_fp8_per_tensor_scale_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=quant_hidden_states,
+        gemm1_weights=gemm1_weights,
+        output1_scales_scalar=output1_scales_scalar,
+        output1_scales_gate_scalar=output1_scales_gate_scalar,
+        gemm2_weights=gemm2_weights,
+        output2_scales_scalar=output2_scales_scalar,
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=local_expert_offset,
+        local_num_experts=local_num_experts,
+        routed_scaling_factor=routed_scaling_factor,
+        use_routing_scales_on_input=use_routing_scales_on_input,
+        tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
+                                                  top_k, num_experts),
+        routing_method_type=routing_method_type)
+
+
+def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        hidden_states: torch.Tensor,
+        gemm1_weights: torch.Tensor,
+        output1_scales_scalar: torch.Tensor,
+        output1_scales_gate_scalar: torch.Tensor,
+        gemm2_weights: torch.Tensor,
+        output2_scales_scalar: torch.Tensor,
+        num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        local_expert_offset: int,
+        local_num_experts: int,
+        routed_scaling_factor: float = 1.0,
+        use_routing_scales_on_input: bool = False,
+        tile_tokens_dim: int = 8,
+        routing_method_type: int = 0) -> torch.Tensor:
+    pass
+
+
+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
+    op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
+    mutates_args=["hidden_states"],
+    fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 def outplace_fused_experts(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
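
Note (not part of the diff): the new per-tensor op folds the checkpoint's per-tensor scales into the three scalars the TRT-LLM kernel expects. A hedged reading of that algebra with made-up scale values (names mirror the diff; the comments are an interpretation, not kernel documentation):

# Hypothetical per-tensor scales from an FP8 checkpoint.
input_scale, gemm1_weights_scale = 0.02, 0.01
activation_scale, gemm2_weights_scale = 0.05, 0.03

# GEMM1 multiplies fp8 activations (real value ~ x_q * input_scale) by fp8
# weights (~ w1_q * gemm1_weights_scale); dequantizing its output therefore
# needs input_scale * gemm1_weights_scale, and requantizing that output to
# activation_scale for GEMM2 folds in a further 1 / activation_scale.
output1_scales_scalar = gemm1_weights_scale * input_scale * (
    1.0 / activation_scale)
# The gate half is consumed without requantization, so no 1/activation_scale.
output1_scales_gate_scalar = gemm1_weights_scale * input_scale
# GEMM2 dequantizes its fp8 input activation and its fp8 weights.
output2_scales_scalar = activation_scale * gemm2_weights_scale
print(output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar)
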
@@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
+    swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -53,11 +56,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 logger = init_logger(__name__)
 
 
-def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
-    return x.reshape(-1, 2, x.shape[-2] // 2,
-                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
-
-
 def _is_col_major(x: torch.Tensor) -> bool:
     assert x.dim() == 3
     b, m, n = x.shape
@@ -695,11 +693,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         elif self.flashinfer_moe_enabled:
             # NOTE: weights have to be swapped since the activation is
             # applied on different half for flashinfer vs vllm
-            w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
-            w13_weight_scale_inv = _swap_w13_to_w31(
+            w13_weight = swap_w13_to_w31(layer.w13_weight.data)
+            w13_weight_scale_inv = swap_w13_to_w31(
                 layer.w13_weight_scale_inv.data)
             w2_weight = layer.w2_weight.data
             w2_weight_scale_inv = layer.w2_weight_scale_inv.data
+            if not self.block_quant:
+                rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
         else:
             w13_weight = layer.w13_weight.data
             w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -998,30 +998,43 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
         elif self.flashinfer_moe_enabled:
-            # Currently only work with DS models
-            assert self.block_quant
-            assert (renormalize and use_grouped_topk
-                    and scoring_func == 'sigmoid'
-                    and custom_routing_function is None)
-            assert activation == "silu"
+            assert activation == 'silu'
+            assert scoring_func == 'sigmoid'
+            if self.block_quant:
+                assert (renormalize and use_grouped_topk
+                        and custom_routing_function is None)
+
                 return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                     routing_logits=router_logits.to(torch.float32),
                     routing_bias=e_score_correction_bias,
                     x=x,
                     w13_weight=layer.w13_weight,
                     w13_weight_scale_inv=layer.w13_weight_scale_inv,
                     w2_weight=layer.w2_weight,
                     w2_weight_scale_inv=layer.w2_weight_scale_inv,
                     global_num_experts=global_num_experts,
                     top_k=top_k,
                     num_expert_group=num_expert_group,
                     topk_group=topk_group,
                     intermediate_size=layer.intermediate_size_per_partition,
                     expert_offset=layer.ep_rank * layer.local_num_experts,
                     local_num_experts=layer.local_num_experts,
                     block_shape=self.quant_config.weight_block_size,
                     routed_scaling=1.0,
                 )
+            else:
+                assert (not renormalize
+                        and custom_routing_function is not None)
+                return apply_flashinfer_per_tensor_scale_fp8(
+                    layer=layer,
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    routing_bias=e_score_correction_bias,
+                    global_num_experts=global_num_experts,
+                    top_k=top_k,
+                    num_expert_group=num_expert_group,
+                    topk_group=topk_group,
+                    apply_router_weight_on_input=apply_router_weight_on_input)
         else:
             return self.fused_experts(
                 hidden_states=x,
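
Note (not part of the diff): the FlashInfer branch of Fp8MoEMethod.apply now forks on quantization granularity instead of requiring block quantization. A condensed, illustrative restatement of the routing asserts added above (not vLLM code):

def check_flashinfer_fp8_routing(block_quant: bool, activation: str,
                                 scoring_func: str, renormalize: bool,
                                 use_grouped_topk: bool,
                                 custom_routing_function) -> None:
    # Both paths require SiLU activation and sigmoid scoring.
    assert activation == 'silu' and scoring_func == 'sigmoid'
    if block_quant:
        # Block-scale path (DeepSeek-style): grouped top-k, renormalized,
        # built-in routing.
        assert (renormalize and use_grouped_topk
                and custom_routing_function is None)
    else:
        # Per-tensor path (Llama4-style): custom routing, no renormalization.
        assert (not renormalize and custom_routing_function is not None)
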
@@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
+    swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     apply_fp4_marlin_linear, is_fp4_marlin_supported,
     prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -34,6 +37,7 @@ from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PerTensorScaleParameter)
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.utils.flashinfer import has_flashinfer_moe
 
 logger = init_logger(__name__)
 
@@ -267,6 +271,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
             cutlass_fp8_supported)
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.flashinfer_moe_enabled = False
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+            logger.info_once(
+                "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
+            self.flashinfer_moe_enabled = True
 
     def create_weights(
         self,
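
Note (not part of the diff): ModelOptFp8MoEMethod only opts into this backend when the VLLM_USE_FLASHINFER_MOE_FP8 environment variable is set and the FlashInfer MoE kernels are importable. A minimal sketch of checking that gate, assuming FlashInfer is installed locally:

import os

# Set before initializing the engine so vLLM picks it up.
os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"

from vllm.utils.flashinfer import has_flashinfer_moe

# Mirrors the condition used above: both must hold for the FlashInfer
# FP8 MoE path to be enabled.
print("FlashInfer MoE kernels available:", has_flashinfer_moe())
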
@@ -410,6 +419,11 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                              requires_grad=False)
 
+        if self.flashinfer_moe_enabled:
+            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
+            rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
+                                              layer.w2_weight)
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -436,6 +450,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             raise NotImplementedError(
                 "EPLB not supported for `ModelOptFp8MoEMethod` yet.")
 
+        if self.flashinfer_moe_enabled:
+            assert activation == 'silu'
+            assert not renormalize
+            return apply_flashinfer_per_tensor_scale_fp8(
+                layer=layer,
+                hidden_states=x,
+                router_logits=router_logits,
+                routing_bias=e_score_correction_bias,
+                global_num_experts=global_num_experts,
+                top_k=top_k,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                apply_router_weight_on_input=apply_router_weight_on_input)
+
         # Expert selection
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -0,0 +1,100 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+
+
+def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
+    from flashinfer import next_positive_power_of_2
+
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
+def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
+    return x.reshape(-1, 2, x.shape[-2] // 2,
+                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
+
+
+def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor,
+                                      gemm2_weights: torch.Tensor):
+    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
+    epilogue_tile_m = 128
+    num_experts = gemm1_weights.shape[0]
+    hidden_size = gemm1_weights.shape[-1]
+    intermediate_size = gemm1_weights.shape[1] // 2
+
+    # Reorder rows of W1 for fused gated activation
+    gemm1_weights_fp8_interleaved = []
+    for i in range(num_experts):
+        gemm1_weights_fp8_interleaved.append(
+            reorder_rows_for_gated_act_gemm(gemm1_weights[i]))
+
+    # Stack weights and scales for all experts
+    gemm1_weights_fp8_interleaved = torch.stack(
+        gemm1_weights_fp8_interleaved).reshape(num_experts,
+                                               2 * intermediate_size,
+                                               hidden_size)
+
+    # Shuffle weights and scaling factors for transposed mma output
+    gemm1_weights_fp8_shuffled = []
+    gemm2_weights_fp8_shuffled = []
+    for i in range(num_experts):
+        gemm1_weights_fp8_shuffled.append(
+            shuffle_matrix_a(
+                gemm1_weights_fp8_interleaved[i].view(torch.uint8),
+                epilogue_tile_m))
+
+        gemm2_weights_fp8_shuffled.append(
+            shuffle_matrix_a(gemm2_weights[i].view(torch.uint8),
+                             epilogue_tile_m))
+
+    # Stack weights for all experts
+    gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view(
+        torch.float8_e4m3fn)
+    gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view(
+        torch.float8_e4m3fn)
+
+
+def apply_flashinfer_per_tensor_scale_fp8(
+    layer: torch.nn.Module,
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    routing_bias: Optional[torch.Tensor],
+    top_k: int,
+    num_expert_group: Optional[int],
+    topk_group: Optional[int],
+    global_num_experts: int,
+    apply_router_weight_on_input: bool,
+) -> torch.Tensor:
+    from flashinfer.fused_moe import RoutingMethodType
+
+    from vllm.model_executor.models.llama4 import Llama4MoE
+    assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
+        "FusedMoE flashinfer kernels are only supported for Llama4"
+    return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
+        routing_logits=router_logits,
+        routing_bias=routing_bias,
+        hidden_states=hidden_states,
+        input_scale=layer.w13_input_scale,
+        gemm1_weights=layer.w13_weight,
+        gemm1_weights_scale=layer.w13_weight_scale,
+        gemm2_weights=layer.w2_weight,
+        gemm2_weights_scale=layer.w2_weight_scale,
+        activation_scale=layer.w2_input_scale,
+        num_experts=global_num_experts,
+        top_k=top_k,
+        num_expert_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=layer.intermediate_size_per_partition,
+        local_expert_offset=layer.ep_rank * layer.local_num_experts,
+        local_num_experts=layer.local_num_experts,
+        use_routing_scales_on_input=apply_router_weight_on_input,
+        routing_method_type=RoutingMethodType.Llama4,
+    )
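
Note (not part of the diff): swap_w13_to_w31 from the new module swaps the two stacked halves of each expert's fused gate/up projection, needed because FlashInfer applies the activation to the opposite half from vLLM (see the NOTE in the Fp8MoEMethod hunk). A toy tensor small enough to check by hand:

import torch

# One expert, intermediate size 2, hidden size 3: rows 0-1 hold one half of
# the fused w13 projection, rows 2-3 the other.
w13 = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)

swapped = w13.reshape(-1, 2, w13.shape[-2] // 2,
                      w13.shape[-1]).flip(dims=[1]).reshape(w13.shape)

assert torch.equal(swapped[0, :2], w13[0, 2:])  # the halves trade places
assert torch.equal(swapped[0, 2:], w13[0, :2])
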
@@ -66,6 +66,8 @@ def _lazy_import_wrapper(module_name: str,
 # Create lazy wrappers for each function
 flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
     "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
+flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
+    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe")
 flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                     "cutlass_fused_moe")
 fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
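
Note (not part of the diff): the new flashinfer_trtllm_fp8_per_tensor_scale_moe symbol follows this file's existing lazy-import pattern, so flashinfer is only imported when the wrapper is first called. A rough sketch of that pattern (an assumed shape for illustration, not vLLM's actual _lazy_import_wrapper):

import importlib
from typing import Any, Callable


def lazy_import_wrapper(module_name: str, attr_name: str) -> Callable[..., Any]:
    # Defer the import until the wrapped kernel is actually invoked, so
    # environments without flashinfer can still import this module.
    def _wrapper(*args: Any, **kwargs: Any) -> Any:
        module = importlib.import_module(module_name)
        return getattr(module, attr_name)(*args, **kwargs)

    return _wrapper


trtllm_fp8_per_tensor_scale_moe = lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe")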