diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 70ac6688deb7f..70a580b9c4c70 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -11,7 +11,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
+from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
+                                  is_blackwell_deep_gemm_used)
 
 logger = init_logger(__name__)
 
@@ -50,6 +51,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     eps: tl.constexpr,
     fp8_min: tl.constexpr,
     fp8_max: tl.constexpr,
+    use_ue8m0: tl.constexpr,
 
     # Meta ---------------------------------------------------------------
     BLOCK: tl.constexpr,
@@ -92,7 +94,9 @@ def _silu_mul_fp8_quant_deep_gemm(
         y = x * y2
 
         _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-        y_s = _absmax / fp8_max
+        scale_raw = _absmax / fp8_max
+        y_s = tl.math.exp2(tl.ceil(
+            tl.log2(scale_raw))) if use_ue8m0 else scale_raw
         y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
         tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
@@ -174,6 +178,7 @@ def silu_mul_fp8_quant_deep_gemm(
         eps,
         fp8_min,
         fp8_max,
+        is_blackwell_deep_gemm_used(),
         BLOCK=group_size,
         num_warps=4,
     )
@@ -290,14 +295,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         # may lead to better performance.
         expected_m = max_num_tokens
         fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
-                                     out=workspace1,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+                                     workspace1, expert_num_tokens, expected_m)
 
         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                       expert_num_tokens)
 
-        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
-                                     out=output,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
+                                     expert_num_tokens, expected_m)
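
For reference, the new UE8M0 branch rounds each group's quantization scale up to the nearest power of two (so the scale fits an 8-bit-exponent, zero-mantissa encoding) before clamping to the FP8 range. Below is a minimal standalone sketch of that same math in plain PyTorch; the helper name quant_group_ue8m0 and the tensor shapes are illustrative only and not part of the patch, with 448.0 used as the float8_e4m3 maximum.

import torch

def quant_group_ue8m0(y: torch.Tensor, fp8_max: float = 448.0, eps: float = 1e-10):
    # Per-group absmax scale, mirroring _absmax / fp8_max in the Triton kernel.
    scale_raw = torch.clamp(y.abs().amax(), min=eps) / fp8_max
    # UE8M0: round the scale up to the next power of two, matching
    # exp2(ceil(log2(scale_raw))) in the kernel's use_ue8m0 branch.
    y_s = torch.exp2(torch.ceil(torch.log2(scale_raw)))
    # Quantize with the rounded scale and clamp to the FP8 range.
    y_q = torch.clamp(y / y_s, -fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return y_q, y_s

Because the rounded scale is never smaller than scale_raw, the clamped values stay within the FP8 range, and the scale itself is exactly representable by an exponent-only format.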