mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-18 21:27:08 +08:00
Initial full CUDA graphs support. Normal runs are working; ubatching does not work yet.
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
97dbafaad6
commit
144b148de2
@ -157,6 +157,7 @@ def _support_torch_compile(
|
||||
vllm_config.compilation_config.level in [
|
||||
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
|
||||
] or not supports_dynamo()
|
||||
self.do_not_compile = True
|
||||
if self.do_not_compile:
|
||||
return
|
||||
compilation_counter.num_models_seen += 1
|
||||
|
||||
@ -4518,6 +4518,11 @@ class VllmConfig:
|
||||
"Piecewise compilation is not supported with "
|
||||
"microbatching. Disabling piecewiseching compilation.")
|
||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
if not self.model_config.enforce_eager:
|
||||
self.compilation_config.full_cuda_graph = True
|
||||
logger.warning_once(
|
||||
"Enabling fullcudagraphs for microbatching"
|
||||
)
|
||||
|
||||
if (self.kv_events_config is not None
|
||||
and self.kv_events_config.enable_kv_cache_events
|
||||
|
||||
@ -903,7 +903,7 @@ def fused_topk(
|
||||
|
||||
|
||||
# This is used by the Deepseek-V2 and Deepseek-V3 model
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
|
||||
@ -75,7 +75,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
if self.runner.full_cuda_graph:
|
||||
n = num_splits.size(0)
|
||||
logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
|
||||
if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1]:
|
||||
# First time around (CUDAGraph capture), allocate the static buffer
|
||||
if self.cg_buf_tile_scheduler_metadata is None:
|
||||
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
|
||||
@ -92,6 +94,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
|
||||
@ -91,6 +91,12 @@ UbatchSlice: TypeAlias = tuple[slice, slice]
|
||||
UBatchSlices: TypeAlias = list[UbatchSlice]
|
||||
|
||||
|
||||
import dataclasses


@dataclasses.dataclass
class CUDAGraphMetaData:
    """Pairs a captured CUDA graph with the static outputs recorded at
    capture time, so a later ``replay()`` can hand those outputs back.

    Attributes:
        cudagraph: The ``torch.cuda.CUDAGraph`` object being captured /
            replayed for a fixed token-count batch.
        outputs: The model outputs produced during the capture run.
            ``None`` until capture completes; the capturing code assigns
            it immediately after recording the graph.
    """

    # The graph itself; replaying it stands in for re-running the model.
    cudagraph: torch.cuda.CUDAGraph
    # Static capture-time outputs, returned verbatim on every replay.
    outputs: Optional[Any] = None
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def __init__(
|
||||
@ -132,6 +138,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
|
||||
self.cudagraphs = {}
|
||||
# Model-related.
|
||||
self.num_query_heads = model_config.get_num_attention_heads(
|
||||
parallel_config)
|
||||
@ -213,6 +220,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.use_cuda_graph = (self.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not self.model_config.enforce_eager)
|
||||
self.use_cuda_graph = True
|
||||
logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
|
||||
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
||||
# The convention is different.
|
||||
@ -222,6 +230,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
reversed(self.compilation_config.cudagraph_capture_sizes))
|
||||
|
||||
self.full_cuda_graph = self.compilation_config.full_cuda_graph
|
||||
self.full_cuda_graph = True
|
||||
logger.info(f"full_cuda_graph {self.full_cuda_graph}")
|
||||
|
||||
# Cache the device properties.
|
||||
self._init_device_properties()
|
||||
@ -1553,7 +1563,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: Optional["SchedulerOutput"] = None,
|
||||
is_dummy_run: bool = False,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
skip_cuda_graphs: bool = False):
|
||||
skip_cuda_graphs: bool = False,
|
||||
build_cuda_graph: bool = False):
|
||||
|
||||
num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
|
||||
|
||||
@ -1566,16 +1577,43 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert scheduler_output is not None
|
||||
return self._get_model_inputs(tokens_slice, scheduler_output)
|
||||
|
||||
def _run(token_slice: slice, context, use_dummy_input: bool = False):
|
||||
def _run(token_slice: slice,
|
||||
context,
|
||||
use_dummy_input: bool = False,
|
||||
build_cuda_graph: bool = False):
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(token_slice, use_dummy_input)
|
||||
with context:
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
# model_output = self.model(
|
||||
# input_ids=input_ids,
|
||||
# positions=positions,
|
||||
# intermediate_tensors=intermediate_tensors,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# )
|
||||
num_tokens = token_slice.stop - token_slice.start
|
||||
if build_cuda_graph and num_tokens not in self.cudagraphs:
|
||||
print(f"Capturing for {num_tokens}")
|
||||
assert use_dummy_input
|
||||
self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph())
|
||||
with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
self.cudagraphs[num_tokens].outputs = model_output
|
||||
elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
|
||||
logger.info("GRAPH REPLAY")
|
||||
self.cudagraphs[num_tokens].cudagraph.replay()
|
||||
model_output = self.cudagraphs[num_tokens].outputs
|
||||
else:
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if isinstance(context, UBatchContext):
|
||||
# Clone before we leave the ubatch context
|
||||
model_output = model_output.clone()
|
||||
@ -1584,16 +1622,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
|
||||
use_dummy_input):
|
||||
use_dummy_input, build_cuda_graph):
|
||||
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
|
||||
model_output = _run(token_slice, ubatch_ctx, use_dummy_input, build_cuda_graph)
|
||||
|
||||
if save_results:
|
||||
results.append((ubatch_ctx.id, model_output))
|
||||
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
|
||||
def _run_ubatches(ubatch_slices, attn_metadata,
|
||||
is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
|
||||
is_dummy_run, num_tokens_across_dp, build_cuda_graph) -> torch.Tensor:
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||
root_stream = current_stream()
|
||||
@ -1632,6 +1670,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
or is_dummy_run,
|
||||
is_dummy_ubatch
|
||||
or is_dummy_run,
|
||||
build_cuda_graph
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
@ -1650,8 +1689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# num_tokens = ubatch_slices[1][1].stop
|
||||
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
|
||||
# assert not is_dummy_run
|
||||
model_output = _run_ubatches(ubatch_slices, attn_metadata,
|
||||
is_dummy_run, num_tokens_across_dp=num_tokens_across_dp)
|
||||
model_output = _run_ubatches(ubatch_slices,
|
||||
attn_metadata,
|
||||
is_dummy_run,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
build_cuda_graph=build_cuda_graph)
|
||||
# run single batch
|
||||
else:
|
||||
# print("RUN NORMAL")
|
||||
@ -1662,7 +1704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens=num_scheduled_tokens or 1,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs),
|
||||
is_dummy_run)
|
||||
is_dummy_run,
|
||||
build_cuda_graph=build_cuda_graph)
|
||||
|
||||
return model_output
|
||||
|
||||
@ -2222,6 +2265,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# For profiling runs we dont want microbatching but for
|
||||
# dp dummy runs we do.
|
||||
allow_microbatching: bool = False,
|
||||
build_cuda_graph: bool = False
|
||||
) -> torch.Tensor:
|
||||
|
||||
if allow_microbatching:
|
||||
@ -2239,7 +2283,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# _dummy_run doesn't go through _prepare_inputs so
|
||||
# we synchronize with other DP ranks here
|
||||
should_ubatch = self.should_ubatch(allow_microbatching)
|
||||
assert not should_ubatch
|
||||
# Padding for DP
|
||||
# logger.info("PADDING DUMMY")
|
||||
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
|
||||
@ -2295,10 +2338,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# over a certain threshold.
|
||||
if should_ubatch:
|
||||
assert num_tokens % 2 == 0
|
||||
# TODO (Sage) Add actual slices here
|
||||
assert False
|
||||
dummy_microbatches = [(slice(0, 0), slice(0, 0)),
|
||||
(slice(0, 0), slice(0, 0))]
|
||||
dummy_microbatches = [(slice(0, num_tokens // 2),
|
||||
slice(0, num_tokens // 2)),
|
||||
(slice(num_tokens // 2, num_tokens),
|
||||
slice(num_tokens // 2, num_tokens))]
|
||||
|
||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||
num_scheduled_tokens):
|
||||
@ -2307,7 +2350,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens,
|
||||
ubatch_slices=dummy_microbatches,
|
||||
is_dummy_run=True,
|
||||
num_tokens_across_dp=num_tokens_across_dp
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
build_cuda_graph=build_cuda_graph
|
||||
)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, _ = outputs
|
||||
@ -2505,10 +2549,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.compilation_config.cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens,
|
||||
capture_attn_cudagraph=full_cg,
|
||||
allow_microbatching=allow_microbatching)
|
||||
allow_microbatching=allow_microbatching,
|
||||
build_cuda_graph=True)
|
||||
self._dummy_run(num_tokens,
|
||||
capture_attn_cudagraph=full_cg,
|
||||
allow_microbatching=allow_microbatching)
|
||||
allow_microbatching=allow_microbatching,
|
||||
build_cuda_graph=True)
|
||||
|
||||
logger.info("CAPTURE MODEL END")
|
||||
end_time = time.perf_counter()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user