fix eager mode

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-08-05 18:01:25 +00:00
parent e283eff060
commit 4819bb8715

View File

@ -250,12 +250,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
) )
can_use_cudagraphs = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
or self.compilation_config.full_cuda_graph)
self.use_cuda_graph = ( self.use_cuda_graph = (
self.vllm_config.compilation_config.level can_use_cudagraphs
== CompilationLevel.PIECEWISE
and self.vllm_config.compilation_config.use_cudagraph and self.vllm_config.compilation_config.use_cudagraph
and not self.model_config.enforce_eager) and not self.model_config.enforce_eager)
self.use_cuda_graph = True # self.use_cuda_graph = True
# TODO(woosuk): Provide an option to tune the max cudagraph batch size. # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different. # The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order. # self.cudagraph_batch_sizes sorts in ascending order.
@ -654,7 +656,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_after_padding = None num_tokens_after_padding = None
ubatch_abort = False ubatch_abort = False
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch( num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(
ubatch_slices, True) ubatch_slices)
if num_pad_tokens > 0: if num_pad_tokens > 0:
# Check if the padding would result in an empty second ubatch. # Check if the padding would result in an empty second ubatch.
# If so abort ubatching # If so abort ubatching
@ -1546,9 +1548,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_padded = num_tokens_unpadded num_tokens_padded = num_tokens_unpadded
# if (self.use_cuda_graph if (self.use_cuda_graph and not self.parallel_config.enable_microbatching
# and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
if False: # if False:
# Use piecewise CUDA graphs. # Use piecewise CUDA graphs.
# Add padding to the batch size. # Add padding to the batch size.
num_tokens_padded = self.vllm_config.pad_for_cudagraph( num_tokens_padded = self.vllm_config.pad_for_cudagraph(
@ -1571,8 +1573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def get_dp_padding_ubatch( def get_dp_padding_ubatch(
self, self,
ubatch_slices: UBatchSlices, ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]:
include_cudagraphs: bool) -> tuple[int, Optional[torch.Tensor]]:
dp_size = self.vllm_config.parallel_config.data_parallel_size dp_size = self.vllm_config.parallel_config.data_parallel_size
if dp_size == 1: if dp_size == 1:
@ -1592,7 +1593,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
num_tokens_padded = round_up(num_tokens_unpadded, 2) num_tokens_padded = round_up(num_tokens_unpadded, 2)
if (include_cudagraphs and self.use_cuda_graph if (self.full_cuda_graph
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
# Add padding to the batch size. # Add padding to the batch size.
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded) num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
@ -3056,7 +3057,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._dummy_run(num_tokens, self._dummy_run(num_tokens,
capture_attn_cudagraph=full_cg, capture_attn_cudagraph=full_cg,
allow_microbatching=allow_microbatching, allow_microbatching=allow_microbatching,
build_cuda_graph=True, build_cuda_graph=full_cg,
skip_eplb=True) skip_eplb=True)
end_time = time.perf_counter() end_time = time.perf_counter()