[Model Runner V2] Change Numba AoT to JIT (#29328)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-11-24 09:34:37 -08:00 (committed by GitHub)
parent cc313cb73d
commit cec418b5df
2 changed files with 32 additions and 63 deletions
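
Background: with an explicit signature list, Numba compiles a function eagerly at decoration (import) time, ahead of the first call; plain @numba.njit defers compilation to the first call and specializes on the argument types it actually sees, with cache=True persisting the compiled code on disk. A minimal sketch of the two modes (function names here are illustrative, not from this commit):

import numba
import numba.types as types
import numpy as np

# AoT-style: an explicit signature list compiles at decoration time,
# before the first call; the signatures must be kept in sync by hand.
@numba.jit([types.int64(types.int64[:])], nopython=True, cache=True)
def sum_eager(x):
    return x.sum()

# Lazy JIT: compiled on the first call for the observed argument types;
# cache=True persists the result on disk across processes.
@numba.njit(cache=True)
def sum_lazy(x):
    return x.sum()

x = np.arange(4, dtype=np.int64)
assert sum_eager(x) == sum_lazy(x) == 6

Dropping the signature list trades a one-time compile at the first call (amortized by the on-disk cache) for faster imports and a type list that cannot drift out of sync with the function.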

File 1:

@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Any

 import numba
-import numba.types as types
 import numpy as np
 import torch
@@ -147,80 +146,42 @@ class InputBatch:
     )

-# NOTE: With the type annotations, this function is pre-compiled
-# before the first call.
-@numba.jit(
-    [
-        types.none(
-            types.int32[:],  # idx_mapping
-            types.int32[:],  # num_scheduled_tokens
-            types.int32[:, :],  # prefill_token_ids
-            types.int32[:],  # num_computed_prefill_tokens
-            types.int32[:],  # prefill_len
-            types.int32[:],  # input_ids
-            types.int32[:],  # query_start_loc
-        )
-    ],
-    nopython=True,
-    cache=True,
-)
+@numba.njit(cache=True)
 def _prepare_prefill_inputs(
-    idx_mapping: np.ndarray,  # batch_idx -> req_idx
-    num_scheduled_tokens: np.ndarray,  # [B]
+    idx_mapping: np.ndarray,  # [B]
+    query_lens: np.ndarray,  # [B]
+    query_start_loc: np.ndarray,  # [B + 1]
     prefill_token_ids: np.ndarray,  # [N, max_model_len]
     num_computed_prefill_tokens: np.ndarray,  # [N]
-    prefill_len: np.ndarray,  # [N]
     input_ids: np.ndarray,  # [num_input_tokens]
-    query_start_loc: np.ndarray,  # [B + 1]
 ) -> None:
-    num_reqs = num_scheduled_tokens.shape[0]
-    query_start_loc[0] = 0
-    cu_num_tokens = 0
+    num_reqs = idx_mapping.shape[0]
+    query_starts = query_start_loc[:num_reqs]
+    query_ends = query_start_loc[1 : num_reqs + 1]
+    starts = num_computed_prefill_tokens[idx_mapping]
+    ends = starts + query_lens
     for i in range(num_reqs):
-        req_idx = idx_mapping[i]
-        query_len = num_scheduled_tokens[i]
-        start = num_computed_prefill_tokens[req_idx]
-        end = min(start + query_len, prefill_len[req_idx])
-        n = end - start
-        start_idx = cu_num_tokens
-        input_ids[start_idx : start_idx + n] = prefill_token_ids[req_idx, start:end]
-        cu_num_tokens = start_idx + query_len
-        query_start_loc[i + 1] = cu_num_tokens
-    # Pad the inputs for CUDA graphs.
-    # Note: pad query_start_loc to be non-decreasing, as kernels
-    # like FlashAttention requires that
-    query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
+        input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[
+            idx_mapping[i], starts[i] : ends[i]
+        ]

 def prepare_prefill_inputs(
     idx_mapping: np.ndarray,
     num_scheduled_tokens: np.ndarray,
-    total_num_tokens: int,
+    query_start_loc: np.ndarray,
     prefill_token_ids: np.ndarray,
     num_computed_prefill_tokens: np.ndarray,
-    prefill_len: np.ndarray,
-    input_ids: CpuGpuBuffer,
-    query_start_loc: CpuGpuBuffer,
+    input_ids: np.ndarray,
 ) -> None:
     _prepare_prefill_inputs(
         idx_mapping,
         num_scheduled_tokens,
+        query_start_loc,
         prefill_token_ids,
         num_computed_prefill_tokens,
-        prefill_len,
-        input_ids.np,
-        query_start_loc.np,
+        input_ids,
     )
-    input_ids.copy_to_gpu(total_num_tokens)
-    # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
-    # tensors from CPU to GPU, because they may include paddings needed
-    # for full CUDA graph mode.
-    query_start_loc.copy_to_gpu()

 @triton.jit
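
To sanity-check the rewritten kernel, here is a NumPy-only sketch of the same gather (toy shapes and values, no Numba involved): each batch slot i copies its scheduled slice of the owning request's prefill tokens into the flat input_ids buffer at [query_start_loc[i], query_start_loc[i + 1]).

import numpy as np

prefill_token_ids = np.array([[10, 11, 12, 13],
                              [20, 21, 22, 23]], dtype=np.int32)  # [N=2, max_model_len=4]
idx_mapping = np.array([1, 0], dtype=np.int32)          # batch_idx -> req_idx
query_lens = np.array([2, 3], dtype=np.int32)           # [B=2]
query_start_loc = np.array([0, 2, 5], dtype=np.int32)   # [B + 1], cumsum of query_lens
num_computed_prefill_tokens = np.array([1, 0], dtype=np.int32)  # [N]
input_ids = np.zeros(8, dtype=np.int32)                 # flat, padded token buffer

starts = num_computed_prefill_tokens[idx_mapping]
ends = starts + query_lens
for i in range(idx_mapping.shape[0]):
    # Batch slot i owns request idx_mapping[i]; copy its next
    # query_lens[i] uncomputed prefill tokens into the flat buffer.
    input_ids[query_start_loc[i] : query_start_loc[i + 1]] = prefill_token_ids[
        idx_mapping[i], starts[i] : ends[i]
    ]

print(input_ids[:5])  # [20 21 11 12 13]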

File 2:

@@ -502,20 +502,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
         block_tables = self.block_tables.gather_block_tables(idx_mapping)

-        # Copy prefill tokens from CPU to GPU and get query_start_loc.
+        # Get query_start_loc.
+        np.cumsum(
+            num_scheduled_tokens,
+            out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1],
+        )
+        # Pad for full CUDA graph mode.
+        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
+        self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
+        self.input_buffers.query_start_loc.copy_to_gpu()
+        query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
+        query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
+
+        # Copy prefill tokens from CPU to GPU.
         prepare_prefill_inputs(
             idx_mapping_np,
             num_scheduled_tokens,
-            num_tokens,
+            query_start_loc_np,
             self.req_states.prefill_token_ids,
             self.req_states.num_computed_prefill_tokens,
-            self.req_states.prefill_len.np,
-            self.input_buffers.input_ids,
-            self.input_buffers.query_start_loc,
+            self.input_buffers.input_ids.np,
         )
-        query_start_loc = self.input_buffers.query_start_loc
-        query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
-        query_start_loc_np = query_start_loc.np[: num_reqs + 1]
+        self.input_buffers.input_ids.copy_to_gpu(num_tokens)

         # Prepare positions and seq_lens.
         prepare_pos_seq_lens(
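
Finally, a toy illustration (made-up buffer sizes and values) of the query_start_loc bookkeeping that moved into the model runner: the cumsum fills the entries for the active requests, and the tail is set to the total token count so the array stays non-decreasing when the batch is padded for full CUDA graph capture.

import numpy as np

max_num_reqs = 4
num_reqs = 2
num_scheduled_tokens = np.array([3, 2], dtype=np.int32)
num_tokens = int(num_scheduled_tokens.sum())  # 5

query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, out=query_start_loc[1 : num_reqs + 1])
# Pad the unused tail so query_start_loc is non-decreasing even for the
# padded part of the batch (required by backends such as FA3).
query_start_loc[num_reqs + 1 :] = num_tokens

print(query_start_loc)  # [0 3 5 5 5]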