split some of the ubatching logic out of _run_model

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-07-03 20:26:56 +00:00
parent 908e9f8f54
commit 10ca263058
2 changed files with 97 additions and 101 deletions
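
In outline, this change lifts the micro-batch ("ubatch") plumbing out of GPUModelRunner._run_model: input preparation moves into a model_inputs helper, per-ubatch setup into _make_ubatch_metadata, the threaded execution into _run_ubatches, and a module-level UbatchMetadata dataclass ties the pieces together. The skeleton below is a condensed sketch of the resulting layout, with imports, bodies, and most arguments elided, so it is illustrative rather than the literal code; the names are taken from the diff that follows.

    @dataclasses.dataclass
    class UbatchMetadata:
        context: UBatchContext
        input_ids: torch.Tensor
        positions: torch.Tensor
        inputs_embeds: Optional[torch.Tensor]
        intermediate_tensors: Optional[IntermediateTensors]


    class GPUModelRunner(LoRAModelRunnerMixin):

        def model_inputs(self, tokens_slice, use_dummy_input, scheduler_output):
            # dummy inputs for dummy/warmup runs, real inputs otherwise
            ...

        def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, compute_stream,
                                  is_dummy_run, num_tokens_across_dp,
                                  skip_cuda_graphs, scheduler_output):
            # build one ForwardContext per ubatch, hand them to
            # make_ubatch_contexts, then bundle inputs + context per ubatch
            ...

        def _run_ubatches(self, ubatch_metadata, model):
            # run each ubatch on its own thread, re-order, concatenate
            ...

        def _run_model(self, attn_metadata, num_scheduled_tokens, ubatch_slices=None,
                       scheduler_output=None, is_dummy_run=False,
                       num_tokens_across_dp=None, skip_cuda_graphs=False):
            # micro-batched path: _make_ubatch_metadata + _run_ubatches;
            # normal path: unchanged single forward pass
            ...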


@@ -75,6 +75,7 @@ from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
import dataclasses
if TYPE_CHECKING:
import xgrammar as xgr
@@ -98,7 +99,14 @@ PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
UbatchSlice: TypeAlias = tuple[slice, slice]
UBatchSlices: TypeAlias = list[UbatchSlice]
import dataclasses
@dataclasses.dataclass
class UbatchMetadata:
context: UBatchContext
input_ids: torch.Tensor
positions: torch.Tensor
inputs_embeds: Optional[torch.Tensor]
intermediate_tensors: Optional[IntermediateTensors]
class GPUModelRunner(LoRAModelRunnerMixin):
@@ -1498,85 +1506,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tokens_slice, intermediate_tensors, True)
return input_ids, positions, inputs_embeds, intermediate_tensors
def _run_model(self,
attn_metadata: Optional[PerLayerAttnMetadata],
num_scheduled_tokens: Optional[int],
ubatch_slices: Optional[UBatchSlices] = None,
scheduler_output: Optional["SchedulerOutput"] = None,
is_dummy_run: bool = False,
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False):
@dataclasses.dataclass
class UbatchMetadata:
context: UBatchContext
input_ids: torch.Tensor
positions: torch.Tensor
inputs_embeds: Optional[torch.Tensor]
intermediate_tensors: Optional[IntermediateTensors]
def model_inputs(self, tokens_slice: slice, use_dummy_input: bool,
scheduler_output: Optional["SchedulerOutput"]) -> tuple:
if use_dummy_input:
return self._get_dummy_model_inputs(tokens_slice.stop - tokens_slice.start)
else:
assert scheduler_output is not None
return self._get_model_inputs(tokens_slice, scheduler_output)
num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
def _make_ubatch_contexts(ubatch_slices,
attn_metadata,
compute_stream,
num_tokens_across_dp,
skip_cuda_graphs) -> list[UBatchContext]:
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
compute_stream=compute_stream,
device=self.device)
def _make_ubatch_metadata(self,
ubatch_slices,
attn_metadata,
compute_stream,
is_dummy_run,
num_tokens_across_dp,
skip_cuda_graphs,
scheduler_output) -> list[UbatchMetadata]:
for i, (_, tokens_slice) in enumerate(ubatch_slices):
num_tokens = (tokens_slice.stop - tokens_slice.start)
# TODO (Sage) Instead of using this setter we should be able
# to just create the forward context in advance and pass it
# to the UBatchContext's __init__ method
ubatch_ctxs[i].forward_context = create_forward_context(
attn_metadata[i]
if attn_metadata is not None else None,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs)
return ubatch_ctxs
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
if use_dummy_input:
assert num_dummy_tokens == tokens_slice.stop - tokens_slice.start
return self._get_dummy_model_inputs(num_dummy_tokens)
else:
assert scheduler_output is not None
return self._get_model_inputs(tokens_slice, scheduler_output)
def _make_ubatch_metadata(ubatch_slices,
attn_metadata,
compute_stream,
is_dummy_run,
num_tokens_across_dp,
skip_cuda_graphs) -> list[UbatchMetadata]:
ubatch_ctxs = _make_ubatch_contexts(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
compute_stream=compute_stream,
# Create one forward context per ubatch
forward_contexts = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
num_tokens = (tokens_slice.stop - tokens_slice.start)
forward_contexts.append(create_forward_context(
attn_metadata[i]
if attn_metadata is not None else None,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs
)
skip_cuda_graphs=skip_cuda_graphs))
ubatch_metadata: list[UbatchMetadata] = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(tokens_slice, is_dummy_run)
ubatch_metadata.append(UbatchMetadata(
context=ubatch_ctxs[i],
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
))
return ubatch_metadata
ubatch_ctxs = make_ubatch_contexts(num_micro_batches=len(ubatch_slices),
compute_stream=compute_stream,
forward_contexts=forward_contexts,
device=self.device)
ubatch_metadata: list[UbatchMetadata] = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
input_ids, positions, inputs_embeds, intermediate_tensors = \
self.model_inputs(tokens_slice, is_dummy_run, scheduler_output)
ubatch_metadata.append(UbatchMetadata(
context=ubatch_ctxs[i],
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
))
return ubatch_metadata
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
@torch.inference_mode()
def _ubatch_thread(results, model, ubatch_metadata):
with ubatch_metadata.context:
@@ -1588,49 +1569,59 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
results.append((ubatch_metadata.context.id, model_output))
def _run_ubatches(ubatch_metadata, model) -> torch.Tensor:
results: list[tuple[int, torch.Tensor]] = []
results: list[tuple[int, torch.Tensor]] = []
# Ubatch threads will manually manage the forward context, so we
# override it to None here so we can have it restored correctly
# after both threads have finished
with override_forward_context(None):
ubatch_threads = []
for metadata in ubatch_metadata:
thread = threading.Thread(target=_ubatch_thread,
args=(
results,
model,
metadata,
))
ubatch_threads.append(thread)
thread.start()
# Ubatch threads will manually manage the forward context, so we
# override it to None here so we can have it restored correctly
# after both threads have finished
with override_forward_context(None):
ubatch_threads = []
for metadata in ubatch_metadata:
thread = threading.Thread(target=_ubatch_thread,
args=(
results,
model,
metadata,
))
ubatch_threads.append(thread)
thread.start()
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
return result
def _run_model(self,
attn_metadata: Optional[PerLayerAttnMetadata],
num_scheduled_tokens: Optional[int],
ubatch_slices: Optional[UBatchSlices] = None,
scheduler_output: Optional["SchedulerOutput"] = None,
is_dummy_run: bool = False,
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False):
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
return result
# run micro-batched
if ubatch_slices is not None:
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
compute_stream = torch.cuda.current_stream()
ubatch_metadata = _make_ubatch_metadata(
ubatch_metadata = self._make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
compute_stream=compute_stream,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs
skip_cuda_graphs=skip_cuda_graphs,
scheduler_output=scheduler_output
)
return _run_ubatches(ubatch_metadata, self.model)
return self._run_ubatches(ubatch_metadata, self.model)
# run normal batch
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
self.model_inputs(slice(0, num_scheduled_tokens), is_dummy_run, scheduler_output)
with set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
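
For readers skimming the hunks above, the threaded execution that _run_ubatches now encapsulates reduces to the pattern below. This is a minimal sketch of the idea only: the real method keeps the torch.inference_mode decorator and the override_forward_context(None) guard, and the model call arguments shown here follow the runner's usual forward signature, which is an assumption since the hunk above truncates that call.

    import threading

    import torch


    def run_ubatches_sketch(ubatch_metadata, model):
        results = []  # (ubatch id, model output) pairs

        def worker(metadata):
            # each thread enters its own UBatchContext, which installs the
            # per-ubatch forward context and coordinates with its sibling
            # through the cpu_wait_event / cpu_signal_event pair
            with metadata.context:
                output = model(input_ids=metadata.input_ids,
                               positions=metadata.positions,
                               intermediate_tensors=metadata.intermediate_tensors,
                               inputs_embeds=metadata.inputs_embeds)
            results.append((metadata.context.id, output))

        threads = [threading.Thread(target=worker, args=(m,))
                   for m in ubatch_metadata]
        for t in threads:
            t.start()
        # release the first ubatch; the contexts hand control back and
        # forth until both threads finish
        ubatch_metadata[0].context.cpu_wait_event.set()
        for t in threads:
            t.join()
        # restore the original order and stitch the outputs back together
        return torch.cat([out for _, out in sorted(results)], dim=0)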


@@ -5,6 +5,7 @@ from typing import Optional
import torch
from vllm import forward_context
from vllm.forward_context import ForwardContext
from vllm.utils import current_stream
@@ -18,7 +19,7 @@ class UBatchContext:
id: int,
comm_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream,
#fwd_ctx: forward_context.ForwardContext,
forward_context: ForwardContext,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.cuda.Event,
@@ -27,7 +28,7 @@ class UBatchContext:
self.id = id
self.comm_stream = comm_stream
self.compute_stream = compute_stream
self.forward_context = None #fwd_ctx
self.forward_context = forward_context
self.cpu_wait_event = cpu_wait_event
self.cpu_signal_event = cpu_signal_event
self.current_stream = compute_stream
@@ -150,6 +151,7 @@ def yield_and_switch_from_comm_to_compute(schedule="default"):
def make_ubatch_contexts(
num_micro_batches: int,
compute_stream: torch.cuda.Stream,
forward_contexts: list[ForwardContext],
device: Optional[torch.device] = None,
schedule: str = "default",
) -> list[UBatchContext]:
@@ -167,11 +169,14 @@ def make_ubatch_contexts(
device = device or torch.cuda.current_device()
comm_stream = torch.cuda.Stream(device)
assert len(forward_contexts) == 2
ctxs = []
for i in range(num_micro_batches):
ctx = UBatchContext(id=i,
compute_stream=compute_stream,
comm_stream=comm_stream,
forward_context=forward_contexts[i],
cpu_wait_event=cpu_events[i],
cpu_signal_event=cpu_events[(i + 1) %
num_micro_batches],
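
The net effect of this second file's change is that a UBatchContext now receives its ForwardContext at construction time instead of having it assigned afterwards through the setter flagged by the TODO in the old _run_model. As used from the model runner, the new calling convention looks roughly like the sketch below, which is simplified from the first file's hunks and relies on the surrounding method's variables (ubatch_slices, attn_metadata, compute_stream, and so on) rather than being standalone.

    # inside GPUModelRunner._make_ubatch_metadata (simplified):
    # one ForwardContext per micro-batch, sized by its token slice
    forward_contexts = []
    for i, (_, tokens_slice) in enumerate(ubatch_slices):
        forward_contexts.append(create_forward_context(
            attn_metadata[i] if attn_metadata is not None else None,
            self.vllm_config,
            num_tokens=tokens_slice.stop - tokens_slice.start,
            num_tokens_across_dp=num_tokens_across_dp,
            skip_cuda_graphs=skip_cuda_graphs))

    # the forward contexts are built up front and passed in, rather than
    # patched onto the contexts after the fact
    ubatch_ctxs = make_ubatch_contexts(num_micro_batches=len(ubatch_slices),
                                       compute_stream=compute_stream,
                                       forward_contexts=forward_contexts,
                                       device=self.device)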