From 020269c4c5c50aeee18c2e9d8dcd479b6bea4dd8 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Tue, 20 May 2025 21:27:33 +0000
Subject: [PATCH] added multithreading support

Signed-off-by: Sage Moore
---
 examples/basic-ub.py               | 10 ++--
 vllm/v1/worker/gpu_model_runner.py | 94 ++++++++++++++++++++++--------
 2 files changed, 76 insertions(+), 28 deletions(-)

diff --git a/examples/basic-ub.py b/examples/basic-ub.py
index 9e0fb2fd60df3..8f1fbc2a25420 100644
--- a/examples/basic-ub.py
+++ b/examples/basic-ub.py
@@ -30,17 +30,19 @@ sampling_params = SamplingParams(**param_kwargs)
 
 def main():
     # Create an LLM.
-    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
-              enforce_eager=False,
+    model = "deepseek-ai/DeepSeek-V2-Lite"
+    # model = "facebook/opt-125m"
+    llm = LLM(model=model,
+              enforce_eager=True,
               compilation_config=2,
               ###############
               trust_remote_code=True,
               max_model_len=1024,
               #load_format="dummy",
               ###############
-              tensor_parallel_size=2,
+              tensor_parallel_size=4,
               #data_parallel_size=2,
-              enable_expert_parallel=True,
+              enable_expert_parallel=False,
               ###############
               enable_microbatching=True,
     )

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 768be3d145b3e..ca8b62faac70c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -61,6 +61,8 @@
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                     scatter_mm_placeholders)
+import threading
+
 if TYPE_CHECKING:
     import xgrammar as xgr
 
@@ -1289,39 +1292,82 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert scheduler_output is not None
             num_tokens = tokens_slice.stop - tokens_slice.start
             return num_tokens, *self._get_model_inputs(tokens_slice, scheduler_output)
+
+        @torch.inference_mode()
+        def process_batch(i, is_dummy_ubatch, is_dummy_run, attn_metadata,
+                          vllm_config, model, num_tokens, input_ids, positions,
+                          inputs_embeds, intermediate_tensors, results):
+            with set_forward_context(
+                    attn_metadata[i] if attn_metadata is not None else None,
+                    vllm_config,
+                    num_tokens=num_tokens):
+                model_output = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
+
+            # The clone is important for eventual piecewise cuda-graphs.
+            # Drop the output if this is a dummy ubatch but not a dummy run:
+            # in a dummy run all ubatches are dummy, so we still need to
+            # return some output; otherwise the output should match what it
+            # would have been without the dummy ubatch.
+            if not is_dummy_ubatch or is_dummy_run:
+                results[i] = model_output.clone()
 
-        # run micro-batched
-        if ubatch_slices is not None:
-            model_outputs = []
+        def threaded_processing(ubatch_slices, attn_metadata, vllm_config,
+                                model, is_dummy_run=False):
+            # One result slot per ubatch so the concatenation order stays
+            # deterministic regardless of thread completion order.
+            results = [None] * len(ubatch_slices)
+            threads = []
             for i, (_, tokens_slice) in enumerate(ubatch_slices):
                 is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
-                # only support the last ubatch being a dummy ubatch, or all batches,
-                # i.e. dummy_run for other DP workers
+                # only support the last ubatch being a dummy ubatch, or all
+                # ubatches, i.e. dummy_run for other DP workers
                 assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
                 num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \
                     model_inputs(tokens_slice, is_dummy_ubatch)
+                thread = threading.Thread(target=process_batch, args=(
+                    i,
+                    is_dummy_ubatch,
+                    is_dummy_run,
+                    attn_metadata,
+                    vllm_config,
+                    model,
+                    num_tokens,
+                    input_ids,
+                    positions,
+                    inputs_embeds,
+                    intermediate_tensors,
+                    results,
+                ))
+                threads.append(thread)
+                thread.start()
+            # Join only after every ubatch thread has been started so the
+            # forward passes can overlap; a start()/join() pair inside the
+            # launch loop would serialize them.
+            for thread in threads:
+                thread.join()
 
-                with set_forward_context(attn_metadata[i] if attn_metadata is not None else None,
-                                         self.vllm_config,
-                                         num_tokens=num_tokens):
-                    model_output = self.model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=inputs_embeds,
-                    )
-
-                # clone is important for eventually piecewise cuda-graphs
-                # drop the ouputs its a dummy ubatch but not a dummy run
-                # In a dummy run is all the ubatches are dummy so we to
-                # still pass some output, when its not a dummy run we
-                # want the output to match what it would be if we had run
-                # without the dummy ubatch.
-                if not is_dummy_ubatch or is_dummy_run:
-                    model_outputs.append(model_output.clone())
-            model_output = torch.cat(model_outputs, dim=0)
+            results = [r for r in results if r is not None]
+            if results:
+                return torch.cat(results, dim=0)
+            return None
+
+        # run micro-batched
+        if ubatch_slices is not None:
+            model_output = threaded_processing(ubatch_slices,
+                                               attn_metadata,
+                                               self.vllm_config,
+                                               self.model,
+                                               is_dummy_run)
         # run single batch
         else:
            num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \
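
Note: the threaded dispatch in this patch is an ordered fan-out/join over
ubatches. The sketch below shows the same pattern in isolation; it is
standalone and illustrative, and fake_forward and run_ubatches are assumed
names, not vLLM APIs. The key points are to start every worker thread before
joining any of them, to store each result by ubatch index so the torch.cat
order is deterministic, and to drop empty (dummy) ubatches.

    import threading

    import torch


    def fake_forward(tokens: torch.Tensor) -> torch.Tensor:
        # Stand-in for the real model forward; just doubles the tokens.
        return tokens * 2


    def run_ubatches(ubatches):
        # One result slot per ubatch so outputs are concatenated in ubatch
        # order no matter which thread finishes first.
        results = [None] * len(ubatches)

        def worker(i, tokens):
            out = fake_forward(tokens)
            if tokens.numel() > 0:  # drop empty (dummy) ubatches
                results[i] = out.clone()

        threads = [
            threading.Thread(target=worker, args=(i, t))
            for i, t in enumerate(ubatches)
        ]
        for t in threads:
            t.start()
        # Join only after all threads have started; a start()/join() pair
        # inside the launch loop would run the ubatches serially.
        for t in threads:
            t.join()

        kept = [r for r in results if r is not None]
        return torch.cat(kept, dim=0) if kept else None


    if __name__ == "__main__":
        out = run_ubatches([torch.arange(4), torch.arange(3), torch.empty(0)])
        print(out)  # tensor([0, 2, 4, 6, 0, 2, 4])

Plain threads can overlap this kind of work because PyTorch releases the GIL
inside its C++/CUDA kernels; writing into indexed result slots avoids any
dependence on thread completion order.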