mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:15:20 +08:00
[ROCm][BugFix]Fix get_cu_count in rocm_aiter_fa.py (#28618)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
parent
86d15bfd8d
commit
8da2f28f53
@ -18,6 +18,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
@ -38,7 +39,7 @@ if current_platform.is_rocm():
|
||||
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
||||
|
||||
def num_programs(total_tokens):
|
||||
return min(total_tokens, current_platform.get_cu_count())
|
||||
return min(total_tokens, get_cu_count())
|
||||
|
||||
@triton.jit
|
||||
def cp_mha_gather_cache_kernel(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user