[ROCm] Add attention sink to use_rocm_custom_paged_attention (#22329)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
commit 98a3a81024 (parent de98252f49)
@@ -127,7 +127,8 @@ def use_rocm_custom_paged_attention(
         max_seq_len: int,
         sliding_window: int,
         kv_cache_dtype: str,
-        alibi_slopes: Optional[torch.Tensor] = None) -> bool:
+        alibi_slopes: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None) -> bool:
 
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
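This hunk threads a new optional sinks tensor through the gating helper. For orientation, a hedged sketch of a call site follows; the import path and the four leading positional arguments (query dtype, head size, block size, GQA ratio) are not visible in this hunk and are assumptions, while the keyword arguments match the signature shown in the diff:

import torch

from vllm.platforms.rocm import use_rocm_custom_paged_attention  # assumed path

# Illustrative sink tensor: one sink value per attention head.
sinks = torch.zeros(8)

use_custom = use_rocm_custom_paged_attention(
    torch.float16, 128, 16, 8,  # assumed leading arguments (not in this hunk)
    max_seq_len=4096,
    sliding_window=0,
    kv_cache_dtype="auto",
    alibi_slopes=None,
    sinks=sinks,  # new parameter added by this commit
)
print(use_custom)  # False whenever sinks is not None (see the hunks below)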
@@ -145,7 +146,7 @@ def use_rocm_custom_paged_attention
                 and max_seq_len <= 128 * 1024
                 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
                 and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
-                         and envs.VLLM_ROCM_USE_AITER))
+                         and envs.VLLM_ROCM_USE_AITER) and sinks is None)
 
     else:
         return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
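On gfx9, the custom paged-attention kernel is now selected only when no sink tensor is supplied. A minimal, self-contained distillation of the clauses visible in this hunk (env lookups replaced by plain booleans; the real condition has additional clauses above line 145 that the hunk does not show, so this is a sketch, not the function in the tree):

from typing import Optional

import torch

def gfx9_gate_sketch(max_seq_len: int,
                     custom_paged_attn: bool,  # envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
                     aiter_paged_attn: bool,   # envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
                     use_aiter: bool,          # envs.VLLM_ROCM_USE_AITER
                     sinks: Optional[torch.Tensor]) -> bool:
    # Only the clauses visible in this hunk are reproduced here.
    return (max_seq_len <= 128 * 1024
            and custom_paged_attn
            and not (aiter_paged_attn and use_aiter)
            and sinks is None)  # new guard: sinks force a fallback backend

assert gfx9_gate_sketch(4096, True, False, False, sinks=None)
assert not gfx9_gate_sketch(4096, True, False, False, sinks=torch.zeros(8))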
@@ -155,7 +156,7 @@ def use_rocm_custom_paged_attention
                 and (gqa_ratio >= 3 and gqa_ratio <= 16)
                 and max_seq_len <= 128 * 1024 and alibi_slopes is None
                 and kv_cache_dtype == "auto"
-                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+                and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
 
 
 class RocmPlatform(Platform):
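The gfx11/gfx12 branch gains the same sinks-is-None guard, so on either architecture family a non-None sinks tensor now disqualifies the custom kernel and the caller falls back to a sink-capable attention path. A hedged smoke check of that property, with the same assumed leading arguments as the sketch above:

import torch

from vllm.platforms.rocm import use_rocm_custom_paged_attention  # assumed path

for sinks in (None, torch.zeros(8)):
    ok = use_rocm_custom_paged_attention(
        torch.float16, 128, 16, 8,  # assumed leading arguments
        max_seq_len=4096, sliding_window=0,
        kv_cache_dtype="auto", alibi_slopes=None, sinks=sinks)
    # After this commit, any non-None sinks tensor must yield False.
    assert sinks is None or not ok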
@@ -170,7 +171,7 @@ class RocmPlatform(Platform):
 
     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
-        "quark", "ptpc_fp8"
+        "quark", "ptpc_fp8", "mxfp4"
     ]
 
     @classmethod
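Independently of the attention-sink gating, the commit registers "mxfp4" in RocmPlatform.supported_quantization. A hypothetical sketch of how such a list is typically consulted when validating a requested quantization method (check_quant_method is illustrative, not a function in the tree):

supported_quantization = [
    "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
    "quark", "ptpc_fp8", "mxfp4",  # "mxfp4" is new in this commit
]

def check_quant_method(method: str) -> None:
    # Reject any quantization scheme the platform does not list as supported.
    if method not in supported_quantization:
        raise ValueError(
            f"{method!r} quantization is not supported on this platform; "
            f"supported methods: {supported_quantization}")

check_quant_method("mxfp4")  # passes after this commit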