mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 00:07:08 +08:00
manually manage stream
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
020269c4c5
commit
ffb740ae95
@ -34,8 +34,9 @@ def main():
|
||||
enforce_eager=False,
|
||||
compilation_config=2,
|
||||
enable_microbatching=True,
|
||||
enable_expert_parallel=True,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=4,
|
||||
tensor_parallel_size=2,
|
||||
max_model_len=1024,
|
||||
#load_format="dummy",
|
||||
)
|
||||
|
||||
@ -4328,7 +4328,13 @@ class VllmConfig:
|
||||
if self.parallel_config.enable_microbatching:
|
||||
# Microbatching is not supported with piecewise compilation yet.
|
||||
# More specifically piecewise cuda-graphs
|
||||
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||
if self.compilation_config.level >= CompilationLevel.PIECEWISE:
|
||||
logger.warning_once(
|
||||
"Piecewise compilation is not supported with "
|
||||
"microbatching. Disabling piecewiseching compilation.")
|
||||
self.parallel_config.enable_microbatching = False
|
||||
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||
|
||||
|
||||
if self.model_config and self.model_config.use_mla and \
|
||||
not (current_platform.is_cuda() or current_platform.is_rocm()):
|
||||
|
||||
@ -1294,10 +1294,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return num_tokens, *self._get_model_inputs(tokens_slice, scheduler_output)
|
||||
|
||||
@torch.inference_mode()
|
||||
def process_batch(i, is_dummy_ubatch, is_dummy_run, attn_metadata, vllm_config, model, num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors, results):
|
||||
with set_forward_context(attn_metadata[i] if attn_metadata is not None else None,
|
||||
def process_batch(save_results, attn_metadata, vllm_config, model, num_tokens, input_ids, positions, inputs_embeds, intermediate_tensors, results, stream):
|
||||
with set_forward_context(attn_metadata,
|
||||
vllm_config,
|
||||
num_tokens=num_tokens):
|
||||
torch.cuda.set_stream(stream)
|
||||
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@ -1305,7 +1307,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
if not is_dummy_ubatch or is_dummy_run:
|
||||
if save_results:
|
||||
results.append(model_output.clone())
|
||||
|
||||
def threaded_processing(ubatch_slices, attn_metadata, vllm_config, model, is_dummy_run=False):
|
||||
@ -1320,10 +1322,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
model_inputs(tokens_slice, is_dummy_ubatch)
|
||||
|
||||
thread = threading.Thread(target=process_batch, args=(
|
||||
i,
|
||||
is_dummy_ubatch,
|
||||
is_dummy_run,
|
||||
attn_metadata,
|
||||
not is_dummy_ubatch or is_dummy_run,
|
||||
attn_metadata[i] if attn_metadata is not None else None,
|
||||
vllm_config,
|
||||
model,
|
||||
num_tokens,
|
||||
@ -1332,6 +1332,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
inputs_embeds,
|
||||
intermediate_tensors,
|
||||
results,
|
||||
torch.cuda.current_stream()
|
||||
))
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
@ -10,44 +10,58 @@ class UBatchContext:
|
||||
"""
|
||||
Context manager for micro-batching synchronization using threading events.
|
||||
"""
|
||||
def __init__(self, wait_event: threading.Event, signal_event: threading.Event):
|
||||
def __init__(self,
|
||||
stream: torch.cuda.Stream,
|
||||
wait_event: threading.Event,
|
||||
signal_event: threading.Event,
|
||||
schedule="default"):
|
||||
self.wait_event = wait_event
|
||||
self.signal_event = signal_event
|
||||
self.schedule = schedule
|
||||
self.stream = stream
|
||||
self.original_stream = torch.cuda.current_stream()
|
||||
|
||||
def __enter__(self):
|
||||
global _CURRENT_CONTEXT
|
||||
self.original_stream = torch.cuda.current_stream()
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = self
|
||||
# Set micro-batch stream
|
||||
torch.cuda.set_stream(self.stream)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = None
|
||||
# Restore the original stream
|
||||
torch.cuda.set_stream(self.original_stream)
|
||||
return False
|
||||
|
||||
def yield_(self):
|
||||
# Signal that this batch reached the barrier and wait for the other
|
||||
self.signal_event.set()
|
||||
# Wait for the next batch to signal back
|
||||
self.wait_event.wait()
|
||||
# Reset for reuse
|
||||
self.wait_event.clear()
|
||||
# When we resume switch back to the microbatch stream
|
||||
torch.cuda.set_stream(self.stream)
|
||||
|
||||
_CURRENT_CONTEXT: dict = {}
|
||||
|
||||
def yield_impl():
|
||||
def yield_impl(schedule="default"):
|
||||
# Perform the barrier if a context exists for this thread
|
||||
ctx = _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
||||
if ctx is not None:
|
||||
if ctx is not None and ctx.schedule == schedule:
|
||||
ctx.yield_()
|
||||
|
||||
|
||||
# 2) Register kernel for CUDA
|
||||
@custom_op("custom::yield_", mutates_args=("x",))
|
||||
def yield_(x: torch.Tensor) -> None:
|
||||
yield_impl()
|
||||
@custom_op("vllm::yield_", mutates_args=("x",))
|
||||
def yield_(x: torch.Tensor, schedule="default") -> None:
|
||||
yield_impl(schedule)
|
||||
|
||||
# 3) Fake implementation for shape prop and FX tracing
|
||||
@yield_.register_fake
|
||||
def yield_(x: torch.Tensor):
|
||||
def yield_(x: torch.Tensor, schedule="default") -> None:
|
||||
pass
|
||||
|
||||
"""
|
||||
@ -63,10 +77,7 @@ def make_ubatch_context_chain(num_micro_batches: int) -> list[UBatchContext]:
|
||||
for i in range(num_micro_batches):
|
||||
wait_event = events[i]
|
||||
signal_event = events[(i + 1) % num_micro_batches]
|
||||
ctx = UBatchContext(wait_event, signal_event)
|
||||
ctx = UBatchContext(torch.Stream(), wait_event, signal_event)
|
||||
ctxs.append(ctx)
|
||||
|
||||
# Create the context manager
|
||||
ctx = UBatchContext(wait_event, signal_event)
|
||||
|
||||
return ctx
|
||||
return ctxs
|
||||
Loading…
x
Reference in New Issue
Block a user