diff --git a/examples/basic-ub.py b/examples/basic-ub.py
index 8f1fbc2a25420..21ddff5dddbe8 100644
--- a/examples/basic-ub.py
+++ b/examples/basic-ub.py
@@ -40,7 +40,7 @@ def main():
         max_model_len=1024,
         #load_format="dummy",
         ###############
-        tensor_parallel_size=4,
+        tensor_parallel_size=2,
         #data_parallel_size=2,
         enable_expert_parallel=False,
         ###############
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0c3a9084f3301..2e9b63c4cb00e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -57,6 +57,7 @@
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+from vllm.v1.worker.ubatching import make_ubatch_context_chain, UBatchContext
 from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                     scatter_mm_placeholders)
@@ -1284,104 +1285,108 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                        scheduler_output: Optional["SchedulerOutput"] = None,
                        is_dummy_run: bool = False):
 
-        def model_inputs(tokens_slice: slice, is_dummy: bool) -> tuple:
-            if is_dummy:
+        def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
+            if use_dummy_input:
                 num_tokens = num_scheduled_tokens or 1
                 return num_tokens, *self._get_dummy_model_inputs(num_tokens)
             else:
                 assert scheduler_output is not None
                 num_tokens = tokens_slice.stop - tokens_slice.start
                 return num_tokens, *self._get_model_inputs(tokens_slice,
                                                            scheduler_output)
-
-        @torch.inference_mode()
-        def process_batch(save_results, attn_metadata, vllm_config, model, num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors, results, stream):
-            with set_forward_context(attn_metadata,
-                                     vllm_config,
-                                     num_tokens=num_tokens):
-                torch.cuda.set_stream(stream)
-
-                model_output = model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=inputs_embeds,
-                )
-
-                if save_results:
-                    results.append(model_output.clone())
-
-        def threaded_processing(ubatch_slices, attn_metadata, vllm_config, model, is_dummy_run=False):
-            results = []
-            # print(f"UBATCH SLICES: {len(ubatch_slices)}")
-            for i, (_, tokens_slice) in enumerate(ubatch_slices):
-                # print("ITERATION")
-                is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
-                assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
-
-                num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \
-                    model_inputs(tokens_slice, is_dummy_ubatch)
-
-                thread = threading.Thread(target=process_batch, args=(
-                    not is_dummy_ubatch or is_dummy_run,
-                    attn_metadata[i] if attn_metadata is not None else None,
-                    vllm_config,
-                    model,
-                    num_tokens,
-                    input_ids,
-                    positions,
-                    inputs_embeds,
-                    intermediate_tensors,
-                    results,
-                    torch.cuda.current_stream()
-                ))
-                thread.start()
-                thread.join()
-            # for i, (_, tokens_slice) in enumerate(ubatch_slices):
-            #     is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
-            #     assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
-            #     num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \
-            #         model_inputs(tokens_slice, is_dummy_ubatch)
-            #     process_batch(
-            #         i,
-            #         is_dummy_ubatch,
-            #         is_dummy_run,
-            #         attn_metadata,
-            #         vllm_config,
-            #         model,
-            #         num_tokens,
-            #         input_ids,
-            #         positions,
-            #         inputs_embeds,
-            #         intermediate_tensors,
-            #         results,
-            #     )
-
-            if results:
-                return torch.cat(results, dim=0)
-            else:
-                return None
-
-        # run micro-batched
-        if ubatch_slices is not None:
-            model_output = threaded_processing(ubatch_slices,
-                                               attn_metadata,
-                                               self.vllm_config,
-                                               self.model,
-                                               is_dummy_run)
-            # print("FINISHED MODEL OUTPUT")
-        # run single batch
-        else:
+
+
+        def _run(token_slice: slice, attn_metadata, use_dummy_input: bool = False,
+                 ubatch_context: Optional[UBatchContext]=None,
+                 setup_done_evt: Optional[threading.Event]=None):
             num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \
-                model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
+                model_inputs(token_slice, use_dummy_input)
 
             with set_forward_context(attn_metadata,
                                      self.vllm_config,
                                      num_tokens=num_tokens):
+                if ubatch_context:
+                    # Update the forward context now that it's available
+                    ubatch_context.update_forward_context()
+
+                if setup_done_evt is not None:
+                    # Signal the launching thread that per-ubatch setup is done
+                    setup_done_evt.set()
+
                 model_output = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                 )
+            return model_output
+
+        @torch.inference_mode()
+        def _ubatch_thread(ubatch_ctx, root_stream, token_slice, attn_metadata, results, save_results, use_dummy_input, setup_done_evt):
+            with ubatch_ctx:
+                # an event to enable the start of the ubatch execution on the GPU
+                # since different ubatches may be on different streams than this one
+                # they all need to wait on the gpu kernels launched in
+                # _prepare_inputs before continuing
+                if torch.cuda.current_stream() != root_stream:
+                    start_evt = torch.cuda.Event()
+                    # Record first, then wait: waiting on a never-recorded event is a no-op
+                    root_stream.record_event(start_evt)
+                    torch.cuda.current_stream().wait_event(start_evt)
+
+                model_output = _run(token_slice, attn_metadata, use_dummy_input, ubatch_ctx, setup_done_evt)
+
+                if save_results:
+                    results.append(model_output.clone())
+
+                if torch.cuda.current_stream() != root_stream:
+                    # Make the root stream wait for the ubatch to finish:
+                    # record on the ubatch stream first, then wait on the root stream
+                    torch.cuda.current_stream().record_event(ubatch_ctx.done_evt)
+                    root_stream.wait_event(ubatch_ctx.done_evt)
+
+        def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
+            results = []
+            assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
+            root_stream = torch.cuda.current_stream()
+            ubatch_ctxs = make_ubatch_context_chain(len(ubatch_slices),
+                stream=root_stream, # Only works currently if everything is run on the same stream
+                device=self.device)
+            setup_done = threading.Event()
+            ubatch_threads = []
+
+            for i, (_, tokens_slice) in enumerate(ubatch_slices):
+                is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
+                assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
+
+                thread = threading.Thread(target=_ubatch_thread, args=(
+                    ubatch_ctxs[i],
+                    root_stream,
+                    tokens_slice,
+                    attn_metadata[i] if attn_metadata is not None else None,
+                    results,
+                    not is_dummy_ubatch or is_dummy_run,
+                    is_dummy_ubatch or is_dummy_run,
+                    setup_done,
+                ))
+                #ubatch_threads.append(thread)
+                thread.start()
+                setup_done.wait()
+                thread.join()
+
+            # for thread in ubatch_threads:
+            #     thread.join()
+
+            torch.cuda.set_stream(root_stream)
+            return torch.cat(results, dim=0)
+
+        # run micro-batched
+        if ubatch_slices is not None:
+            model_output = _run_ubatches(
+                ubatch_slices, attn_metadata, is_dummy_run)
+        # run single batch
+        else:
+            model_output = _run(
+                slice(0, num_scheduled_tokens), attn_metadata, is_dummy_run)
+
         return model_output
 
     @torch.inference_mode()
diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py
index 82d82ac68b23e..5ed4fb91458bc 100644
--- a/vllm/v1/worker/ubatching.py
+++ b/vllm/v1/worker/ubatching.py
@@ -3,8 +3,10 @@
 import threading
 import torch
 import torch._dynamo
 import torch.profiler as profiler
+from typing import Optional
 from torch.library import Library
 from torch.library import custom_op, register_kernel
+from vllm import forward_context
 
 class UBatchContext:
     """
@@ -20,11 +22,20 @@ class UBatchContext:
         self.schedule = schedule
         self.stream = stream
         self.original_stream = torch.cuda.current_stream()
+        self.done_evt = torch.cuda.Event()
+        self.forward_context = None
+
+    def update_forward_context(self):
+        self.forward_context = forward_context._forward_context
 
     def __enter__(self):
         global _CURRENT_CONTEXT
-        self.original_stream = torch.cuda.current_stream()
         _CURRENT_CONTEXT[threading.get_ident()] = self
+
+        self.original_stream = torch.cuda.current_stream()
+        self.original_forward_context = forward_context._forward_context
+        self.forward_context = self.original_forward_context
+
         # Set micro-batch stream
         torch.cuda.set_stream(self.stream)
         return self
@@ -32,8 +43,9 @@ class UBatchContext:
     def __exit__(self, exc_type, exc_val, exc_tb):
         global _CURRENT_CONTEXT
         _CURRENT_CONTEXT[threading.get_ident()] = None
-        # Restore the original stream
+
         torch.cuda.set_stream(self.original_stream)
+        forward_context._forward_context = self.original_forward_context
         return False
 
     def yield_(self):
@@ -42,8 +54,14 @@ class UBatchContext:
         # Wait for the next batch to signal back
         self.wait_event.wait()
         self.wait_event.clear()
-        # When we resume switch back to the microbatch stream
+        # When we resume i.e. switch back to this micro-batch, we make sure
+        # we have the correct stream and forward context
         torch.cuda.set_stream(self.stream)
+        forward_context._forward_context = self.forward_context
+
+    def wait(self):
+        self.wait_event.wait()
+        self.wait_event.clear()
 
 _CURRENT_CONTEXT: dict = {}
@@ -54,30 +72,37 @@ def yield_impl(schedule="default"):
     ctx.yield_()
 
-# 2) Register kernel for CUDA
+# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
+# optimizing it away (TODO: see if this is actually needed)
 @custom_op("vllm::yield_", mutates_args=("x",))
-def yield_(x: torch.Tensor, schedule="default") -> None:
+def yield_(x: torch.Tensor, schedule: str="default") -> None:
     yield_impl(schedule)
 
 # 3) Fake implementation for shape prop and FX tracing
 @yield_.register_fake
-def yield_(x: torch.Tensor, schedule="default") -> None:
+def yield_(x: torch.Tensor, schedule: str="default") -> None:
     pass
 
 """
 """
 
-def make_ubatch_context_chain(num_micro_batches: int) -> list[UBatchContext]:
+def make_ubatch_context_chain(
+    num_micro_batches: int,
+    stream: Optional[torch.Stream] = None,
+    device: Optional[torch.device] = None
+) -> list[UBatchContext]:
     """
     Create a context manager for micro-batching synchronization.
     """
     events = [threading.Event() for _ in range(num_micro_batches)]
+    device = device or torch.cuda.current_device()
     ctxs = []
     for i in range(num_micro_batches):
         wait_event = events[i]
         signal_event = events[(i + 1) % num_micro_batches]
-        ctx = UBatchContext(torch.Stream(), wait_event, signal_event)
+        ctx = UBatchContext(stream or torch.cuda.Stream(device),
+                            wait_event, signal_event)
         ctxs.append(ctx)
     return ctxs
\ No newline at end of file