mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 03:24:27 +08:00
Refactor: Move CUDA graph dispatch logic earlier (#27382)
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
parent
7df331c66b
commit
df78aeef08
@ -3740,6 +3740,31 @@ class GPUModelRunner(
|
|||||||
dp_rank = self.parallel_config.data_parallel_rank
|
dp_rank = self.parallel_config.data_parallel_rank
|
||||||
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
|
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
|
||||||
|
|
||||||
|
# filter out the valid batch descriptor
|
||||||
|
_cg_mode, batch_descriptor = (
|
||||||
|
self.cudagraph_dispatcher.dispatch(
|
||||||
|
BatchDescriptor(
|
||||||
|
num_tokens=num_tokens_after_padding,
|
||||||
|
uniform_decode=uniform_decode,
|
||||||
|
has_lora=activate_lora and self.lora_config is not None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not is_profile
|
||||||
|
else (CUDAGraphMode.NONE, None)
|
||||||
|
)
|
||||||
|
if cudagraph_runtime_mode is not None:
|
||||||
|
# we allow forcing NONE when the dispatcher disagrees to support
|
||||||
|
# warm ups for cudagraph capture
|
||||||
|
assert (
|
||||||
|
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
||||||
|
or cudagraph_runtime_mode == _cg_mode
|
||||||
|
), (
|
||||||
|
f"Cudagraph runtime mode mismatch at dummy_run. "
|
||||||
|
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cudagraph_runtime_mode = _cg_mode
|
||||||
|
|
||||||
attn_metadata: PerLayerAttnMetadata | None = None
|
attn_metadata: PerLayerAttnMetadata | None = None
|
||||||
|
|
||||||
# If force_attention is True, we always capture attention. Otherwise,
|
# If force_attention is True, we always capture attention. Otherwise,
|
||||||
@ -3814,31 +3839,6 @@ class GPUModelRunner(
|
|||||||
num_tokens_after_padding, None, False
|
num_tokens_after_padding, None, False
|
||||||
)
|
)
|
||||||
|
|
||||||
# filter out the valid batch descriptor
|
|
||||||
_cg_mode, batch_descriptor = (
|
|
||||||
self.cudagraph_dispatcher.dispatch(
|
|
||||||
BatchDescriptor(
|
|
||||||
num_tokens=num_tokens_after_padding,
|
|
||||||
uniform_decode=uniform_decode,
|
|
||||||
has_lora=activate_lora and self.lora_config is not None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not is_profile
|
|
||||||
else (CUDAGraphMode.NONE, None)
|
|
||||||
)
|
|
||||||
if cudagraph_runtime_mode is not None:
|
|
||||||
# we allow forcing NONE when the dispatcher disagrees to support
|
|
||||||
# warm ups for cudagraph capture
|
|
||||||
assert (
|
|
||||||
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
|
||||||
or cudagraph_runtime_mode == _cg_mode
|
|
||||||
), (
|
|
||||||
f"Cudagraph runtime mode mismatch at dummy_run. "
|
|
||||||
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cudagraph_runtime_mode = _cg_mode
|
|
||||||
|
|
||||||
if ubatch_slices is not None:
|
if ubatch_slices is not None:
|
||||||
# Adjust values to reflect a single ubatch.
|
# Adjust values to reflect a single ubatch.
|
||||||
# TODO(sage,lucas): this is cruft that should be addressed in
|
# TODO(sage,lucas): this is cruft that should be addressed in
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user