mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 09:07:04 +08:00
tp1 working multistream tp > 1 broken
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
2259b47951
commit
9c60a6299d
@ -40,7 +40,7 @@ def main():
|
||||
max_model_len=1024,
|
||||
#load_format="dummy",
|
||||
###############
|
||||
tensor_parallel_size=2,
|
||||
tensor_parallel_size=1,
|
||||
#data_parallel_size=2,
|
||||
enable_expert_parallel=False,
|
||||
###############
|
||||
|
||||
@ -58,20 +58,11 @@ def get_forward_context() -> ForwardContext:
|
||||
"Please use `set_forward_context` to set the forward context.")
|
||||
return _forward_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_forward_context(attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: int = 0):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
Here we can inject common logic for every model forward pass.
|
||||
"""
|
||||
global forward_start_time
|
||||
need_to_track_batchsize = track_batchsize and attn_metadata is not None
|
||||
if need_to_track_batchsize:
|
||||
forward_start_time = time.perf_counter()
|
||||
def create_forward_context(attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: int = 0
|
||||
):
|
||||
dp_metadata: Optional[DPMetadata] = None
|
||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
@ -96,17 +87,49 @@ def set_forward_context(attn_metadata: Any,
|
||||
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
|
||||
cu_tokens_across_dp_cpu)
|
||||
|
||||
global _forward_context
|
||||
prev_context = _forward_context
|
||||
_forward_context = ForwardContext(
|
||||
return ForwardContext(
|
||||
no_compile_layers=vllm_config.compilation_config.
|
||||
static_forward_context,
|
||||
virtual_engine=virtual_engine,
|
||||
attn_metadata=attn_metadata,
|
||||
dp_metadata=dp_metadata)
|
||||
|
||||
@contextmanager
|
||||
def override_forward_context(forward_context: Optional[ForwardContext]):
|
||||
"""A context manager that overrides the current forward context.
|
||||
This is used to override the forward context for a specific
|
||||
forward pass.
|
||||
"""
|
||||
global _forward_context
|
||||
prev_context = _forward_context
|
||||
print("overriding forward context with", forward_context)
|
||||
_forward_context = forward_context
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_forward_context = prev_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_forward_context(attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: int = 0):
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
Here we can inject common logic for every model forward pass.
|
||||
"""
|
||||
global forward_start_time
|
||||
need_to_track_batchsize = track_batchsize and attn_metadata is not None
|
||||
if need_to_track_batchsize:
|
||||
forward_start_time = time.perf_counter()
|
||||
|
||||
forward_context = create_forward_context(
|
||||
attn_metadata, vllm_config, virtual_engine, num_tokens)
|
||||
|
||||
try:
|
||||
with override_forward_context(forward_context):
|
||||
yield
|
||||
finally:
|
||||
global last_logging_time, batchsize_logging_interval
|
||||
if need_to_track_batchsize:
|
||||
@ -140,5 +163,3 @@ def set_forward_context(attn_metadata: Any,
|
||||
logger.info(("Batchsize forward time stats "
|
||||
"(batchsize, count, median_time(ms)): %s"),
|
||||
forward_stats)
|
||||
|
||||
_forward_context = prev_context
|
||||
|
||||
@ -6,6 +6,7 @@ import os
|
||||
import time
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
|
||||
import contextlib
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -26,7 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group, get_tp_group, graph_capture,
|
||||
prepare_communication_buffer_for_model)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.forward_context import get_forward_context, set_forward_context, create_forward_context, override_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
@ -1289,93 +1290,93 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
|
||||
if use_dummy_input:
|
||||
num_tokens = num_scheduled_tokens or 1
|
||||
return num_tokens, *self._get_dummy_model_inputs(num_tokens)
|
||||
return self._get_dummy_model_inputs(num_tokens)
|
||||
else:
|
||||
assert scheduler_output is not None
|
||||
num_tokens = tokens_slice.stop - tokens_slice.start
|
||||
return num_tokens, *self._get_model_inputs(tokens_slice, scheduler_output)
|
||||
return self._get_model_inputs(tokens_slice, scheduler_output)
|
||||
|
||||
|
||||
def _run(token_slice: slice, attn_metadata, use_dummy_input: bool = False,
|
||||
ubatch_context: Optional[UBatchContext]=None,
|
||||
setup_done_evt: Optional[threading.Event]=None):
|
||||
num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
def _run(token_slice: slice, context, use_dummy_input: bool = False):
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(token_slice, use_dummy_input)
|
||||
with set_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
if ubatch_context:
|
||||
# Update the forward context now that its available
|
||||
ubatch_context.update_forward_context()
|
||||
|
||||
if setup_done_evt is not None:
|
||||
# Wait for the setup to be done
|
||||
setup_done_evt.set()
|
||||
|
||||
with context:
|
||||
if isinstance(context, UBatchContext):
|
||||
print("running ubatch ctx", context.id)
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if isinstance(context, UBatchContext):
|
||||
print("done ubatch ctx", context.id)
|
||||
if isinstance(context, UBatchContext):
|
||||
# Clone before we leave the ubatch context
|
||||
model_output = model_output.clone()
|
||||
|
||||
return model_output
|
||||
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(ubatch_ctx, root_stream, token_slice, attn_metadata, results, save_results, use_dummy_input, setup_done_evt):
|
||||
with ubatch_ctx:
|
||||
# an event to enable the start of the ubatch execution on the GPU
|
||||
# since different ubatches may be on different streams than this one
|
||||
# they all need to wait on the gpu kernels launched in
|
||||
# _prepare_inputs before continuing
|
||||
if current_stream() != root_stream:
|
||||
start_evt = torch.cuda.Event()
|
||||
# Make sure we wait then record so we don't miss the event
|
||||
current_stream().wait_event(start_evt)
|
||||
root_stream.record_event(start_evt)
|
||||
ubatch_ctx.stream.wait_stream(root_stream)
|
||||
|
||||
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
|
||||
|
||||
model_output = _run(token_slice, attn_metadata, use_dummy_input, ubatch_ctx, setup_done_evt)
|
||||
|
||||
if save_results:
|
||||
results.append(model_output.clone())
|
||||
|
||||
if current_stream() != root_stream:
|
||||
# Make the root stream for the ubatch to finish
|
||||
# Make sure we wait then record so we don't miss the event
|
||||
root_stream.wait_event(ubatch_ctx.done_evt)
|
||||
current_stream().record_event(ubatch_ctx.done_evt)
|
||||
if save_results:
|
||||
results.append(model_output)
|
||||
|
||||
def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
|
||||
results = []
|
||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||
root_stream = current_stream()
|
||||
ubatch_ctxs = make_ubatch_context_chain(len(ubatch_slices),
|
||||
stream=root_stream, # Only works currently if everything is run on the same stream
|
||||
device=self.device)
|
||||
|
||||
if not hasattr(self, "ubatch_streams"):
|
||||
# Create the ubatch streams
|
||||
self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))]
|
||||
|
||||
ubatch_fwd_ctxs = [create_forward_context(
|
||||
attn_metadata[i], self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
|
||||
) for i, (_, tokens_slice) in enumerate(ubatch_slices)]
|
||||
ubatch_ctxs, start_hook = make_ubatch_context_chain(
|
||||
len(ubatch_slices),
|
||||
fwd_ctxs=ubatch_fwd_ctxs,
|
||||
streams=self.ubatch_streams, #stream=root_stream, # Only works currently if everything is run on the same stream
|
||||
device=self.device)
|
||||
setup_done = threading.Event()
|
||||
ubatch_threads = []
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
# it to None here so we can have it restored correctly later
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
|
||||
assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
|
||||
|
||||
thread = threading.Thread(target=_ubatch_thread, args=(
|
||||
ubatch_ctxs[i],
|
||||
root_stream,
|
||||
tokens_slice,
|
||||
attn_metadata[i] if attn_metadata is not None else None,
|
||||
results,
|
||||
not is_dummy_ubatch or is_dummy_run,
|
||||
is_dummy_ubatch or is_dummy_run,
|
||||
setup_done,
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Single the first ubatch to start
|
||||
start_hook(root_stream)
|
||||
print("started first ubatch")
|
||||
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
|
||||
for ubatch_ctx in ubatch_ctxs:
|
||||
root_stream.wait_stream(ubatch_ctx.stream)
|
||||
|
||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
|
||||
assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
|
||||
|
||||
thread = threading.Thread(target=_ubatch_thread, args=(
|
||||
ubatch_ctxs[i],
|
||||
root_stream,
|
||||
tokens_slice,
|
||||
attn_metadata[i] if attn_metadata is not None else None,
|
||||
results,
|
||||
not is_dummy_ubatch or is_dummy_run,
|
||||
is_dummy_ubatch or is_dummy_run,
|
||||
setup_done,
|
||||
))
|
||||
#ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
setup_done.wait()
|
||||
thread.join()
|
||||
|
||||
# for thread in ubatch_threads:
|
||||
# thread.join()
|
||||
|
||||
print("torch cat")
|
||||
torch.cuda.set_stream(root_stream)
|
||||
return torch.cat(results, dim=0)
|
||||
|
||||
@ -1386,7 +1387,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# run single batch
|
||||
else:
|
||||
model_output = _run(
|
||||
slice(0, num_scheduled_tokens), attn_metadata, is_dummy_run)
|
||||
slice(0, num_scheduled_tokens),
|
||||
set_forward_context(
|
||||
attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_scheduled_tokens or 1),
|
||||
is_dummy_run)
|
||||
|
||||
return model_output
|
||||
|
||||
|
||||
@ -15,55 +15,59 @@ class UBatchContext:
|
||||
Context manager for micro-batching synchronization using threading events.
|
||||
"""
|
||||
def __init__(self,
|
||||
id: int,
|
||||
stream: torch.cuda.Stream,
|
||||
wait_event: threading.Event,
|
||||
signal_event: threading.Event,
|
||||
fwd_ctx: forward_context.ForwardContext,
|
||||
cpu_wait_event: threading.Event,
|
||||
cpu_signal_event: threading.Event,
|
||||
gpu_wait_event: torch.cuda.Event,
|
||||
gpu_signal_event: torch.cuda.Event,
|
||||
schedule="default"):
|
||||
self.wait_event = wait_event
|
||||
self.signal_event = signal_event
|
||||
self.schedule = schedule
|
||||
self.id = id
|
||||
self.stream = stream
|
||||
self.original_stream = current_stream()
|
||||
self.done_evt = torch.cuda.Event()
|
||||
self.forward_context = None
|
||||
|
||||
def update_forward_context(self):
|
||||
self.forward_context = forward_context._forward_context
|
||||
self.forward_context = fwd_ctx
|
||||
self.cpu_wait_event = cpu_wait_event
|
||||
self.cpu_signal_event = cpu_signal_event
|
||||
self.gpu_wait_event = gpu_wait_event
|
||||
self.gpu_signal_event = gpu_signal_event
|
||||
self.schedule = schedule
|
||||
self.done_event = torch.cuda.Event()
|
||||
|
||||
def __enter__(self):
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = self
|
||||
|
||||
self.original_stream = current_stream()
|
||||
self.original_forward_context = forward_context._forward_context
|
||||
self.forward_context = self.original_forward_context
|
||||
|
||||
# Set micro-batch stream
|
||||
torch.cuda.set_stream(self.stream)
|
||||
self._wait()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = None
|
||||
|
||||
torch.cuda.set_stream(self.original_stream)
|
||||
forward_context._forward_context = self.original_forward_context
|
||||
self._signal()
|
||||
return False
|
||||
|
||||
def yield_(self):
|
||||
# Signal that this batch reached the barrier and wait for the other
|
||||
self.signal_event.set()
|
||||
# Wait for the next batch to signal back
|
||||
self.wait_event.wait()
|
||||
self.wait_event.clear()
|
||||
def _restore_context(self):
|
||||
# When we resume i.e. switch back to this micro-batch, we make sure
|
||||
# we have the correct stream and forward context
|
||||
torch.cuda.set_stream(self.stream)
|
||||
forward_context._forward_context = self.forward_context
|
||||
|
||||
def wait(self):
|
||||
self.wait_event.wait()
|
||||
self.wait_event.clear()
|
||||
|
||||
def yield_(self):
|
||||
self._signal()
|
||||
self._wait()
|
||||
|
||||
def _signal(self):
|
||||
# Wait for the next batch to signal back
|
||||
self.gpu_signal_event.record(self.stream)
|
||||
# Signal that this batch reached the barrier
|
||||
self.cpu_signal_event.set()
|
||||
|
||||
def _wait(self):
|
||||
self.stream.wait_event(self.gpu_wait_event)
|
||||
self.cpu_wait_event.wait()
|
||||
self.cpu_wait_event.clear()
|
||||
self._restore_context()
|
||||
|
||||
_CURRENT_CONTEXT: dict = {}
|
||||
|
||||
@ -90,21 +94,34 @@ def yield_(x: torch.Tensor, schedule: str="default") -> None:
|
||||
"""
|
||||
def make_ubatch_context_chain(
|
||||
num_micro_batches: int,
|
||||
stream: Optional[torch.Stream] = None,
|
||||
fwd_ctxs: forward_context.ForwardContext,
|
||||
streams: Optional[list[torch.Stream]] = None,
|
||||
device: Optional[torch.device] = None
|
||||
) -> list[UBatchContext]:
|
||||
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
|
||||
|
||||
"""
|
||||
Create a context manager for micro-batching synchronization.
|
||||
"""
|
||||
events = [threading.Event() for _ in range(num_micro_batches)]
|
||||
cpu_events = [threading.Event() for _ in range(num_micro_batches)]
|
||||
gpu_events = [torch.cuda.Event() for _ in range(num_micro_batches)]
|
||||
device = device or torch.cuda.current_device()
|
||||
|
||||
ctxs = []
|
||||
for i in range(num_micro_batches):
|
||||
wait_event = events[i]
|
||||
signal_event = events[(i + 1) % num_micro_batches]
|
||||
ctx = UBatchContext(stream or torch.cuda.Stream(device),
|
||||
wait_event, signal_event)
|
||||
stream = (streams[i] if streams else None) or torch.cuda.Stream(device)
|
||||
ctx = UBatchContext(id=i,
|
||||
stream=stream,
|
||||
fwd_ctx=fwd_ctxs[i],
|
||||
cpu_wait_event=cpu_events[i],
|
||||
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
|
||||
gpu_wait_event=gpu_events[i],
|
||||
gpu_signal_event=gpu_events[(i + 1) % num_micro_batches],
|
||||
)
|
||||
ctxs.append(ctx)
|
||||
|
||||
def start_hook(from_stream: torch.cuda.Stream):
|
||||
ctxs[0].cpu_wait_event.set()
|
||||
ctxs[0].gpu_wait_event.record(from_stream)
|
||||
|
||||
return ctxs
|
||||
return ctxs, start_hook
|
||||
@ -1417,7 +1417,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
if model_input.attn_metadata is not None:
|
||||
model_input.attn_metadata.enable_kv_scales_calculation = False
|
||||
|
||||
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
||||
import nvtx
|
||||
with nvtx.annotate("execute_model"):
|
||||
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
||||
torch.cuda.synchronize()
|
||||
if self.lora_config:
|
||||
self._remove_dummy_loras()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user