[BugFix] Fix using dbo_decode_token_threshold always (and ignoring dbo_prefill_token_threshold) (#25622)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Author: Lucas Wilkinson, 2025-09-26 12:22:49 -04:00 (committed by GitHub)
parent d4d9899860
commit 984d18498a
2 changed files with 9 additions and 3 deletions
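
Before this change, `ubatch_split` took the whole `vllm_config` and the downstream threshold check compared every batch against `dbo_decode_token_threshold`, so `dbo_prefill_token_threshold` was never consulted. The fix computes a `uniform_decode` flag in the model runner (true only when the step is a pure, same-length decode batch) and threads it down so the check can pick the threshold that matches the batch type. Below is a minimal sketch of the selection logic this enables; the body of `check_ubatch_thresholds` is not part of this diff, so the `enable_dbo` gate, the stand-in config class, and the exact structure are assumptions (only the two threshold names come from the commit title):

```python
from dataclasses import dataclass

@dataclass
class _ParallelConfigSketch:
    # Stand-in for the vLLM ParallelConfig fields this fix touches.
    enable_dbo: bool                   # assumed gating flag
    dbo_decode_token_threshold: int
    dbo_prefill_token_threshold: int

def check_ubatch_thresholds(config: _ParallelConfigSketch, num_tokens: int,
                            uniform_decode: bool) -> bool:
    if not config.enable_dbo:
        return False
    if uniform_decode:
        # Uniform decode batch: compare against the decode threshold.
        return num_tokens >= config.dbo_decode_token_threshold
    # Prefill/mixed batch: before this fix there was no such branch and
    # the decode threshold was applied to every batch.
    return num_tokens >= config.dbo_prefill_token_threshold
```

Passing an explicit `uniform_decode: bool` keeps the batch classification at the call sites that already compute it, instead of re-deriving it inside the splitting helper.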

vllm/v1/worker/gpu_model_runner.py

@@ -1045,11 +1045,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
         num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
             num_tokens_unpadded)
+        uniform_decode = \
+            (max_num_scheduled_tokens == self.uniform_decode_query_len) and \
+            (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
         ubatch_slices, num_tokens_after_padding = \
             ubatch_split(num_scheduled_tokens,
                          num_tokens_unpadded,
                          num_tokens_padded,
-                         self.vllm_config)
+                         uniform_decode=uniform_decode,
+                         vllm_config=self.vllm_config)
 
         self.seq_lens.np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
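
The `uniform_decode` predicate added above is true exactly when the step is a pure, uniform decode batch: the largest per-request token count equals `uniform_decode_query_len` (typically 1 without speculative decoding) and the total equals `num_reqs` times that length, i.e. no request is doing prefill. A small standalone example of the same arithmetic (variable names mirror the diff; the helper itself is illustrative):

```python
import numpy as np

uniform_decode_query_len = 1  # typical value without speculative decoding

def is_uniform_decode(num_scheduled_tokens: np.ndarray) -> bool:
    num_reqs = len(num_scheduled_tokens)
    max_tokens = int(num_scheduled_tokens.max())
    total_tokens = int(num_scheduled_tokens.sum())
    return (max_tokens == uniform_decode_query_len
            and total_tokens == num_reqs * max_tokens)

print(is_uniform_decode(np.array([1, 1, 1, 1])))   # True: all requests decoding
print(is_uniform_decode(np.array([1, 1, 37, 1])))  # False: one request in prefill
```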
@@ -2989,7 +2993,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_scheduled_tokens,
                 total_num_scheduled_tokens,
                 total_num_scheduled_tokens,
-                self.vllm_config,
+                uniform_decode=uniform_decode,
+                vllm_config=self.vllm_config,
             )
             # If we failed to microbatch, currently need to resynchronize
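
The resynchronization comment refers to keeping data-parallel ranks in agreement: every DP worker must make the same microbatching decision, or their collective ops per step would mismatch. vLLM's actual synchronization mechanism is outside this diff; the sketch below shows one conventional way to reach such a consensus (hypothetical helper, assuming an initialized `torch.distributed` process group):

```python
import torch
import torch.distributed as dist

def all_ranks_agree_to_ubatch(want_ubatch: bool,
                              dp_group: dist.ProcessGroup) -> bool:
    # Hypothetical helper: microbatch only if *every* DP rank wants to.
    # MIN-reduce a 0/1 flag so a single dissenting rank vetoes ubatching.
    flag = torch.tensor([int(want_ubatch)], dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=dp_group)
    return bool(flag.item())
```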

vllm/v1/worker/ubatch_splitting.py

@@ -139,6 +139,7 @@ def ubatch_split(
     num_scheduled_tokens_per_request: np.ndarray,
     num_tokens_unpadded: int,
     num_tokens_padded: int,
+    uniform_decode: bool,
     vllm_config: VllmConfig,
 ) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]:
     """
@@ -164,7 +165,7 @@ def ubatch_split(
     should_attempt_ubatching = check_ubatch_thresholds(
         parallel_config,
         num_tokens_unpadded,
-        vllm_config,
+        uniform_decode=uniform_decode,
     )
     # Don't microbatch unless every other DP worker is also microbatching
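
With the fix, the call above selects the threshold by batch type instead of always reading the decode threshold. Concretely, assuming illustrative values of `dbo_decode_token_threshold=32` and `dbo_prefill_token_threshold=512` (not defaults taken from this diff): a 256-token prefill batch previously passed the check, since 256 >= 32 under the unconditionally applied decode threshold, whereas now `uniform_decode=False` routes it to the prefill threshold and 256 < 512 correctly skips microbatching:

```python
# Illustrative values; in vLLM these thresholds live on the parallel config.
decode_threshold, prefill_threshold = 32, 512
num_tokens, uniform_decode = 256, False  # a prefill-heavy batch

old_decision = num_tokens >= decode_threshold                  # True (the bug)
new_decision = num_tokens >= (decode_threshold if uniform_decode
                              else prefill_threshold)          # False (fixed)
print(old_decision, new_decision)  # True False
```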