diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index ee5f2b51564d1..8a7e809d082b1 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -366,7 +366,7 @@ def per_token_group_quant_fp8(
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
     out_q: Optional[torch.Tensor] = None,
-    use_ue8m0: bool = is_blackwell_deep_gemm_used(),
+    use_ue8m0: Optional[bool] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -383,6 +383,10 @@ def per_token_group_quant_fp8(
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor.
     """
+    # TODO(wentao): refactor this
+    # use_ue8m0 should be a global flag that could be set by user
+    if use_ue8m0 is None:
+        use_ue8m0 = is_blackwell_deep_gemm_used()
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
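
For context, here is a minimal standalone sketch of the pattern the diff applies. It is not vLLM code: `detect_hardware`, `quantize_eager`, and `quantize_lazy` are hypothetical names standing in for `is_blackwell_deep_gemm_used()` and `per_token_group_quant_fp8`. The point is that a default expression in the signature is evaluated once at import time, while a `None` sentinel defers the decision to each call.

    from typing import Optional

    def detect_hardware() -> bool:
        # Hypothetical runtime probe; in the diff this role is played by
        # is_blackwell_deep_gemm_used().
        return False

    # Anti-pattern (pre-diff): detect_hardware() runs once when the module is
    # imported, so the default is frozen before the runtime environment is set up.
    def quantize_eager(x, use_ue8m0: bool = detect_hardware()):
        return x, use_ue8m0

    # Pattern from the diff: a None sentinel, resolved at call time.
    def quantize_lazy(x, use_ue8m0: Optional[bool] = None):
        if use_ue8m0 is None:
            use_ue8m0 = detect_hardware()
        return x, use_ue8m0

Callers that pass `use_ue8m0` explicitly behave the same under both versions; only the default changes from an import-time snapshot to a per-call check.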