[ROCm][BugFix] Remove the usage of device_info from aiter (#28383)

Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
Pleaplusone 2025-11-13 13:43:42 +08:00 committed by GitHub
parent d44fbbab0e
commit ca00b1bfc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -31,15 +31,14 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
if current_platform.is_rocm():
import aiter
from aiter.ops.triton.utils.device_info import get_num_sms
from vllm.triton_utils import tl, triton
def block_size(x, head_dim):
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
def num_programs(head_dim):
return min(head_dim, get_num_sms())
def num_programs(total_tokens):
return min(total_tokens, current_platform.get_cu_count())
@triton.jit
def cp_mha_gather_cache_kernel(
@ -58,11 +57,11 @@ if current_platform.is_rocm():
x,
max_block_num,
num_tokens,
num_programs,
DEQUANT: tl.constexpr,
PAGE_SIZE: tl.constexpr,
CACHE_FORMAT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
NUM_PRGMS: tl.constexpr,
):
bid = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
@ -70,7 +69,7 @@ if current_platform.is_rocm():
k_scale = tl.load(k_scale_ptr)
v_scale = tl.load(v_scale_ptr)
for token_id in tl.range(bid, num_tokens, NUM_PRGMS):
for token_id in tl.range(bid, num_tokens, num_programs):
key_ptr_offset = key_ptr + token_id * head_size * num_heads
value_ptr_offset = value_ptr + token_id * head_size * num_heads
batch_idx = tl.load(token_to_batch_ptr + token_id)
@ -162,11 +161,11 @@ if current_platform.is_rocm():
x,
block_tables.size(1),
total_tokens,
NUM_PRGMS,
DEQUANT=dequant,
PAGE_SIZE=page_size,
CACHE_FORMAT=kv_cache_layout,
BLOCK_SIZE=BLOCK_SIZE,
NUM_PRGMS=NUM_PRGMS,
)