async output

Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Woosuk Kwon 2025-09-19 07:10:42 +00:00
parent 33672774f5
commit 37478c18cf
2 changed files with 70 additions and 29 deletions

vllm/v1/worker/gpu/async_utils.py

@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.v1.outputs import (AsyncModelRunnerOutput, LogprobsTensors,
+                             ModelRunnerOutput, SamplerOutput)
+
+
+class AsyncOutput(AsyncModelRunnerOutput):
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        sampler_output: SamplerOutput,
+        copy_stream: torch.cuda.Stream,
+    ):
+        self.model_runner_output = model_runner_output
+        self.sampler_output = sampler_output
+        self.copy_stream = copy_stream
+        self.copy_event = torch.cuda.Event()
+
+        default_stream = torch.cuda.current_stream()
+        with torch.cuda.stream(self.copy_stream):
+            self.copy_stream.wait_stream(default_stream)
+            self.sampled_token_ids = sampler_output.sampled_token_ids.to(
+                "cpu", non_blocking=True)
+            x = sampler_output.logprobs_tensors
+            if x is not None:
+                self.logprobs_tensors = LogprobsTensors(
+                    logprob_token_ids=x.logprob_token_ids.to(
+                        "cpu", non_blocking=True),
+                    logprobs=x.logprobs.to("cpu", non_blocking=True),
+                    selected_token_ranks=x.selected_token_ranks.to(
+                        "cpu", non_blocking=True),
+                )
+            else:
+                self.logprobs_tensors = None
+            self.copy_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        self.copy_event.synchronize()
+        self.model_runner_output.sampled_token_ids = (
+            self.sampled_token_ids.numpy())
+        if self.logprobs_tensors is not None:
+            self.model_runner_output.logprobs = (
+                self.logprobs_tensors.tolists())
+        return self.model_runner_output
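The new file is the heart of the commit: instead of blocking on a `.cpu()` call right after sampling, the runner hands back an AsyncOutput whose device-to-host copies run on a dedicated stream, and the consumer pays the synchronization cost only inside get_output(). For readers unfamiliar with the idiom, here is a minimal standalone sketch of the same stream/event dance; the tensor values and function name are illustrative, not part of the commit:

import torch

def copy_to_cpu_async(gpu_tensor: torch.Tensor,
                      copy_stream: torch.cuda.Stream):
    # Capture the stream that produced gpu_tensor *before* switching
    # streams, exactly as AsyncOutput.__init__ does above.
    default_stream = torch.cuda.current_stream()
    copy_event = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        # The copy must not start until the producer kernels finish.
        copy_stream.wait_stream(default_stream)
        # Note: a D2H copy is only truly asynchronous when the
        # destination is pinned host memory; with pageable memory the
        # call may synchronize under the hood.
        cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
        # Record on the copy stream so a later synchronize() waits
        # for exactly these copies and nothing more.
        copy_event.record()
    return cpu_tensor, copy_event

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    sampled = torch.randint(0, 32_000, (8, 1), device="cuda")
    cpu_copy, event = copy_to_cpu_async(sampled, side_stream)
    # ... unrelated CPU work can run here while the copy is in flight ...
    event.synchronize()  # block only at the point of use
    print(cpu_copy.numpy())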

vllm/v1/worker/gpu/model_runner.py

@@ -20,6 +20,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.sample.sampler import SamplerOutput
+from vllm.v1.worker.gpu.async_utils import AsyncOutput
 from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
                                            init_attn_backend, init_kv_cache)
 from vllm.v1.worker.gpu.block_table import BlockTables
@@ -65,6 +66,10 @@ class GPUModelRunner:
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.scheduler_config.max_num_seqs
 
+        self.use_async_scheduling = self.scheduler_config.async_scheduling
+        assert self.use_async_scheduling
+        self.output_copy_stream = torch.cuda.Stream()
+
         self.req_states = RequestState(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
@@ -412,29 +417,32 @@
         self,
         sampler_output: SamplerOutput,
         input_batch: InputBatch,
-    ) -> tuple[np.ndarray, np.ndarray]:
+    ) -> AsyncOutput:
         # Get the number of sampled tokens.
         # 0 if chunked-prefilling, 1 if not.
         is_chunked_prefilling = input_batch.is_chunked_prefilling
         num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
 
-        # Update the token IDs and the number of tokens.
-        sampled_token_ids_cpu = sampler_output.sampled_token_ids.cpu()
-        sampled_token_ids_np = sampled_token_ids_cpu.numpy()
-        self.req_states.append_token_ids(
-            input_batch.idx_mapping_np,
-            sampled_token_ids_np,
-            num_sampled_tokens,
+        model_runner_output = ModelRunnerOutput(
+            req_ids=input_batch.req_ids,
+            sampled_token_ids=None,
+            num_sampled_tokens=num_sampled_tokens,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[],
+            kv_connector_output=None,
+            num_nans_in_logits=None,
         )
-        # self.req_states.last_token_ids[input_batch.idx_mapping] = (
-        #     sampler_output.sampled_token_ids)
-        return sampled_token_ids_np, num_sampled_tokens
+        return AsyncOutput(
+            model_runner_output=model_runner_output,
+            sampler_output=sampler_output,
+            copy_stream=self.output_copy_stream,
+        )
 
     def execute_model(
         self,
         scheduler_output: SchedulerOutput,
-    ):
+    ) -> AsyncOutput:
         self.update_states(scheduler_output)
         if scheduler_output.total_num_scheduled_tokens == 0:
             return EMPTY_MODEL_RUNNER_OUTPUT
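Note how postprocess now builds the ModelRunnerOutput with sampled_token_ids=None as a placeholder; AsyncOutput.get_output() fills it in once the copies land. The num_sampled_tokens line is easy to misread: it turns the per-request chunked-prefill mask into a 0/1 token count. A tiny worked example with made-up values:

import numpy as np

# Requests 0 and 2 finished their prompt this step; request 1 is still
# chunked-prefilling and therefore samples no token.
is_chunked_prefilling = np.array([False, True, False])
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
print(num_sampled_tokens)  # [1 0 1]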
@@ -453,19 +461,4 @@
         )
 
         sampler_output = self.sample(hidden_states, input_batch)
-        sampled_token_ids_np, num_sampled_tokens = self.postprocess(
-            sampler_output, input_batch)
-        logprobs = None
-        if sampler_output.logprobs_tensors is not None:
-            logprobs = sampler_output.logprobs_tensors.tolists()
-        return ModelRunnerOutput(
-            req_ids=input_batch.req_ids,
-            sampled_token_ids=sampled_token_ids_np,
-            num_sampled_tokens=num_sampled_tokens,
-            logprobs=logprobs,
-            prompt_logprobs_dict={},
-            pooler_output=[],
-            kv_connector_output=None,
-            num_nans_in_logits=None,
-        )
+        return self.postprocess(sampler_output, input_batch)
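With execute_model returning an AsyncOutput, the caller decides when to block. A hedged sketch of how an async-scheduling loop might consume it; the `scheduler` collaborator and its method names are illustrative here, not a claim about the exact vLLM executor API:

def run_one_step(model_runner, scheduler, scheduler_output):
    # Launch GPU work; execute_model() returns promptly with an
    # AsyncOutput whose D2H copies are in flight on output_copy_stream.
    async_output = model_runner.execute_model(scheduler_output)

    # Overlap: build the next step's schedule on the CPU while the
    # copies run.
    next_scheduler_output = scheduler.schedule()

    # Pay the synchronization cost only at the point of use:
    # get_output() blocks on copy_event and fills in the numpy fields.
    output = async_output.get_output()
    scheduler.update_from_output(scheduler_output, output)
    return next_scheduler_output

One caveat visible in the diff itself: when no tokens are scheduled, execute_model returns EMPTY_MODEL_RUNNER_OUTPUT (a plain ModelRunnerOutput) despite the AsyncOutput annotation, so a caller following this sketch must tolerate both shapes.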