Micro-batched execution is working, but only when all micro-batches run on the same CUDA stream

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-21 04:45:41 +00:00
parent ffb740ae95
commit 04f11d97a0
3 changed files with 121 additions and 91 deletions

View File

@ -40,7 +40,7 @@ def main():
max_model_len=1024,
#load_format="dummy",
###############
tensor_parallel_size=4,
tensor_parallel_size=2,
#data_parallel_size=2,
enable_expert_parallel=False,
###############

View File

@ -57,6 +57,7 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatching import make_ubatch_context_chain, UBatchContext
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
@ -1284,104 +1285,108 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: Optional["SchedulerOutput"] = None,
is_dummy_run: bool = False):
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
    """Build the forward-pass inputs for one micro-batch token slice.

    Returns a tuple of (num_tokens, input_ids, positions, inputs_embeds,
    intermediate_tensors), as produced by the runner's input helpers.
    """
    # NOTE: this view of the file contained both the pre- and post-diff
    # signatures; only the new one (use_dummy_input) is kept here.
    if use_dummy_input:
        # Dummy path: run with at least one token so the forward pass
        # still executes even when nothing is scheduled.
        num_tokens = num_scheduled_tokens or 1
        return num_tokens, *self._get_dummy_model_inputs(num_tokens)
    else:
        assert scheduler_output is not None
        num_tokens = tokens_slice.stop - tokens_slice.start
        return num_tokens, *self._get_model_inputs(tokens_slice, scheduler_output)
@torch.inference_mode()
def process_batch(save_results, attn_metadata, vllm_config, model, num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors, results, stream):
    # Run one micro-batch forward pass on `stream`, optionally appending
    # the output to `results` (a list shared with the launching thread).
    with set_forward_context(attn_metadata,
                             vllm_config,
                             num_tokens=num_tokens):
        # Redirect all subsequent CUDA work issued from this thread onto
        # the micro-batch's stream.
        torch.cuda.set_stream(stream)
        model_output = model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        if save_results:
            # NOTE(review): clone presumably decouples the result from
            # buffers reused by later micro-batches — confirm.
            results.append(model_output.clone())
def threaded_processing(ubatch_slices, attn_metadata, vllm_config, model, is_dummy_run=False):
    """Run each micro-batch slice via `process_batch` in a worker thread.

    Returns the concatenated model outputs, or None when no ubatch
    produced a result.

    NOTE(review): each thread is joined immediately after being started,
    so the micro-batches execute strictly sequentially; true overlap
    would require starting all threads before joining any of them.
    """
    results = []
    for i, (_, tokens_slice) in enumerate(ubatch_slices):
        # An empty slice marks a padding ("dummy") ubatch.
        is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
        # Only the last ubatch may be empty, unless the whole run is a
        # dummy run.
        assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
        num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \
            model_inputs(tokens_slice, is_dummy_ubatch)
        thread = threading.Thread(target=process_batch, args=(
            not is_dummy_ubatch or is_dummy_run,  # save_results
            attn_metadata[i] if attn_metadata is not None else None,
            vllm_config,
            model,
            num_tokens,
            input_ids,
            positions,
            inputs_embeds,
            intermediate_tensors,
            results,
            torch.cuda.current_stream()
        ))
        thread.start()
        thread.join()
    if results:
        return torch.cat(results, dim=0)
    else:
        return None
# run micro-batched
if ubatch_slices is not None:
model_output = threaded_processing(ubatch_slices,
attn_metadata,
self.vllm_config,
self.model,
is_dummy_run)
# print("FINISHED MODEL OUTPUT")
# run single batch
else:
def _run(token_slice: slice, attn_metadata, use_dummy_input: bool = False,
         ubatch_context: Optional[UBatchContext] = None,
         setup_done_evt: Optional[threading.Event] = None):
    """Run the model forward pass for one token slice.

    When a `ubatch_context` is supplied, its forward-context snapshot is
    refreshed once we are inside `set_forward_context`; when
    `setup_done_evt` is supplied, it is set to tell the launcher that
    setup is complete.
    """
    num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \
        model_inputs(token_slice, use_dummy_input)
    with set_forward_context(attn_metadata,
                             self.vllm_config,
                             num_tokens=num_tokens):
        if ubatch_context:
            # Snapshot the forward context now that it is available so
            # the ubatch can restore it after a yield.
            ubatch_context.update_forward_context()
        if setup_done_evt is not None:
            # Signal (not wait): tells the launching thread that setup
            # is done.  The original comment said "Wait", which the code
            # does not do — `.set()` only signals.
            setup_done_evt.set()
        model_output = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
    return model_output
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, root_stream, token_slice, attn_metadata, results, save_results, use_dummy_input, setup_done_evt):
    """Thread body for one micro-batch: sync with the root stream, run
    the forward pass inside the ubatch context, and publish the result.
    """
    with ubatch_ctx:
        # Ubatches may run on a different stream than the root one; they
        # must not start until the GPU kernels launched in
        # _prepare_inputs (on root_stream) are ordered before them.
        if torch.cuda.current_stream() != root_stream:
            start_evt = torch.cuda.Event()
            # Record on the producer stream BEFORE waiting on the
            # consumer stream: waiting on a never-recorded CUDA event is
            # a no-op, so the previous wait-then-record order provided
            # no synchronization (likely why this only worked when
            # everything shared one stream).
            root_stream.record_event(start_evt)
            torch.cuda.current_stream().wait_event(start_evt)
        model_output = _run(token_slice, attn_metadata, use_dummy_input, ubatch_ctx, setup_done_evt)
        if save_results:
            results.append(model_output.clone())
        if torch.cuda.current_stream() != root_stream:
            # Make the root stream wait for this ubatch to finish:
            # record completion on the ubatch stream, then have the root
            # stream wait on it (record before wait, as above).
            torch.cuda.current_stream().record_event(ubatch_ctx.done_evt)
            root_stream.wait_event(ubatch_ctx.done_evt)
def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
    """Execute the scheduled tokens as micro-batches, one thread per
    ubatch, and return the concatenated model outputs.
    """
    results = []
    assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
    root_stream = torch.cuda.current_stream()
    ubatch_ctxs = make_ubatch_context_chain(len(ubatch_slices),
        stream=root_stream, # Only works currently if everything is run on the same stream
        device=self.device)
    setup_done = threading.Event()
    for i, (_, tokens_slice) in enumerate(ubatch_slices):
        # An empty slice marks a padding ("dummy") ubatch.
        is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
        # Only the final ubatch may be empty, unless this is a dummy run.
        assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
        thread = threading.Thread(target=_ubatch_thread, args=(
            ubatch_ctxs[i],
            root_stream,
            tokens_slice,
            attn_metadata[i] if attn_metadata is not None else None,
            results,
            not is_dummy_ubatch or is_dummy_run,  # save_results
            is_dummy_ubatch or is_dummy_run,      # use_dummy_input
            setup_done,
        ))
        thread.start()
        # Wait until the thread has finished its setup before joining.
        # NOTE(review): `setup_done` is never cleared between
        # iterations, so the second wait returns immediately — harmless
        # today because the thread is joined right after, but confirm.
        setup_done.wait()
        thread.join()
    # Restore the root stream for the remainder of execution on this
    # thread.
    torch.cuda.set_stream(root_stream)
    return torch.cat(results, dim=0)
# run micro-batched
if ubatch_slices is not None:
model_output = _run_ubatches(
ubatch_slices, attn_metadata, is_dummy_run)
# run single batch
else:
model_output = _run(
slice(0, num_scheduled_tokens), attn_metadata, is_dummy_run)
return model_output
@torch.inference_mode()

View File

@ -3,8 +3,10 @@ import threading
import torch
import torch._dynamo
import torch.profiler as profiler
from typing import Optional
from torch.library import Library
from torch.library import custom_op, register_kernel
from vllm import forward_context
class UBatchContext:
"""
@ -20,11 +22,20 @@ class UBatchContext:
self.schedule = schedule
self.stream = stream
self.original_stream = torch.cuda.current_stream()
self.done_evt = torch.cuda.Event()
self.forward_context = None
def update_forward_context(self):
    # Snapshot the thread-local forward context so it can be restored
    # when this micro-batch is resumed after yielding to a peer.
    self.forward_context = forward_context._forward_context
def __enter__(self):
    """Enter the micro-batch context: register it for this thread, save
    the state to restore on exit, and switch to the ubatch stream."""
    global _CURRENT_CONTEXT
    # Register this context so the yield op can look it up by thread id.
    _CURRENT_CONTEXT[threading.get_ident()] = self
    # Remember what to restore in __exit__.  (The diff residue contained
    # this assignment twice; one copy is kept.)
    self.original_stream = torch.cuda.current_stream()
    self.original_forward_context = forward_context._forward_context
    self.forward_context = self.original_forward_context
    # Set micro-batch stream
    torch.cuda.set_stream(self.stream)
    return self
@ -32,8 +43,9 @@ class UBatchContext:
def __exit__(self, exc_type, exc_val, exc_tb):
    # Deregister this context for the current thread.
    global _CURRENT_CONTEXT
    _CURRENT_CONTEXT[threading.get_ident()] = None
    # Restore the original stream
    torch.cuda.set_stream(self.original_stream)
    # Restore the forward context captured in __enter__.
    forward_context._forward_context = self.original_forward_context
    # Never suppress exceptions raised inside the context.
    return False
def yield_(self):
@ -42,8 +54,14 @@ class UBatchContext:
# Wait for the next batch to signal back
self.wait_event.wait()
self.wait_event.clear()
# When we resume switch back to the microbatch stream
# When we resume i.e. switch back to this micro-batch, we make sure
# we have the correct stream and forward context
torch.cuda.set_stream(self.stream)
forward_context._forward_context = self.forward_context
def wait(self):
    # Block until the peer micro-batch signals this context, then clear
    # the event so the next hand-off can be awaited.
    self.wait_event.wait()
    self.wait_event.clear()
_CURRENT_CONTEXT: dict = {}
@ -54,30 +72,37 @@ def yield_impl(schedule="default"):
ctx.yield_()
# 2) Register kernel for CUDA
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
# optimizing it away (TODO: see if this is actually needed)
@custom_op("vllm::yield_", mutates_args=("x",))
def yield_(x: torch.Tensor, schedule="default") -> None:
def yield_(x: torch.Tensor, schedule: str="default") -> None:
yield_impl(schedule)
# 3) Fake implementation for shape prop and FX tracing
@yield_.register_fake
def yield_(x: torch.Tensor, schedule="default") -> None:
def yield_(x: torch.Tensor, schedule: str="default") -> None:
pass
"""
"""
def make_ubatch_context_chain(
    num_micro_batches: int,
    stream: Optional[torch.Stream] = None,
    device: Optional[torch.device] = None
) -> list[UBatchContext]:
    """
    Create a context manager for micro-batching synchronization.

    The contexts form a ring: context i waits on its own event and
    signals context (i + 1) % num_micro_batches, so the ubatches hand
    off to each other in order.  (The diff residue contained both the
    old and new signatures; only the new one is kept.)
    """
    events = [threading.Event() for _ in range(num_micro_batches)]
    device = device or torch.cuda.current_device()
    ctxs = []
    for i in range(num_micro_batches):
        wait_event = events[i]
        signal_event = events[(i + 1) % num_micro_batches]
        # With a shared `stream` every ubatch runs on it; otherwise each
        # ubatch gets its own freshly created stream on `device`.
        ctx = UBatchContext(stream or torch.cuda.Stream(device),
                            wait_event, signal_event)
        ctxs.append(ctx)
    return ctxs