From ffb740ae95518450c533dda4d614b6d24701a96e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 20 May 2025 23:55:31 +0000 Subject: [PATCH] manually manage stream Signed-off-by: Lucas Wilkinson --- examples/offline_inference/basic/basic.py | 3 +- vllm/config.py | 8 ++++- vllm/v1/worker/gpu_model_runner.py | 15 ++++----- vllm/v1/worker/ubatching.py | 37 +++++++++++++++-------- 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 15aa8dd17e64b..0aeaae9f6e351 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -34,8 +34,9 @@ def main(): enforce_eager=False, compilation_config=2, enable_microbatching=True, + enable_expert_parallel=True, trust_remote_code=True, - tensor_parallel_size=4, + tensor_parallel_size=2, max_model_len=1024, #load_format="dummy", ) diff --git a/vllm/config.py b/vllm/config.py index 591a03654bc7d..161d16533c554 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4328,7 +4328,13 @@ class VllmConfig: if self.parallel_config.enable_microbatching: # Microbatching is not supported with piecewise compilation yet. # More specifically piecewise cuda-graphs - self.compilation_config.level = CompilationLevel.DYNAMO_ONCE + if self.compilation_config.level >= CompilationLevel.PIECEWISE: + logger.warning_once( + "Piecewise compilation is not supported with " + "microbatching. 
Disabling piecewise compilation.") + self.parallel_config.enable_microbatching = False + self.compilation_config.level = CompilationLevel.DYNAMO_ONCE + if self.model_config and self.model_config.use_mla and \ not (current_platform.is_cuda() or current_platform.is_rocm()): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ca8b62faac70c..0c3a9084f3301 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1294,10 +1294,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): return num_tokens, *self._get_model_inputs(tokens_slice, scheduler_output) @torch.inference_mode() - def process_batch(i, is_dummy_ubatch, is_dummy_run, attn_metadata, vllm_config, model, num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors, results): - with set_forward_context(attn_metadata[i] if attn_metadata is not None else None, + def process_batch(save_results, attn_metadata, vllm_config, model, num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors, results, stream): + with set_forward_context(attn_metadata, vllm_config, num_tokens=num_tokens): + torch.cuda.set_stream(stream) + model_output = model( input_ids=input_ids, positions=positions, @@ -1305,7 +1307,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): inputs_embeds=inputs_embeds, ) - if not is_dummy_ubatch or is_dummy_run: + if save_results: results.append(model_output.clone()) def threaded_processing(ubatch_slices, attn_metadata, vllm_config, model, is_dummy_run=False): @@ -1320,10 +1322,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): model_inputs(tokens_slice, is_dummy_ubatch) thread = threading.Thread(target=process_batch, args=( - i, - is_dummy_ubatch, - is_dummy_run, - attn_metadata, + not is_dummy_ubatch or is_dummy_run, + attn_metadata[i] if attn_metadata is not None else None, vllm_config, model, num_tokens, @@ -1332,6 +1332,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): inputs_embeds, intermediate_tensors, results, + 
torch.cuda.current_stream() )) thread.start() thread.join() diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 5634fd54c415e..82d82ac68b23e 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -10,44 +10,58 @@ class UBatchContext: """ Context manager for micro-batching synchronization using threading events. """ - def __init__(self, wait_event: threading.Event, signal_event: threading.Event): + def __init__(self, + stream: torch.cuda.Stream, + wait_event: threading.Event, + signal_event: threading.Event, + schedule="default"): self.wait_event = wait_event self.signal_event = signal_event + self.schedule = schedule + self.stream = stream + self.original_stream = torch.cuda.current_stream() def __enter__(self): global _CURRENT_CONTEXT + self.original_stream = torch.cuda.current_stream() _CURRENT_CONTEXT[threading.get_ident()] = self + # Set micro-batch stream + torch.cuda.set_stream(self.stream) return self def __exit__(self, exc_type, exc_val, exc_tb): global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = None + # Restore the original stream + torch.cuda.set_stream(self.original_stream) return False def yield_(self): # Signal that this batch reached the barrier and wait for the other self.signal_event.set() + # Wait for the next batch to signal back self.wait_event.wait() - # Reset for reuse self.wait_event.clear() + # When we resume switch back to the microbatch stream + torch.cuda.set_stream(self.stream) _CURRENT_CONTEXT: dict = {} -def yield_impl(): +def yield_impl(schedule="default"): # Perform the barrier if a context exists for this thread ctx = _CURRENT_CONTEXT.get(threading.get_ident(), None) - if ctx is not None: + if ctx is not None and ctx.schedule == schedule: ctx.yield_() # 2) Register kernel for CUDA -@custom_op("custom::yield_", mutates_args=("x",)) -def yield_(x: torch.Tensor) -> None: - yield_impl() +@custom_op("vllm::yield_", mutates_args=("x",)) +def yield_(x: torch.Tensor, 
schedule="default") -> None: + yield_impl(schedule) # 3) Fake implementation for shape prop and FX tracing @yield_.register_fake -def yield_(x: torch.Tensor): +def yield_(x: torch.Tensor, schedule="default") -> None: pass """ @@ -63,10 +77,7 @@ def make_ubatch_context_chain(num_micro_batches: int) -> list[UBatchContext]: for i in range(num_micro_batches): wait_event = events[i] signal_event = events[(i + 1) % num_micro_batches] - ctx = UBatchContext(wait_event, signal_event) + ctx = UBatchContext(torch.Stream(), wait_event, signal_event) ctxs.append(ctx) - # Create the context manager - ctx = UBatchContext(wait_event, signal_event) - - return ctx \ No newline at end of file + return ctxs \ No newline at end of file