[crashfix] Eagle + multimodal can crash on mm cache miss (#29750)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Mickaël Seznec 2025-12-01 10:29:33 +01:00 committed by GitHub
parent 014ece97c7
commit 86e178f7c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -263,6 +263,7 @@ class Scheduler(SchedulerInterface):
request.num_computed_tokens,
num_new_tokens,
encoder_compute_budget,
shift_computed_tokens=1 if self.use_eagle else 0,
)
if num_new_tokens == 0:
@ -532,6 +533,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens,
num_new_tokens,
encoder_compute_budget,
shift_computed_tokens=1 if self.use_eagle else 0,
)
if num_new_tokens == 0:
# The request cannot be scheduled.
@ -829,6 +831,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens: int,
num_new_tokens: int,
encoder_compute_budget: int,
shift_computed_tokens: int = 0,
) -> tuple[list[int], int, int, list[int]]:
"""
Determine which encoder inputs need to be scheduled in the current step,
@ -873,7 +876,10 @@ class Scheduler(SchedulerInterface):
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_new_tokens:
if (
start_pos
>= num_computed_tokens + num_new_tokens + shift_computed_tokens
):
# The encoder input is not needed in this step.
break
@ -929,10 +935,12 @@ class Scheduler(SchedulerInterface):
# NOTE(woosuk): We assume that the encoder input tokens should
# be processed altogether, as the encoder usually uses
# bidirectional attention.
if num_computed_tokens < start_pos:
if num_computed_tokens + shift_computed_tokens < start_pos:
# We only schedule the decoder tokens just before the
# encoder input.
num_new_tokens = start_pos - num_computed_tokens
num_new_tokens = start_pos - (
num_computed_tokens + shift_computed_tokens
)
else:
# Because of prefix caching, num_computed_tokens is greater
# than start_pos even though its encoder input is not