[Bug] Fix DeepGemm Init Error (#21554)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-07-24 23:07:22 -04:00 committed by GitHub
parent b57296bb9a
commit 633f6e804b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -366,7 +366,7 @@ def per_token_group_quant_fp8(
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
out_q: Optional[torch.Tensor] = None,
use_ue8m0: bool = is_blackwell_deep_gemm_used(),
use_ue8m0: Optional[bool] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
@ -383,6 +383,10 @@ def per_token_group_quant_fp8(
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor.
"""
# TODO(wentao): refactor this
# use_ue8m0 should be a global flag that could be set by user
if use_ue8m0 is None:
use_ue8m0 = is_blackwell_deep_gemm_used()
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "