[Perf] Fix DeepGEMM Contiguous Layout Issue, 5.5% Throughput Improvement (#24783)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
commit fc2dbcda8b (parent fec347dee1)
Author: Wentao Ye
Date:   2025-09-14 11:20:17 -04:00
Committed by: GitHub


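The removed .contiguous() calls below undid the very layout that
get_col_major_tma_aligned_tensor produces: copying a column-major-strided
tensor with .contiguous() yields a row-major tensor again, presumably
forcing DeepGEMM to realign the scales at runtime. A minimal sketch of the
stride effect (illustration only; col_major below is a stand-in for the
DeepGEMM helper, under the assumption that it returns a tensor whose last
two dims are column-major, ignoring TMA padding):

    import torch

    def col_major(t: torch.Tensor) -> torch.Tensor:
        # Stand-in for get_col_major_tma_aligned_tensor: make the last
        # two dims column-major (TMA alignment padding omitted here).
        return t.transpose(-2, -1).contiguous().transpose(-2, -1)

    scale = torch.randn(8, 16)        # e.g. a weight-scale tensor
    cm = col_major(scale)
    print(cm.stride())                # (1, 8): column-major strides
    print(cm.contiguous().stride())   # (16, 1): row-major again, layout lost
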
@@ -772,10 +772,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                 if _is_col_major(layer.w13_weight_scale_inv):
                     layer.w13_weight_scale_inv = \
-                        get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
+                        get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv)
                 if _is_col_major(layer.w2_weight_scale_inv):
                     layer.w2_weight_scale_inv = \
-                        get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
+                        get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv)
         # If checkpoint is fp16, quantize in place.
         elif not self.quant_config.is_checkpoint_fp8_serialized:
@@ -923,10 +923,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # Ensure column-major TMA alignment expected by DeepGEMM.
         if _is_col_major(layer.w13_weight_scale_inv):
             layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w13_weight_scale_inv).contiguous()
+                layer.w13_weight_scale_inv)
         if _is_col_major(layer.w2_weight_scale_inv):
             layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w2_weight_scale_inv).contiguous()
+                layer.w2_weight_scale_inv)

     def select_gemm_impl(
         self,
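
For reference, a column-major guard like the _is_col_major check used in
both hunks can be written as a simple stride test. This is a hypothetical
sketch, not vLLM's implementation:

    import torch

    def is_col_major(t: torch.Tensor) -> bool:
        # Hypothetical stand-in for vLLM's _is_col_major: the last two
        # dims are column-major when rows are unit-stride and the column
        # stride equals the row count.
        return t.stride(-2) == 1 and t.stride(-1) == t.size(-2)

    print(is_col_major(torch.randn(4, 8)))      # False: default row-major
    print(is_col_major(torch.randn(8, 4).t()))  # True: transposed view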