diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 49ba476d78b62..e0478c2aebdaa 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -753,7 +753,7 @@ if triton.__version__ >= "2.1.0": assert (v_cache.dtype == torch.uint8) if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = torch.float8_e4m3fn + target_dtype = current_platform.fp8_dtype() elif kv_cache_dtype == "fp8_e5m2": target_dtype = torch.float8_e5m2 else: