mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 18:45:01 +08:00
Remove hard-coded 'cuda' device in compute_causal_conv1d_metadata (#25555)
Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
parent
99b3a504c5
commit
dd70437a4f
@ -947,6 +947,7 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
|
|||||||
nums_dict = {} # type: ignore
|
nums_dict = {} # type: ignore
|
||||||
batch_ptr = None
|
batch_ptr = None
|
||||||
token_chunk_offset_ptr = None
|
token_chunk_offset_ptr = None
|
||||||
|
device = query_start_loc_p.device
|
||||||
for BLOCK_M in [8]: # cover all BLOCK_M values
|
for BLOCK_M in [8]: # cover all BLOCK_M values
|
||||||
nums = -(-seqlens // BLOCK_M)
|
nums = -(-seqlens // BLOCK_M)
|
||||||
nums_dict[BLOCK_M] = {}
|
nums_dict[BLOCK_M] = {}
|
||||||
@ -968,11 +969,11 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
|
|||||||
batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
batch_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
||||||
PAD_SLOT_ID,
|
PAD_SLOT_ID,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device='cuda')
|
device=device)
|
||||||
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ),
|
||||||
PAD_SLOT_ID,
|
PAD_SLOT_ID,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device='cuda')
|
device=device)
|
||||||
else:
|
else:
|
||||||
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
|
if batch_ptr.nelement() < MAX_NUM_PROGRAMS:
|
||||||
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
|
batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user