From c80c53a30ff7a9c074ec6a7d88021ebe8c19e2e9 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 22 Aug 2025 17:20:41 -0700
Subject: [PATCH] [BugFix] Fix batch updates for pooling models (#23398)

Signed-off-by: Nick Hill
---
 vllm/v1/sample/logits_processor/state.py |  20 +++-
 vllm/v1/worker/gpu_input_batch.py        | 146 ++++++++++++-----------
 vllm/v1/worker/gpu_model_runner.py       |   8 +-
 3 files changed, 95 insertions(+), 79 deletions(-)

diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py
index 0f58b52496956..31cece58c7db5 100644
--- a/vllm/v1/sample/logits_processor/state.py
+++ b/vllm/v1/sample/logits_processor/state.py
@@ -50,6 +50,10 @@ class BatchUpdateBuilder:
         self.added = added or []
         self._is_removed_sorted = False
 
+        # Used to track changes in the pooling case
+        # where we don't populate the added list.
+        self.batch_changed = False
+
     def _ensure_removed_sorted(self) -> None:
         """Sort removed request indices in descending order.
 
@@ -80,6 +84,7 @@ class BatchUpdateBuilder:
             raise RuntimeError("Cannot register new removed request after"
                                " self.removed has been read.")
         self._removed.append(index)
+        self.batch_changed = True
 
     def has_removed(self) -> bool:
         return bool(self._removed)
@@ -98,9 +103,15 @@ class BatchUpdateBuilder:
             return self._removed.pop()
         return None
 
-    def _is_update(self) -> bool:
-        """True if there is a batch state change"""
-        return any((self._removed, self.moved, self.added))
+    def reset(self) -> bool:
+        """Returns True if there were any changes to the batch."""
+        self._is_removed_sorted = False
+        self._removed.clear()
+        self.moved.clear()
+        self.added.clear()
+        batch_changed = self.batch_changed
+        self.batch_changed = False
+        return batch_changed
 
     def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
         """Generate a logitsprocs batch update data structure and reset
@@ -114,7 +125,8 @@ class BatchUpdateBuilder:
         """
         # Reset removal-sorting logic
         self._is_removed_sorted = False
-        if not self._is_update():
+        self.batch_changed = False
+        if not any((self._removed, self.moved, self.added)):
            # No update; short-circuit
            return None
        # Build batch state update
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index e45d1ef31f603..f48c9de2f4e1a 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -65,8 +65,7 @@ class CachedRequestState:
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
             return self.prompt_token_ids[idx]
-        else:
-            return self.output_token_ids[idx - self.num_prompt_tokens]
+        return self.output_token_ids[idx - self.num_prompt_tokens]
 
 
 class InputBatch:
@@ -261,30 +260,27 @@ class InputBatch:
 
         Not applicable to pooling models.
         """
-        # Detailed added request metadata is only required for non-pooling
-        # models, to support logitsprocs
-        assert request.sampling_params
-
         # Fill the next empty index if there is one.
         if (new_req_index := self.batch_update_builder.pop_removed()) is None:
             # Append to end otherwise.
             new_req_index = self.num_reqs
 
         assert new_req_index < self.max_num_reqs
-        self.batch_update_builder.added.append(
-            (new_req_index, request.sampling_params, request.prompt_token_ids,
-             request.output_token_ids))
+        self.batch_update_builder.batch_changed = True
+        if request.sampling_params:
+            # Detailed added request metadata is only required for non-pooling
+            # models, to support logitsprocs.
+            self.batch_update_builder.added.append(
+                (new_req_index, request.sampling_params,
+                 request.prompt_token_ids, request.output_token_ids))
+
         return new_req_index
 
     def add_request(
         self,
         request: "CachedRequestState",
     ) -> int:
-        if not self.is_pooling_model:
-            # New request index bookkeeping for autoregressive models.
-            req_index = self._register_add_request(request)
-        else:
-            req_index = self.num_reqs
+        req_index = self._register_add_request(request)
 
         req_id = request.req_id
         if req_index == len(self._req_ids):
@@ -389,7 +385,7 @@ class InputBatch:
             self.logits_processing_needs_token_ids[req_index] = (
                 pooling_params.requires_token_ids)
         else:
-            raise NotImplementedError(request)
+            raise NotImplementedError("Unrecognized request type")
 
         # Add request lora ID
         if request.lora_request:
@@ -419,13 +415,25 @@ class InputBatch:
         req_index = self.req_id_to_index.pop(req_id, None)
         if req_index is None:
             return None
-        if not self.is_pooling_model:
-            # Autoregressive models require bookkeeping of removed requests to
-            # support logitsprocs.
-            self.batch_update_builder.removed_append(req_index)
+
+        self.batch_update_builder.removed_append(req_index)
         self._req_ids[req_index] = None
         self.req_output_token_ids[req_index] = None
 
+        # LoRA
+        lora_id = self.request_lora_mapping[req_index]
+        if lora_id != 0:
+            lora_req_ids = self.lora_id_to_request_ids[lora_id]
+            lora_req_ids.discard(req_id)
+            if not lora_req_ids:
+                del self.lora_id_to_request_ids[lora_id]
+                del self.lora_id_to_lora_request[lora_id]
+            self.request_lora_mapping[req_index] = 0
+
+        if self.is_pooling_model:
+            self.pooling_params.pop(req_id, None)
+            return req_index
+
         self.greedy_reqs.discard(req_id)
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
@@ -439,29 +447,14 @@ class InputBatch:
         self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
-        # LoRA
-        lora_id = self.request_lora_mapping[req_index]
-        if lora_id != 0:
-            lora_req_ids = self.lora_id_to_request_ids[lora_id]
-            lora_req_ids.discard(req_id)
-            if not lora_req_ids:
-                del self.lora_id_to_request_ids[lora_id]
-                del self.lora_id_to_lora_request[lora_id]
-            self.request_lora_mapping[req_index] = 0
-
         self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             # False means we don't fill with -inf.
             self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         self.bad_words_token_ids.pop(req_index, None)
-        self.pooling_params.pop(req_id, None)
         return req_index
 
     def swap_states(self, i1: int, i2: int) -> None:
-        # For autoregressive models, track detailed request reordering info
-        # to support logitsprocs
-        self.batch_update_builder.moved.append(
-            (i1, i2, MoveDirectionality.SWAP))
         old_id_i1 = self._req_ids[i1]
         old_id_i2 = self._req_ids[i2]
         self._req_ids[i1], self._req_ids[i2] =\
@@ -479,18 +472,6 @@ class InputBatch:
             self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
         self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
             self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
-        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
-            self.temperature_cpu[i2], self.temperature_cpu[i1]
-        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
-            self.top_p_cpu[i2], self.top_p_cpu[i1]
-        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
-            self.top_k_cpu[i2], self.top_k_cpu[i1]
-        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
-            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
-        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
-            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
-        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
-            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
 
         # NOTE: the following is unsafe
         # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
@@ -501,18 +482,41 @@ class InputBatch:
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp
 
+        self.block_table.swap_row(i1, i2)
+
+        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
+            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
+
+        if self.is_pooling_model:
+            # Sampling and logits parameters don't apply to pooling models.
+            return
+
+        # For autoregressive models, track detailed request reordering info
+        # to support logitsprocs.
+        self.batch_update_builder.moved.append(
+            (i1, i2, MoveDirectionality.SWAP))
+
+        self.temperature_cpu[i1], self.temperature_cpu[i2] = \
+            self.temperature_cpu[i2], self.temperature_cpu[i1]
+        self.top_p_cpu[i1], self.top_p_cpu[i2] = \
+            self.top_p_cpu[i2], self.top_p_cpu[i1]
+        self.top_k_cpu[i1], self.top_k_cpu[i2] = \
+            self.top_k_cpu[i2], self.top_k_cpu[i1]
+        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
+            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
+        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
+            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
+        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
+            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
+
         swap_dict_values(self.generators, i1, i2)
         swap_dict_values(self.bad_words_token_ids, i1, i2)
 
-        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
-            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
-
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             self.allowed_token_ids_mask_cpu_tensor[i1], \
                 self.allowed_token_ids_mask_cpu_tensor[i2] =\
                 self.allowed_token_ids_mask_cpu_tensor[i2], \
                 self.allowed_token_ids_mask_cpu_tensor[i1]
-        self.block_table.swap_row(i1, i2)
 
     def condense(self) -> None:
         """Slide non-empty requests down into lower, empty indices.
@@ -529,12 +533,6 @@ class InputBatch:
         """
         num_reqs = self.num_reqs
 
-        if self.is_pooling_model:
-            # Will be contiguous in pooling case, just trim the lists.
-            del self._req_ids[num_reqs:]
-            del self.req_output_token_ids[num_reqs:]
-            return
-
         if not (empty_req_indices := self.batch_update_builder.removed):
             # All removed requests were replaced by added requests, or else no
             # requests were removed at all. No condense() needed
@@ -562,11 +560,6 @@ class InputBatch:
             # Move active request down into empty request
             # index.
             self.batch_update_builder.pop_removed()
-            # Autoregressive models require detailed tracking of condense
-            # operations to support logitsprocs
-            self.batch_update_builder.moved.append(
-                (last_req_index, empty_index,
-                 MoveDirectionality.UNIDIRECTIONAL))
             req_id = self._req_ids[last_req_index]
             output_token_ids = self.req_output_token_ids[last_req_index]
             assert req_id is not None
@@ -587,6 +580,21 @@ class InputBatch:
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
             self.block_table.move_row(last_req_index, empty_index)
+
+            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
+                last_req_index]
+
+            if self.is_pooling_model:
+                last_req_index -= 1
+                # Sampling state not used by pooling models.
+                continue
+
+            # Autoregressive models require detailed tracking of condense
+            # operations to support logitsprocs
+            self.batch_update_builder.moved.append(
+                (last_req_index, empty_index,
+                 MoveDirectionality.UNIDIRECTIONAL))
+
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
@@ -601,9 +609,6 @@ class InputBatch:
             if generator is not None:
                 self.generators[empty_index] = generator
 
-            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
-                last_req_index]
-
             # TODO convert these to LogitsProcessors
             if self.allowed_token_ids_mask_cpu_tensor is not None:
                 self.allowed_token_ids_mask_cpu_tensor[
@@ -626,8 +631,9 @@ class InputBatch:
         """Apply any batch updates to sampling metadata."""
 
         if self.is_pooling_model:
-            # Batch changes every step for pooling models.
-            self.sampling_metadata = self._make_sampling_metadata()
+            batch_changed = self.batch_update_builder.reset()
+            if batch_changed:
+                self.sampling_metadata = self._make_sampling_metadata()
             return
 
         # For non-pooling models - generate and apply logitsprocs update;
@@ -720,7 +726,8 @@ class InputBatch:
         )
 
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
-        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
+        num_reqs = self.num_reqs
+        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
         prompt_token_ids_cpu_tensor = torch.empty(
             (self.num_reqs, max_prompt_len),
             device="cpu",
@@ -728,11 +735,10 @@ class InputBatch:
             pin_memory=self.pin_memory,
         )
         prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
-        prompt_token_ids[:] = self.token_ids_cpu[:self.
-                                                 num_reqs, :max_prompt_len]
+        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
         # Use the value of vocab_size as a pad since we don't have a
         # token_id of this value.
-        for i in range(self.num_reqs):
+        for i in range(num_reqs):
             prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
         return prompt_token_ids_cpu_tensor.to(device=self.device,
                                               non_blocking=True)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7160894b4acda..ed4a4e55f1212 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1489,10 +1489,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for raw_output, seq_len, prompt_len in zip(
                 raw_pooler_output, seq_lens_cpu,
                 pooling_metadata.prompt_lens):
-            if seq_len == prompt_len:
-                pooler_output.append(raw_output.data)
-            else:
-                pooler_output.append(None)
+            output = raw_output.data if seq_len == prompt_len else None
+            pooler_output.append(output)
 
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -1522,7 +1520,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Prepare the decoder inputs.
         (attn_metadata, logits_indices, spec_decode_metadata,
          num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-         max_query_len) = (self._prepare_inputs(scheduler_output))
+         max_query_len) = self._prepare_inputs(scheduler_output)
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
 
         if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE