diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py
index cc9e1bb23b9b3..0d82935bb4185 100644
--- a/vllm/attention/ops/triton_reshape_and_cache_flash.py
+++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py
@@ -137,7 +137,7 @@ def triton_reshape_and_cache_flash(
     # heuristics instead of autotuning
     TILE_SIZE = min(2048, triton.next_power_of_2(n))
-    if torch.version.hip or torch.version.xpu:
+    if current_platform.is_rocm() or current_platform.is_xpu():
         num_stages = 4
         num_warps = 8
     else:  # cuda