diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py
new file mode 100644
index 0000000000000..ed11701739986
--- /dev/null
+++ b/vllm/v1/worker/gpu/async_utils.py
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.v1.outputs import (AsyncModelRunnerOutput, LogprobsTensors,
+                             ModelRunnerOutput, SamplerOutput)
+
+
+class AsyncOutput(AsyncModelRunnerOutput):
+
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        sampler_output: SamplerOutput,
+        copy_stream: torch.cuda.Stream,
+    ):
+        self.model_runner_output = model_runner_output
+        self.sampler_output = sampler_output
+        self.copy_stream = copy_stream
+        self.copy_event = torch.cuda.Event()
+
+        default_stream = torch.cuda.current_stream()
+        with torch.cuda.stream(self.copy_stream):
+            self.copy_stream.wait_stream(default_stream)
+
+            self.sampled_token_ids = sampler_output.sampled_token_ids.to(
+                "cpu", non_blocking=True)
+            x = sampler_output.logprobs_tensors
+            if x is not None:
+                self.logprobs_tensors = LogprobsTensors(
+                    logprob_token_ids=x.logprob_token_ids.to(
+                        "cpu", non_blocking=True),
+                    logprobs=x.logprobs.to("cpu", non_blocking=True),
+                    selected_token_ranks=x.selected_token_ranks.to(
+                        "cpu", non_blocking=True),
+                )
+            else:
+                self.logprobs_tensors = None
+            self.copy_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        self.copy_event.synchronize()
+        self.model_runner_output.sampled_token_ids = (
+            self.sampled_token_ids.numpy())
+        if self.logprobs_tensors is not None:
+            self.model_runner_output.logprobs = (
+                self.logprobs_tensors.tolists())
+        return self.model_runner_output
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 153f4af34eb8d..53072ac9d5ecc 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -20,6 +20,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.sample.sampler import SamplerOutput
+from vllm.v1.worker.gpu.async_utils import AsyncOutput
 from vllm.v1.worker.gpu.attn_utils import (get_kv_cache_spec,
                                            init_attn_backend, init_kv_cache)
 from vllm.v1.worker.gpu.block_table import BlockTables
@@ -65,6 +66,10 @@ class GPUModelRunner:
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.scheduler_config.max_num_seqs
 
+        self.use_async_scheduling = self.scheduler_config.async_scheduling
+        assert self.use_async_scheduling
+        self.output_copy_stream = torch.cuda.Stream()
+
         self.req_states = RequestState(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
@@ -412,29 +417,32 @@ class GPUModelRunner:
         self,
         sampler_output: SamplerOutput,
         input_batch: InputBatch,
-    ) -> tuple[np.ndarray, np.ndarray]:
+    ) -> AsyncOutput:
         # Get the number of sampled tokens.
         # 0 if chunked-prefilling, 1 if not.
         is_chunked_prefilling = input_batch.is_chunked_prefilling
         num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
 
-        # Update the token IDs and the number of tokens.
-        sampled_token_ids_cpu = sampler_output.sampled_token_ids.cpu()
-        sampled_token_ids_np = sampled_token_ids_cpu.numpy()
-        self.req_states.append_token_ids(
-            input_batch.idx_mapping_np,
-            sampled_token_ids_np,
-            num_sampled_tokens,
+        model_runner_output = ModelRunnerOutput(
+            req_ids=input_batch.req_ids,
+            sampled_token_ids=None,
+            num_sampled_tokens=num_sampled_tokens,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=[],
+            kv_connector_output=None,
+            num_nans_in_logits=None,
+        )
+        return AsyncOutput(
+            model_runner_output=model_runner_output,
+            sampler_output=sampler_output,
+            copy_stream=self.output_copy_stream,
         )
-
-        # self.req_states.last_token_ids[input_batch.idx_mapping] = (
-        #     sampler_output.sampled_token_ids)
-        return sampled_token_ids_np, num_sampled_tokens
 
     def execute_model(
         self,
         scheduler_output: SchedulerOutput,
-    ):
+    ) -> AsyncOutput:
         self.update_states(scheduler_output)
         if scheduler_output.total_num_scheduled_tokens == 0:
             return EMPTY_MODEL_RUNNER_OUTPUT
@@ -453,19 +461,4 @@ class GPUModelRunner:
         )
 
         sampler_output = self.sample(hidden_states, input_batch)
-
-        sampled_token_ids_np, num_sampled_tokens = self.postprocess(
-            sampler_output, input_batch)
-        logprobs = None
-        if sampler_output.logprobs_tensors is not None:
-            logprobs = sampler_output.logprobs_tensors.tolists()
-        return ModelRunnerOutput(
-            req_ids=input_batch.req_ids,
-            sampled_token_ids=sampled_token_ids_np,
-            num_sampled_tokens=num_sampled_tokens,
-            logprobs=logprobs,
-            prompt_logprobs_dict={},
-            pooler_output=[],
-            kv_connector_output=None,
-            num_nans_in_logits=None,
-        )
+        return self.postprocess(sampler_output, input_batch)
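A minimal sketch of how the AsyncOutput above might be consumed under async scheduling. The surrounding loop, variable names, and the second execute_model call are illustrative assumptions, not part of this diff; only execute_model() and get_output() come from the code above:

    # Step N: the forward pass and sampling are launched, and the device-to-host
    # copy of sampled token IDs / logprobs is enqueued on output_copy_stream
    # without blocking the default stream.
    async_output = model_runner.execute_model(scheduler_output_n)

    # Step N+1 can be scheduled and launched while that copy is still in flight.
    next_async_output = model_runner.execute_model(scheduler_output_n1)

    # Resolving the earlier output waits only on copy_event, then converts the
    # sampled token IDs to numpy and the logprobs tensors to Python lists.
    output_n = async_output.get_output()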