mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 01:15:31 +08:00
[ROCm][BugFix] Remove the usage of device_info from aiter (#28383)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
parent
d44fbbab0e
commit
ca00b1bfc6
@ -31,15 +31,14 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
|||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
import aiter
|
import aiter
|
||||||
from aiter.ops.triton.utils.device_info import get_num_sms
|
|
||||||
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
def block_size(x, head_dim):
|
def block_size(x, head_dim):
|
||||||
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
||||||
|
|
||||||
def num_programs(head_dim):
|
def num_programs(total_tokens):
|
||||||
return min(head_dim, get_num_sms())
|
return min(total_tokens, current_platform.get_cu_count())
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def cp_mha_gather_cache_kernel(
|
def cp_mha_gather_cache_kernel(
|
||||||
@ -58,11 +57,11 @@ if current_platform.is_rocm():
|
|||||||
x,
|
x,
|
||||||
max_block_num,
|
max_block_num,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
|
num_programs,
|
||||||
DEQUANT: tl.constexpr,
|
DEQUANT: tl.constexpr,
|
||||||
PAGE_SIZE: tl.constexpr,
|
PAGE_SIZE: tl.constexpr,
|
||||||
CACHE_FORMAT: tl.constexpr,
|
CACHE_FORMAT: tl.constexpr,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
NUM_PRGMS: tl.constexpr,
|
|
||||||
):
|
):
|
||||||
bid = tl.program_id(0)
|
bid = tl.program_id(0)
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
@ -70,7 +69,7 @@ if current_platform.is_rocm():
|
|||||||
k_scale = tl.load(k_scale_ptr)
|
k_scale = tl.load(k_scale_ptr)
|
||||||
v_scale = tl.load(v_scale_ptr)
|
v_scale = tl.load(v_scale_ptr)
|
||||||
|
|
||||||
for token_id in tl.range(bid, num_tokens, NUM_PRGMS):
|
for token_id in tl.range(bid, num_tokens, num_programs):
|
||||||
key_ptr_offset = key_ptr + token_id * head_size * num_heads
|
key_ptr_offset = key_ptr + token_id * head_size * num_heads
|
||||||
value_ptr_offset = value_ptr + token_id * head_size * num_heads
|
value_ptr_offset = value_ptr + token_id * head_size * num_heads
|
||||||
batch_idx = tl.load(token_to_batch_ptr + token_id)
|
batch_idx = tl.load(token_to_batch_ptr + token_id)
|
||||||
@ -162,11 +161,11 @@ if current_platform.is_rocm():
|
|||||||
x,
|
x,
|
||||||
block_tables.size(1),
|
block_tables.size(1),
|
||||||
total_tokens,
|
total_tokens,
|
||||||
|
NUM_PRGMS,
|
||||||
DEQUANT=dequant,
|
DEQUANT=dequant,
|
||||||
PAGE_SIZE=page_size,
|
PAGE_SIZE=page_size,
|
||||||
CACHE_FORMAT=kv_cache_layout,
|
CACHE_FORMAT=kv_cache_layout,
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
NUM_PRGMS=NUM_PRGMS,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user