commit 92e0cc79a8 (parent 8ea80fca4a)

    format

    Signed-off-by: Sage Moore <sage@neuralmagic.com>
vllm/v1/worker/gpu_model_runner.py

@@ -2,18 +2,16 @@

 import copy
 import gc
 import os
 import threading
 import time
 import weakref
 from typing import TYPE_CHECKING, Optional, TypeAlias, Union
-import contextlib

 import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn

-from vllm.utils import current_stream
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadataBuilder)
@@ -27,7 +25,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture,
     prepare_communication_buffer_for_model)
-from vllm.forward_context import get_forward_context, set_forward_context, create_forward_context, override_forward_context
+from vllm.forward_context import (create_forward_context, get_forward_context,
+                                  override_forward_context,
+                                  set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
@@ -38,7 +38,7 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
-                        is_pin_memory_available)
+                        current_stream, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -59,14 +59,11 @@ from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-from vllm.v1.worker.ubatching import make_ubatch_contexts, UBatchContext
+from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

 from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                     scatter_mm_placeholders)

-import threading
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -504,40 +501,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.input_batch.refresh_sampling_metadata()

     def _ubatch_split(
-        self,
-        query_start_loc_np: torch.Tensor,
-        max_num_scheduled_tokens: int,
-        scheduler_output: "SchedulerOutput"
-    ) -> Optional[UBatchSlices]:
+            self, query_start_loc_np: torch.Tensor,
+            max_num_scheduled_tokens: int,
+            scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         num_reqs = self.input_batch.num_reqs

         if self.parallel_config.enable_microbatching and \
-            total_num_scheduled_tokens >= self.parallel_config.microbatching_token_threshold \
-            and max_num_scheduled_tokens == 1:
+            total_num_scheduled_tokens >= \
+            self.parallel_config.microbatching_token_threshold \
+            and max_num_scheduled_tokens == 1:
             # For pure decode we can just create ubatchs by cutting the request
             # in half
             b0_reqs_end = num_reqs // 2
             b0_tokens_end = total_num_scheduled_tokens // 2
-            assert b0_reqs_end < num_reqs and b0_tokens_end < total_num_scheduled_tokens
+            assert b0_reqs_end < num_reqs and \
+                b0_tokens_end < total_num_scheduled_tokens
             return [
                 (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
                 (slice(b0_reqs_end, num_reqs),
                  slice(b0_tokens_end, total_num_scheduled_tokens)),
             ]

         if self.parallel_config.enable_microbatching and \
             self.parallel_config.always_microbatch_if_enabled:
             # TODO we can do something more advanced here to try to balance,
             # i.e. split to the left of `total_num_scheduled_tokens // 2` if it
             # is more balanced
-            req_split_id = np.argmax(query_start_loc_np > (total_num_scheduled_tokens // 2))
-            return [(slice(0, req_split_id), slice(0, query_start_loc_np[req_split_id])),
-                    (slice(req_split_id, num_reqs), slice(query_start_loc_np[req_split_id], total_num_scheduled_tokens))]
+            req_split_id = np.argmax(
+                query_start_loc_np > (total_num_scheduled_tokens // 2))
+            return [(slice(0, req_split_id),
+                     slice(0, query_start_loc_np[req_split_id])),
+                    (slice(req_split_id, num_reqs),
+                     slice(query_start_loc_np[req_split_id],
+                           total_num_scheduled_tokens))]
         return None

     def _is_dummy_ubatch(self, ubatch_slice: UBatchSlices) -> bool:
         return ubatch_slice[1].start >= ubatch_slice[1].stop

     def _prepare_inputs(
         self, scheduler_output: "SchedulerOutput"
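The pure-decode branch of `_ubatch_split` above simply bisects the batch: with one scheduled token per request, halving the request range also halves the token range. A minimal standalone sketch of that split, with illustrative names rather than the vLLM API:

    def split_decode_batch(num_reqs: int, total_tokens: int):
        # In pure decode, each request schedules exactly one token, so the
        # request midpoint and the token midpoint coincide.
        b0_reqs_end = num_reqs // 2
        b0_tokens_end = total_tokens // 2
        return [
            (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
            (slice(b0_reqs_end, num_reqs), slice(b0_tokens_end, total_tokens)),
        ]

    # 5 decode requests: ubatch 0 gets requests/tokens 0-1, ubatch 1 gets 2-4.
    print(split_decode_batch(5, 5))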
@@ -630,7 +628,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

         ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
-            self.query_start_loc_np, max_num_scheduled_tokens, scheduler_output)
+            self.query_start_loc_np, max_num_scheduled_tokens,
+            scheduler_output)

         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@@ -708,6 +707,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 ))
                 for layer_name in kv_cache_group_spec.layer_names:
+                    assert type(attn_metadata) is list
+                    assert attn_metadata_i is not None
                     # What if it's None? Do we still add it to the list?
                     attn_metadata[ubid][layer_name] = attn_metadata_i
             else:
                 attn_metadata_i = (
@@ -749,7 +750,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)

-        return attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices
+        return (attn_metadata, logits_indices, spec_decode_metadata,
+                ubatch_slices)

     def _compute_cascade_attn_prefix_len(
         self,
@@ -1167,8 +1169,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if sync_self:
             assert intermediate_tensors is not None
             for k, v in intermediate_tensors.items():
-                is_scattered = "residual" and is_residual_scattered
-                _copy_slice = copy_slice(is_scattered)
+                _copy_slice = copy_slice(is_residual_scattered)
                 self.intermediate_tensors[k][_copy_slice].copy_(
                     v[_copy_slice], non_blocking=True)
@@ -1181,7 +1182,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
         # Dummy batch. (hopefully we are the last one so we can just
         # update this to a one token batch and return)

         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -1203,9 +1204,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     batch_size=self.max_num_tokens,
                     dtype=self.model_config.dtype,
                     device=self.device))
-
         return input_ids, positions, inputs_embeds, intermediate_tensors

     def _get_model_inputs(self, tokens_slice: slice,
                           scheduler_output: "SchedulerOutput"):
@@ -1215,7 +1215,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # update this to a one token batch and return)
             tokens_slice = slice(tokens_slice.start, tokens_slice.start + 1)
             num_tokens = 1

         if (self.use_cuda_graph
                 and num_tokens <= self.cudagraph_batch_sizes[-1]):
             # Use piecewise CUDA graphs.
@@ -1281,85 +1281,87 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return input_ids, positions, inputs_embeds, intermediate_tensors

     def _run_model(self,
                    attn_metadata: Optional[PerLayerAttnMetadata],
                    num_scheduled_tokens: Optional[int],
                    ubatch_slices: Optional[UBatchSlices] = None,
                    scheduler_output: Optional["SchedulerOutput"] = None,
                    is_dummy_run: bool = False):

         num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1

         def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
             if use_dummy_input:
                 assert num_dummy_tokens == 1
                 return self._get_dummy_model_inputs(num_dummy_tokens)
             else:
                 assert scheduler_output is not None
                 return self._get_model_inputs(tokens_slice, scheduler_output)

         def _run(token_slice: slice, context, use_dummy_input: bool = False):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 model_inputs(token_slice, use_dummy_input)
             with context:
-                # if isinstance(context, UBatchContext):
-                #     print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}")
                 model_output = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                 )
-                # if isinstance(context, UBatchContext):
-                #     print(f"Ran ubatch {context.id}putput {model_output.shape}")
                 if isinstance(context, UBatchContext):
                     # Clone before we leave the ubatch context
                     model_output = model_output.clone()

             return model_output

         @torch.inference_mode()
-        def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input):
+        def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
+                           use_dummy_input):
             model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
             if save_results:
                 results.append(model_output)

-        def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
-            results = []
+        def _run_ubatches(ubatch_slices, attn_metadata,
+                          is_dummy_run) -> torch.Tensor:
+            results: list[torch.Tensor] = []
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             root_stream = current_stream()

-            ubatch_ctxs = make_ubatch_contexts(
-                len(ubatch_slices),
-                compute_stream=root_stream,
-                device=self.device)
+            ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
+                                               compute_stream=root_stream,
+                                               device=self.device)

             # Ubatches will manually manage the forward context, so we override
             # it to None here so we can have it restored correctly later
             with override_forward_context(None):
                 ubatch_threads = []
                 for i, (_, tokens_slice) in enumerate(ubatch_slices):
                     is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
-                    assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
-
-                    # print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run)
-
-                    num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start)
+                    assert not is_dummy_ubatch or i == len(
+                        ubatch_slices) - 1 or is_dummy_run
+
+                    num_tokens = num_dummy_tokens if is_dummy_ubatch or \
+                        is_dummy_run else (tokens_slice.stop - tokens_slice.start)
                     ubatch_ctxs[i].forward_context = create_forward_context(
-                        attn_metadata[i] if attn_metadata is not None else None,
-                        self.vllm_config, num_tokens=num_tokens)
-
-                    thread = threading.Thread(target=_ubatch_thread, args=(
-                        ubatch_ctxs[i],
-                        tokens_slice,
-                        results,
-                        not is_dummy_ubatch or is_dummy_run,
-                        is_dummy_ubatch or is_dummy_run,
-                    ))
+                        attn_metadata[i]
+                        if attn_metadata is not None else None,
+                        self.vllm_config,
+                        num_tokens=num_tokens)
+
+                    thread = threading.Thread(target=_ubatch_thread,
+                                              args=(
+                                                  ubatch_ctxs[i],
+                                                  tokens_slice,
+                                                  results,
+                                                  not is_dummy_ubatch
+                                                  or is_dummy_run,
+                                                  is_dummy_ubatch
+                                                  or is_dummy_run,
+                                              ))
                     ubatch_threads.append(thread)
                     thread.start()
                 ubatch_ctxs[0].cpu_wait_event.set()

                 for thread in ubatch_threads:
                     thread.join()
@@ -1369,18 +1371,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         # run micro-batched
         if ubatch_slices is not None:
-            model_output = _run_ubatches(
-                ubatch_slices, attn_metadata, is_dummy_run)
+            model_output = _run_ubatches(ubatch_slices, attn_metadata,
+                                         is_dummy_run)
         # run single batch
         else:
             model_output = _run(
-                slice(0, num_scheduled_tokens),
-                set_forward_context(
-                    attn_metadata,
-                    vllm_config=self.vllm_config,
-                    num_tokens=num_scheduled_tokens or 1),
+                slice(0, num_scheduled_tokens),
+                set_forward_context(attn_metadata,
+                                    vllm_config=self.vllm_config,
+                                    num_tokens=num_scheduled_tokens or 1),
                 is_dummy_run)

         return model_output

     @torch.inference_mode()
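For context on the control flow above: `_run_ubatches` launches one thread per micro-batch, and each thread's wait event is the other thread's signal event, so only one micro-batch advances at a time, with micro-batch 0 kicked off by `ubatch_ctxs[0].cpu_wait_event.set()`. A minimal sketch of that ping-pong using only the standard library (names here are illustrative, not the vLLM API):

    import threading

    def worker(name: str, wait_evt: threading.Event,
               signal_evt: threading.Event, steps: int) -> None:
        for step in range(steps):
            wait_evt.wait()    # block until the other micro-batch yields
            wait_evt.clear()
            print(f"{name} running step {step}")
            signal_evt.set()   # hand control to the other micro-batch

    evt_a, evt_b = threading.Event(), threading.Event()
    t0 = threading.Thread(target=worker, args=("ubatch0", evt_a, evt_b, 3))
    t1 = threading.Thread(target=worker, args=("ubatch1", evt_b, evt_a, 3))
    t0.start()
    t1.start()
    evt_a.set()  # kick off micro-batch 0, as the runner does above
    t0.join()
    t1.join()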
@@ -1583,7 +1584,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
             target_token_ids = self.input_ids[:num_scheduled_tokens]
-            target_positions = positions[:num_scheduled_tokens]
+            #TODO(sage) make sure this works with mrope
+            target_positions = self.positions[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
                 target_hidden_states = torch.cat(
                     [h[:num_scheduled_tokens] for h in aux_hidden_states],
@@ -1609,7 +1611,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 num_rejected_tokens,
             )
             target_token_ids = self.input_ids[token_indices]
-            target_positions = positions[token_indices]
+            #TODO(sage) make sure this works with mrope
+            target_positions = self.positions[token_indices]
             if self.use_aux_hidden_state_outputs:
                 target_hidden_states = torch.cat(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
@@ -1647,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):

     def kv_connector_no_forward(
             self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+
         # KV send/recv even if no work to do.
         with set_forward_context(None, self.vllm_config):
             self.maybe_setup_kv_connector(scheduler_output)
@@ -1894,20 +1898,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             ))
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i

         should_microbatch = (
-            allow_microbatching
-            and self.vllm_config.parallel_config.enable_microbatching
-            and self.vllm_config.parallel_config.always_microbatch_if_enabled
-        )
-        dummy_microbatches = [(slice(0, 0), slice(0, 0)), (slice(0, 0), slice(0, 0))]
+            allow_microbatching
+            and self.vllm_config.parallel_config.enable_microbatching
+            and self.vllm_config.parallel_config.always_microbatch_if_enabled)
+        dummy_microbatches = [(slice(0, 0), slice(0, 0)),
+                              (slice(0, 0), slice(0, 0))]

         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             outputs = self._run_model(
                 attn_metadata,
                 num_tokens,
-                ubatch_slices=None if not should_microbatch else dummy_microbatches,
+                ubatch_slices=None
+                if not should_microbatch else dummy_microbatches,
                 is_dummy_run=True,
             )
         if self.use_aux_hidden_state_outputs:
vllm/v1/worker/ubatching.py

@@ -1,36 +1,36 @@
 # SPDX-License-Identifier: Apache-2.0
 import threading
+from typing import Optional

 import torch
 import torch._dynamo
-import torch.profiler as profiler
-import os
-from typing import Optional
-from torch.library import Library
-from torch.library import custom_op, register_kernel
-from vllm.distributed import (get_dp_group)
+from torch.library import custom_op

-from vllm.utils import current_stream
 from vllm import forward_context
+from vllm.utils import current_stream


 class UBatchContext:
     """
     Context manager for micro-batching synchronization using threading events.
     """

-    def __init__(self,
-                 id: int,
-                 comm_stream: torch.cuda.Stream,
-                 compute_stream: torch.cuda.Stream,
-                 #fwd_ctx: forward_context.ForwardContext,
-                 cpu_wait_event: threading.Event,
-                 cpu_signal_event: threading.Event,
-                 gpu_comm_done_event: torch.cuda.Event,
-                 gpu_compute_done_event: torch.cuda.Event,
-                 schedule: str = "default"):
+    def __init__(
+            self,
+            id: int,
+            comm_stream: torch.cuda.Stream,
+            compute_stream: torch.cuda.Stream,
+            #fwd_ctx: forward_context.ForwardContext,
+            cpu_wait_event: threading.Event,
+            cpu_signal_event: threading.Event,
+            gpu_comm_done_event: torch.cuda.Event,
+            gpu_compute_done_event: torch.cuda.Event,
+            schedule: str = "default"):
         self.id = id
         self.comm_stream = comm_stream
         self.compute_stream = compute_stream
         self.original_stream = current_stream()
         self.forward_context = None  #fwd_ctx
         self.cpu_wait_event = cpu_wait_event
         self.cpu_signal_event = cpu_signal_event
         self.current_stream = compute_stream
@@ -47,8 +47,7 @@ class UBatchContext:
         self.cpu_wait_event.clear()
         self._restore_context()
         # Assume we start on the compute stream
-        assert current_stream() == self.compute_stream, \
-            "Expected to start on the compute stream, but found %s" % current_stream()
+        assert current_stream() == self.compute_stream
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -64,7 +63,7 @@ class UBatchContext:
     def _restore_context(self):
         forward_context._forward_context = self.forward_context
         torch.cuda.set_stream(self.current_stream)

     def update_stream(self, stream):
         self.current_stream = stream
         torch.cuda.set_stream(self.current_stream)
@@ -74,10 +73,11 @@ class UBatchContext:
         # assert current_stream() == self.current_stream
         # assert not self.cpu_wait_event.is_set()
         pass

     def _signal_comm_done(self):
         # self.ctx_valid_state()
         self.gpu_comm_done_event.record(self.comm_stream)

     def _signal_compute_done(self):
         # self.ctx_valid_state()
         self.gpu_compute_done_event.record(self.compute_stream)
@@ -114,32 +114,38 @@ class UBatchContext:

     def yield_and_switch_from_compute_to_comm(self):
         assert current_stream() == self.compute_stream
         # dp_rank = get_dp_group().rank_in_group
-        # print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
+        # print(f"DP: {dp_rank} UB: {self.id} "
+        #       f"Yield and switch from {self.stream_string()}", flush=True)
         # self.ctx_valid_state()
         self._signal_compute_done()
         self._cpu_yield()
         # self.ctx_valid_state()
         assert self.current_stream == self.compute_stream
         self.update_stream(self.comm_stream)
-        # print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
+        # print(f"DP: {dp_rank} UB: {self.id} "
+        #       f"Resuming on stream {self.stream_string()}", flush=True)
         self._wait_compute_done()

     def yield_and_switch_from_comm_to_compute(self):
         assert current_stream() == self.comm_stream
         # dp_rank = get_dp_group().rank_in_group
-        # print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
+        # print(f"DP: {dp_rank} UB: {self.id} "
+        #       f"Yield and switch from {self.stream_string()}", flush=True)
         # self.ctx_valid_state()
         self._signal_comm_done()
         self._cpu_yield()
         # self.ctx_valid_state()
         assert self.current_stream == self.comm_stream
         self.update_stream(self.compute_stream)
-        # print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
+        # print(f"DP: {dp_rank} UB: {self.id} "
+        #       f"Resuming on stream {self.stream_string()}", flush=True)
         self._wait_comm_done()


 _CURRENT_CONTEXT: dict = {}


 def get_current_ubatch_context() -> Optional[UBatchContext]:
     global _CURRENT_CONTEXT
     """
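The `yield_and_switch_*` methods above pair a CPU-side yield with a GPU-side handoff: record a CUDA event on the stream being left, switch the thread's current stream, then make the new stream wait on the relevant "done" event. A simplified sketch of just the GPU half using plain PyTorch stream/event APIs (requires a CUDA device; names are illustrative and the real code coordinates across two micro-batches rather than one):

    import torch

    compute_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()
    compute_done = torch.cuda.Event()

    with torch.cuda.stream(compute_stream):
        x = torch.randn(1024, 1024, device="cuda")
        y = x @ x                        # compute work on the compute stream
    compute_done.record(compute_stream)  # cf. _signal_compute_done()

    torch.cuda.set_stream(comm_stream)    # cf. update_stream(self.comm_stream)
    comm_stream.wait_event(compute_done)  # cf. _wait_compute_done(): comm work
                                          # starts only after compute finishes
    with torch.cuda.stream(comm_stream):
        z = y.sum()  # stand-in for a collective launched on the comm stream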
@@ -147,57 +153,59 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
     """
     return _CURRENT_CONTEXT.get(threading.get_ident(), None)


 def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
     # Perform the barrier if a context exists for this thread
     ctx = get_current_ubatch_context()
-    #print("you are in yield_impl", ctx)
     if ctx is not None and ctx.schedule == schedule:
         ctx.yield_and_switch_from_compute_to_comm()


 def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
     # Perform the barrier if a context exists for this thread
     ctx = get_current_ubatch_context()
     if ctx is not None and ctx.schedule == schedule:
         ctx.yield_and_switch_from_comm_to_compute()


 # 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
 # optimizing it away (TODO: see if this is actually needed)
-@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x",))
-def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
+@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x", ))
+def yield_and_switch_from_compute_to_comm(x: torch.Tensor,
+                                          schedule: str = "default") -> None:
     yield_and_switch_from_compute_to_comm_impl(schedule)


 # 3) Fake implementation for shape prop and FX tracing
 @yield_and_switch_from_compute_to_comm.register_fake
-def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
+def yield_and_switch_from_compute_to_comm_fake(x: torch.Tensor,
+                                               schedule: str = "default"
+                                               ) -> None:
     pass


-@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x",))
-def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
+@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x", ))
+def yield_and_switch_from_comm_to_compute(x: torch.Tensor,
+                                          schedule: str = "default") -> None:
     yield_and_switch_from_comm_to_compute_impl(schedule)


 @yield_and_switch_from_comm_to_compute.register_fake
-def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
+def yield_and_switch_from_comm_to_compute_fake(x: torch.Tensor,
+                                               schedule: str = "default"
+                                               ) -> None:
     pass


 def dump_ubatching_state():
     pass
     # """
     # Dump the current UBatchContext state for debugging.
     # """
     # dp_rank = os.getenv("VLLM_DP_RANK", None)
     # for ctx in _CURRENT_CONTEXT.values():
     #     print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n"
     #           f"  Stream: {ctx.stream}, ({ctx.stream.query()})\n"
     #           f"  Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
     #           f"  CPU Wait Event: {ctx.cpu_wait_event}\n"
     #           f"  GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
     #           f"  CPU Signal Event: {ctx.cpu_signal_event}\n"
     #           f"  GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")


 """
 """


 def make_ubatch_contexts(
     num_micro_batches: int,
     compute_stream: torch.cuda.Stream,
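The registrations above follow the standard torch.library pattern: a real implementation with a visible side effect, plus a fake (meta) implementation so torch.compile/FX tracing can pass through without executing the effect. A minimal self-contained sketch of the same pattern (assumes PyTorch >= 2.4; the op name and namespace are illustrative):

    import torch
    from torch.library import custom_op

    # Declaring mutates_args=("x", ) marks the op as having a visible effect,
    # which keeps an otherwise-unused call from being optimized away.
    @custom_op("demo::sync_point", mutates_args=("x", ))
    def sync_point(x: torch.Tensor, schedule: str = "default") -> None:
        # Eager implementation: a stand-in for the CPU/GPU yield that the
        # real ops above perform.
        print(f"sync_point hit with schedule={schedule}")

    @sync_point.register_fake
    def _(x: torch.Tensor, schedule: str = "default") -> None:
        # Fake implementation for shape propagation and tracing: no effects.
        pass

    sync_point(torch.zeros(1), schedule="default")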
@@ -205,7 +213,6 @@ def make_ubatch_contexts(
     schedule: str = "default",
 ) -> list[UBatchContext]:
     assert num_micro_batches == 2, "only been tested with 2 micro-batches"
-
     """
     Create a context manager for micro-batching synchronization.
     """
@@ -225,11 +232,11 @@ def make_ubatch_contexts(
             compute_stream=compute_stream,
             comm_stream=comm_stream,
             cpu_wait_event=cpu_events[i],
-            cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
+            cpu_signal_event=cpu_events[(i + 1) %
+                                        num_micro_batches],
             gpu_comm_done_event=gpu_comm_done_events[i],
             gpu_compute_done_event=gpu_compute_done_events[i],
-            schedule=schedule
-        )
+            schedule=schedule)
         ctxs.append(ctx)

     return ctxs
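A possible usage sketch for make_ubatch_contexts as called from the model runner above (assuming the call shape shown in this diff and a CUDA device): each context's cpu_signal_event is the other context's cpu_wait_event, so yielding in one micro-batch unblocks the other.

    import torch
    from vllm.v1.worker.ubatching import make_ubatch_contexts

    ctxs = make_ubatch_contexts(2,
                                compute_stream=torch.cuda.current_stream(),
                                device=torch.device("cuda:0"))
    # Events are paired cyclically: cpu_signal_event of i is cpu_wait_event
    # of (i + 1) % 2, per the constructor call above.
    assert ctxs[0].cpu_signal_event is ctxs[1].cpu_wait_event
    assert ctxs[1].cpu_signal_event is ctxs[0].cpu_wait_event
    ctxs[0].cpu_wait_event.set()  # let micro-batch 0 enter first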