Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-31 23:50:20 -07:00
parent 62d23b3006
commit af7b6c5dd4
4 changed files with 77 additions and 54 deletions

View File

@ -16,8 +16,8 @@ class SamplingMetadata:
all_greedy: bool
all_random: bool
top_p: torch.Tensor
top_k: torch.Tensor
top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
generators: dict[int, torch.Generator]

View File

@ -67,7 +67,7 @@ class BlockTables:
dtype=torch.int32,
device=self.device)
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
self.max_num_reqs,
self.max_num_cached_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
@ -119,8 +119,7 @@ class BlockTables:
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
n = len(new_block_ids[i])
self.new_block_ids.np[i, :n] = new_block_ids[i]
self.new_block_ids.np[i, :len(new_block_ids[i])] = new_block_ids[i]
_append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
self.req_indices.copy_to_gpu(num_reqs),
@ -154,17 +153,17 @@ class BlockTables:
def compute_slot_mappings(
self,
cu_num_tokens: torch.Tensor,
pos: torch.Tensor,
num_tokens: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_reqs = cu_num_tokens.shape[0] - 1
num_reqs = query_start_loc.shape[0] - 1
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](
num_tokens,
self.max_num_batched_tokens,
cu_num_tokens,
pos,
query_start_loc,
positions,
self.block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
@ -188,7 +187,7 @@ def _append_block_ids_kernel(
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_buffer_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_ptr, # [num_kv_cache_groups, max_num_cached_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
@ -235,7 +234,7 @@ def _compute_block_tables_kernel(
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_ptr, # [num_kv_cache_groups, max_num_cached_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):

View File

@ -262,15 +262,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None
self.block_tables = BlockTables(
block_sizes=[self.cache_config.block_size],
max_num_reqs=self.max_num_reqs,
max_num_cached_reqs=2 * self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
)
self.idx_mapping = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
@ -581,7 +572,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute the slot mappings on GPUs.
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc, self.positions.gpu, total_num_scheduled_tokens)
query_start_loc, self.positions.gpu[:total_num_scheduled_tokens])
if self.uses_mrope:
self._calc_mrope_positions(req_ids, num_scheduled_tokens)
@ -635,8 +626,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Used in the below loop.
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
num_computed_tokens_cpu = (
self.requests.num_computed_tokens.cpu[:num_reqs])
num_computed_tokens_np = self.requests.num_computed_tokens.np[
idx_mapping_np]
num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
spec_decode_common_attn_metadata = None
attn_metadata: dict[str, Any] = {}
@ -1444,16 +1436,29 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
sampler_output.sampled_token_ids = output_token_ids
for i in range(input_batch.num_reqs):
req_idx = input_batch.idx_mapping_np[i]
num_tokens = input_batch.num_scheduled_tokens[i]
self.requests.num_computed_tokens.np[req_idx] += num_tokens
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(
input_batch.req_ids, logits)
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
discard_sampled_tokens_req_indices: list[int] = []
for i, req_id in enumerate(input_batch.req_ids):
req_idx = self.requests.req_id_to_index[req_id]
seq_len = (self.requests.num_computed_tokens.np[req_idx] +
input_batch.num_scheduled_tokens[i])
if seq_len < self.requests.num_tokens.np[req_idx]:
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator = self.requests.generators.get(req_idx)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
@ -1471,23 +1476,36 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids_np = sampled_token_ids.cpu().numpy()
valid_sampled_token_ids = valid_sampled_token_ids_np.tolist()
# valid_sampled_token_ids = self._to_list(sampled_token_ids)
valid_sampled_token_ids = self._to_list(sampled_token_ids)
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, self.vocab_size)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
self.requests.append_sampled_token_ids(
input_batch.idx_mapping_np,
valid_sampled_token_ids,
)
for i, req_id in enumerate(input_batch.req_ids):
sampled_ids = valid_sampled_token_ids[i]
if not sampled_ids:
continue
req_idx = self.requests.req_id_to_index[req_id]
start_idx = self.requests.num_tokens.np[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")
self.requests.token_ids.np[req_idx,
start_idx:end_idx] = sampled_ids
self.requests.num_tokens.np[req_idx] = end_idx
if self.speculative_config:
assert input_batch.spec_decode_common_attn_metadata is not None
@ -1648,14 +1666,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
draft_token_ids.append([])
continue
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = input_batch.req_ids[i]
if req_id in self.requests.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue
# # Skip requests that require sampling parameters that are not
# # supported with speculative decoding.
# req_id = input_batch.req_ids[i]
# if req_id in self.requests.spec_decode_unsupported_reqs:
# draft_token_ids.append([])
# continue
num_tokens = self.requests.num_tokens_no_spec[i]
num_tokens = self.requests.num_tokens.np[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
@ -2824,6 +2842,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
break
def init_block_tables(self, kv_cache_config: KVCacheConfig) -> None:
    """Build ``self.block_tables`` from the given KV cache configuration.

    One block size is collected per entry of
    ``kv_cache_config.kv_cache_groups`` (read from that group's
    ``kv_cache_spec.block_size``), and the resulting list is handed to
    ``BlockTables`` together with the runner's request/token limits.

    Args:
        kv_cache_config: Configuration whose ``kv_cache_groups`` supply the
            per-group block sizes.
    """
    # Gather the block size of every KV cache group, preserving group order.
    per_group_block_sizes = []
    for group in kv_cache_config.kv_cache_groups:
        per_group_block_sizes.append(group.kv_cache_spec.block_size)

    request_limit = self.max_num_reqs
    self.block_tables = BlockTables(
        block_sizes=per_group_block_sizes,
        max_num_reqs=request_limit,
        # The cached-request capacity is passed as twice the request limit.
        max_num_cached_reqs=2 * request_limit,
        max_num_batched_tokens=self.max_num_tokens,
        max_model_len=self.max_model_len,
        device=self.device,
        pin_memory=self.pin_memory,
    )
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@ -2833,6 +2866,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self.init_block_tables(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)

View File

@ -227,16 +227,6 @@ class RequestState:
self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
self.num_tokens.np[req_idx] = end_idx
def append_sampled_token_ids(
    self,
    idx_mapping: np.ndarray,
    sampled_token_ids: np.ndarray,
) -> None:
    """Append each batch entry's sampled tokens to its mapped request slot.

    For every position ``i`` along the first axis of ``idx_mapping``, the
    tokens in ``sampled_token_ids[i]`` are forwarded to
    ``self.append_token_ids`` for the request index ``idx_mapping[i]``.

    Args:
        idx_mapping: Array whose ``i``-th entry is the request index for
            batch position ``i``.
        sampled_token_ids: Indexable container holding, per batch position,
            the token IDs to append.
    """
    batch_size = idx_mapping.shape[0]
    for batch_pos, slot in enumerate(idx_mapping[:batch_size]):
        self.append_token_ids(slot, sampled_token_ids[batch_pos])
def remove_request(self, req_id: str) -> None:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None: