[Model Runner V2] Change Numba AoT to JIT (#29328)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date: 2025-11-24 09:34:37 -08:00 (committed by GitHub)
commit cec418b5df
parent cc313cb73d
2 changed files with 32 additions and 63 deletions
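
The patch drops the eagerly compiled `@numba.jit` decorator with explicit `numba.types` signatures (the "AoT" style in the title) in favor of a lazy `@numba.njit(cache=True)`, which compiles on first call and caches the machine code on disk. As a minimal sketch of the two decorator styles, assuming Numba's documented behavior (the helper names below are invented for illustration and are not part of the patch):

import numba
import numba.types as types
import numpy as np

# Eager ("AoT-style"): the explicit signature list makes Numba compile the
# function at decoration time, before the first call.
@numba.jit(
    [types.none(types.int32[:], types.int32[:])],
    nopython=True,
    cache=True,
)
def add_into_eager(out: np.ndarray, x: np.ndarray) -> None:
    for i in range(out.shape[0]):
        out[i] += x[i]

# Lazy JIT: no signature, so compilation is deferred to the first call and the
# compiled code is cached on disk for later runs.
@numba.njit(cache=True)
def add_into_lazy(out: np.ndarray, x: np.ndarray) -> None:
    for i in range(out.shape[0]):
        out[i] += x[i]

out = np.zeros(4, dtype=np.int32)
x = np.arange(4, dtype=np.int32)
add_into_eager(out, x)  # already compiled when the decorator ran
add_into_lazy(out, x)   # compiled here, on the first call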


@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Any
 
 import numba
-import numba.types as types
 import numpy as np
 import torch
@@ -147,80 +146,42 @@ class InputBatch:
     )
 
 
-# NOTE: With the type annotations, this function is pre-compiled
-# before the first call.
-@numba.jit(
-    [
-        types.none(
-            types.int32[:],  # idx_mapping
-            types.int32[:],  # num_scheduled_tokens
-            types.int32[:, :],  # prefill_token_ids
-            types.int32[:],  # num_computed_prefill_tokens
-            types.int32[:],  # prefill_len
-            types.int32[:],  # input_ids
-            types.int32[:],  # query_start_loc
-        )
-    ],
-    nopython=True,
-    cache=True,
-)
+@numba.njit(cache=True)
 def _prepare_prefill_inputs(
-    idx_mapping: np.ndarray,  # batch_idx -> req_idx
-    num_scheduled_tokens: np.ndarray,  # [B]
+    idx_mapping: np.ndarray,  # [B]
+    query_lens: np.ndarray,  # [B]
+    query_start_loc: np.ndarray,  # [B + 1]
     prefill_token_ids: np.ndarray,  # [N, max_model_len]
     num_computed_prefill_tokens: np.ndarray,  # [N]
-    prefill_len: np.ndarray,  # [N]
     input_ids: np.ndarray,  # [num_input_tokens]
-    query_start_loc: np.ndarray,  # [B + 1]
 ) -> None:
-    num_reqs = num_scheduled_tokens.shape[0]
-    query_start_loc[0] = 0
-    cu_num_tokens = 0
+    num_reqs = idx_mapping.shape[0]
+    query_starts = query_start_loc[:num_reqs]
+    query_ends = query_start_loc[1 : num_reqs + 1]
+    starts = num_computed_prefill_tokens[idx_mapping]
+    ends = starts + query_lens
     for i in range(num_reqs):
-        req_idx = idx_mapping[i]
-        query_len = num_scheduled_tokens[i]
-        start = num_computed_prefill_tokens[req_idx]
-        end = min(start + query_len, prefill_len[req_idx])
-        n = end - start
-        start_idx = cu_num_tokens
-        input_ids[start_idx : start_idx + n] = prefill_token_ids[req_idx, start:end]
-        cu_num_tokens = start_idx + query_len
-        query_start_loc[i + 1] = cu_num_tokens
-    # Pad the inputs for CUDA graphs.
-    # Note: pad query_start_loc to be non-decreasing, as kernels
-    # like FlashAttention requires that
-    query_start_loc[num_reqs + 1 :].fill(cu_num_tokens)
+        input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[
+            idx_mapping[i], starts[i] : ends[i]
+        ]
 
 
 def prepare_prefill_inputs(
     idx_mapping: np.ndarray,
     num_scheduled_tokens: np.ndarray,
-    total_num_tokens: int,
+    query_start_loc: np.ndarray,
     prefill_token_ids: np.ndarray,
     num_computed_prefill_tokens: np.ndarray,
-    prefill_len: np.ndarray,
-    input_ids: CpuGpuBuffer,
-    query_start_loc: CpuGpuBuffer,
+    input_ids: np.ndarray,
 ) -> None:
     _prepare_prefill_inputs(
         idx_mapping,
         num_scheduled_tokens,
+        query_start_loc,
         prefill_token_ids,
         num_computed_prefill_tokens,
-        prefill_len,
-        input_ids.np,
-        query_start_loc.np,
+        input_ids,
     )
-    input_ids.copy_to_gpu(total_num_tokens)
-    # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
-    # tensors from CPU to GPU, because they may include paddings needed
-    # for full CUDA graph mode.
-    query_start_loc.copy_to_gpu()
 
 
 @triton.jit
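
For reference, a hand-worked toy call of the rewritten _prepare_prefill_inputs (all values invented for illustration; assumes the function above is importable): each batch slot i copies the not-yet-computed slice of its request's prefill tokens into input_ids[query_start_loc[i] : query_start_loc[i + 1]].

import numpy as np

# Two scheduled requests drawn from a pool of three, max_model_len = 8.
idx_mapping = np.array([2, 0], dtype=np.int32)           # batch slot -> request index
query_lens = np.array([3, 2], dtype=np.int32)            # tokens scheduled per slot
query_start_loc = np.array([0, 3, 5], dtype=np.int32)    # leading 0 + cumsum of query_lens
prefill_token_ids = np.arange(24, dtype=np.int32).reshape(3, 8)
num_computed_prefill_tokens = np.array([1, 0, 4], dtype=np.int32)
input_ids = np.zeros(5, dtype=np.int32)

_prepare_prefill_inputs(
    idx_mapping,
    query_lens,
    query_start_loc,
    prefill_token_ids,
    num_computed_prefill_tokens,
    input_ids,
)
# input_ids -> [20, 21, 22, 1, 2]:
#   slot 0 copies request 2's tokens [4:7], slot 1 copies request 0's tokens [1:3].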


@@ -502,20 +502,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
         block_tables = self.block_tables.gather_block_tables(idx_mapping)
 
-        # Copy prefill tokens from CPU to GPU and get query_start_loc.
+        # Get query_start_loc.
+        np.cumsum(
+            num_scheduled_tokens,
+            out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1],
+        )
+        # Pad for full CUDA graph mode.
+        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
+        self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
+        self.input_buffers.query_start_loc.copy_to_gpu()
+        query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
+        query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
+
+        # Copy prefill tokens from CPU to GPU.
         prepare_prefill_inputs(
             idx_mapping_np,
             num_scheduled_tokens,
-            num_tokens,
+            query_start_loc_np,
             self.req_states.prefill_token_ids,
             self.req_states.num_computed_prefill_tokens,
-            self.req_states.prefill_len.np,
-            self.input_buffers.input_ids,
-            self.input_buffers.query_start_loc,
+            self.input_buffers.input_ids.np,
         )
-        query_start_loc = self.input_buffers.query_start_loc
-        query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
-        query_start_loc_np = query_start_loc.np[: num_reqs + 1]
+        self.input_buffers.input_ids.copy_to_gpu(num_tokens)
 
         # Prepare positions and seq_lens.
         prepare_pos_seq_lens(
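
The caller now builds query_start_loc itself (cumsum of the per-request token counts) and pads the tail so the array stays non-decreasing for full CUDA graph capture, instead of doing this inside the Numba kernel. A minimal standalone sketch of that preparation, with invented sizes:

import numpy as np

# Two active requests, buffer sized for a CUDA graph capture of up to 4 requests.
num_reqs = 2
num_tokens = 5                                       # total scheduled tokens
num_scheduled_tokens = np.array([3, 2], dtype=np.int32)
query_start_loc = np.zeros(4 + 1, dtype=np.int32)    # max_num_reqs + 1 slots

np.cumsum(num_scheduled_tokens, out=query_start_loc[1 : num_reqs + 1])
# Pad the unused tail so the array stays non-decreasing, as FA3-style kernels require.
query_start_loc[num_reqs + 1 :] = num_tokens
# query_start_loc -> [0, 3, 5, 5, 5]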