diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 10b5d42ebaa30..160a276426ff4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -75,6 +75,7 @@ from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
 from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
+import dataclasses
 
 if TYPE_CHECKING:
     import xgrammar as xgr
 
@@ -98,7 +99,14 @@ PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
 UbatchSlice: TypeAlias = tuple[slice, slice]
 UBatchSlices: TypeAlias = list[UbatchSlice]
 
-import dataclasses
+@dataclasses.dataclass
+class UbatchMetadata:
+    context: UBatchContext
+    input_ids: torch.Tensor
+    positions: torch.Tensor
+    inputs_embeds: Optional[torch.Tensor]
+    intermediate_tensors: Optional[IntermediateTensors]
+
 
 class GPUModelRunner(LoRAModelRunnerMixin):
 
@@ -1498,85 +1506,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 tokens_slice, intermediate_tensors, True)
         return input_ids, positions, inputs_embeds, intermediate_tensors
 
-    def _run_model(self,
-                   attn_metadata: Optional[PerLayerAttnMetadata],
-                   num_scheduled_tokens: Optional[int],
-                   ubatch_slices: Optional[UBatchSlices] = None,
-                   scheduler_output: Optional["SchedulerOutput"] = None,
-                   is_dummy_run: bool = False,
-                   num_tokens_across_dp: Optional[torch.Tensor] = None,
-                   skip_cuda_graphs: bool = False):
-
-        @dataclasses.dataclass
-        class UbatchMetadata:
-            context: UBatchContext
-            input_ids: torch.Tensor
-            positions: torch.Tensor
-            inputs_embeds: Optional[torch.Tensor]
-            intermediate_tensors: Optional[IntermediateTensors]
+    def model_inputs(self, tokens_slice: slice, use_dummy_input: bool,
+                     scheduler_output: Optional["SchedulerOutput"]) -> tuple:
+        # Returns (input_ids, positions, inputs_embeds, intermediate_tensors)
+        # for the given token slice; placeholder inputs are used for dummy runs.
+        if use_dummy_input:
+            return self._get_dummy_model_inputs(tokens_slice.stop -
+                                                tokens_slice.start)
+        else:
+            assert scheduler_output is not None
+            return self._get_model_inputs(tokens_slice, scheduler_output)
 
-        num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
-        def _make_ubatch_contexts(ubatch_slices,
-                                  attn_metadata,
-                                  compute_stream,
-                                  num_tokens_across_dp,
-                                  skip_cuda_graphs) -> list[UBatchContext]:
-            ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
-                                               compute_stream=compute_stream,
-                                               device=self.device)
+    def _make_ubatch_metadata(self,
+                              ubatch_slices,
+                              attn_metadata,
+                              compute_stream,
+                              is_dummy_run,
+                              num_tokens_across_dp,
+                              skip_cuda_graphs,
+                              scheduler_output) -> list[UbatchMetadata]:
 
-            for i, (_, tokens_slice) in enumerate(ubatch_slices):
-                num_tokens = (tokens_slice.stop - tokens_slice.start)
-                # TODO (Sage) Instead of using this setter we should be able
-                # to just create the forward context in advance and pass it
-                # to the UBatchContext's __init__ method
-                ubatch_ctxs[i].forward_context = create_forward_context(
-                    attn_metadata[i]
-                    if attn_metadata is not None else None,
-                    self.vllm_config,
-                    num_tokens=num_tokens,
-                    num_tokens_across_dp=num_tokens_across_dp,
-                    skip_cuda_graphs=skip_cuda_graphs)
-            return ubatch_ctxs
-
-        def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
-            if use_dummy_input:
-                assert num_dummy_tokens == tokens_slice.stop - tokens_slice.start
-                return self._get_dummy_model_inputs(num_dummy_tokens)
-            else:
-                assert scheduler_output is not None
-                return self._get_model_inputs(tokens_slice, scheduler_output)
-
-        def _make_ubatch_metadata(ubatch_slices,
-                                  attn_metadata,
-                                  compute_stream,
-                                  is_dummy_run,
-                                  num_tokens_across_dp,
-                                  skip_cuda_graphs) -> list[UbatchMetadata]:
-            ubatch_ctxs = _make_ubatch_contexts(
-                ubatch_slices=ubatch_slices,
-                attn_metadata=attn_metadata,
-                compute_stream=compute_stream,
+        # Create one forward context per ubatch
+        forward_contexts = []
+        for i, (_, tokens_slice) in enumerate(ubatch_slices):
+            num_tokens = (tokens_slice.stop - tokens_slice.start)
+            forward_contexts.append(create_forward_context(
+                attn_metadata[i] if attn_metadata is not None else None,
+                self.vllm_config,
+                num_tokens=num_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
-                skip_cuda_graphs=skip_cuda_graphs
-            )
+                skip_cuda_graphs=skip_cuda_graphs))
 
-            ubatch_metadata: list[UbatchMetadata] = []
-            for i, (_, tokens_slice) in enumerate(ubatch_slices):
-                input_ids, positions, inputs_embeds, intermediate_tensors = \
-                    model_inputs(tokens_slice, is_dummy_run)
-                ubatch_metadata.append(UbatchMetadata(
-                    context=ubatch_ctxs[i],
-                    input_ids=input_ids,
-                    positions=positions,
-                    inputs_embeds=inputs_embeds,
-                    intermediate_tensors=intermediate_tensors
-                ))
-
-            return ubatch_metadata
+        ubatch_ctxs = make_ubatch_contexts(
+            num_micro_batches=len(ubatch_slices),
+            compute_stream=compute_stream,
+            forward_contexts=forward_contexts,
+            device=self.device)
+
+        # Bundle each ubatch's context with the model inputs for its slice
+        ubatch_metadata: list[UbatchMetadata] = []
+        for i, (_, tokens_slice) in enumerate(ubatch_slices):
+            input_ids, positions, inputs_embeds, intermediate_tensors = \
+                self.model_inputs(tokens_slice, is_dummy_run, scheduler_output)
+            ubatch_metadata.append(UbatchMetadata(
+                context=ubatch_ctxs[i],
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors
+            ))
+
+        return ubatch_metadata
+
+    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
         @torch.inference_mode()
         def _ubatch_thread(results, model, ubatch_metadata):
             with ubatch_metadata.context:
@@ -1588,49 +1569,59 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 )
                 results.append((ubatch_metadata.context.id, model_output))
 
-        def _run_ubatches(ubatch_metadata, model) -> torch.Tensor:
-            results: list[tuple[int, torch.Tensor]] = []
+        results: list[tuple[int, torch.Tensor]] = []
 
-            # Ubatch threads will manually manage the forward context, so we
-            # override it to None here so we can have it restored correctly
-            # after both threads have finished
-            with override_forward_context(None):
-                ubatch_threads = []
-                for metadata in ubatch_metadata:
-                    thread = threading.Thread(target=_ubatch_thread,
-                                              args=(
-                                                  results,
-                                                  model,
-                                                  metadata,
-                                              ))
-                    ubatch_threads.append(thread)
-                    thread.start()
+        # Ubatch threads will manually manage the forward context, so we
+        # override it to None here so we can have it restored correctly
+        # after both threads have finished
+        with override_forward_context(None):
+            ubatch_threads = []
+            for metadata in ubatch_metadata:
+                thread = threading.Thread(target=_ubatch_thread,
+                                          args=(
+                                              results,
+                                              model,
+                                              metadata,
+                                          ))
+                ubatch_threads.append(thread)
+                thread.start()
+
+            ubatch_metadata[0].context.cpu_wait_event.set()
+            for thread in ubatch_threads:
+                thread.join()
+        sorted_results = [value for position, value in sorted(results)]
+        result = torch.cat(sorted_results, dim=0)
+        return result
+
+    def _run_model(self,
+                   attn_metadata: Optional[PerLayerAttnMetadata],
+                   num_scheduled_tokens: Optional[int],
+                   ubatch_slices: Optional[UBatchSlices] = None,
+                   scheduler_output: Optional["SchedulerOutput"] = None,
+                   is_dummy_run: bool = False,
+                   num_tokens_across_dp: Optional[torch.Tensor] = None,
+                   skip_cuda_graphs: bool = False):
-                ubatch_metadata[0].context.cpu_wait_event.set()
-                for thread in ubatch_threads:
-                    thread.join()
-            sorted_results = [value for position, value in sorted(results)]
-            result = torch.cat(sorted_results, dim=0)
-            return result
 
         # run micro-batched
         if ubatch_slices is not None:
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             compute_stream = torch.cuda.current_stream()
-            ubatch_metadata = _make_ubatch_metadata(
+            ubatch_metadata = self._make_ubatch_metadata(
                 ubatch_slices=ubatch_slices,
                 attn_metadata=attn_metadata,
                 compute_stream=compute_stream,
                 is_dummy_run=is_dummy_run,
                 num_tokens_across_dp=num_tokens_across_dp,
-                skip_cuda_graphs=skip_cuda_graphs
+                skip_cuda_graphs=skip_cuda_graphs,
+                scheduler_output=scheduler_output
             )
-            return _run_ubatches(ubatch_metadata, self.model)
+            return self._run_ubatches(ubatch_metadata, self.model)
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
+                self.model_inputs(slice(0, num_scheduled_tokens),
+                                  is_dummy_run, scheduler_output)
             with set_forward_context(attn_metadata,
                                      vllm_config=self.vllm_config,
                                      num_tokens=num_scheduled_tokens or 1,
diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py
index 33f33d7f78f85..160fbe75411c5 100644
--- a/vllm/v1/worker/ubatching.py
+++ b/vllm/v1/worker/ubatching.py
@@ -5,6 +5,7 @@
 from typing import Optional
 
 import torch
 
 from vllm import forward_context
+from vllm.forward_context import ForwardContext
 from vllm.utils import current_stream
 
@@ -18,7 +19,7 @@ class UBatchContext:
                  id: int,
                  comm_stream: torch.cuda.Stream,
                  compute_stream: torch.cuda.Stream,
-                 #fwd_ctx: forward_context.ForwardContext,
+                 forward_context: ForwardContext,
                  cpu_wait_event: threading.Event,
                  cpu_signal_event: threading.Event,
                  gpu_comm_done_event: torch.cuda.Event,
@@ -27,7 +28,7 @@
         self.id = id
         self.comm_stream = comm_stream
         self.compute_stream = compute_stream
-        self.forward_context = None #fwd_ctx
+        self.forward_context = forward_context
         self.cpu_wait_event = cpu_wait_event
         self.cpu_signal_event = cpu_signal_event
         self.current_stream = compute_stream
@@ -150,6 +151,7 @@ def yield_and_switch_from_comm_to_compute(schedule="default"):
 def make_ubatch_contexts(
     num_micro_batches: int,
     compute_stream: torch.cuda.Stream,
+    forward_contexts: list[ForwardContext],
     device: Optional[torch.device] = None,
     schedule: str = "default",
 ) -> list[UBatchContext]:
@@ -167,11 +169,14 @@ def make_ubatch_contexts(
     device = device or torch.cuda.current_device()
     comm_stream = torch.cuda.Stream(device)
 
+    assert len(forward_contexts) == num_micro_batches
+
     ctxs = []
     for i in range(num_micro_batches):
         ctx = UBatchContext(id=i,
                             compute_stream=compute_stream,
                             comm_stream=comm_stream,
+                            forward_context=forward_contexts[i],
                             cpu_wait_event=cpu_events[i],
                             cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
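
Note for reviewers (not part of the diff): the CPU-side handshake that _run_ubatches depends on is spread across the two files, so here is a minimal, self-contained sketch of the same ping-pong pattern under the event wiring used in make_ubatch_contexts: each thread waits on its own event and signals the other thread's event (cpu_events[i] / cpu_events[(i + 1) % num_micro_batches]), and the pipeline is kicked off by setting ubatch 0's wait event, exactly as ubatch_metadata[0].context.cpu_wait_event.set() does above. All names in the sketch (ubatch_worker, NUM_MICRO_BATCHES, the placeholder compute step) are illustrative stand-ins, not vLLM APIs.

import threading

NUM_MICRO_BATCHES = 2
cpu_events = [threading.Event() for _ in range(NUM_MICRO_BATCHES)]
results: list[tuple[int, str]] = []

def ubatch_worker(ubatch_id: int, num_yields: int) -> None:
    # Mirrors the UBatchContext wiring: thread i waits on its own event
    # and signals the event of thread (i + 1) % NUM_MICRO_BATCHES.
    cpu_wait_event = cpu_events[ubatch_id]
    cpu_signal_event = cpu_events[(ubatch_id + 1) % NUM_MICRO_BATCHES]
    for _ in range(num_yields):
        cpu_wait_event.wait()    # block until the peer yields to us
        cpu_wait_event.clear()
        # ... one compute/communication slice for this micro-batch ...
        cpu_signal_event.set()   # hand execution over to the peer
    results.append((ubatch_id, f"ubatch-{ubatch_id} finished"))

threads = [threading.Thread(target=ubatch_worker, args=(i, 3))
           for i in range(NUM_MICRO_BATCHES)]
for t in threads:
    t.start()
cpu_events[0].set()  # kick off ubatch 0, as _run_ubatches does
for t in threads:
    t.join()
# Thread completion order is nondeterministic, hence the sort by ubatch id
# before merging (the sorted(results) + torch.cat step in the real code).
print([msg for _, msg in sorted(results)])

The same ordering argument explains why results carries the context id: the two threads append whenever they finish, and the sort restores micro-batch order before the outputs are concatenated.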