diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 54ffc83cd565a..d26e4b3350381 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -127,7 +127,8 @@ def use_rocm_custom_paged_attention(
         max_seq_len: int,
         sliding_window: int,
         kv_cache_dtype: str,
-        alibi_slopes: Optional[torch.Tensor] = None) -> bool:
+        alibi_slopes: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None) -> bool:
 
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
@@ -145,7 +146,7 @@ def use_rocm_custom_paged_attention(
                 and max_seq_len <= 128 * 1024
                 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
                 and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
-                         and envs.VLLM_ROCM_USE_AITER))
+                         and envs.VLLM_ROCM_USE_AITER) and sinks is None)
     else:
         return (ON_GFX11_GFX12
                 and (not envs.VLLM_USE_V1 or sliding_window == 0
@@ -155,7 +156,7 @@
                 and (gqa_ratio >= 3 and gqa_ratio <= 16)
                 and max_seq_len <= 128 * 1024 and alibi_slopes is None
                 and kv_cache_dtype == "auto"
-                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
 
 
 class RocmPlatform(Platform):
@@ -170,7 +171,7 @@ class RocmPlatform(Platform):
 
     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
-        "quark", "ptpc_fp8"
+        "quark", "ptpc_fp8", "mxfp4"
     ]
 
     @classmethod
@@ -469,4 +470,4 @@ class RocmPlatform(Platform):
 
     @classmethod
     def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
-        return True
\ No newline at end of file
+        return True
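
For context, a minimal caller-side sketch of the new `sinks` gate (not part of the patch). The leading parameters (qtype, head_size, block_size, gqa_ratio) are assumed to match the upstream signature, which this hunk does not show; the sink-tensor shape below is a placeholder, and running this requires a ROCm device since the function queries the GPU architecture at call time.

import torch

from vllm.platforms.rocm import use_rocm_custom_paged_attention

# Hypothetical sink tensor; the shape is a placeholder (e.g. one value per head).
sinks = torch.zeros(32, dtype=torch.float32)

use_custom = use_rocm_custom_paged_attention(
    qtype=torch.bfloat16,
    head_size=128,
    block_size=16,
    gqa_ratio=8,
    max_seq_len=4096,
    sliding_window=0,
    kv_cache_dtype="auto",
    alibi_slopes=None,
    sinks=sinks,
)
# Both the GFX9 and GFX11/GFX12 branches now require `sinks is None`, so any
# non-None sinks tensor disables the custom paged-attention kernel and forces
# the fallback attention path; with sinks=None the pre-existing gating
# conditions apply unchanged.
assert use_custom is False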