mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 23:27:12 +08:00
Working, but currently only when all micro-batches share the same stream
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
ffb740ae95
commit
04f11d97a0
@ -40,7 +40,7 @@ def main():
|
||||
max_model_len=1024,
|
||||
#load_format="dummy",
|
||||
###############
|
||||
tensor_parallel_size=4,
|
||||
tensor_parallel_size=2,
|
||||
#data_parallel_size=2,
|
||||
enable_expert_parallel=False,
|
||||
###############
|
||||
|
||||
@ -57,6 +57,7 @@ from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.ubatching import make_ubatch_context_chain, UBatchContext
|
||||
|
||||
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders)
|
||||
@ -1284,104 +1285,108 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: Optional["SchedulerOutput"] = None,
|
||||
is_dummy_run: bool = False):
|
||||
|
||||
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
    """Build the forward-pass inputs for one micro-batch.

    Returns ``(num_tokens, input_ids, positions, inputs_embeds,
    intermediate_tensors)``.  With ``use_dummy_input`` the inputs are
    synthetic (used for warm-up / padding ubatches); otherwise they are
    sliced out of the scheduled batch.
    """
    if use_dummy_input:
        # A dummy ubatch still needs at least one token to run the model.
        num_tokens = num_scheduled_tokens or 1
        return num_tokens, *self._get_dummy_model_inputs(num_tokens)
    # Real inputs can only be produced from the scheduler output.
    assert scheduler_output is not None
    num_tokens = tokens_slice.stop - tokens_slice.start
    return num_tokens, *self._get_model_inputs(tokens_slice, scheduler_output)
|
||||
|
||||
@torch.inference_mode()
def process_batch(save_results, attn_metadata, vllm_config, model,
                  num_tokens, input_ids, positions, inputs_embeds,
                  intermediate_tensors, results, stream):
    """Run one micro-batch forward pass on ``stream``.

    Appends a clone of the model output to ``results`` when
    ``save_results`` is true; dummy ubatches are still executed but
    their output is discarded.
    """
    with set_forward_context(attn_metadata,
                             vllm_config,
                             num_tokens=num_tokens):
        # This runs on a worker thread, so bind the CUDA stream here.
        torch.cuda.set_stream(stream)

        model_output = model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        if save_results:
            # Clone so the result survives later in-place buffer reuse.
            results.append(model_output.clone())
|
||||
|
||||
def threaded_processing(ubatch_slices, attn_metadata, vllm_config, model,
                        is_dummy_run=False):
    """Run each micro-batch in its own short-lived thread.

    NOTE(review): each thread is joined immediately after start, so the
    ubatches still execute strictly one after another; the thread only
    gives each ubatch its own thread-local context.
    """
    results = []
    for i, (_, tokens_slice) in enumerate(ubatch_slices):
        is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
        # Only the trailing ubatch may be empty, unless the whole run
        # is a dummy run.
        assert (not is_dummy_ubatch or i == len(ubatch_slices) - 1
                or is_dummy_run)

        num_tokens, input_ids, positions, inputs_embeds, \
            intermediate_tensors = model_inputs(tokens_slice,
                                                is_dummy_ubatch)

        worker = threading.Thread(target=process_batch, args=(
            not is_dummy_ubatch or is_dummy_run,
            attn_metadata[i] if attn_metadata is not None else None,
            vllm_config,
            model,
            num_tokens,
            input_ids,
            positions,
            inputs_embeds,
            intermediate_tensors,
            results,
            torch.cuda.current_stream(),
        ))
        worker.start()
        worker.join()

    return torch.cat(results, dim=0) if results else None
|
||||
|
||||
# run micro-batched
|
||||
if ubatch_slices is not None:
|
||||
model_output = threaded_processing(ubatch_slices,
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
self.model,
|
||||
is_dummy_run)
|
||||
# print("FINISHED MODEL OUTPUT")
|
||||
# run single batch
|
||||
else:
|
||||
|
||||
|
||||
def _run(token_slice: slice, attn_metadata, use_dummy_input: bool = False,
         ubatch_context: Optional[UBatchContext] = None,
         setup_done_evt: Optional[threading.Event] = None):
    """Execute a single forward pass over ``token_slice``.

    When called from a ubatch thread, ``ubatch_context`` carries the
    per-ubatch stream/forward-context state and ``setup_done_evt`` is
    used to tell the launcher that setup on this thread is complete.
    """
    num_tokens, input_ids, positions, inputs_embeds, \
        intermediate_tensors = model_inputs(token_slice, use_dummy_input)

    with set_forward_context(attn_metadata,
                             self.vllm_config,
                             num_tokens=num_tokens):
        if ubatch_context:
            # The forward context only exists inside this `with`; stash
            # it on the ubatch context so a yield can restore it later.
            ubatch_context.update_forward_context()

        if setup_done_evt is not None:
            # Signal (not wait) that setup on this thread is finished.
            setup_done_evt.set()

        model_output = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return model_output
|
||||
|
||||
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, root_stream, token_slice, attn_metadata,
                   results, save_results, use_dummy_input, setup_done_evt):
    """Thread body for one micro-batch.

    On entry, synchronizes the ubatch stream with ``root_stream`` (so
    the GPU kernels launched in _prepare_inputs have been ordered before
    the ubatch work); on exit, hands completion back to ``root_stream``.
    """
    with ubatch_ctx:
        if torch.cuda.current_stream() != root_stream:
            # The ubatch runs on its own stream, so it must wait for the
            # kernels launched on the root stream by _prepare_inputs.
            # FIX: record the event on the root stream *before*
            # enqueueing the wait -- waiting on a never-recorded CUDA
            # event is a no-op, so the original wait-then-record order
            # did not actually synchronize the streams.
            start_evt = torch.cuda.Event()
            root_stream.record_event(start_evt)
            torch.cuda.current_stream().wait_event(start_evt)

        model_output = _run(token_slice, attn_metadata, use_dummy_input,
                            ubatch_ctx, setup_done_evt)

        if save_results:
            # Clone: the output buffer may be reused by the next ubatch.
            results.append(model_output.clone())

        if torch.cuda.current_stream() != root_stream:
            # Mirror of the entry sync: record completion on the ubatch
            # stream first, then make the root stream wait for it.
            torch.cuda.current_stream().record_event(ubatch_ctx.done_evt)
            root_stream.wait_event(ubatch_ctx.done_evt)
|
||||
|
||||
def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
    """Run the scheduled tokens as a chain of micro-batches."""
    results = []
    assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
    root_stream = torch.cuda.current_stream()
    # Only works currently if everything is run on the same stream, so
    # every ubatch context shares root_stream.
    ubatch_ctxs = make_ubatch_context_chain(len(ubatch_slices),
                                            stream=root_stream,
                                            device=self.device)
    setup_done = threading.Event()

    for i, (_, tokens_slice) in enumerate(ubatch_slices):
        is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
        # Only the trailing ubatch may be empty, unless the whole run
        # is a dummy run.
        assert (not is_dummy_ubatch or i == len(ubatch_slices) - 1
                or is_dummy_run)

        worker = threading.Thread(target=_ubatch_thread, args=(
            ubatch_ctxs[i],
            root_stream,
            tokens_slice,
            attn_metadata[i] if attn_metadata is not None else None,
            results,
            not is_dummy_ubatch or is_dummy_run,
            is_dummy_ubatch or is_dummy_run,
            setup_done,
        ))
        worker.start()
        # Wait until the worker has finished setup before continuing.
        # NOTE(review): setup_done is never cleared between iterations,
        # so from the second ubatch on this wait returns immediately;
        # the join() below is what actually serializes -- confirm intent.
        setup_done.wait()
        worker.join()

    torch.cuda.set_stream(root_stream)
    return torch.cat(results, dim=0)
|
||||
|
||||
# run micro-batched
|
||||
if ubatch_slices is not None:
|
||||
model_output = _run_ubatches(
|
||||
ubatch_slices, attn_metadata, is_dummy_run)
|
||||
# run single batch
|
||||
else:
|
||||
model_output = _run(
|
||||
slice(0, num_scheduled_tokens), attn_metadata, is_dummy_run)
|
||||
|
||||
return model_output
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@ -3,8 +3,10 @@ import threading
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch.profiler as profiler
|
||||
from typing import Optional
|
||||
from torch.library import Library
|
||||
from torch.library import custom_op, register_kernel
|
||||
from vllm import forward_context
|
||||
|
||||
class UBatchContext:
|
||||
"""
|
||||
@ -20,11 +22,20 @@ class UBatchContext:
|
||||
self.schedule = schedule
|
||||
self.stream = stream
|
||||
self.original_stream = torch.cuda.current_stream()
|
||||
self.done_evt = torch.cuda.Event()
|
||||
self.forward_context = None
|
||||
|
||||
def update_forward_context(self):
    # Capture the forward context currently installed on this thread so
    # it can be re-installed when this ubatch resumes after a yield.
    self.forward_context = forward_context._forward_context
|
||||
|
||||
def __enter__(self):
    """Make this ubatch context current on the calling thread."""
    global _CURRENT_CONTEXT
    _CURRENT_CONTEXT[threading.get_ident()] = self

    # Remember what has to be restored on exit.
    self.original_stream = torch.cuda.current_stream()
    self.original_forward_context = forward_context._forward_context
    self.forward_context = self.original_forward_context

    # Switch to the micro-batch stream.
    torch.cuda.set_stream(self.stream)
    return self
|
||||
@ -32,8 +43,9 @@ class UBatchContext:
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
    """Restore the stream and forward context saved in __enter__."""
    global _CURRENT_CONTEXT
    _CURRENT_CONTEXT[threading.get_ident()] = None

    torch.cuda.set_stream(self.original_stream)
    forward_context._forward_context = self.original_forward_context
    # Never suppress exceptions.
    return False
|
||||
|
||||
def yield_(self):
|
||||
@ -42,8 +54,14 @@ class UBatchContext:
|
||||
# Wait for the next batch to signal back
|
||||
self.wait_event.wait()
|
||||
self.wait_event.clear()
|
||||
# When we resume switch back to the microbatch stream
|
||||
# When we resume i.e. switch back to this micro-batch, we make sure
|
||||
# we have the correct stream and forward context
|
||||
torch.cuda.set_stream(self.stream)
|
||||
forward_context._forward_context = self.forward_context
|
||||
|
||||
def wait(self):
|
||||
self.wait_event.wait()
|
||||
self.wait_event.clear()
|
||||
|
||||
_CURRENT_CONTEXT: dict = {}
|
||||
|
||||
@ -54,30 +72,37 @@ def yield_impl(schedule="default"):
|
||||
ctx.yield_()
|
||||
|
||||
|
||||
# Register the yield op for CUDA.  ``mutates_args`` marks the op as
# mutating to prevent the compiler from optimizing the call away
# (TODO: see if this is actually needed).
@custom_op("vllm::yield_", mutates_args=("x",))
def yield_(x: torch.Tensor, schedule: str = "default") -> None:
    yield_impl(schedule)


# Fake implementation used for shape propagation and FX tracing.
@yield_.register_fake
def yield_(x: torch.Tensor, schedule: str = "default") -> None:
    pass
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
def make_ubatch_context_chain(
    num_micro_batches: int,
    stream: Optional[torch.Stream] = None,
    device: Optional[torch.device] = None
) -> list[UBatchContext]:
    """Create a context manager chain for micro-batching synchronization.

    Each context waits on its own event and signals the next context's
    event, forming a ring over ``num_micro_batches`` ubatches.  If
    ``stream`` is given, every context shares it; otherwise each context
    gets a fresh CUDA stream on ``device``.
    """
    events = [threading.Event() for _ in range(num_micro_batches)]
    device = device or torch.cuda.current_device()

    contexts = []
    for idx in range(num_micro_batches):
        ctx = UBatchContext(
            stream or torch.cuda.Stream(device),
            events[idx],                               # wait_event
            events[(idx + 1) % num_micro_batches],     # signal_event
        )
        contexts.append(ctx)

    return contexts
||||
Loading…
x
Reference in New Issue
Block a user