[ROCm][BugFix]Fix get_cu_count in rocm_aiter_fa.py (#28618)

Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
Pleaplusone 2025-11-13 22:18:20 +08:00 committed by GitHub
parent 86d15bfd8d
commit 8da2f28f53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,6 +18,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
@ -38,7 +39,7 @@ if current_platform.is_rocm():
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
def num_programs(total_tokens):
return min(total_tokens, current_platform.get_cu_count())
return min(total_tokens, get_cu_count())
@triton.jit
def cp_mha_gather_cache_kernel(