From 92e0cc79a8c8bc08bde41aaa45a71a255b1d4083 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Mon, 2 Jun 2025 19:04:26 +0000
Subject: [PATCH] format

Signed-off-by: Sage Moore
---
 vllm/v1/worker/gpu_model_runner.py | 177 +++++++++++++++--------------
 vllm/v1/worker/ubatching.py        | 117 ++++++++++---------
 2 files changed, 153 insertions(+), 141 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c728c17e873f0..15edc0d67fba9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2,18 +2,16 @@
 
 import copy
 import gc
-import os
+import threading
 import time
 import weakref
 from typing import TYPE_CHECKING, Optional, TypeAlias, Union
 
-import contextlib
 import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn
 
-from vllm.utils import current_stream
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadataBuilder)
@@ -27,7 +25,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture,
     prepare_communication_buffer_for_model)
-from vllm.forward_context import get_forward_context, set_forward_context, create_forward_context
+from vllm.forward_context import (create_forward_context, get_forward_context,
+                                  override_forward_context,
+                                  set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
@@ -38,7 +38,7 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        current_stream, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -59,14 +59,11 @@ from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-from vllm.v1.worker.ubatching import make_ubatch_contexts, UBatchContext
+from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
 
 from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                     scatter_mm_placeholders)
 
-import threading
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
 if TYPE_CHECKING:
     import xgrammar as xgr
 
@@ -504,40 +501,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.input_batch.refresh_sampling_metadata()
 
     def _ubatch_split(
-        self,
-        query_start_loc_np: torch.Tensor,
-        max_num_scheduled_tokens: int,
-        scheduler_output: "SchedulerOutput"
-    ) -> Optional[UBatchSlices]:
+            self, query_start_loc_np: torch.Tensor,
+            max_num_scheduled_tokens: int,
+            scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         num_reqs = self.input_batch.num_reqs
-        
+
         if self.parallel_config.enable_microbatching and \
-            total_num_scheduled_tokens >= self.parallel_config.microbatching_token_threshold \
-            and max_num_scheduled_tokens == 1:
+            total_num_scheduled_tokens >= \
+                self.parallel_config.microbatching_token_threshold \
+                and max_num_scheduled_tokens == 1:
             # For pure decode we can just create ubatchs by cutting the request
             # in half
             b0_reqs_end = num_reqs // 2
             b0_tokens_end = total_num_scheduled_tokens // 2
-            assert b0_reqs_end < num_reqs and b0_tokens_end < total_num_scheduled_tokens
+            assert b0_reqs_end < num_reqs and \
+                b0_tokens_end < total_num_scheduled_tokens
             return [
                 (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
                 (slice(b0_reqs_end, num_reqs),
                  slice(b0_tokens_end, total_num_scheduled_tokens)),
             ]
-        
+
         if self.parallel_config.enable_microbatching and \
            self.parallel_config.always_microbatch_if_enabled:
             # TODO we can do something more advanced here to try to balance,
             # i.e. split to the left of `total_num_scheduled_tokens // 2` if it
             # is more balanced
-            req_split_id = np.argmax(query_start_loc_np > (total_num_scheduled_tokens // 2))
-            return [(slice(0, req_split_id), slice(0, query_start_loc_np[req_split_id])),
-                    (slice(req_split_id, num_reqs), slice(query_start_loc_np[req_split_id], total_num_scheduled_tokens))]
+            req_split_id = np.argmax(
+                query_start_loc_np > (total_num_scheduled_tokens // 2))
+            return [(slice(0, req_split_id),
+                     slice(0, query_start_loc_np[req_split_id])),
+                    (slice(req_split_id, num_reqs),
+                     slice(query_start_loc_np[req_split_id],
+                           total_num_scheduled_tokens))]
 
         return None
-
-    def _is_dummy_ubatch(self, ubatch_slice: UBatchSlices) -> bool:
-        return ubatch_slice[1].start >= ubatch_slice[1].stop
 
     def _prepare_inputs(
         self, scheduler_output: "SchedulerOutput"
@@ -630,7 +628,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
 
         ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
-            self.query_start_loc_np, max_num_scheduled_tokens, scheduler_output)
+            self.query_start_loc_np, max_num_scheduled_tokens,
+            scheduler_output)
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@@ -708,6 +707,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     ))
                 for layer_name in kv_cache_group_spec.layer_names:
                     assert type(attn_metadata) is list
+                    assert attn_metadata_i is not None
+                    # What if it's None? Do we still add it to the list?
                     attn_metadata[ubid][layer_name] = attn_metadata_i
             else:
                 attn_metadata_i = (
@@ -749,7 +750,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        return attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices
+        return (attn_metadata, logits_indices, spec_decode_metadata,
+                ubatch_slices)
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -1167,8 +1169,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if sync_self:
             assert intermediate_tensors is not None
             for k, v in intermediate_tensors.items():
-                is_scattered = "residual" and is_residual_scattered
-                _copy_slice = copy_slice(is_scattered)
+                _copy_slice = copy_slice(is_residual_scattered)
                 self.intermediate_tensors[k][_copy_slice].copy_(
                     v[_copy_slice], non_blocking=True)
 
@@ -1181,7 +1182,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
         # Dummy batch. (hopefully we are the last one so we can just
         # update this to a one token batch and return)
-        
+
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -1203,9 +1204,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     batch_size=self.max_num_tokens,
                     dtype=self.model_config.dtype,
                     device=self.device))
-        
-        return input_ids, positions, inputs_embeds, intermediate_tensors
+        return input_ids, positions, inputs_embeds, intermediate_tensors
 
     def _get_model_inputs(self, tokens_slice: slice,
                           scheduler_output: "SchedulerOutput"):
@@ -1215,7 +1215,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # update this to a one token batch and return)
             tokens_slice = slice(tokens_slice.start, tokens_slice.start + 1)
             num_tokens = 1
-        
+
         if (self.use_cuda_graph
                 and num_tokens <= self.cudagraph_batch_sizes[-1]):
             # Use piecewise CUDA graphs.
@@ -1281,85 +1281,87 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return input_ids, positions, inputs_embeds, intermediate_tensors
 
     def _run_model(self,
-                  attn_metadata: Optional[PerLayerAttnMetadata],
-                  num_scheduled_tokens: Optional[int],
+                   attn_metadata: Optional[PerLayerAttnMetadata],
+                   num_scheduled_tokens: Optional[int],
                    ubatch_slices: Optional[UBatchSlices] = None,
                    scheduler_output: Optional["SchedulerOutput"] = None,
                    is_dummy_run: bool = False):
-        
+
         num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
-        
+
         def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
             if use_dummy_input:
+                assert num_dummy_tokens == 1
                 return self._get_dummy_model_inputs(num_dummy_tokens)
             else:
                 assert scheduler_output is not None
                 return self._get_model_inputs(tokens_slice, scheduler_output)
-        
-        
+
         def _run(token_slice: slice, context, use_dummy_input: bool = False):
            input_ids, positions, inputs_embeds, intermediate_tensors = \
                model_inputs(token_slice, use_dummy_input)
            with context:
-                # if isinstance(context, UBatchContext):
-                #     print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}")
                model_output = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )
-                # if isinstance(context, UBatchContext):
-                #     print(f"Ran ubatch {context.id}putput {model_output.shape}")
                if isinstance(context, UBatchContext):
                    # Clone before we leave the ubatch context
                    model_output = model_output.clone()
-            
+
            return model_output
-        
+
        @torch.inference_mode()
-        def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input):
+       def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
+                          use_dummy_input):
            model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
            if save_results:
                results.append(model_output)
 
-        def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
-            results = []
+       def _run_ubatches(ubatch_slices, attn_metadata,
+                         is_dummy_run) -> torch.Tensor:
+           results: list[torch.Tensor] = []
            assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
            root_stream = current_stream()
-            ubatch_ctxs = make_ubatch_contexts(
-                len(ubatch_slices),
-                compute_stream=root_stream,
-                device=self.device)
-            
+           ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
+                                              compute_stream=root_stream,
+                                              device=self.device)
+
            # Ubatches will manually manage the forward context, so we override
            # it to None here so we can have it restored correctly later
            with override_forward_context(None):
                ubatch_threads = []
                for i, (_, tokens_slice) in enumerate(ubatch_slices):
                    is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
-                    assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
-
-                    # print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run)
-
-                    num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start)
+                    assert not is_dummy_ubatch or i == len(
+                        ubatch_slices) - 1 or is_dummy_run
+
+                    num_tokens = num_dummy_tokens if is_dummy_ubatch or \
+                        is_dummy_run else (tokens_slice.stop - tokens_slice.start)
                    ubatch_ctxs[i].forward_context = create_forward_context(
-                        attn_metadata[i] if attn_metadata is not None else None,
-                        self.vllm_config, num_tokens=num_tokens)
-
-                    thread = threading.Thread(target=_ubatch_thread, args=(
-                        ubatch_ctxs[i],
-                        tokens_slice,
-                        results,
-                        not is_dummy_ubatch or is_dummy_run,
-                        is_dummy_ubatch or is_dummy_run,
-                    ))
+                        attn_metadata[i]
+                        if attn_metadata is not None else None,
+                        self.vllm_config,
+                        num_tokens=num_tokens)
+
+                    thread = threading.Thread(target=_ubatch_thread,
+                                              args=(
+                                                  ubatch_ctxs[i],
+                                                  tokens_slice,
+                                                  results,
+                                                  not is_dummy_ubatch
+                                                  or is_dummy_run,
+                                                  is_dummy_ubatch
+                                                  or is_dummy_run,
+                                              ))
                    ubatch_threads.append(thread)
                    thread.start()
                ubatch_ctxs[0].cpu_wait_event.set()
-            
+
            for thread in ubatch_threads:
                thread.join()
@@ -1369,18 +1371,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
        # run micro-batched
        if ubatch_slices is not None:
-            model_output = _run_ubatches(
-                ubatch_slices, attn_metadata, is_dummy_run)
+           model_output = _run_ubatches(ubatch_slices, attn_metadata,
+                                        is_dummy_run)
        # run single batch
        else:
            model_output = _run(
-                    slice(0, num_scheduled_tokens),
-                    set_forward_context(
-                        attn_metadata,
-                        vllm_config=self.vllm_config,
-                        num_tokens=num_scheduled_tokens or 1),
+               slice(0, num_scheduled_tokens),
+               set_forward_context(attn_metadata,
+                                   vllm_config=self.vllm_config,
+                                   num_tokens=num_scheduled_tokens or 1),
                is_dummy_run)
-        
+
        return model_output
 
    @torch.inference_mode()
@@ -1583,7 +1584,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if spec_decode_metadata is None:
            # input_ids can be None for multimodal models.
            target_token_ids = self.input_ids[:num_scheduled_tokens]
-            target_positions = positions[:num_scheduled_tokens]
+            #TODO(sage) make sure this works with mrope
+            target_positions = self.positions[:num_scheduled_tokens]
            if self.use_aux_hidden_state_outputs:
                target_hidden_states = torch.cat(
                    [h[:num_scheduled_tokens] for h in aux_hidden_states],
@@ -1609,7 +1611,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                num_rejected_tokens,
            )
            target_token_ids = self.input_ids[token_indices]
-            target_positions = positions[token_indices]
+            #TODO(sage) make sure this works with mrope
+            target_positions = self.positions[token_indices]
            if self.use_aux_hidden_state_outputs:
                target_hidden_states = torch.cat(
                    [h[token_indices] for h in aux_hidden_states], dim=-1)
@@ -1647,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
    def kv_connector_no_forward(
            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+        # KV send/recv even if no work to do.
        with set_forward_context(None, self.vllm_config):
            self.maybe_setup_kv_connector(scheduler_output)
 
@@ -1894,20 +1898,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                ))
            for layer_name in kv_cache_group_spec.layer_names:
                attn_metadata[layer_name] = attn_metadata_i
-        
+
        should_microbatch = (
-            allow_microbatching
-            and self.vllm_config.parallel_config.enable_microbatching
-            and self.vllm_config.parallel_config.always_microbatch_if_enabled
-        )
-        dummy_microbatches = [(slice(0, 0), slice(0, 0)), (slice(0, 0), slice(0, 0))]
+            allow_microbatching
+            and self.vllm_config.parallel_config.enable_microbatching
+            and self.vllm_config.parallel_config.always_microbatch_if_enabled)
+        dummy_microbatches = [(slice(0, 0), slice(0, 0)),
+                              (slice(0, 0), slice(0, 0))]
 
        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            outputs = self._run_model(
                attn_metadata,
                num_tokens,
-                ubatch_slices=None if not should_microbatch else dummy_microbatches,
+                ubatch_slices=None
+                if not should_microbatch else dummy_microbatches,
                is_dummy_run=True,
            )
        if self.use_aux_hidden_state_outputs:
diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py
index 1b08d12bd9f08..9a8b546819e43 100644
--- a/vllm/v1/worker/ubatching.py
+++ b/vllm/v1/worker/ubatching.py
@@ -1,36 +1,36 @@
 # SPDX-License-Identifier: Apache-2.0
 import threading
+from typing import Optional
+
 import torch
 import torch._dynamo
-import torch.profiler as profiler
-import os
-from typing import Optional
-from torch.library import Library
-from torch.library import custom_op, register_kernel
-from vllm.distributed import (get_dp_group)
+from torch.library import custom_op
 
-from vllm.utils import current_stream
 from vllm import forward_context
+from vllm.utils import current_stream
+
 
 class UBatchContext:
    """
    Context manager for micro-batching synchronization using threading events.
""" - def __init__(self, - id: int, - comm_stream: torch.cuda.Stream, - compute_stream: torch.cuda.Stream, - #fwd_ctx: forward_context.ForwardContext, - cpu_wait_event: threading.Event, - cpu_signal_event: threading.Event, - gpu_comm_done_event: torch.cuda.Event, - gpu_compute_done_event: torch.cuda.Event, - schedule: str = "default"): + + def __init__( + self, + id: int, + comm_stream: torch.cuda.Stream, + compute_stream: torch.cuda.Stream, + #fwd_ctx: forward_context.ForwardContext, + cpu_wait_event: threading.Event, + cpu_signal_event: threading.Event, + gpu_comm_done_event: torch.cuda.Event, + gpu_compute_done_event: torch.cuda.Event, + schedule: str = "default"): self.id = id self.comm_stream = comm_stream self.compute_stream = compute_stream self.original_stream = current_stream() - self.forward_context = None #fwd_ctx + self.forward_context = None #fwd_ctx self.cpu_wait_event = cpu_wait_event self.cpu_signal_event = cpu_signal_event self.current_stream = compute_stream @@ -47,8 +47,7 @@ class UBatchContext: self.cpu_wait_event.clear() self._restore_context() # Assume we start on the compute stream - assert current_stream() == self.compute_stream, \ - "Expected to start on the compute stream, but found %s" % current_stream() + assert current_stream() == self.compute_stream return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -64,7 +63,7 @@ class UBatchContext: def _restore_context(self): forward_context._forward_context = self.forward_context torch.cuda.set_stream(self.current_stream) - + def update_stream(self, stream): self.current_stream = stream torch.cuda.set_stream(self.current_stream) @@ -74,10 +73,11 @@ class UBatchContext: # assert current_stream() == self.current_stream # assert not self.cpu_wait_event.is_set() pass + def _signal_comm_done(self): # self.ctx_valid_state() self.gpu_comm_done_event.record(self.comm_stream) - + def _signal_compute_done(self): # self.ctx_valid_state() self.gpu_compute_done_event.record(self.compute_stream) @@ -114,32 +114,38 @@ class UBatchContext: def yield_and_switch_from_compute_to_comm(self): assert current_stream() == self.compute_stream - # dp_rank = get_dp_group().rank_in_group - # print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True) + # dp_rank = get_dp_group().rank_in_group + # print(f"DP: {dp_rank} UB: {self.id} " + # f"Yield and switch from {self.stream_string()}", flush=True) # self.ctx_valid_state() self._signal_compute_done() self._cpu_yield() # self.ctx_valid_state() assert self.current_stream == self.compute_stream self.update_stream(self.comm_stream) - # print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True) + # print(f"DP: {dp_rank} UB: {self.id} " + # f"Resuming on stream {self.stream_string()}", flush=True) self._wait_compute_done() def yield_and_switch_from_comm_to_compute(self): assert current_stream() == self.comm_stream - # dp_rank = get_dp_group().rank_in_group - # print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True) + # dp_rank = get_dp_group().rank_in_group + # print(f"DP: {dp_rank} UB: {self.id} " + # f"Yield and switch from {self.stream_string()}", flush=True) # self.ctx_valid_state() self._signal_comm_done() self._cpu_yield() # self.ctx_valid_state() assert self.current_stream == self.comm_stream self.update_stream(self.compute_stream) - # print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True) + # print(f"DP: {dp_rank} UB: {self.id} " + # f"Resuming on 
        self._wait_comm_done()
-    
+
 
 _CURRENT_CONTEXT: dict = {}
+
+
 def get_current_ubatch_context() -> Optional[UBatchContext]:
    global _CURRENT_CONTEXT
    """
@@ -147,57 +153,59 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
    """
    return _CURRENT_CONTEXT.get(threading.get_ident(), None)
 
+
 def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
    # Perform the barrier if a context exists for this thread
-    ctx = get_current_ubatch_context() 
+    ctx = get_current_ubatch_context()
    #print("you are in yield_impl", ctx)
    if ctx is not None and ctx.schedule == schedule:
        ctx.yield_and_switch_from_compute_to_comm()
 
+
 def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
    # Perform the barrier if a context exists for this thread
    ctx = get_current_ubatch_context()
    if ctx is not None and ctx.schedule == schedule:
        ctx.yield_and_switch_from_comm_to_compute()
 
+
 # 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
 # optimizing it away (TODO: see if this is actually needed)
-@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x",))
-def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
+@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x", ))
+def yield_and_switch_from_compute_to_comm(x: torch.Tensor,
+                                          schedule: str = "default") -> None:
    yield_and_switch_from_compute_to_comm_impl(schedule)
 
+
 # 3) Fake implementation for shape prop and FX tracing
 @yield_and_switch_from_compute_to_comm.register_fake
-def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
+def yield_and_switch_from_compute_to_comm_fake(x: torch.Tensor,
+                                               schedule: str = "default"
+                                               ) -> None:
    pass
 
-@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x",))
-def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
+
+@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x", ))
+def yield_and_switch_from_comm_to_compute(x: torch.Tensor,
+                                          schedule: str = "default") -> None:
    yield_and_switch_from_comm_to_compute_impl(schedule)
 
+
 @yield_and_switch_from_comm_to_compute.register_fake
-def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
+def yield_and_switch_from_comm_to_compute_fake(x: torch.Tensor,
+                                               schedule: str = "default"
+                                               ) -> None:
    pass
 
+
 def dump_ubatching_state():
    pass
-    # """
-    # Dump the current UBatchContext state for debugging.
-    # """
-
-    # dp_rank = os.getenv("VLLM_DP_RANK", None)
-
-    # for ctx in _CURRENT_CONTEXT.values():
-    #     print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n"
-    #           f" Stream: {ctx.stream}, ({ctx.stream.query()})\n"
-    #           f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
-    #           f" CPU Wait Event: {ctx.cpu_wait_event}\n"
-    #           f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
-    #           f" CPU Signal Event: {ctx.cpu_signal_event}\n"
-    #           f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")
+    """ """
+
+
 def make_ubatch_contexts(
    num_micro_batches: int,
    compute_stream: torch.cuda.Stream,
@@ -205,7 +213,6 @@ def make_ubatch_contexts(
    schedule: str = "default",
 ) -> list[UBatchContext]:
    assert num_micro_batches == 2, "only been tested with 2 micro-batches"
-
    """
    Create a context manager for micro-batching synchronization.
""" @@ -225,11 +232,11 @@ def make_ubatch_contexts( compute_stream=compute_stream, comm_stream=comm_stream, cpu_wait_event=cpu_events[i], - cpu_signal_event=cpu_events[(i + 1) % num_micro_batches], + cpu_signal_event=cpu_events[(i + 1) % + num_micro_batches], gpu_comm_done_event=gpu_comm_done_events[i], gpu_compute_done_event=gpu_compute_done_events[i], - schedule=schedule - ) + schedule=schedule) ctxs.append(ctx) - return ctxs \ No newline at end of file + return ctxs