mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-12 14:55:49 +08:00
[ROCm][BugFix]Fix get_cu_count in rocm_aiter_fa.py (#28618)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
parent
86d15bfd8d
commit
8da2f28f53
@ -18,6 +18,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
|
from vllm.utils.platform_utils import get_cu_count
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
@ -38,7 +39,7 @@ if current_platform.is_rocm():
|
|||||||
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
||||||
|
|
||||||
def num_programs(total_tokens):
|
def num_programs(total_tokens):
|
||||||
return min(total_tokens, current_platform.get_cu_count())
|
return min(total_tokens, get_cu_count())
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def cp_mha_gather_cache_kernel(
|
def cp_mha_gather_cache_kernel(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user