[ROCm][Bugfix] Use platform specific FP8 dtype (#15717)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg 2025-04-04 12:40:20 -04:00 committed by GitHub
parent ef608c37a7
commit 40a36ccfeb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -753,7 +753,7 @@ if triton.__version__ >= "2.1.0":
assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
-target_dtype = torch.float8_e4m3fn
+target_dtype = current_platform.fp8_dtype()
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else: