Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-18 15:15:31 -07:00
parent 2bb2cb13f4
commit 67d8c0c21b
3 changed files with 18 additions and 17 deletions

View File

@@ -864,6 +864,7 @@ class Scheduler(SchedulerInterface):
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
         sampled_token_ids = model_runner_output.sampled_token_ids
+        num_sampled_tokens = model_runner_output.num_sampled_tokens
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@@ -878,7 +879,8 @@ class Scheduler(SchedulerInterface):
         # to avoid expensive operations inside the loop.
         stopped_running_reqs: set[Request] = set()
         stopped_preempted_reqs: set[Request] = set()
-        for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
+        for req_index, req_id in enumerate(model_runner_output.req_ids):
+            num_tokens_scheduled = num_scheduled_tokens[req_id]
             assert num_tokens_scheduled > 0
             request = self.requests.get(req_id)
             if request is None:
@@ -887,9 +889,13 @@ class Scheduler(SchedulerInterface):
                 # in pipeline parallelism).
                 continue
-            req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[
-                req_index] if sampled_token_ids else []
+            generated_token_ids = []
+            if sampled_token_ids is not None:
+                assert num_sampled_tokens is not None
+                n = num_sampled_tokens[req_index]
+                if n > 0:
+                    generated_token_ids = sampled_token_ids[req_index, :n]
+                    generated_token_ids = generated_token_ids.tolist()
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id))
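With this hunk, req_index comes straight from enumerate(model_runner_output.req_ids), so the per-request req_id_to_index lookup is gone: position i in every output array belongs to req_ids[i]. A minimal standalone sketch of the new consumption pattern (the toy data below is illustrative, not the Scheduler's real state):

import numpy as np

# Padded [num_reqs, max_sampled] token matrix plus per-row valid counts,
# mirroring the new ModelRunnerOutput fields.
req_ids = ["req-0", "req-1"]
sampled_token_ids = np.array([[11, 12, 0], [21, 0, 0]])
num_sampled_tokens = np.array([2, 1])

for req_index, req_id in enumerate(req_ids):
    n = num_sampled_tokens[req_index]
    generated = sampled_token_ids[req_index, :n].tolist() if n > 0 else []
    print(req_id, generated)  # req-0 [11, 12] / req-1 [21]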

View File

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import NamedTuple, Optional
+import numpy as np
 import torch
@@ -80,20 +81,18 @@ class KVConnectorOutput:
 # ModelRunnerOutput is serialized and sent to the scheduler process.
 # This is expensive for torch.Tensor so prefer to use list instead.
 @dataclass
 class ModelRunnerOutput:
     # [num_reqs]
     req_ids: list[str]
-    # req_id -> index
-    req_id_to_index: dict[str, int]
     # num_reqs x num_generated_tokens
     # num_generated_tokens is the number of tokens
     # generated in the current step. It can be different for
     # each request due to speculative/jump decoding.
-    sampled_token_ids: list[list[int]]
+    sampled_token_ids: Optional[np.ndarray]
+    num_sampled_tokens: Optional[np.ndarray]
     # [num_reqs, max_num_logprobs + 1]
@@ -139,8 +138,8 @@ class DraftTokenIds:
 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
-                                              req_id_to_index={},
-                                              sampled_token_ids=[],
+                                              sampled_token_ids=None,
+                                              num_sampled_tokens=None,
                                               logprobs=None,
                                               prompt_logprobs_dict={},
                                               pooler_output=[],
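Reading the new pair of fields: sampled_token_ids is a right-padded [num_reqs, max_sampled] array and num_sampled_tokens[i] gives the number of valid entries in row i (requests can accept different counts under speculative decoding). A hedged sketch of the packing contract; to_padded is an illustrative helper, not part of this commit:

import numpy as np

def to_padded(ragged: list[list[int]]) -> tuple[np.ndarray, np.ndarray]:
    # Pack variable-length token lists into a zero-padded matrix + counts.
    counts = np.array([len(row) for row in ragged], dtype=np.int32)
    padded = np.zeros((len(ragged), int(counts.max())), dtype=np.int64)
    for i, row in enumerate(ragged):
        padded[i, :len(row)] = row
    return padded, counts

tokens, counts = to_padded([[11, 12, 13], [21], [31, 32]])
assert counts.tolist() == [3, 1, 2]
assert tokens[1, :counts[1]].tolist() == [21]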

View File

@@ -348,7 +348,7 @@ class GPUModelRunner:
         self.req_states.append_token_ids(
             input_batch.idx_mapping_np,
             sampled_token_ids_np,
-            num_sampled_tokens=num_sampled_tokens,
+            num_sampled_tokens,
         )
         return sampled_token_ids_np, num_sampled_tokens
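The body of append_token_ids is not part of this diff; a plausible reading of the num_sampled_tokens argument, with hypothetical buffer shapes, is a per-row variable-length append:

import numpy as np

def append_token_ids(
    token_ids: np.ndarray,    # [max_reqs, max_model_len] token buffer
    num_tokens: np.ndarray,   # [max_reqs] current length per request
    idx_mapping: np.ndarray,  # batch position -> request row
    sampled: np.ndarray,      # [num_reqs, max_sampled] padded samples
    num_sampled: np.ndarray,  # [num_reqs] valid tokens per row
) -> None:
    # Hypothetical sketch: write each request's accepted tokens at its
    # current end, then advance the length by the accepted count.
    for i, row in enumerate(idx_mapping):
        n = int(num_sampled[i])
        start = int(num_tokens[row])
        token_ids[row, start:start + n] = sampled[i, :n]
        num_tokens[row] += n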
@@ -380,14 +380,10 @@ class GPUModelRunner:
         sampler_output = self.sample(logits, input_batch)
         sampled_token_ids_np, num_sampled_tokens = self.postprocess(
             sampler_output, input_batch)
-        req_id_to_index = {
-            req_id: i
-            for i, req_id in enumerate(input_batch.req_ids)
-        }
         return ModelRunnerOutput(
             req_ids=input_batch.req_ids,
-            req_id_to_index=req_id_to_index,
-            sampled_token_ids=sampled_token_ids_np.tolist(),
+            sampled_token_ids=sampled_token_ids_np,
+            num_sampled_tokens=num_sampled_tokens,
             logprobs=sampler_output.logprobs_tensors,
             prompt_logprobs_dict={},
             pooler_output=[],
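The dropped .tolist() is the point of this last hunk: ModelRunnerOutput is serialized across processes (see the comment in the second file), and a contiguous ndarray avoids boxing every token into a Python int before pickling. A rough way to see the difference on your own machine; exact numbers vary:

import pickle
import timeit

import numpy as np

tokens = np.random.randint(0, 50_000, size=(1024, 8), dtype=np.int64)

t_array = timeit.timeit(lambda: pickle.dumps(tokens), number=1_000)
t_lists = timeit.timeit(lambda: pickle.dumps(tokens.tolist()), number=1_000)

# The ndarray pickles as one contiguous buffer; the list path first
# materializes ~8k Python ints and then pickles each one.
print(f"ndarray: {t_array:.3f}s  list-of-lists: {t_lists:.3f}s")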