manually manage stream

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-20 23:55:31 +00:00
parent 020269c4c5
commit ffb740ae95
4 changed files with 41 additions and 22 deletions

View File

@ -34,8 +34,9 @@ def main():
enforce_eager=False,
compilation_config=2,
enable_microbatching=True,
enable_expert_parallel=True,
trust_remote_code=True,
tensor_parallel_size=4,
tensor_parallel_size=2,
max_model_len=1024,
#load_format="dummy",
)

View File

@ -4328,7 +4328,13 @@ class VllmConfig:
if self.parallel_config.enable_microbatching:
# Microbatching is not supported with piecewise compilation yet.
# More specifically piecewise cuda-graphs
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
if self.compilation_config.level >= CompilationLevel.PIECEWISE:
logger.warning_once(
"Piecewise compilation is not supported with "
"microbatching. Disabling piecewiseching compilation.")
self.parallel_config.enable_microbatching = False
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
if self.model_config and self.model_config.use_mla and \
not (current_platform.is_cuda() or current_platform.is_rocm()):

View File

@ -1294,10 +1294,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return num_tokens, *self._get_model_inputs(tokens_slice, scheduler_output)
@torch.inference_mode()
def process_batch(i, is_dummy_ubatch, is_dummy_run, attn_metadata, vllm_config, model, num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors, results):
with set_forward_context(attn_metadata[i] if attn_metadata is not None else None,
def process_batch(save_results, attn_metadata, vllm_config, model, num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors, results, stream):
with set_forward_context(attn_metadata,
vllm_config,
num_tokens=num_tokens):
torch.cuda.set_stream(stream)
model_output = model(
input_ids=input_ids,
positions=positions,
@ -1305,7 +1307,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds=inputs_embeds,
)
if not is_dummy_ubatch or is_dummy_run:
if save_results:
results.append(model_output.clone())
def threaded_processing(ubatch_slices, attn_metadata, vllm_config, model, is_dummy_run=False):
@ -1320,10 +1322,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
model_inputs(tokens_slice, is_dummy_ubatch)
thread = threading.Thread(target=process_batch, args=(
i,
is_dummy_ubatch,
is_dummy_run,
attn_metadata,
not is_dummy_ubatch or is_dummy_run,
attn_metadata[i] if attn_metadata is not None else None,
vllm_config,
model,
num_tokens,
@ -1332,6 +1332,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds,
intermediate_tensors,
results,
torch.cuda.current_stream()
))
thread.start()
thread.join()

View File

@ -10,44 +10,58 @@ class UBatchContext:
"""
Context manager for micro-batching synchronization using threading events.
"""
def __init__(self, wait_event: threading.Event, signal_event: threading.Event):
def __init__(self,
stream: torch.cuda.Stream,
wait_event: threading.Event,
signal_event: threading.Event,
schedule="default"):
self.wait_event = wait_event
self.signal_event = signal_event
self.schedule = schedule
self.stream = stream
self.original_stream = torch.cuda.current_stream()
def __enter__(self):
global _CURRENT_CONTEXT
self.original_stream = torch.cuda.current_stream()
_CURRENT_CONTEXT[threading.get_ident()] = self
# Set micro-batch stream
torch.cuda.set_stream(self.stream)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = None
# Restore the original stream
torch.cuda.set_stream(self.original_stream)
return False
def yield_(self):
# Signal that this batch reached the barrier and wait for the other
self.signal_event.set()
# Wait for the next batch to signal back
self.wait_event.wait()
# Reset for reuse
self.wait_event.clear()
# When we resume switch back to the microbatch stream
torch.cuda.set_stream(self.stream)
_CURRENT_CONTEXT: dict = {}
def yield_impl():
def yield_impl(schedule="default"):
# Perform the barrier if a context exists for this thread
ctx = _CURRENT_CONTEXT.get(threading.get_ident(), None)
if ctx is not None:
if ctx is not None and ctx.schedule == schedule:
ctx.yield_()
# 2) Register kernel for CUDA
@custom_op("custom::yield_", mutates_args=("x",))
def yield_(x: torch.Tensor) -> None:
yield_impl()
@custom_op("vllm::yield_", mutates_args=("x",))
def yield_(x: torch.Tensor, schedule="default") -> None:
yield_impl(schedule)
# 3) Fake implementation for shape prop and FX tracing
@yield_.register_fake
def yield_(x: torch.Tensor):
def yield_(x: torch.Tensor, schedule="default") -> None:
pass
"""
@ -63,10 +77,7 @@ def make_ubatch_context_chain(num_micro_batches: int) -> list[UBatchContext]:
for i in range(num_micro_batches):
wait_event = events[i]
signal_event = events[(i + 1) % num_micro_batches]
ctx = UBatchContext(wait_event, signal_event)
ctx = UBatchContext(torch.Stream(), wait_event, signal_event)
ctxs.append(ctx)
# Create the context manager
ctx = UBatchContext(wait_event, signal_event)
return ctx
return ctxs