From efba25e21a55f8f0641fb7dec62233a990ff4b92 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 28 Aug 2025 12:39:15 -0700
Subject: [PATCH] minor

Signed-off-by: Woosuk Kwon
---
 vllm/v1/sample/metadata.py          |  2 +
 vllm/v1/worker/gpu_block_table.py   |  2 +-
 vllm/v1/worker/gpu_model_runner.py  | 82 +++++------------------------
 vllm/v1/worker/gpu_worker_states.py |  2 +
 4 files changed, 17 insertions(+), 71 deletions(-)

diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
index b62b6e5c331ce..2059eac4bad20 100644
--- a/vllm/v1/sample/metadata.py
+++ b/vllm/v1/sample/metadata.py
@@ -30,6 +30,8 @@ class SamplingMetadata:
     repetition_penalties: torch.Tensor
 
     token_ids: Optional[torch.Tensor]
+    num_tokens: Optional[torch.Tensor]
+    num_prompt_tokens: Optional[torch.Tensor]
 
     # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
     # vocab size).
diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py
index a0d39c511ace6..f828ea5009c59 100644
--- a/vllm/v1/worker/gpu_block_table.py
+++ b/vllm/v1/worker/gpu_block_table.py
@@ -7,7 +7,7 @@ import triton
 import triton.language as tl
 
 from vllm.utils import cdiv
-from vllm.v1.worker.utils import CpuGpuBuffer
+from vllm.v1.utils import CpuGpuBuffer
 
 PAD_SLOT_ID = -1
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index bdf66bc14dd54..14f78f927df43 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -446,12 +446,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             new_token_ids = cached_reqs.new_token_ids[i]
             self.requests.append_token_ids(req_index, new_token_ids)
 
-            if cached_reqs.new_block_ids[i] is not None:
+            req_new_block_ids = cached_reqs.new_block_ids[i]
+            if req_new_block_ids is not None:
                 req_indices.append(req_index)
-                for i, block_ids in enumerate(cached_reqs.new_block_ids[i]):
-                    x = cu_num_new_blocks[i][-1]
-                    cu_num_new_blocks[i].append(x + len(block_ids))
-                    new_block_ids[i].extend(block_ids)
+                for group_id, block_ids in enumerate(req_new_block_ids):
+                    x = cu_num_new_blocks[group_id][-1]
+                    cu_num_new_blocks[group_id].append(x + len(block_ids))
+                    new_block_ids[group_id].extend(block_ids)
                 # If the request is resumed from preemption, we need to
                 # overwrite the existing block IDs.
                 overwrite.append(cached_reqs.resumed_from_preemption[i])
@@ -1686,7 +1687,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                                        dtype=torch.int32)
             common_attn_metadata, token_indices =\
                 self.drafter.prepare_inputs(
-                common_attn_metadata, num_rejected_tokens_cpu)
+                input_batch.spec_decode_common_attn_metadata,
+                num_rejected_tokens_cpu)
 
             target_token_ids = self.input_ids.gpu[token_indices]
             # TODO(woosuk): Support M-RoPE.
@@ -2142,10 +2144,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
             max_seq_len=self.max_model_len,
-            block_table_tensor=self.requests.
-            block_tables[kv_cache_group_id].gpu[:num_reqs],
-            slot_mapping=self.requests.slot_mappings[kv_cache_group_id]
-            [:num_tokens],
+            block_table_tensor=self.block_tables.
+            block_tables[kv_cache_group_id][:num_reqs],
+            slot_mapping=self.block_tables.
+            slot_mappings[kv_cache_group_id][:num_tokens],
             causal=True,
         )
 
@@ -2607,9 +2609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.attn_groups.append(
                 create_attn_groups(attn_backends, kv_cache_spec))
 
-        # Calculate reorder batch threshold (if neeeded)
-        self.calculate_reorder_batch_threshold()
-
     def initialize_cudagraph_capture(self) -> None:
         min_cg_support = AttentionCGSupport.ALWAYS
         min_cg_builder_name = None
@@ -2679,62 +2678,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.compilation_config.cudagraph_mode,
                 self.uniform_decode_query_len)
 
-    def calculate_reorder_batch_threshold(self) -> None:
-        """
-        Check that if any backends reorder batches; that the reordering
-        is compatible (e.g., decode threshold is the same)
-        """
-        for group in self._attn_group_iterator():
-            attn_metadata_builder_i = group.metadata_builder
-
-            # check that if any backends reorder batches; that the reordering
-            # is compatible (e.g., decode threshold is the same)
-            reorder_batch_threshold_i = (
-                attn_metadata_builder_i.reorder_batch_threshold)
-            if reorder_batch_threshold_i is not None:
-                if self.reorder_batch_threshold is not None:
-                    if reorder_batch_threshold_i != \
-                            self.reorder_batch_threshold:
-                        raise ValueError(
-                            f"Attention backend reorders decodes with "
-                            f"threshold {reorder_batch_threshold_i} but other "
-                            f"backend uses threshold "
-                            f"{self.reorder_batch_threshold}")
-                else:
-                    self.reorder_batch_threshold = reorder_batch_threshold_i
-
-    def may_reinitialize_input_batch(self,
-                                     kv_cache_config: KVCacheConfig) -> None:
-        """
-        Re-initialize the input batch if the block sizes are different from
-        `[self.cache_config.block_size]`. This usually happens when there
-        are multiple KV cache groups.
-
-        Args:
-            kv_cache_config: The KV cache configuration.
-        """
-        block_sizes = [
-            kv_cache_group.kv_cache_spec.block_size
-            for kv_cache_group in kv_cache_config.kv_cache_groups
-        ]
-        if block_sizes != [self.cache_config.block_size]:
-            assert self.cache_config.cpu_offload_gb == 0, (
-                "Cannot re-initialize the input batch when CPU weight "
-                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
-                "for more details.")
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.max_model_len,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=self.pin_memory,
-                vocab_size=self.model_config.get_vocab_size(),
-                block_sizes=block_sizes,
-                is_spec_decode=bool(self.vllm_config.speculative_config),
-                logitsprocs=self.input_batch.logitsprocs,
-                is_pooling_model=self.is_pooling_model,
-            )
-
     def _allocate_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
@@ -2941,7 +2884,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         """
        kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
-        self.may_reinitialize_input_batch(kv_cache_config)
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.initialize_attn_backend(kv_cache_config)
         kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py
index aad6cd5e5345a..9e7189e3bd40a 100644
--- a/vllm/v1/worker/gpu_worker_states.py
+++ b/vllm/v1/worker/gpu_worker_states.py
@@ -282,6 +282,8 @@ class RequestState:
             # TODO
             generators={},
             token_ids=None,
+            num_tokens=None,
+            num_prompt_tokens=None,
             max_num_logprobs=None,
             allowed_token_ids_mask=None,
             bad_words_token_ids={},