[Bug Fix] Fix the support check for FP8 CUTLASS (#5352)

Bug description:
With torch 2.4.0.dev20240603+cu121,
cutlass_fp8_supported outputs False, and the (capability, version) before the comparison is (90, 11111111112)

This PR fixes the support check for FP8 CUTLASS (`cutlass_fp8_supported`), which was introduced in https://github.com/vllm-project/vllm/pull/5183.
This commit is contained in:
Cheng Li 2024-06-07 17:42:05 -07:00 committed by GitHub
parent 767c727a81
commit e69ded7d1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -20,16 +20,16 @@ logger = init_logger(__name__)
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
-    version = torch.version.cuda
-    version = version[0] * 10 + version[1]
+    major, minor = torch.version.cuda.split(".")
+    version = int(major) * 10 + int(minor)
     # CUTLASS FP8 kernels need at least
     # CUDA 12.0 on SM90 systems (Hopper)
     # CUDA 12.4 on SM89 systems (Lovelace)
     gpu_is_supported = False
-    if capability >= 900:
+    if capability >= 90:
         gpu_is_supported = version > 120
-    elif capability >= 890:
+    elif capability >= 89:
         gpu_is_supported = version > 124
     return gpu_is_supported