diff --git a/examples/basic-ub.py b/examples/basic-ub.py new file mode 100644 index 0000000000000..9e0fb2fd60df3 --- /dev/null +++ b/examples/basic-ub.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Configure logging level for vllm (optional, uses VLLM_LOGGING_LEVEL env var). +logging_level = os.getenv("VLLM_LOGGING_LEVEL", "").upper() +if logging_level: + logging.basicConfig(level=getattr(logging, logging_level, logging.INFO)) + +# Create a sampling params object, optionally limiting output tokens via MAX_TOKENS env var. +param_kwargs = {"temperature": 0.8, "top_p": 0.95} +max_tokens_env = os.getenv("MAX_TOKENS") +if max_tokens_env is not None: + try: + param_kwargs["max_tokens"] = int(max_tokens_env) + except ValueError: + raise ValueError(f"Invalid MAX_TOKENS value: {max_tokens_env}") +sampling_params = SamplingParams(**param_kwargs) + + +def main(): + # Create an LLM. + llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", + enforce_eager=False, + compilation_config=2, + ############### + trust_remote_code=True, + max_model_len=1024, + #load_format="dummy", + ############### + tensor_parallel_size=2, + #data_parallel_size=2, + enable_expert_parallel=True, + ############### + enable_microbatching=True, + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. 
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/vllm/config.py b/vllm/config.py index aaf419a61a2d4..591a03654bc7d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1743,6 +1743,15 @@ class ParallelConfig: enable_microbatching: bool = False """Enable microbatching for the model executor.""" + + always_microbatch_if_enabled: bool = True + """Always microbatch if microbatching is enabled. Easier to sync bewteen + dp workers.""" + + microbatching_token_threshold: int = 4 + """The threshold for microbatching. If the number of tokens in the + request is greater than this threshold, microbatching will be used. + Otherwise, the request will be processed in a single batch.""" @property def world_size_across_dp(self) -> int: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f1cb77f64eae7..553d468018081 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -74,7 +74,7 @@ class FusedMoEParallelConfig: @property def use_pplx_kernels(self): - return self.dp_size > 1 and self.use_ep and has_pplx + return self.dp_size > 1 and self.use_ep and has_pplx and False @staticmethod def make(tp_size_: int, dp_size_: int, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 927968120a09e..768be3d145b3e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -74,7 +74,9 @@ AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata] # list when ubatching is enabled PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], AttnMetadataDict] -UBatchSlices: TypeAlias = Optional[list[tuple[slice, slice]]] + +UbatchSlice: TypeAlias = tuple[slice, slice] 
+UBatchSlices: TypeAlias = list[UbatchSlice] class GPUModelRunner(LoRAModelRunnerMixin): @@ -497,23 +499,39 @@ class GPUModelRunner(LoRAModelRunnerMixin): def _ubatch_split( self, + query_start_loc_np: torch.Tensor, max_num_scheduled_tokens: int, scheduler_output: "SchedulerOutput" ) -> Optional[UBatchSlices]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_reqs = self.input_batch.num_reqs - if self.parallel_config.enable_microbatching and max_num_scheduled_tokens == 1: + if self.parallel_config.enable_microbatching and \ + total_num_scheduled_tokens >= self.parallel_config.microbatching_token_threshold \ + and max_num_scheduled_tokens == 1: # For pure decode we can just create ubatchs by cutting the request # in half b0_reqs_end = num_reqs // 2 b0_tokens_end = total_num_scheduled_tokens // 2 + assert b0_reqs_end < num_reqs and b0_tokens_end < total_num_scheduled_tokens return [ (slice(0, b0_reqs_end), slice(0, b0_tokens_end)), (slice(b0_reqs_end, num_reqs), slice(b0_tokens_end, total_num_scheduled_tokens)), ] + + if self.parallel_config.enable_microbatching and \ + self.parallel_config.always_microbatch_if_enabled: + # TODO we can do something more advanced here to try to balance, + # i.e. 
split to the left of `total_num_scheduled_tokens // 2` if it + # is more balanced + req_split_id = np.argmax(query_start_loc_np > (total_num_scheduled_tokens // 2)) + return [(slice(0, req_split_id), slice(0, query_start_loc_np[req_split_id])), + (slice(req_split_id, num_reqs), slice(query_start_loc_np[req_split_id], total_num_scheduled_tokens))] return None + + def _is_dummy_ubatch(self, ubatch_slice: UBatchSlices) -> bool: + return ubatch_slice[1].start >= ubatch_slice[1].stop def _prepare_inputs( self, scheduler_output: "SchedulerOutput" @@ -534,9 +552,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens = np.array(tokens, dtype=np.int32) max_num_scheduled_tokens = max(tokens) - ubatch_slices: Optional[UBatchSlices] = self._ubatch_split( - max_num_scheduled_tokens, scheduler_output) - # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] req_indices = np.repeat(self.arange_np[:num_reqs], @@ -608,6 +623,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + ubatch_slices: Optional[UBatchSlices] = self._ubatch_split( + self.query_start_loc_np, max_num_scheduled_tokens, scheduler_output) + self.seq_lens_np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) @@ -669,15 +687,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): if ubatch_slices is not None: for ubid, (req_slice, token_slice) in enumerate(ubatch_slices): - attn_metadata_i = ( - self.attn_metadata_builders[kv_cache_group_id]. - build_slice( - req_slice=req_slice, - token_slice=token_slice, - max_query_len=max(tokens[req_slice]), - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - )) + # Run a dummy batch if its a empty ubatch + if token_slice.stop <= token_slice.start: + attn_metadata_i = None + else: + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id]. 
+ build_slice( + req_slice=req_slice, + token_slice=token_slice, + max_query_len=max(tokens[req_slice]), + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + )) for layer_name in kv_cache_group_spec.layer_names: assert type(attn_metadata) is list attn_metadata[ubid][layer_name] = attn_metadata_i @@ -1150,9 +1172,44 @@ class GPUModelRunner(LoRAModelRunnerMixin): for k, v in self.intermediate_tensors.items() }) + def _get_dummy_model_inputs(self, num_tokens: int) -> tuple: + # Dummy batch. (hopefully we are the last one so we can just + # update this to a one token batch and return) + + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions = self.positions[:num_tokens] + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + + return input_ids, positions, inputs_embeds, intermediate_tensors + + def _get_model_inputs(self, tokens_slice: slice, scheduler_output: "SchedulerOutput"): num_tokens = tokens_slice.stop - tokens_slice.start + if num_tokens == 0: + # Dummy batch. (hopefully we are the last one so we can just + # update this to a one token batch and return) + tokens_slice = slice(tokens_slice.start, tokens_slice.start + 1) + num_tokens = 1 + if (self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]): # Use piecewise CUDA graphs. 
@@ -1217,6 +1274,69 @@ class GPUModelRunner(LoRAModelRunnerMixin): tokens_slice, intermediate_tensors, True) return input_ids, positions, inputs_embeds, intermediate_tensors + def _run_model(self, + attn_metadata: Optional[PerLayerAttnMetadata], + num_scheduled_tokens: Optional[int], + ubatch_slices: Optional[UBatchSlices] = None, + scheduler_output: Optional["SchedulerOutput"] = None, + is_dummy_run: bool = False): + + def model_inputs(tokens_slice: slice, is_dummy: bool) -> tuple: + if is_dummy: + num_tokens = num_scheduled_tokens or 1 + return num_tokens, *self._get_dummy_model_inputs(num_tokens) + else: + assert scheduler_output is not None + num_tokens = tokens_slice.stop - tokens_slice.start + return num_tokens, *self._get_model_inputs(tokens_slice, scheduler_output) + + # run micro-batched + if ubatch_slices is not None: + model_outputs = [] + for i, (_, tokens_slice) in enumerate(ubatch_slices): + is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start + # only support the last ubatch being a dummy ubatch, or all batches, + # i.e. dummy_run for other DP workers + assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run + + num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \ + model_inputs(tokens_slice, is_dummy_ubatch) + + + with set_forward_context(attn_metadata[i] if attn_metadata is not None else None, + self.vllm_config, + num_tokens=num_tokens): + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + # clone is important for eventually piecewise cuda-graphs + # drop the ouputs its a dummy ubatch but not a dummy run + # In a dummy run is all the ubatches are dummy so we to + # still pass some output, when its not a dummy run we + # want the output to match what it would be if we had run + # without the dummy ubatch. 
+ if not is_dummy_ubatch or is_dummy_run: + model_outputs.append(model_output.clone()) + model_output = torch.cat(model_outputs, dim=0) + # run single batch + else: + num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \ + model_inputs(slice(0, num_scheduled_tokens), is_dummy_run) + with set_forward_context(attn_metadata, + self.vllm_config, + num_tokens=num_tokens): + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return model_output + @torch.inference_mode() def execute_model( self, @@ -1240,42 +1360,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Run the decoder. # Use persistent buffers for CUDA graphs. self.maybe_setup_kv_connector(scheduler_output) - - if ubatch_slices is not None: - model_outputs = [] - for i, (_, tokens_slice) in enumerate(ubatch_slices): - input_ids, positions, inputs_embeds, intermediate_tensors = \ - self._get_model_inputs(tokens_slice, scheduler_output) - num_input_token = tokens_slice.stop - tokens_slice.start - - with set_forward_context(attn_metadata[i], - self.vllm_config, - num_tokens=num_input_token): - model_output = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - - # clone is important for eventually piecewise cuda-graphs - model_outputs.append(model_output.clone()) - model_output = torch.cat(model_outputs, dim=0) - else: - input_ids, positions, inputs_embeds, intermediate_tensors = \ - self._get_model_inputs(slice(0, num_scheduled_tokens), - scheduler_output) - - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_scheduled_tokens): - model_output = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - + model_output = self._run_model( + attn_metadata, + num_scheduled_tokens, + ubatch_slices, + 
scheduler_output, + ) self.maybe_wait_for_kv_save() finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) @@ -1717,6 +1807,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): self, num_tokens: int, skip_attn: bool = True, + # For profiling runs we dont want microbatching but for + # dp dummy runs we do. + allow_microbatching: bool = False, ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs @@ -1755,43 +1848,22 @@ class GPUModelRunner(LoRAModelRunnerMixin): )) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i + + should_microbatch = ( + allow_microbatching + and self.vllm_config.parallel_config.enable_microbatching + and self.vllm_config.parallel_config.always_microbatch_if_enabled + ) + dummy_microbatches = [(slice(0, 0), slice(0, 0)), (slice(0, 0), slice(0, 0))] with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): - model = self.model - if self.is_multimodal_model: - input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] - else: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = None - if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] - else: - positions = self.positions[:num_tokens] - - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - if self.intermediate_tensors is None: - self.intermediate_tensors = ( - self.model.make_empty_intermediate_tensors( - batch_size=self.max_num_tokens, - dtype=self.model_config.dtype, - device=self.device)) - - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False) - - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_tokens): - outputs = model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + outputs = self._run_model( + attn_metadata, + num_tokens, + ubatch_slices=None if not should_microbatch 
else dummy_microbatches, + is_dummy_run=True, + ) if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 93129d9879401..2ce07acbb8938 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -294,7 +294,8 @@ class Worker(WorkerBase): self.profiler.stop() def execute_dummy_batch(self) -> None: - self.model_runner._dummy_run(1) + # TODO: adding allow_microbatching will break non-gpu backends + self.model_runner._dummy_run(1, allow_microbatching=True) def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index a08bf4b41d587..5634fd54c415e 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -1,5 +1,72 @@ # SPDX-License-Identifier: Apache-2.0 -class UBatchContext: +import threading +import torch +import torch._dynamo +import torch.profiler as profiler +from torch.library import Library +from torch.library import custom_op, register_kernel - def __init__(self, ubatch_id: int): - self.ubatch_id = ubatch_id +class UBatchContext: + """ + Context manager for micro-batching synchronization using threading events. 
+    """
+    def __init__(self, wait_event: threading.Event, signal_event: threading.Event):
+        self.wait_event = wait_event
+        self.signal_event = signal_event
+
+    def __enter__(self):
+        global _CURRENT_CONTEXT
+        _CURRENT_CONTEXT[threading.get_ident()] = self
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global _CURRENT_CONTEXT
+        _CURRENT_CONTEXT[threading.get_ident()] = None
+        return False
+
+    def yield_(self):
+        # Signal that this batch reached the barrier and wait for the other
+        self.signal_event.set()
+        self.wait_event.wait()
+        # Reset for reuse
+        self.wait_event.clear()
+
+_CURRENT_CONTEXT: dict = {}
+
+def yield_impl():
+    # Perform the barrier if a context exists for this thread
+    ctx = _CURRENT_CONTEXT.get(threading.get_ident(), None)
+    if ctx is not None:
+        ctx.yield_()
+
+
+# 2) Register kernel for CUDA (mutates_args keeps the barrier from being DCE'd)
+@custom_op("custom::yield_", mutates_args=("x",))
+def yield_(x: torch.Tensor) -> None:
+    yield_impl()
+
+# 3) Fake impl for shape prop / FX tracing (distinct name so it does not
+@yield_.register_fake
+def _yield_fake(x: torch.Tensor):
+    pass
+
+"""
+
+"""
+def make_ubatch_context_chain(num_micro_batches: int) -> list[UBatchContext]:
+    """
+    Create one UBatchContext per micro-batch, chained via threading events.
+    """
+    events = [threading.Event() for _ in range(num_micro_batches)]
+
+    ctxs = []
+    for i in range(num_micro_batches):
+        wait_event = events[i]
+        signal_event = events[(i + 1) % num_micro_batches]
+        ctx = UBatchContext(wait_event, signal_event)
+        ctxs.append(ctx)
+
+    # Return the full chain: one context per micro-batch. (The previous
+    # version returned a single context built from stale loop variables.)
+
+    return ctxs