Signed-off-by: Sage Moore <sage@neuralmagic.com>
Author: Sage Moore
Date: 2025-06-02 19:04:26 +00:00
commit 92e0cc79a8 (parent 8ea80fca4a)
2 changed files with 153 additions and 141 deletions

View File

@@ -2,18 +2,16 @@
 import copy
 import gc
-import os
+import threading
 import time
 import weakref
 from typing import TYPE_CHECKING, Optional, TypeAlias, Union
-import contextlib

 import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn
-from vllm.utils import current_stream

 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadataBuilder)
@@ -27,7 +25,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture,
     prepare_communication_buffer_for_model)
-from vllm.forward_context import get_forward_context, set_forward_context, create_forward_context, override_forward_context
+from vllm.forward_context import (create_forward_context, get_forward_context,
+                                  override_forward_context,
+                                  set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
@@ -38,7 +38,7 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        current_stream, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -59,14 +59,11 @@ from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-from vllm.v1.worker.ubatching import make_ubatch_contexts, UBatchContext
+from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

 from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                     scatter_mm_placeholders)
-import threading
-from concurrent.futures import ThreadPoolExecutor, as_completed

 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -504,22 +501,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.input_batch.refresh_sampling_metadata()

     def _ubatch_split(
-        self,
-        query_start_loc_np: torch.Tensor,
-        max_num_scheduled_tokens: int,
-        scheduler_output: "SchedulerOutput"
-    ) -> Optional[UBatchSlices]:
+            self, query_start_loc_np: torch.Tensor,
+            max_num_scheduled_tokens: int,
+            scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         num_reqs = self.input_batch.num_reqs
         if self.parallel_config.enable_microbatching and \
-            total_num_scheduled_tokens >= self.parallel_config.microbatching_token_threshold \
-            and max_num_scheduled_tokens == 1:
+            total_num_scheduled_tokens >= \
+            self.parallel_config.microbatching_token_threshold \
+            and max_num_scheduled_tokens == 1:
             # For pure decode we can just create ubatchs by cutting the request
             # in half
             b0_reqs_end = num_reqs // 2
             b0_tokens_end = total_num_scheduled_tokens // 2
-            assert b0_reqs_end < num_reqs and b0_tokens_end < total_num_scheduled_tokens
+            assert b0_reqs_end < num_reqs and \
+                b0_tokens_end < total_num_scheduled_tokens
             return [
                 (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
                 (slice(b0_reqs_end, num_reqs),
@@ -531,14 +528,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # TODO we can do something more advanced here to try to balance,
             # i.e. split to the left of `total_num_scheduled_tokens // 2` if it
             # is more balanced
-            req_split_id = np.argmax(query_start_loc_np > (total_num_scheduled_tokens // 2))
-            return [(slice(0, req_split_id), slice(0, query_start_loc_np[req_split_id])),
-                    (slice(req_split_id, num_reqs), slice(query_start_loc_np[req_split_id], total_num_scheduled_tokens))]
+            req_split_id = np.argmax(
+                query_start_loc_np > (total_num_scheduled_tokens // 2))
+            return [(slice(0, req_split_id),
+                     slice(0, query_start_loc_np[req_split_id])),
+                    (slice(req_split_id, num_reqs),
+                     slice(query_start_loc_np[req_split_id],
+                           total_num_scheduled_tokens))]
         return None

-    def _is_dummy_ubatch(self, ubatch_slice: UBatchSlices) -> bool:
-        return ubatch_slice[1].start >= ubatch_slice[1].stop

     def _prepare_inputs(
         self, scheduler_output: "SchedulerOutput"
     ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
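For readers skimming the hunks above: _ubatch_split either halves a pure-decode batch down the middle or, in the mixed case, cuts at the first request whose start offset passes the halfway token count. A minimal, self-contained sketch of that second rule (the helper name and example data are illustrative, not part of the diff):

import numpy as np

# Hypothetical standalone helper mirroring the np.argmax-based split above.
# query_start_loc[i] is the cumulative token offset where request i begins,
# with query_start_loc[num_reqs] == total_tokens.
def split_ubatches(query_start_loc: np.ndarray, num_reqs: int,
                   total_tokens: int) -> list[tuple[slice, slice]]:
    # First request whose start offset lies past the halfway point.
    req_split = int(np.argmax(query_start_loc > (total_tokens // 2)))
    tok_split = int(query_start_loc[req_split])
    return [(slice(0, req_split), slice(0, tok_split)),
            (slice(req_split, num_reqs), slice(tok_split, total_tokens))]

# Three requests with 5, 4 and 3 tokens: query_start_loc = [0, 5, 9, 12].
# The split lands after request 1, giving ubatches of 9 and 3 tokens
# (the TODO in the diff notes this could be balanced better).
print(split_ubatches(np.array([0, 5, 9, 12]), num_reqs=3, total_tokens=12))
# [(slice(0, 2), slice(0, 9)), (slice(2, 3), slice(9, 12))]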
@@ -630,7 +628,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
         ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
-            self.query_start_loc_np, max_num_scheduled_tokens, scheduler_output)
+            self.query_start_loc_np, max_num_scheduled_tokens,
+            scheduler_output)
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@@ -708,6 +707,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     ))
                 for layer_name in kv_cache_group_spec.layer_names:
                     assert type(attn_metadata) is list
+                    assert attn_metadata_i is not None
+                    # What if it's None? Do we still add it to the list?
                     attn_metadata[ubid][layer_name] = attn_metadata_i
             else:
                 attn_metadata_i = (
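A note on the container being filled in this hunk: with micro-batching, attention metadata is kept as one dict per ubatch, keyed by attention layer name. A rough sketch (types simplified, layer names invented for illustration):

from typing import Any

# Simplified sketch of the per-ubatch metadata layout used above.
num_ubatches = 2
attn_metadata: list[dict[str, Any]] = [{} for _ in range(num_ubatches)]

for ubid in range(num_ubatches):
    for layer_name in ("layers.0.attn", "layers.1.attn"):  # illustrative names
        attn_metadata[ubid][layer_name] = {"ubatch": ubid}  # builder-output stand-in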
@@ -749,7 +750,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
-        return attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices
+        return (attn_metadata, logits_indices, spec_decode_metadata,
+                ubatch_slices)

     def _compute_cascade_attn_prefix_len(
         self,
@@ -1167,8 +1169,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if sync_self:
             assert intermediate_tensors is not None
             for k, v in intermediate_tensors.items():
-                is_scattered = "residual" and is_residual_scattered
-                _copy_slice = copy_slice(is_scattered)
+                _copy_slice = copy_slice(is_residual_scattered)
                 self.intermediate_tensors[k][_copy_slice].copy_(
                     v[_copy_slice], non_blocking=True)
@@ -1206,7 +1207,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return input_ids, positions, inputs_embeds, intermediate_tensors

     def _get_model_inputs(self, tokens_slice: slice,
                           scheduler_output: "SchedulerOutput"):
         num_tokens = tokens_slice.stop - tokens_slice.start
@@ -1291,26 +1291,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
             if use_dummy_input:
-                assert num_dummy_tokens == 1
                 return self._get_dummy_model_inputs(num_dummy_tokens)
             else:
                 assert scheduler_output is not None
                 return self._get_model_inputs(tokens_slice, scheduler_output)

         def _run(token_slice: slice, context, use_dummy_input: bool = False):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 model_inputs(token_slice, use_dummy_input)
             with context:
-                # if isinstance(context, UBatchContext):
-                # print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}")
                 model_output = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                 )
-                # if isinstance(context, UBatchContext):
-                # print(f"Ran ubatch {context.id}putput {model_output.shape}")
             if isinstance(context, UBatchContext):
                 # Clone before we leave the ubatch context
                 model_output = model_output.clone()
@@ -1318,21 +1314,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return model_output

         @torch.inference_mode()
-        def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input):
+        def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
+                           use_dummy_input):
             model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
             if save_results:
                 results.append(model_output)

-        def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
-            results = []
+        def _run_ubatches(ubatch_slices, attn_metadata,
+                          is_dummy_run) -> torch.Tensor:
+            results: list[torch.Tensor] = []
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             root_stream = current_stream()
-            ubatch_ctxs = make_ubatch_contexts(
-                len(ubatch_slices),
-                compute_stream=root_stream,
-                device=self.device)
+            ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
+                                               compute_stream=root_stream,
+                                               device=self.device)

             # Ubatches will manually manage the forward context, so we override
             # it to None here so we can have it restored correctly later
@@ -1340,22 +1337,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             ubatch_threads = []
             for i, (_, tokens_slice) in enumerate(ubatch_slices):
                 is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
-                assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
-                # print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run)
-                num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start)
+                assert not is_dummy_ubatch or i == len(
+                    ubatch_slices) - 1 or is_dummy_run
+                num_tokens = num_dummy_tokens if is_dummy_ubatch or \
+                    is_dummy_run else (tokens_slice.stop - tokens_slice.start)
                 ubatch_ctxs[i].forward_context = create_forward_context(
-                    attn_metadata[i] if attn_metadata is not None else None,
-                    self.vllm_config, num_tokens=num_tokens)
+                    attn_metadata[i]
+                    if attn_metadata is not None else None,
+                    self.vllm_config,
+                    num_tokens=num_tokens)
-                thread = threading.Thread(target=_ubatch_thread, args=(
-                    ubatch_ctxs[i],
-                    tokens_slice,
-                    results,
-                    not is_dummy_ubatch or is_dummy_run,
-                    is_dummy_ubatch or is_dummy_run,
-                ))
+                thread = threading.Thread(target=_ubatch_thread,
+                                           args=(
+                                               ubatch_ctxs[i],
+                                               tokens_slice,
+                                               results,
+                                               not is_dummy_ubatch
+                                               or is_dummy_run,
+                                               is_dummy_ubatch
+                                               or is_dummy_run,
+                                           ))
                 ubatch_threads.append(thread)
                 thread.start()
             ubatch_ctxs[0].cpu_wait_event.set()
@@ -1369,16 +1371,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # run micro-batched
         if ubatch_slices is not None:
-            model_output = _run_ubatches(
-                ubatch_slices, attn_metadata, is_dummy_run)
+            model_output = _run_ubatches(ubatch_slices, attn_metadata,
+                                         is_dummy_run)
         # run single batch
         else:
             model_output = _run(
                 slice(0, num_scheduled_tokens),
-                set_forward_context(
-                    attn_metadata,
-                    vllm_config=self.vllm_config,
-                    num_tokens=num_scheduled_tokens or 1),
+                set_forward_context(attn_metadata,
+                                    vllm_config=self.vllm_config,
+                                    num_tokens=num_scheduled_tokens or 1),
                 is_dummy_run)
         return model_output
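The _run_ubatches helper above drives each micro-batch from its own thread and unblocks ubatch 0 first via cpu_wait_event. A stripped-down sketch of that CPU-side ping-pong (pure illustration; the event handoff inside UBatchContext is more involved):

import threading

# Illustrative only: two worker threads alternate via events, the way the
# runner's _ubatch_thread workers take turns under their UBatchContexts.
def run_two_ubatches(work_items: list[list[str]]) -> list[str]:
    results: list[str] = ["", ""]
    events = [threading.Event(), threading.Event()]

    def worker(i: int) -> None:
        for item in work_items[i]:
            events[i].wait()           # wait for this ubatch's turn
            events[i].clear()
            results[i] = item          # stand-in for a chunk of forward pass
            events[(i + 1) % 2].set()  # hand control to the other ubatch

    threads = [threading.Thread(target=worker, args=(i, )) for i in range(2)]
    for t in threads:
        t.start()
    events[0].set()                    # kick off ubatch 0 first, as above
    for t in threads:
        t.join()
    return results

print(run_two_ubatches([["a0", "a1"], ["b0", "b1"]]))  # ['a1', 'b1']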
@@ -1583,7 +1584,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
             target_token_ids = self.input_ids[:num_scheduled_tokens]
-            target_positions = positions[:num_scheduled_tokens]
+            #TODO(sage) make sure this works with mrope
+            target_positions = self.positions[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
                 target_hidden_states = torch.cat(
                     [h[:num_scheduled_tokens] for h in aux_hidden_states],
@@ -1609,7 +1611,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 num_rejected_tokens,
             )
             target_token_ids = self.input_ids[token_indices]
-            target_positions = positions[token_indices]
+            #TODO(sage) make sure this works with mrope
+            target_positions = self.positions[token_indices]
             if self.use_aux_hidden_state_outputs:
                 target_hidden_states = torch.cat(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
@@ -1647,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def kv_connector_no_forward(
             self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
         # KV send/recv even if no work to do.
         with set_forward_context(None, self.vllm_config):
             self.maybe_setup_kv_connector(scheduler_output)
@@ -1898,16 +1902,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         should_microbatch = (
             allow_microbatching
             and self.vllm_config.parallel_config.enable_microbatching
-            and self.vllm_config.parallel_config.always_microbatch_if_enabled
-        )
-        dummy_microbatches = [(slice(0, 0), slice(0, 0)), (slice(0, 0), slice(0, 0))]
+            and self.vllm_config.parallel_config.always_microbatch_if_enabled)
+        dummy_microbatches = [(slice(0, 0), slice(0, 0)),
+                              (slice(0, 0), slice(0, 0))]
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             outputs = self._run_model(
                 attn_metadata,
                 num_tokens,
-                ubatch_slices=None if not should_microbatch else dummy_microbatches,
+                ubatch_slices=None
+                if not should_microbatch else dummy_microbatches,
                 is_dummy_run=True,
             )
         if self.use_aux_hidden_state_outputs:

View File

@@ -1,36 +1,36 @@
 # SPDX-License-Identifier: Apache-2.0
 import threading
+from typing import Optional

 import torch
 import torch._dynamo
-import torch.profiler as profiler
-import os
-from typing import Optional
-from torch.library import Library
-from torch.library import custom_op, register_kernel
-from vllm.distributed import (get_dp_group)
-from vllm.utils import current_stream
+from torch.library import custom_op

 from vllm import forward_context
+from vllm.utils import current_stream


 class UBatchContext:
     """
     Context manager for micro-batching synchronization using threading events.
     """

-    def __init__(self,
-                 id: int,
-                 comm_stream: torch.cuda.Stream,
-                 compute_stream: torch.cuda.Stream,
-                 #fwd_ctx: forward_context.ForwardContext,
-                 cpu_wait_event: threading.Event,
-                 cpu_signal_event: threading.Event,
-                 gpu_comm_done_event: torch.cuda.Event,
-                 gpu_compute_done_event: torch.cuda.Event,
-                 schedule: str = "default"):
+    def __init__(
+            self,
+            id: int,
+            comm_stream: torch.cuda.Stream,
+            compute_stream: torch.cuda.Stream,
+            #fwd_ctx: forward_context.ForwardContext,
+            cpu_wait_event: threading.Event,
+            cpu_signal_event: threading.Event,
+            gpu_comm_done_event: torch.cuda.Event,
+            gpu_compute_done_event: torch.cuda.Event,
+            schedule: str = "default"):
         self.id = id
         self.comm_stream = comm_stream
         self.compute_stream = compute_stream
         self.original_stream = current_stream()
         self.forward_context = None  #fwd_ctx
         self.cpu_wait_event = cpu_wait_event
         self.cpu_signal_event = cpu_signal_event
         self.current_stream = compute_stream
@@ -47,8 +47,7 @@ class UBatchContext:
         self.cpu_wait_event.clear()
         self._restore_context()
         # Assume we start on the compute stream
-        assert current_stream() == self.compute_stream, \
-            "Expected to start on the compute stream, but found %s" % current_stream()
+        assert current_stream() == self.compute_stream
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -74,6 +73,7 @@ class UBatchContext:
         # assert current_stream() == self.current_stream
         # assert not self.cpu_wait_event.is_set()
         pass

     def _signal_comm_done(self):
         # self.ctx_valid_state()
         self.gpu_comm_done_event.record(self.comm_stream)
@@ -115,31 +115,37 @@ class UBatchContext:
     def yield_and_switch_from_compute_to_comm(self):
         assert current_stream() == self.compute_stream
         # dp_rank = get_dp_group().rank_in_group
-        # print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
+        # print(f"DP: {dp_rank} UB: {self.id} "
+        #       f"Yield and switch from {self.stream_string()}", flush=True)
         # self.ctx_valid_state()
         self._signal_compute_done()
         self._cpu_yield()
         # self.ctx_valid_state()
         assert self.current_stream == self.compute_stream
         self.update_stream(self.comm_stream)
-        # print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
+        # print(f"DP: {dp_rank} UB: {self.id} "
+        #       f"Resuming on stream {self.stream_string()}", flush=True)
         self._wait_compute_done()

     def yield_and_switch_from_comm_to_compute(self):
         assert current_stream() == self.comm_stream
         # dp_rank = get_dp_group().rank_in_group
-        # print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
+        # print(f"DP: {dp_rank} UB: {self.id} "
+        #       f"Yield and switch from {self.stream_string()}", flush=True)
         # self.ctx_valid_state()
         self._signal_comm_done()
         self._cpu_yield()
         # self.ctx_valid_state()
         assert self.current_stream == self.comm_stream
         self.update_stream(self.compute_stream)
-        # print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
+        # print(f"DP: {dp_rank} UB: {self.id} "
+        #       f"Resuming on stream {self.stream_string()}", flush=True)
         self._wait_comm_done()


 _CURRENT_CONTEXT: dict = {}


 def get_current_ubatch_context() -> Optional[UBatchContext]:
     global _CURRENT_CONTEXT
     """
@@ -147,6 +153,7 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
     """
     return _CURRENT_CONTEXT.get(threading.get_ident(), None)


 def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
     # Perform the barrier if a context exists for this thread
     ctx = get_current_ubatch_context()
@@ -154,50 +161,51 @@ def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
     if ctx is not None and ctx.schedule == schedule:
         ctx.yield_and_switch_from_compute_to_comm()


 def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
     # Perform the barrier if a context exists for this thread
     ctx = get_current_ubatch_context()
     if ctx is not None and ctx.schedule == schedule:
         ctx.yield_and_switch_from_comm_to_compute()


 # 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
 # optimizing it away (TODO: see if this is actually needed)
-@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x",))
-def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
+@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x", ))
+def yield_and_switch_from_compute_to_comm(x: torch.Tensor,
+                                          schedule: str = "default") -> None:
     yield_and_switch_from_compute_to_comm_impl(schedule)


 # 3) Fake implementation for shape prop and FX tracing
 @yield_and_switch_from_compute_to_comm.register_fake
-def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
+def yield_and_switch_from_compute_to_comm_fake(x: torch.Tensor,
+                                               schedule: str = "default"
+                                               ) -> None:
     pass


-@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x",))
-def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
+@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x", ))
+def yield_and_switch_from_comm_to_compute(x: torch.Tensor,
+                                          schedule: str = "default") -> None:
     yield_and_switch_from_comm_to_compute_impl(schedule)


 @yield_and_switch_from_comm_to_compute.register_fake
-def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
+def yield_and_switch_from_comm_to_compute_fake(x: torch.Tensor,
+                                               schedule: str = "default"
+                                               ) -> None:
     pass


 def dump_ubatching_state():
     pass
-    # """
-    # Dump the current UBatchContext state for debugging.
-    # """
-    # dp_rank = os.getenv("VLLM_DP_RANK", None)
-    # for ctx in _CURRENT_CONTEXT.values():
-    #     print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n"
-    #           f"  Stream: {ctx.stream}, ({ctx.stream.query()})\n"
-    #           f"  Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
-    #           f"  CPU Wait Event: {ctx.cpu_wait_event}\n"
-    #           f"  GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
-    #           f"  CPU Signal Event: {ctx.cpu_signal_event}\n"
-    #           f"  GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")


 """
 """


 def make_ubatch_contexts(
     num_micro_batches: int,
     compute_stream: torch.cuda.Stream,
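Once registered this way, the ops live under the vllm:: namespace, the mutable tensor argument keeps them from being dead-code-eliminated, and the fake kernels let torch.compile trace through the yield points. The call sites are not shown in this diff; as a hedged sketch, an already-imported registration would typically be invoked like:

import torch

import vllm.v1.worker.ubatching  # noqa: F401  (runs the @custom_op registrations)

x = torch.empty(1)
# Outside a ubatch thread there is no UBatchContext, so this is a no-op;
# inside one it yields from the compute stream to the comm stream.
torch.ops.vllm.yield_and_switch_from_compute_to_comm(x, "default")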
@@ -205,7 +213,6 @@ def make_ubatch_contexts(
     schedule: str = "default",
 ) -> list[UBatchContext]:
     assert num_micro_batches == 2, "only been tested with 2 micro-batches"
     """
     Create a context manager for micro-batching synchronization.
     """
@@ -225,11 +232,11 @@ def make_ubatch_contexts(
             compute_stream=compute_stream,
             comm_stream=comm_stream,
             cpu_wait_event=cpu_events[i],
-            cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
+            cpu_signal_event=cpu_events[(i + 1) %
+                                        num_micro_batches],
             gpu_comm_done_event=gpu_comm_done_events[i],
             gpu_compute_done_event=gpu_compute_done_events[i],
-            schedule=schedule
-        )
+            schedule=schedule)
         ctxs.append(ctx)
     return ctxs
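Putting the two files together, the model runner builds the contexts from its current compute stream and releases ubatch 0 first. A condensed usage sketch based only on the calls visible in this diff (the surrounding scaffolding and the chunk callables are illustrative):

import threading
from typing import Callable

import torch

from vllm.utils import current_stream
from vllm.v1.worker.ubatching import make_ubatch_contexts

def run_microbatched(chunks: list[Callable[[], None]], device: torch.device):
    # One context per micro-batch, sharing the caller's compute stream
    # (make_ubatch_contexts currently asserts exactly two micro-batches).
    ctxs = make_ubatch_contexts(len(chunks),
                                compute_stream=current_stream(),
                                device=device)

    def worker(i: int) -> None:
        with ctxs[i]:   # enter this ubatch's context, as _ubatch_thread does
            chunks[i]() # forward pass for this micro-batch's token slice

    threads = [threading.Thread(target=worker, args=(i, )) for i in range(2)]
    for t in threads:
        t.start()
    ctxs[0].cpu_wait_event.set()   # same kick-off the runner performs
    for t in threads:
        t.join()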