From 75eb302a2e4000470d1ad6bfc3a009379554b648 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Tue, 16 Dec 2025 15:20:19 +0100
Subject: [PATCH] [Bugfix] Whisper fix number of allocated CrossAttn blocks per-request (#30772)

Signed-off-by: NickLucche
---
 vllm/v1/core/sched/scheduler.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 278970ae7ee88..754e0b9d08316 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -187,6 +187,12 @@ class Scheduler(SchedulerInterface):
             if self.is_encoder_decoder
             else EncoderCacheManager(cache_size=encoder_cache_size)
         )
+        # For encoder-decoder models, allocate the maximum number of tokens for
+        # cross-attention blocks, since Whisper's input is always padded to max length.
+        # TODO (NickLucche): Generalize to models with variable-length encoder inputs.
+        self._num_encoder_max_input_tokens = (
+            MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config)
+        )

         speculative_config = vllm_config.speculative_config
         self.use_eagle = False
@@ -568,17 +574,11 @@ class Scheduler(SchedulerInterface):
             0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
         )

-        # Determine if we need to allocate cross-attention blocks.
-        if self.is_encoder_decoder and request.has_encoder_inputs:
-            # TODO(russellb): For Whisper, we know that the input is
-            # always padded to the maximum length. If we support other
-            # encoder-decoder models, this will need to be updated if we
-            # want to only allocate what is needed.
-            num_encoder_tokens = (
-                self.scheduler_config.max_num_encoder_input_tokens
-            )
-        else:
-            num_encoder_tokens = 0
+        num_encoder_tokens = (
+            self._num_encoder_max_input_tokens
+            if self.is_encoder_decoder and request.has_encoder_inputs
+            else 0
+        )

         new_blocks = self.kv_cache_manager.allocate_slots(
             request,
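
---

Note (not applied by git am): a minimal sketch of the allocation logic before and
after this patch. The helper names and the concrete values (a scheduler budget of
8192 and an encoder length of 1500) are illustrative assumptions, not values taken
from this diff; the sketch only shows why using the scheduler-wide encoder budget
per request can over-allocate cross-attention KV blocks.

    # sketch.py -- per-request cross-attn token count, before vs. after (assumed values)

    MAX_NUM_ENCODER_INPUT_TOKENS = 8192  # scheduler-wide encoder budget (assumption)
    ENCDEC_MAX_ENCODER_LEN = 1500        # model's padded encoder length (assumption)

    def num_encoder_tokens_before(is_encoder_decoder: bool, has_encoder_inputs: bool) -> int:
        # Old logic: reuse the scheduler's global encoder-token budget for every
        # request, so each request reserves far more cross-attn blocks than needed.
        if is_encoder_decoder and has_encoder_inputs:
            return MAX_NUM_ENCODER_INPUT_TOKENS
        return 0

    def num_encoder_tokens_after(is_encoder_decoder: bool, has_encoder_inputs: bool) -> int:
        # New logic: use the model's true maximum encoder length, looked up once in
        # __init__ via MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(model_config).
        return ENCDEC_MAX_ENCODER_LEN if is_encoder_decoder and has_encoder_inputs else 0

    assert num_encoder_tokens_before(True, True) == 8192  # over-allocation per request
    assert num_encoder_tokens_after(True, True) == 1500   # exactly the encoder length
    assert num_encoder_tokens_after(True, False) == 0     # no encoder input, no blocks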