mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 21:37:10 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
2bb2cb13f4
commit
67d8c0c21b
@ -864,6 +864,7 @@ class Scheduler(SchedulerInterface):
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
) -> dict[int, EngineCoreOutputs]:
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
num_sampled_tokens = model_runner_output.num_sampled_tokens
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
@ -878,7 +879,8 @@ class Scheduler(SchedulerInterface):
|
||||
# to avoid expensive operations inside the loop.
|
||||
stopped_running_reqs: set[Request] = set()
|
||||
stopped_preempted_reqs: set[Request] = set()
|
||||
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
|
||||
for req_index, req_id in enumerate(model_runner_output.req_ids):
|
||||
num_tokens_scheduled = num_scheduled_tokens[req_id]
|
||||
assert num_tokens_scheduled > 0
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
@ -887,9 +889,13 @@ class Scheduler(SchedulerInterface):
|
||||
# in pipeline parallelism).
|
||||
continue
|
||||
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids = sampled_token_ids[
|
||||
req_index] if sampled_token_ids else []
|
||||
generated_token_ids = []
|
||||
if sampled_token_ids is not None:
|
||||
assert num_sampled_tokens is not None
|
||||
n = num_sampled_tokens[req_index]
|
||||
if n > 0:
|
||||
generated_token_ids = sampled_token_ids[req_index, :n]
|
||||
generated_token_ids = generated_token_ids.tolist()
|
||||
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||
|
||||
@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@ -80,20 +81,18 @@ class KVConnectorOutput:
|
||||
|
||||
|
||||
# ModelRunnerOutput is serialized and sent to the scheduler process.
|
||||
# This is expensive for torch.Tensor so prefer to use list instead.
|
||||
@dataclass
|
||||
class ModelRunnerOutput:
|
||||
|
||||
# [num_reqs]
|
||||
req_ids: list[str]
|
||||
# req_id -> index
|
||||
req_id_to_index: dict[str, int]
|
||||
|
||||
# num_reqs x num_generated_tokens
|
||||
# num_generated_tokens is the number of tokens
|
||||
# generated in the current step. It can be different for
|
||||
# each request due to speculative/jump decoding.
|
||||
sampled_token_ids: list[list[int]]
|
||||
sampled_token_ids: Optional[np.ndarray]
|
||||
num_sampled_tokens: Optional[np.ndarray]
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
@ -139,8 +138,8 @@ class DraftTokenIds:
|
||||
|
||||
|
||||
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
||||
req_id_to_index={},
|
||||
sampled_token_ids=[],
|
||||
sampled_token_ids=None,
|
||||
num_sampled_tokens=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
|
||||
@ -348,7 +348,7 @@ class GPUModelRunner:
|
||||
self.req_states.append_token_ids(
|
||||
input_batch.idx_mapping_np,
|
||||
sampled_token_ids_np,
|
||||
num_sampled_tokens=num_sampled_tokens,
|
||||
num_sampled_tokens,
|
||||
)
|
||||
return sampled_token_ids_np, num_sampled_tokens
|
||||
|
||||
@ -380,14 +380,10 @@ class GPUModelRunner:
|
||||
sampler_output = self.sample(logits, input_batch)
|
||||
sampled_token_ids_np, num_sampled_tokens = self.postprocess(
|
||||
sampler_output, input_batch)
|
||||
req_id_to_index = {
|
||||
req_id: i
|
||||
for i, req_id in enumerate(input_batch.req_ids)
|
||||
}
|
||||
return ModelRunnerOutput(
|
||||
req_ids=input_batch.req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids_np.tolist(),
|
||||
sampled_token_ids=sampled_token_ids_np,
|
||||
num_sampled_tokens=num_sampled_tokens,
|
||||
logprobs=sampler_output.logprobs_tensors,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user