diff --git a/examples/basic-ub.py b/examples/basic-ub.py index 21ddff5dddbe8..397f586b6598b 100644 --- a/examples/basic-ub.py +++ b/examples/basic-ub.py @@ -40,7 +40,7 @@ def main(): max_model_len=1024, #load_format="dummy", ############### - tensor_parallel_size=2, + tensor_parallel_size=1, #data_parallel_size=2, enable_expert_parallel=False, ############### diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 5d2d95f18d2fa..a2641759ac555 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -58,20 +58,11 @@ def get_forward_context() -> ForwardContext: "Please use `set_forward_context` to set the forward context.") return _forward_context - -@contextmanager -def set_forward_context(attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: int = 0): - """A context manager that stores the current forward context, - can be attention metadata, etc. - Here we can inject common logic for every model forward pass. - """ - global forward_start_time - need_to_track_batchsize = track_batchsize and attn_metadata is not None - if need_to_track_batchsize: - forward_start_time = time.perf_counter() +def create_forward_context(attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: int = 0 +): dp_metadata: Optional[DPMetadata] = None if vllm_config.parallel_config.data_parallel_size > 1: dp_size = vllm_config.parallel_config.data_parallel_size @@ -96,17 +87,49 @@ def set_forward_context(attn_metadata: Any, dp_metadata = DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) - global _forward_context - prev_context = _forward_context - _forward_context = ForwardContext( + return ForwardContext( no_compile_layers=vllm_config.compilation_config. 
static_forward_context, virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata) +@contextmanager +def override_forward_context(forward_context: Optional[ForwardContext]): + """A context manager that overrides the current forward context. + This is used to override the forward context for a specific + forward pass. + """ + global _forward_context + prev_context = _forward_context + print("overriding forward context with", forward_context) + _forward_context = forward_context try: yield + finally: + _forward_context = prev_context + + +@contextmanager +def set_forward_context(attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: int = 0): + """A context manager that stores the current forward context, + can be attention metadata, etc. + Here we can inject common logic for every model forward pass. + """ + global forward_start_time + need_to_track_batchsize = track_batchsize and attn_metadata is not None + if need_to_track_batchsize: + forward_start_time = time.perf_counter() + + forward_context = create_forward_context( + attn_metadata, vllm_config, virtual_engine, num_tokens) + + try: + with override_forward_context(forward_context): + yield finally: global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: @@ -140,5 +163,3 @@ def set_forward_context(attn_metadata: Any, logger.info(("Batchsize forward time stats " "(batchsize, count, median_time(ms)): %s"), forward_stats) - - _forward_context = prev_context diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a8becaaf5f74c..9c01abc9c6998 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6,6 +6,7 @@ import os import time import weakref from typing import TYPE_CHECKING, Optional, TypeAlias, Union +import contextlib import numpy as np import torch @@ -26,7 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from 
vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, prepare_communication_buffer_for_model) -from vllm.forward_context import get_forward_context, set_forward_context +from vllm.forward_context import get_forward_context, set_forward_context, create_forward_context, override_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model @@ -1289,93 +1290,93 @@ class GPUModelRunner(LoRAModelRunnerMixin): def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple: if use_dummy_input: num_tokens = num_scheduled_tokens or 1 - return num_tokens, *self._get_dummy_model_inputs(num_tokens) + return self._get_dummy_model_inputs(num_tokens) else: assert scheduler_output is not None - num_tokens = tokens_slice.stop - tokens_slice.start - return num_tokens, *self._get_model_inputs(tokens_slice, scheduler_output) + return self._get_model_inputs(tokens_slice, scheduler_output) - def _run(token_slice: slice, attn_metadata, use_dummy_input: bool = False, - ubatch_context: Optional[UBatchContext]=None, - setup_done_evt: Optional[threading.Event]=None): - num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \ + def _run(token_slice: slice, context, use_dummy_input: bool = False): + input_ids, positions, inputs_embeds, intermediate_tensors = \ model_inputs(token_slice, use_dummy_input) - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_tokens): - if ubatch_context: - # Update the forward context now that its available - ubatch_context.update_forward_context() - - if setup_done_evt is not None: - # Wait for the setup to be done - setup_done_evt.set() - + with context: + if isinstance(context, UBatchContext): + print("running ubatch ctx", context.id) model_output = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, 
inputs_embeds=inputs_embeds, ) + if isinstance(context, UBatchContext): + print("done ubatch ctx", context.id) + if isinstance(context, UBatchContext): + # Clone before we leave the ubatch context + model_output = model_output.clone() + return model_output @torch.inference_mode() def _ubatch_thread(ubatch_ctx, root_stream, token_slice, attn_metadata, results, save_results, use_dummy_input, setup_done_evt): - with ubatch_ctx: - # an event to enable the start of the ubatch execution on the GPU - # since different ubatches may be on different streams than this one - # they all need to wait on the gpu kernels launched in - # _prepare_inputs before continuing - if current_stream() != root_stream: - start_evt = torch.cuda.Event() - # Make sure we wait then record so we don't miss the event - current_stream().wait_event(start_evt) - root_stream.record_event(start_evt) + ubatch_ctx.stream.wait_stream(root_stream) + + model_output = _run(token_slice, ubatch_ctx, use_dummy_input) - model_output = _run(token_slice, attn_metadata, use_dummy_input, ubatch_ctx, setup_done_evt) - - if save_results: - results.append(model_output.clone()) - - if current_stream() != root_stream: - # Make the root stream for the ubatch to finish - # Make sure we wait then record so we don't miss the event - root_stream.wait_event(ubatch_ctx.done_evt) - current_stream().record_event(ubatch_ctx.done_evt) + if save_results: + results.append(model_output) def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run): results = [] assert len(ubatch_slices) == 2, "Only two ubatches has been tested" root_stream = current_stream() - ubatch_ctxs = make_ubatch_context_chain(len(ubatch_slices), - stream=root_stream, # Only works currently if everything is run on the same stream - device=self.device) + + if not hasattr(self, "ubatch_streams"): + # Create the ubatch streams + self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))] + + ubatch_fwd_ctxs = [create_forward_context( + 
attn_metadata[i], self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start) + ) for i, (_, tokens_slice) in enumerate(ubatch_slices)] + ubatch_ctxs, start_hook = make_ubatch_context_chain( + len(ubatch_slices), + fwd_ctxs=ubatch_fwd_ctxs, + streams=self.ubatch_streams, #stream=root_stream, # Only works currently if everything is run on the same stream + device=self.device) setup_done = threading.Event() ubatch_threads = [] + + # Ubatches will manually manage the forward context, so we override + # it to None here so we can have it restored correctly later + with override_forward_context(None): + ubatch_threads = [] + for i, (_, tokens_slice) in enumerate(ubatch_slices): + is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start + assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run + + thread = threading.Thread(target=_ubatch_thread, args=( + ubatch_ctxs[i], + root_stream, + tokens_slice, + attn_metadata[i] if attn_metadata is not None else None, + results, + not is_dummy_ubatch or is_dummy_run, + is_dummy_ubatch or is_dummy_run, + setup_done, + )) + ubatch_threads.append(thread) + thread.start() + + # Signal the first ubatch to start + start_hook(root_stream) + print("started first ubatch") + + for thread in ubatch_threads: + thread.join() + + for ubatch_ctx in ubatch_ctxs: + root_stream.wait_stream(ubatch_ctx.stream) - for i, (_, tokens_slice) in enumerate(ubatch_slices): - is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start - assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run - - thread = threading.Thread(target=_ubatch_thread, args=( - ubatch_ctxs[i], - root_stream, - tokens_slice, - attn_metadata[i] if attn_metadata is not None else None, - results, - not is_dummy_ubatch or is_dummy_run, - is_dummy_ubatch or is_dummy_run, - setup_done, - )) - #ubatch_threads.append(thread) - thread.start() - setup_done.wait() - thread.join() - - # for thread in ubatch_threads: - # thread.join() - + print("torch
cat") torch.cuda.set_stream(root_stream) return torch.cat(results, dim=0) @@ -1386,7 +1387,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): # run single batch else: model_output = _run( - slice(0, num_scheduled_tokens), attn_metadata, is_dummy_run) + slice(0, num_scheduled_tokens), + set_forward_context( + attn_metadata, + vllm_config=self.vllm_config, + num_tokens=num_scheduled_tokens or 1), + is_dummy_run) return model_output diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 024f64a07ff8c..7449f196636f0 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -15,55 +15,59 @@ class UBatchContext: Context manager for micro-batching synchronization using threading events. """ def __init__(self, + id: int, stream: torch.cuda.Stream, - wait_event: threading.Event, - signal_event: threading.Event, + fwd_ctx: forward_context.ForwardContext, + cpu_wait_event: threading.Event, + cpu_signal_event: threading.Event, + gpu_wait_event: torch.cuda.Event, + gpu_signal_event: torch.cuda.Event, schedule="default"): - self.wait_event = wait_event - self.signal_event = signal_event - self.schedule = schedule + self.id = id self.stream = stream self.original_stream = current_stream() - self.done_evt = torch.cuda.Event() - self.forward_context = None - - def update_forward_context(self): - self.forward_context = forward_context._forward_context + self.forward_context = fwd_ctx + self.cpu_wait_event = cpu_wait_event + self.cpu_signal_event = cpu_signal_event + self.gpu_wait_event = gpu_wait_event + self.gpu_signal_event = gpu_signal_event + self.schedule = schedule + self.done_event = torch.cuda.Event() def __enter__(self): global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = self - - self.original_stream = current_stream() - self.original_forward_context = forward_context._forward_context - self.forward_context = self.original_forward_context - - # Set micro-batch stream - torch.cuda.set_stream(self.stream) + self._wait() 
return self def __exit__(self, exc_type, exc_val, exc_tb): global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = None - torch.cuda.set_stream(self.original_stream) - forward_context._forward_context = self.original_forward_context + self._signal() return False - def yield_(self): - # Signal that this batch reached the barrier and wait for the other - self.signal_event.set() - # Wait for the next batch to signal back - self.wait_event.wait() - self.wait_event.clear() + def _restore_context(self): # When we resume i.e. switch back to this micro-batch, we make sure # we have the correct stream and forward context torch.cuda.set_stream(self.stream) forward_context._forward_context = self.forward_context - - def wait(self): - self.wait_event.wait() - self.wait_event.clear() + + def yield_(self): + self._signal() + self._wait() + + def _signal(self): + # Record this batch's GPU event so the other micro-batch's stream can wait on it + self.gpu_signal_event.record(self.stream) + # Signal that this batch reached the barrier + self.cpu_signal_event.set() + + def _wait(self): + self.stream.wait_event(self.gpu_wait_event) + self.cpu_wait_event.wait() + self.cpu_wait_event.clear() + self._restore_context() _CURRENT_CONTEXT: dict = {} @@ -90,21 +94,34 @@ def yield_(x: torch.Tensor, schedule: str="default") -> None: """ def make_ubatch_context_chain( num_micro_batches: int, - stream: Optional[torch.Stream] = None, + fwd_ctxs: forward_context.ForwardContext, + streams: Optional[list[torch.Stream]] = None, device: Optional[torch.device] = None ) -> list[UBatchContext]: + assert num_micro_batches == 2, "only been tested with 2 micro-batches" + """ Create a context manager for micro-batching synchronization.
""" - events = [threading.Event() for _ in range(num_micro_batches)] + cpu_events = [threading.Event() for _ in range(num_micro_batches)] + gpu_events = [torch.cuda.Event() for _ in range(num_micro_batches)] device = device or torch.cuda.current_device() ctxs = [] for i in range(num_micro_batches): - wait_event = events[i] - signal_event = events[(i + 1) % num_micro_batches] - ctx = UBatchContext(stream or torch.cuda.Stream(device), - wait_event, signal_event) + stream = (streams[i] if streams else None) or torch.cuda.Stream(device) + ctx = UBatchContext(id=i, + stream=stream, + fwd_ctx=fwd_ctxs[i], + cpu_wait_event=cpu_events[i], + cpu_signal_event=cpu_events[(i + 1) % num_micro_batches], + gpu_wait_event=gpu_events[i], + gpu_signal_event=gpu_events[(i + 1) % num_micro_batches], + ) ctxs.append(ctx) + + def start_hook(from_stream: torch.cuda.Stream): + ctxs[0].cpu_wait_event.set() + ctxs[0].gpu_wait_event.record(from_stream) - return ctxs \ No newline at end of file + return ctxs, start_hook \ No newline at end of file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6ecdfa6204f2a..fac7efaa6753b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1417,7 +1417,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): if model_input.attn_metadata is not None: model_input.attn_metadata.enable_kv_scales_calculation = False - self.execute_model(model_input, kv_caches, intermediate_tensors) + import nvtx + with nvtx.annotate("execute_model"): + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() if self.lora_config: self._remove_dummy_loras()