mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 19:34:30 +08:00
[ROCm][Bugfix] Use platform specific FP8 dtype (#15717)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent
ef608c37a7
commit
40a36ccfeb
@ -753,7 +753,7 @@ if triton.__version__ >= "2.1.0":
|
|||||||
assert (v_cache.dtype == torch.uint8)
|
assert (v_cache.dtype == torch.uint8)
|
||||||
|
|
||||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||||
target_dtype = torch.float8_e4m3fn
|
target_dtype = current_platform.fp8_dtype()
|
||||||
elif kv_cache_dtype == "fp8_e5m2":
|
elif kv_cache_dtype == "fp8_e5m2":
|
||||||
target_dtype = torch.float8_e5m2
|
target_dtype = torch.float8_e5m2
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user