[Bug Fix] Fix the support check for FP8 CUTLASS (#5352)
Bug description: With torch 2.4.0.dev20240603+cu121, cutlass_fp8_supported returns False. The (capability, version) pair computed before the comparison is (90, "11111111112"), because the CUDA version string was indexed character by character instead of being parsed into integers, and the capability was compared against 900/890 even though it is computed as major * 10 + minor. This PR fixes the support check for FP8 CUTLASS (cutlass_fp8_supported), which was introduced in https://github.com/vllm-project/vllm/pull/5183.
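A minimal reproduction of the reported values in plain Python (no GPU needed; "12.1" stands in for torch.version.cuda under cu121):

# torch.version.cuda is a string such as "12.1", so indexing yields
# characters, "*" repeats them, and "+" concatenates them:
version = "12.1"
print(version[0] * 10 + version[1])  # '11111111112' -- the value in the report

# The capability tuple was converted correctly ...
capability = (9, 0)  # SM90 (Hopper)
cap = capability[0] * 10 + capability[1]
print(cap)  # 90

# ... but was then compared against 900/890, so both branches were skipped
# and the function returned False before the bad version string was used.
print(cap >= 900, cap >= 890)  # False False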
This commit is contained in:
parent 767c727a81
commit e69ded7d1c
@@ -20,16 +20,16 @@ logger = init_logger(__name__)
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
-    version = torch.version.cuda
-    version = version[0] * 10 + version[1]
+    major, minor = torch.version.cuda.split(".")
+    version = int(major) * 10 + int(minor)

     # CUTLASS FP8 kernels need at least
     # CUDA 12.0 on SM90 systems (Hopper)
     # CUDA 12.4 on SM89 systems (Lovelace)
     gpu_is_supported = False
-    if capability >= 900:
+    if capability >= 90:
         gpu_is_supported = version > 120
-    elif capability >= 890:
+    elif capability >= 89:
         gpu_is_supported = version > 124

     return gpu_is_supported
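For reference, a self-contained sketch of the fixed logic with the device and CUDA queries lifted into parameters, so it runs without a GPU (the helper name and inputs are illustrative, not part of the patch):

def cutlass_fp8_supported_sketch(device_capability, cuda_version):
    # e.g. device_capability=(9, 0), cuda_version="12.1"
    capability = device_capability[0] * 10 + device_capability[1]
    major, minor = cuda_version.split(".")
    version = int(major) * 10 + int(minor)

    gpu_is_supported = False
    if capability >= 90:    # SM90 (Hopper)
        gpu_is_supported = version > 120
    elif capability >= 89:  # SM89 (Lovelace)
        gpu_is_supported = version > 124
    return gpu_is_supported

print(cutlass_fp8_supported_sketch((9, 0), "12.1"))  # True
print(cutlass_fp8_supported_sketch((8, 9), "12.1"))  # False (check requires version > 124)
print(cutlass_fp8_supported_sketch((8, 0), "12.1"))  # False (pre-SM89, no FP8 path)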
@@ -103,7 +103,7 @@ class Fp8LinearMethod(LinearMethodBase):
     1. Only support per-tensor quantization due to torch._scaled_mm support.
     2. Only support float8_e4m3fn data type due to the limitation of
        torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

     Args:
         quant_config: The quantization config.
     """
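As a rough illustration of the per-tensor scheme this docstring describes (an assumption-laden sketch, not vLLM's actual code path; it assumes a PyTorch build that exposes torch.float8_e4m3fn, i.e. roughly 2.1+): a single scale covers the whole tensor before casting.

import torch

# One scale for the entire weight tensor (per-tensor quantization).
w = torch.randn(4, 4)
scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / scale).to(torch.float8_e4m3fn)   # quantize
w_back = w_fp8.to(torch.float32) * scale      # dequantize for comparison
print((w - w_back).abs().max())               # small quantization error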
@@ -298,8 +298,8 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
         self.quant_config = quant_config

     def create_weights(self, layer: torch.nn.Module):
         """Create "weight" (aka kv_scale) for an attention layer.

         Args:
             layer: The layer that is using the QuantizeMethodBase factory.
         """