diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index e07d53ff84d37..914597e7ae629 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -842,6 +842,7 @@ class Scheduler(SchedulerInterface):
         scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
+        num_sampled_tokens = model_runner_output.num_sampled_tokens
         sampled_token_ids = model_runner_output.sampled_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
@@ -867,14 +868,17 @@ class Scheduler(SchedulerInterface):
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[
-                req_index] if sampled_token_ids else []
+            num_sampled = num_sampled_tokens[req_index]
+            if num_sampled > 0:
+                generated_token_ids = sampled_token_ids[req_index, :num_sampled].tolist()
+            else:
+                generated_token_ids = []
 
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id))
             if scheduled_spec_token_ids:
                 num_draft_tokens = len(scheduled_spec_token_ids)
-                num_accepted = len(generated_token_ids) - 1
+                num_accepted = num_sampled - 1
                 num_rejected = num_draft_tokens - num_accepted
                 # num_computed_tokens represents the number of tokens
                 # processed in the current step, considering scheduled
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f8d6b24702f3c..2a2a498ec2453 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -4,6 +4,7 @@
 from dataclasses import dataclass
 from typing import NamedTuple, Optional
 
+import numpy as np
 import torch
 
 
@@ -88,11 +89,12 @@ class ModelRunnerOutput:
     # req_id -> index
     req_id_to_index: dict[str, int]
 
-    # num_reqs x num_generated_tokens
-    # num_generated_tokens is the number of tokens
-    # generated in the current step. It can be different for
-    # each request due to speculative/jump decoding.
-    sampled_token_ids: list[list[int]]
+    # [num_reqs]
+    # Number of tokens sampled in the current step. Each request may generate
+    # a different number of tokens due to chunked prefill and spec decoding.
+    num_sampled_tokens: np.ndarray
+    # [num_reqs, max_num_sampled_tokens]
+    sampled_token_ids: np.ndarray
 
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
@@ -123,10 +125,13 @@ class DraftTokenIds:
     draft_token_ids: list[list[int]]
 
 
-EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
-                                              req_id_to_index={},
-                                              sampled_token_ids=[],
-                                              logprobs=None,
-                                              prompt_logprobs_dict={},
-                                              pooler_output=[],
-                                              num_nans_in_logits=None)
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+    req_ids=[],
+    req_id_to_index={},
+    num_sampled_tokens=np.empty(0, dtype=np.int32),
+    sampled_token_ids=np.empty((0, 0), dtype=np.int32),
+    logprobs=None,
+    prompt_logprobs_dict={},
+    pooler_output=[],
+    num_nans_in_logits=None,
+)
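
The diff replaces the ragged list[list[int]] of sampled tokens with a fixed-shape numpy array padded out to max_num_sampled_tokens, plus a per-request count. The following is a minimal standalone sketch (the token ids, request count, and padding value are invented for illustration, not taken from the PR) of how the padded rows plus num_sampled_tokens recover the same per-request token lists, mirroring the updated scheduler loop:

    import numpy as np

    # Per-request number of tokens sampled this step.
    num_sampled_tokens = np.array([1, 3, 2], dtype=np.int32)
    # Padded [num_reqs, max_num_sampled_tokens] array; entries past each
    # request's count are padding and must not be read.
    sampled_token_ids = np.array(
        [[101, 0, 0],    # request 0: 1 token
         [7, 8, 9],      # request 1: 3 tokens (e.g. 2 accepted drafts + 1 sampled)
         [42, 43, 0]],   # request 2: 2 tokens
        dtype=np.int32)

    for req_index in range(len(num_sampled_tokens)):
        num_sampled = int(num_sampled_tokens[req_index])
        generated_token_ids = (sampled_token_ids[req_index, :num_sampled].tolist()
                               if num_sampled > 0 else [])
        print(req_index, generated_token_ids)
    # 0 [101]
    # 1 [7, 8, 9]
    # 2 [42, 43]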