[Bug] Fix Weight Loading for Block FP8 Cutlass SM90 (#25909)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Wentao Ye 2025-09-29 21:15:19 -04:00 committed by GitHub
parent 78a47f87ce
commit 89e4050af4

@@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
     # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
     # requantize the weight and input to the specific scale
     # at the same time.
-    if is_deep_gemm_e8m0_used():
+    should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
+        layer.orig_dtype, layer.weight)
+    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
         block_sz = tuple(layer.weight_block_size)
         requant_weight_ue8m0_inplace(layer.weight.data,
                                      layer.weight_scale.data, block_sz)
     # SM90 Block FP8 CUTLASS requires row-major weight scales
     elif (current_platform.is_device_capability(90)
-          and cutlass_block_fp8_supported
-          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
-                                                     layer.weight)):
+          and cutlass_block_fp8_supported and not should_use_deepgemm):
         layer.weight_scale = torch.nn.Parameter(
             layer.weight_scale.data.T.contiguous(), requires_grad=False)