mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 03:55:01 +08:00
[Bug] Fix DeepGemm Init Error (#21554)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
b57296bb9a
commit
633f6e804b
@ -366,7 +366,7 @@ def per_token_group_quant_fp8(
|
|||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
column_major_scales: bool = False,
|
column_major_scales: bool = False,
|
||||||
out_q: Optional[torch.Tensor] = None,
|
out_q: Optional[torch.Tensor] = None,
|
||||||
use_ue8m0: bool = is_blackwell_deep_gemm_used(),
|
use_ue8m0: Optional[bool] = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||||
It converts the tensor values into signed float8 values and returns the
|
It converts the tensor values into signed float8 values and returns the
|
||||||
@ -383,6 +383,10 @@ def per_token_group_quant_fp8(
|
|||||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||||
scaling factor.
|
scaling factor.
|
||||||
"""
|
"""
|
||||||
|
# TODO(wentao): refactor this
|
||||||
|
# use_ue8m0 should be a global flag that could be set by user
|
||||||
|
if use_ue8m0 is None:
|
||||||
|
use_ue8m0 = is_blackwell_deep_gemm_used()
|
||||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||||
assert (x.shape[-1] % group_size == 0), (
|
assert (x.shape[-1] % group_size == 0), (
|
||||||
f"the last dimension of `x` {x.shape[-1]} must be divisible "
|
f"the last dimension of `x` {x.shape[-1]} must be divisible "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user