From af7b6c5dd440be67005c10bc86c94d13495fe7d4 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 31 Aug 2025 23:50:20 -0700
Subject: [PATCH] fix

Signed-off-by: Woosuk Kwon
---
 vllm/v1/sample/metadata.py          |  4 +-
 vllm/v1/worker/gpu_block_table.py   | 21 +++----
 vllm/v1/worker/gpu_model_runner.py  | 96 +++++++++++++++++++----------
 vllm/v1/worker/gpu_worker_states.py | 10 ---
 4 files changed, 77 insertions(+), 54 deletions(-)

diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
index 2059eac4bad20..4749542cb6306 100644
--- a/vllm/v1/sample/metadata.py
+++ b/vllm/v1/sample/metadata.py
@@ -16,8 +16,8 @@ class SamplingMetadata:
     all_greedy: bool
     all_random: bool
 
-    top_p: torch.Tensor
-    top_k: torch.Tensor
+    top_p: Optional[torch.Tensor]
+    top_k: Optional[torch.Tensor]
 
     generators: dict[int, torch.Generator]
 
diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py
index 01700db523887..f7cb7edf43ee2 100644
--- a/vllm/v1/worker/gpu_block_table.py
+++ b/vllm/v1/worker/gpu_block_table.py
@@ -67,7 +67,7 @@ class BlockTables:
                                       dtype=torch.int32,
                                       device=self.device)
         self.num_blocks = torch.zeros(self.num_kv_cache_groups,
-                                      self.max_num_reqs,
+                                      self.max_num_cached_reqs,
                                       dtype=torch.int32,
                                       device=self.device)
         self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
@@ -119,8 +119,7 @@ class BlockTables:
         self.overwrite.np[:num_reqs] = overwrite
         for i in range(self.num_kv_cache_groups):
             self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
-            n = len(new_block_ids[i])
-            self.new_block_ids.np[i, :n] = new_block_ids[i]
+            self.new_block_ids.np[i, :len(new_block_ids[i])] = new_block_ids[i]
 
         _append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
             self.req_indices.copy_to_gpu(num_reqs),
@@ -154,17 +153,17 @@ class BlockTables:
 
     def compute_slot_mappings(
         self,
-        cu_num_tokens: torch.Tensor,
-        pos: torch.Tensor,
-        num_tokens: int,
+        query_start_loc: torch.Tensor,
+        positions: torch.Tensor,
     ) -> tuple[torch.Tensor, ...]:
-        num_reqs = cu_num_tokens.shape[0] - 1
+        num_reqs = query_start_loc.shape[0] - 1
+        num_tokens = positions.shape[0]
         num_groups = self.num_kv_cache_groups
         _compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](
             num_tokens,
             self.max_num_batched_tokens,
-            cu_num_tokens,
-            pos,
+            query_start_loc,
+            positions,
             self.block_table_ptrs,
             self.block_table_strides,
             self.block_sizes_tensor,
@@ -188,7 +187,7 @@ def _append_block_ids_kernel(
     block_table_strides,  # [num_kv_cache_groups]
     # Outputs
    block_table_buffer_ptrs,  # [num_kv_cache_groups]
-    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
+    num_blocks_ptr,  # [num_kv_cache_groups, max_num_cached_reqs]
     num_blocks_stride,
     # Constants
     BLOCK_SIZE: tl.constexpr,
@@ -235,7 +234,7 @@ def _compute_block_tables_kernel(
     src_block_table_ptrs,  # [num_kv_cache_groups]
     dst_block_table_ptrs,  # [num_kv_cache_groups]
     block_table_strides,  # [num_kv_cache_groups]
-    num_blocks_ptr,  # [num_kv_cache_groups, max_num_reqs]
+    num_blocks_ptr,  # [num_kv_cache_groups, max_num_cached_reqs]
     num_blocks_stride,
     BLOCK_SIZE: tl.constexpr,
 ):
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 3892208f44ffa..489f31c50ff97 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -262,15 +262,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
 
-        self.block_tables = BlockTables(
-            block_sizes=[self.cache_config.block_size],
-            max_num_reqs=self.max_num_reqs,
-            max_num_cached_reqs=2 * self.max_num_reqs,
-            max_num_batched_tokens=self.max_num_tokens,
-            max_model_len=self.max_model_len,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
 
         self.idx_mapping = self._make_buffer(self.max_num_reqs,
                                              dtype=torch.int32)
@@ -581,7 +572,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # Compute the slot mappings on GPUs.
         slot_mappings = self.block_tables.compute_slot_mappings(
-            query_start_loc, self.positions.gpu, total_num_scheduled_tokens)
+            query_start_loc, self.positions.gpu[:total_num_scheduled_tokens])
 
         if self.uses_mrope:
             self._calc_mrope_positions(req_ids, num_scheduled_tokens)
@@ -635,8 +626,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
         seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
-        num_computed_tokens_cpu = (
-            self.requests.num_computed_tokens.cpu[:num_reqs])
+        num_computed_tokens_np = self.requests.num_computed_tokens.np[
+            idx_mapping_np]
+        num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
         spec_decode_common_attn_metadata = None
 
         attn_metadata: dict[str, Any] = {}
@@ -1444,16 +1436,29 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         sampler_output.sampled_token_ids = output_token_ids
 
-        for i in range(input_batch.num_reqs):
-            req_idx = input_batch.idx_mapping_np[i]
-            num_tokens = input_batch.num_scheduled_tokens[i]
-            self.requests.num_computed_tokens.np[req_idx] += num_tokens
-
         num_nans_in_logits = {}
         if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
             num_nans_in_logits = self._get_nans_in_logits(
                 input_batch.req_ids, logits)
 
+        # TODO(woosuk): The following loop can be slow since it iterates over
+        # the requests one by one. Optimize.
+        discard_sampled_tokens_req_indices: list[int] = []
+        for i, req_id in enumerate(input_batch.req_ids):
+            req_idx = self.requests.req_id_to_index[req_id]
+            seq_len = (self.requests.num_computed_tokens.np[req_idx] +
+                       input_batch.num_scheduled_tokens[i])
+            if seq_len < self.requests.num_tokens.np[req_idx]:
+                # Ignore the sampled token for partial prefills.
+                # Rewind the generator state as if the token was not sampled.
+                # This relies on cuda-specific torch-internal impl details
+                generator = self.requests.generators.get(req_idx)
+                if generator is not None:
+                    generator.set_offset(generator.get_offset() - 4)
+                # Record the index of the request that should not be sampled,
+                # so that we could clear the sampled tokens before returning.
+                discard_sampled_tokens_req_indices.append(i)
+
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
@@ -1471,23 +1476,36 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
             # No spec decode tokens.
-            valid_sampled_token_ids_np = sampled_token_ids.cpu().numpy()
-            valid_sampled_token_ids = valid_sampled_token_ids_np.tolist()
-            # valid_sampled_token_ids = self._to_list(sampled_token_ids)
+            valid_sampled_token_ids = self._to_list(sampled_token_ids)
         else:
             # Includes spec decode tokens.
             valid_sampled_token_ids = self.rejection_sampler.parse_output(
                 sampled_token_ids, self.vocab_size)
+        # Mask out the sampled tokens that should not be sampled.
+        for i in discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()
 
         # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
         # NOTE(woosuk): As an exception, when using PP, the scheduler sends
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
-        self.requests.append_sampled_token_ids(
-            input_batch.idx_mapping_np,
-            valid_sampled_token_ids,
-        )
+        for i, req_id in enumerate(input_batch.req_ids):
+            sampled_ids = valid_sampled_token_ids[i]
+            if not sampled_ids:
+                continue
+            req_idx = self.requests.req_id_to_index[req_id]
+
+            start_idx = self.requests.num_tokens.np[req_idx]
+            end_idx = start_idx + len(sampled_ids)
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            self.requests.token_ids.np[req_idx,
+                                       start_idx:end_idx] = sampled_ids
+            self.requests.num_tokens.np[req_idx] = end_idx
 
         if self.speculative_config:
             assert input_batch.spec_decode_common_attn_metadata is not None
@@ -1648,14 +1666,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 draft_token_ids.append([])
                 continue
 
-            # Skip requests that require sampling parameters that are not
-            # supported with speculative decoding.
-            req_id = input_batch.req_ids[i]
-            if req_id in self.requests.spec_decode_unsupported_reqs:
-                draft_token_ids.append([])
-                continue
+            # # Skip requests that require sampling parameters that are not
+            # # supported with speculative decoding.
+            # req_id = input_batch.req_ids[i]
+            # if req_id in self.requests.spec_decode_unsupported_reqs:
+            #     draft_token_ids.append([])
+            #     continue
 
-            num_tokens = self.requests.num_tokens_no_spec[i]
+            num_tokens = self.requests.num_tokens.np[i]
             if num_tokens >= self.max_model_len:
                 # Skip requests that have already reached the max model length.
                 draft_token_ids.append([])
@@ -2824,6 +2842,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             else:
                 break
 
+    def init_block_tables(self, kv_cache_config: KVCacheConfig) -> None:
+        block_sizes = [
+            kv_cache_group.kv_cache_spec.block_size
+            for kv_cache_group in kv_cache_config.kv_cache_groups
+        ]
+        self.block_tables = BlockTables(
+            block_sizes=block_sizes,
+            max_num_reqs=self.max_num_reqs,
+            max_num_cached_reqs=2 * self.max_num_reqs,
+            max_num_batched_tokens=self.max_num_tokens,
+            max_model_len=self.max_model_len,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -2833,6 +2866,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
+        self.init_block_tables(kv_cache_config)
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         self.initialize_attn_backend(kv_cache_config)
diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py
index ae3829291413d..c23923423a2cb 100644
--- a/vllm/v1/worker/gpu_worker_states.py
+++ b/vllm/v1/worker/gpu_worker_states.py
@@ -227,16 +227,6 @@ class RequestState:
         self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
         self.num_tokens.np[req_idx] = end_idx
 
-    def append_sampled_token_ids(
-        self,
-        idx_mapping: np.ndarray,
-        sampled_token_ids: np.ndarray,
-    ) -> None:
-        num_reqs = idx_mapping.shape[0]
-        for i in range(num_reqs):
-            req_idx = idx_mapping[i]
-            self.append_token_ids(req_idx, sampled_token_ids[i])
-
     def remove_request(self, req_id: str) -> None:
         req_idx = self.req_id_to_index.pop(req_id, None)
         if req_idx is None:
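
Note on the compute_slot_mappings change in vllm/v1/worker/gpu_block_table.py: the kernel launch now derives num_tokens from positions instead of taking it as an extra argument. The sketch below is a minimal pure-PyTorch reference for the slot-mapping semantics of a single KV cache group; the function name and the dense block_table argument are illustrative assumptions, not part of the patch.

import torch


def reference_slot_mappings(
    query_start_loc: torch.Tensor,  # [num_reqs + 1], cumulative token counts
    positions: torch.Tensor,        # [num_tokens], position of each token
    block_table: torch.Tensor,      # [num_reqs, max_blocks], physical block ids
    block_size: int,
) -> torch.Tensor:
    # num_tokens is derived from positions, matching the new signature.
    num_reqs = query_start_loc.shape[0] - 1
    assert positions.shape[0] == int(query_start_loc[-1])
    # Recover each token's request index from the cumulative counts.
    req_indices = torch.repeat_interleave(
        torch.arange(num_reqs, device=positions.device),
        query_start_loc.diff().long())
    # Slot of the token at position p of request r:
    #   block_table[r, p // block_size] * block_size + p % block_size
    block_ids = block_table[req_indices, positions // block_size]
    return block_ids * block_size + positions % block_size

For example, with block_size=16, a token at position 37 of a request whose third block is physical block 9 lands in slot 9 * 16 + 37 % 16 = 149.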
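
The partial-prefill handling in gpu_model_runner.py rewinds the request's CUDA generator by 4 so that a discarded draw does not perturb later sampling. Below is a standalone sketch of the same trick, assuming a CUDA device and PyTorch's torch.Generator.get_offset()/set_offset(); as the in-code comment warns, it leans on CUDA-specific internals (a single RNG kernel launch advancing the Philox offset by 4), so treat it as an illustration rather than a guaranteed contract.

import torch


def sample_once(probs: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    # Exponential-race sampling: argmax(probs / Exp(1)) draws from probs,
    # similar in spirit to vLLM's random sampling path.
    q = torch.empty_like(probs).exponential_(generator=generator)
    return (probs / q).argmax(dim=-1)


if __name__ == "__main__":
    assert torch.cuda.is_available(), "the offset rewind is CUDA-only"
    gen = torch.Generator(device="cuda")
    gen.manual_seed(0)
    probs = torch.rand(1, 32, device="cuda").softmax(dim=-1)

    discarded = sample_once(probs, gen)    # draw for a partial prefill
    gen.set_offset(gen.get_offset() - 4)   # rewind: as if it never happened
    kept = sample_once(probs, gen)

    gen.manual_seed(0)                     # replay from the same seed
    replay = sample_once(probs, gen)
    # With the rewind, the kept draw matches the replay (prints True on
    # current CUDA Philox implementations).
    print(torch.equal(kept, replay))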
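
Sampled tokens are now cached by writing directly into the preallocated per-request buffers (requests.token_ids.np / requests.num_tokens.np) instead of going through the removed append_sampled_token_ids helper. A minimal NumPy sketch of that storage pattern follows; the TokenStore class and its method names are illustrative, not the RequestState API.

import numpy as np


class TokenStore:
    """Preallocated per-request token buffer, mirroring the patch's pattern."""

    def __init__(self, max_num_cached_reqs: int, max_model_len: int) -> None:
        self.token_ids = np.zeros((max_num_cached_reqs, max_model_len),
                                  dtype=np.int32)
        self.num_tokens = np.zeros(max_num_cached_reqs, dtype=np.int32)

    def append(self, req_idx: int, sampled_ids: list[int]) -> None:
        if not sampled_ids:
            return
        start_idx = int(self.num_tokens[req_idx])
        end_idx = start_idx + len(sampled_ids)
        assert end_idx <= self.token_ids.shape[1], (
            f"Sampled token IDs exceed the max model length: {end_idx}")
        self.token_ids[req_idx, start_idx:end_idx] = sampled_ids
        self.num_tokens[req_idx] = end_idx


store = TokenStore(max_num_cached_reqs=4, max_model_len=16)
store.append(req_idx=2, sampled_ids=[101, 102])
assert store.num_tokens[2] == 2

Appends are plain array writes with no Python-object growth, which is what lets the model runner keep the sampled tokens locally instead of having the scheduler send them back.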