Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-02 19:04:26 +00:00
parent 8ea80fca4a
commit 92e0cc79a8
2 changed files with 153 additions and 141 deletions

View File

@ -2,18 +2,16 @@
import copy
import gc
import os
import threading
import time
import weakref
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
import contextlib
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from vllm.utils import current_stream
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadataBuilder)
@ -27,7 +25,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture,
prepare_communication_buffer_for_model)
from vllm.forward_context import get_forward_context, set_forward_context, create_forward_context, override_forward_context
from vllm.forward_context import (create_forward_context, get_forward_context,
override_forward_context,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
@ -38,7 +38,7 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
is_pin_memory_available)
current_stream, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@ -59,14 +59,11 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatching import make_ubatch_contexts, UBatchContext
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
if TYPE_CHECKING:
import xgrammar as xgr
@ -504,40 +501,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.refresh_sampling_metadata()
def _ubatch_split(
self,
query_start_loc_np: torch.Tensor,
max_num_scheduled_tokens: int,
scheduler_output: "SchedulerOutput"
) -> Optional[UBatchSlices]:
self, query_start_loc_np: torch.Tensor,
max_num_scheduled_tokens: int,
scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_reqs = self.input_batch.num_reqs
if self.parallel_config.enable_microbatching and \
total_num_scheduled_tokens >= self.parallel_config.microbatching_token_threshold \
and max_num_scheduled_tokens == 1:
total_num_scheduled_tokens >= \
self.parallel_config.microbatching_token_threshold \
and max_num_scheduled_tokens == 1:
# For pure decode we can just create ubatchs by cutting the request
# in half
b0_reqs_end = num_reqs // 2
b0_tokens_end = total_num_scheduled_tokens // 2
assert b0_reqs_end < num_reqs and b0_tokens_end < total_num_scheduled_tokens
assert b0_reqs_end < num_reqs and \
b0_tokens_end < total_num_scheduled_tokens
return [
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
(slice(b0_reqs_end, num_reqs),
slice(b0_tokens_end, total_num_scheduled_tokens)),
]
if self.parallel_config.enable_microbatching and \
self.parallel_config.always_microbatch_if_enabled:
# TODO we can do something more advanced here to try to balance,
# i.e. split to the left of `total_num_scheduled_tokens // 2` if it
# is more balanced
req_split_id = np.argmax(query_start_loc_np > (total_num_scheduled_tokens // 2))
return [(slice(0, req_split_id), slice(0, query_start_loc_np[req_split_id])),
(slice(req_split_id, num_reqs), slice(query_start_loc_np[req_split_id], total_num_scheduled_tokens))]
req_split_id = np.argmax(
query_start_loc_np > (total_num_scheduled_tokens // 2))
return [(slice(0, req_split_id),
slice(0, query_start_loc_np[req_split_id])),
(slice(req_split_id, num_reqs),
slice(query_start_loc_np[req_split_id],
total_num_scheduled_tokens))]
return None
def _is_dummy_ubatch(self, ubatch_slice: UBatchSlices) -> bool:
return ubatch_slice[1].start >= ubatch_slice[1].stop
def _prepare_inputs(
self, scheduler_output: "SchedulerOutput"
@ -630,7 +628,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
self.query_start_loc_np, max_num_scheduled_tokens, scheduler_output)
self.query_start_loc_np, max_num_scheduled_tokens,
scheduler_output)
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@ -708,6 +707,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
assert attn_metadata_i is not None
# What if it's None? Do we still add it to the list?
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
attn_metadata_i = (
@ -749,7 +750,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices
return (attn_metadata, logits_indices, spec_decode_metadata,
ubatch_slices)
def _compute_cascade_attn_prefix_len(
self,
@ -1167,8 +1169,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if sync_self:
assert intermediate_tensors is not None
for k, v in intermediate_tensors.items():
is_scattered = "residual" and is_residual_scattered
_copy_slice = copy_slice(is_scattered)
_copy_slice = copy_slice(is_residual_scattered)
self.intermediate_tensors[k][_copy_slice].copy_(
v[_copy_slice], non_blocking=True)
@ -1181,7 +1182,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
# Dummy batch. (hopefully we are the last one so we can just
# update this to a one token batch and return)
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
@ -1203,9 +1204,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))
return input_ids, positions, inputs_embeds, intermediate_tensors
return input_ids, positions, inputs_embeds, intermediate_tensors
def _get_model_inputs(self, tokens_slice: slice,
scheduler_output: "SchedulerOutput"):
@ -1215,7 +1215,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# update this to a one token batch and return)
tokens_slice = slice(tokens_slice.start, tokens_slice.start + 1)
num_tokens = 1
if (self.use_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
@ -1281,85 +1281,87 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return input_ids, positions, inputs_embeds, intermediate_tensors
def _run_model(self,
attn_metadata: Optional[PerLayerAttnMetadata],
num_scheduled_tokens: Optional[int],
attn_metadata: Optional[PerLayerAttnMetadata],
num_scheduled_tokens: Optional[int],
ubatch_slices: Optional[UBatchSlices] = None,
scheduler_output: Optional["SchedulerOutput"] = None,
is_dummy_run: bool = False):
num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
if use_dummy_input:
assert num_dummy_tokens == 1
return self._get_dummy_model_inputs(num_dummy_tokens)
else:
assert scheduler_output is not None
return self._get_model_inputs(tokens_slice, scheduler_output)
def _run(token_slice: slice, context, use_dummy_input: bool = False):
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(token_slice, use_dummy_input)
with context:
# if isinstance(context, UBatchContext):
# print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}")
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
# if isinstance(context, UBatchContext):
# print(f"Ran ubatch {context.id}putput {model_output.shape}")
if isinstance(context, UBatchContext):
# Clone before we leave the ubatch context
model_output = model_output.clone()
return model_output
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input):
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
use_dummy_input):
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
if save_results:
results.append(model_output)
def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
results = []
def _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run) -> torch.Tensor:
results: list[torch.Tensor] = []
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
root_stream = current_stream()
ubatch_ctxs = make_ubatch_contexts(
len(ubatch_slices),
compute_stream=root_stream,
device=self.device)
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
compute_stream=root_stream,
device=self.device)
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
with override_forward_context(None):
ubatch_threads = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
# print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run)
num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start)
assert not is_dummy_ubatch or i == len(
ubatch_slices) - 1 or is_dummy_run
num_tokens = num_dummy_tokens if is_dummy_ubatch or \
is_dummy_run else (tokens_slice.stop - tokens_slice.start)
ubatch_ctxs[i].forward_context = create_forward_context(
attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config, num_tokens=num_tokens)
thread = threading.Thread(target=_ubatch_thread, args=(
ubatch_ctxs[i],
tokens_slice,
results,
not is_dummy_ubatch or is_dummy_run,
is_dummy_ubatch or is_dummy_run,
))
attn_metadata[i]
if attn_metadata is not None else None,
self.vllm_config,
num_tokens=num_tokens)
thread = threading.Thread(target=_ubatch_thread,
args=(
ubatch_ctxs[i],
tokens_slice,
results,
not is_dummy_ubatch
or is_dummy_run,
is_dummy_ubatch
or is_dummy_run,
))
ubatch_threads.append(thread)
thread.start()
ubatch_ctxs[0].cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
@ -1369,18 +1371,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# run micro-batched
if ubatch_slices is not None:
model_output = _run_ubatches(
ubatch_slices, attn_metadata, is_dummy_run)
model_output = _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run)
# run single batch
else:
model_output = _run(
slice(0, num_scheduled_tokens),
set_forward_context(
attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1),
slice(0, num_scheduled_tokens),
set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1),
is_dummy_run)
return model_output
@torch.inference_mode()
@ -1583,7 +1584,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
#TODO(sage) make sure this works with mrope
target_positions = self.positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
@ -1609,7 +1611,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_rejected_tokens,
)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
#TODO(sage) make sure this works with mrope
target_positions = self.positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
@ -1647,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
@ -1894,20 +1898,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
))
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
should_microbatch = (
allow_microbatching
and self.vllm_config.parallel_config.enable_microbatching
and self.vllm_config.parallel_config.always_microbatch_if_enabled
)
dummy_microbatches = [(slice(0, 0), slice(0, 0)), (slice(0, 0), slice(0, 0))]
allow_microbatching
and self.vllm_config.parallel_config.enable_microbatching
and self.vllm_config.parallel_config.always_microbatch_if_enabled)
dummy_microbatches = [(slice(0, 0), slice(0, 0)),
(slice(0, 0), slice(0, 0))]
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
outputs = self._run_model(
attn_metadata,
num_tokens,
ubatch_slices=None if not should_microbatch else dummy_microbatches,
ubatch_slices=None
if not should_microbatch else dummy_microbatches,
is_dummy_run=True,
)
if self.use_aux_hidden_state_outputs:

View File

@ -1,36 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
import threading
from typing import Optional
import torch
import torch._dynamo
import torch.profiler as profiler
import os
from typing import Optional
from torch.library import Library
from torch.library import custom_op, register_kernel
from vllm.distributed import (get_dp_group)
from torch.library import custom_op
from vllm.utils import current_stream
from vllm import forward_context
from vllm.utils import current_stream
class UBatchContext:
"""
Context manager for micro-batching synchronization using threading events.
"""
def __init__(self,
id: int,
comm_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream,
#fwd_ctx: forward_context.ForwardContext,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.cuda.Event,
gpu_compute_done_event: torch.cuda.Event,
schedule: str = "default"):
def __init__(
self,
id: int,
comm_stream: torch.cuda.Stream,
compute_stream: torch.cuda.Stream,
#fwd_ctx: forward_context.ForwardContext,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_comm_done_event: torch.cuda.Event,
gpu_compute_done_event: torch.cuda.Event,
schedule: str = "default"):
self.id = id
self.comm_stream = comm_stream
self.compute_stream = compute_stream
self.original_stream = current_stream()
self.forward_context = None #fwd_ctx
self.forward_context = None #fwd_ctx
self.cpu_wait_event = cpu_wait_event
self.cpu_signal_event = cpu_signal_event
self.current_stream = compute_stream
@ -47,8 +47,7 @@ class UBatchContext:
self.cpu_wait_event.clear()
self._restore_context()
# Assume we start on the compute stream
assert current_stream() == self.compute_stream, \
"Expected to start on the compute stream, but found %s" % current_stream()
assert current_stream() == self.compute_stream
return self
def __exit__(self, exc_type, exc_val, exc_tb):
@ -64,7 +63,7 @@ class UBatchContext:
def _restore_context(self):
forward_context._forward_context = self.forward_context
torch.cuda.set_stream(self.current_stream)
def update_stream(self, stream):
self.current_stream = stream
torch.cuda.set_stream(self.current_stream)
@ -74,10 +73,11 @@ class UBatchContext:
# assert current_stream() == self.current_stream
# assert not self.cpu_wait_event.is_set()
pass
def _signal_comm_done(self):
# self.ctx_valid_state()
self.gpu_comm_done_event.record(self.comm_stream)
def _signal_compute_done(self):
# self.ctx_valid_state()
self.gpu_compute_done_event.record(self.compute_stream)
@ -114,32 +114,38 @@ class UBatchContext:
def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Yield and switch from {self.stream_string()}", flush=True)
# self.ctx_valid_state()
self._signal_compute_done()
self._cpu_yield()
# self.ctx_valid_state()
assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream)
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Resuming on stream {self.stream_string()}", flush=True)
self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Yield and switch from {self.stream_string()}", flush=True)
# self.ctx_valid_state()
self._signal_comm_done()
self._cpu_yield()
# self.ctx_valid_state()
assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream)
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Resuming on stream {self.stream_string()}", flush=True)
self._wait_comm_done()
_CURRENT_CONTEXT: dict = {}
def get_current_ubatch_context() -> Optional[UBatchContext]:
global _CURRENT_CONTEXT
"""
@ -147,57 +153,59 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
"""
return _CURRENT_CONTEXT.get(threading.get_ident(), None)
def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
# Perform the barrier if a context exists for this thread
ctx = get_current_ubatch_context()
ctx = get_current_ubatch_context()
#print("you are in yield_impl", ctx)
if ctx is not None and ctx.schedule == schedule:
ctx.yield_and_switch_from_compute_to_comm()
def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
# Perform the barrier if a context exists for this thread
ctx = get_current_ubatch_context()
if ctx is not None and ctx.schedule == schedule:
ctx.yield_and_switch_from_comm_to_compute()
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
# optimizing it away (TODO: see if this is actually needed)
@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x",))
def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x", ))
def yield_and_switch_from_compute_to_comm(x: torch.Tensor,
schedule: str = "default") -> None:
yield_and_switch_from_compute_to_comm_impl(schedule)
# 3) Fake implementation for shape prop and FX tracing
@yield_and_switch_from_compute_to_comm.register_fake
def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
def yield_and_switch_from_compute_to_comm_fake(x: torch.Tensor,
schedule: str = "default"
) -> None:
pass
@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x",))
def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x", ))
def yield_and_switch_from_comm_to_compute(x: torch.Tensor,
schedule: str = "default") -> None:
yield_and_switch_from_comm_to_compute_impl(schedule)
@yield_and_switch_from_comm_to_compute.register_fake
def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
def yield_and_switch_from_comm_to_compute_fake(x: torch.Tensor,
schedule: str = "default"
) -> None:
pass
def dump_ubatching_state():
pass
# """
# Dump the current UBatchContext state for debugging.
# """
# dp_rank = os.getenv("VLLM_DP_RANK", None)
# for ctx in _CURRENT_CONTEXT.values():
# print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n"
# f" Stream: {ctx.stream}, ({ctx.stream.query()})\n"
# f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
# f" CPU Wait Event: {ctx.cpu_wait_event}\n"
# f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
# f" CPU Signal Event: {ctx.cpu_signal_event}\n"
# f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")
"""
"""
def make_ubatch_contexts(
num_micro_batches: int,
compute_stream: torch.cuda.Stream,
@ -205,7 +213,6 @@ def make_ubatch_contexts(
schedule: str = "default",
) -> list[UBatchContext]:
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
"""
Create a context manager for micro-batching synchronization.
"""
@ -225,11 +232,11 @@ def make_ubatch_contexts(
compute_stream=compute_stream,
comm_stream=comm_stream,
cpu_wait_event=cpu_events[i],
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
cpu_signal_event=cpu_events[(i + 1) %
num_micro_batches],
gpu_comm_done_event=gpu_comm_done_events[i],
gpu_compute_done_event=gpu_compute_done_events[i],
schedule=schedule
)
schedule=schedule)
ctxs.append(ctx)
return ctxs
return ctxs