[BugFix] Temporary fix for IMA with MTP = 2 and full-cg (#28315)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent 1b82fb0ad3
commit 64e39d667c
@@ -18,6 +18,7 @@ from vllm.config.utils import config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.math_utils import round_up
 from vllm.utils.torch_utils import is_torch_equal_or_newer

 if TYPE_CHECKING:
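
The only functional change in this hunk is the new round_up import. Its exact
implementation lives in vllm.utils.math_utils; assuming the usual
round-up-to-a-multiple semantics (an assumption, inferred from how the new
method below uses it), it behaves like this minimal sketch:

# Assumed behavior of vllm.utils.math_utils.round_up:
# round x up to the nearest multiple of `multiple`.
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(1, 3) == 3
assert round_up(4, 3) == 6
assert round_up(6, 3) == 6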
@@ -773,19 +774,8 @@ class CompilationConfig:
         if self.cudagraph_capture_sizes:
             assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

-        # pre-compute the mapping from batch size to padded graph size
-        self.bs_to_padded_graph_size = [
-            0 for i in range(self.max_cudagraph_capture_size + 1)
-        ]
-        for end, start in zip(
-            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
-            [0] + self.cudagraph_capture_sizes,
-        ):
-            for bs in range(start, end):
-                if bs == start:
-                    self.bs_to_padded_graph_size[bs] = start
-                else:
-                    self.bs_to_padded_graph_size[bs] = end
+        # May get recomputed in the model runner if adjustment is needed for spec-decode
+        self.compute_bs_to_padded_graph_size()

     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called only when mode is
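
For reference, a minimal standalone sketch of the mapping that
compute_bs_to_padded_graph_size() builds, with illustrative capture sizes
(the values are assumptions, not taken from this commit):

# Each batch size maps to the smallest captured cudagraph size that can hold it.
cudagraph_capture_sizes = [1, 2, 4, 8]  # illustrative values
max_cudagraph_capture_size = cudagraph_capture_sizes[-1]

bs_to_padded_graph_size = [0] * (max_cudagraph_capture_size + 1)
for end, start in zip(
    cudagraph_capture_sizes + [max_cudagraph_capture_size + 1],
    [0] + cudagraph_capture_sizes,
):
    for bs in range(start, end):
        # A batch size equal to a captured size maps to itself; anything
        # strictly between two captured sizes pads up to the larger one.
        bs_to_padded_graph_size[bs] = start if bs == start else end

print(bs_to_padded_graph_size)  # [0, 1, 2, 4, 4, 8, 8, 8, 8]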
@@ -922,3 +912,64 @@ class CompilationConfig:
             enable_str,
             op,
         )
+
+    def adjust_cudagraph_sizes_for_spec_decode(
+        self, uniform_decode_query_len: int, tensor_parallel_size: int
+    ):
+        multiple_of = uniform_decode_query_len
+        if tensor_parallel_size > 1:
+            multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
+            if (
+                multiple_of % uniform_decode_query_len != 0
+                or multiple_of % tensor_parallel_size != 0
+            ):
+                raise ValueError(
+                    f"Can't determine cudagraph shapes that are both a "
+                    f"multiple of {uniform_decode_query_len} "
+                    f"(num_speculative_tokens + 1) required by spec-decode "
+                    f"and {tensor_parallel_size} (tensor_parallel_size) "
+                    f"required by sequence parallelism; please adjust "
+                    f"num_speculative_tokens or disable sequence parallelism"
+                )
+
+        if not self.cudagraph_capture_sizes or multiple_of <= 1:
+            return
+
+        assert self.max_cudagraph_capture_size is not None
+        rounded_sizes = sorted(
+            set(
+                round_up(size, multiple_of)
+                for size in self.cudagraph_capture_sizes
+                if round_up(size, multiple_of) <= self.max_cudagraph_capture_size
+            )
+        )
+
+        if len(rounded_sizes) == 0:
+            logger.warning(
+                "No valid cudagraph sizes after rounding to multiple of "
+                "num_speculative_tokens + 1 (%d); please adjust num_speculative_tokens"
+                " or max_cudagraph_capture_size (or cudagraph_capture_sizes)",
+                multiple_of,
+            )
+            return
+
+        self.max_cudagraph_capture_size = rounded_sizes[-1]
+        self.cudagraph_capture_sizes = rounded_sizes
+
+        # Recompute after adjusting the cudagraph sizes
+        self.compute_bs_to_padded_graph_size()
+
+    def compute_bs_to_padded_graph_size(self):
+        # pre-compute the mapping from batch size to padded graph size
+        self.bs_to_padded_graph_size = [
+            0 for i in range(self.max_cudagraph_capture_size + 1)
+        ]
+        for end, start in zip(
+            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
+            [0] + self.cudagraph_capture_sizes,
+        ):
+            for bs in range(start, end):
+                if bs == start:
+                    self.bs_to_padded_graph_size[bs] = start
+                else:
+                    self.bs_to_padded_graph_size[bs] = end
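
To see what adjust_cudagraph_sizes_for_spec_decode does, here is a standalone
sketch of the rounding step with illustrative numbers (MTP = 2 gives
uniform_decode_query_len = num_speculative_tokens + 1 = 3; the capture sizes
and the round_up body are assumptions, not taken from the diff):

def round_up(x: int, multiple: int) -> int:
    # assumed semantics of vllm.utils.math_utils.round_up
    return ((x + multiple - 1) // multiple) * multiple

cudagraph_capture_sizes = [1, 2, 4, 8, 16]  # illustrative values
max_cudagraph_capture_size = 16
multiple_of = 3  # uniform_decode_query_len for MTP = 2, TP = 1

rounded_sizes = sorted(
    set(
        round_up(size, multiple_of)
        for size in cudagraph_capture_sizes
        if round_up(size, multiple_of) <= max_cudagraph_capture_size
    )
)
print(rounded_sizes)  # [3, 6, 9]

Every surviving capture size is now a whole number of 3-token decode groups;
18 (the rounding of 16) is dropped because it exceeds
max_cudagraph_capture_size.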
@@ -4332,6 +4332,22 @@ class GPUModelRunner(
             "and make sure compilation mode is VLLM_COMPILE"
         )

+        # if we have dedicated decode cudagraphs, and spec-decode is enabled,
+        # we need to adjust the cudagraph sizes to be a multiple of the uniform
+        # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
+        # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
+        # Will be removed in the near future when we have separate cudagraph capture
+        # sizes for decode and mixed prefill-decode.
+        if (
+            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
+            and cudagraph_mode.separate_routine()
+            and self.uniform_decode_query_len > 1
+        ):
+            self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
+                self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
+            )
+            self.cudagraph_batch_sizes = self.compilation_config.cudagraph_capture_sizes
+
         # Trigger cudagraph dispatching keys initialization after
         # resolved cudagraph mode.
         self.cudagraph_dispatcher.initialize_cudagraph_keys(
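
Why the multiple matters, as a hedged worked example (the exact out-of-bounds
mechanism is an assumption; the need for divisibility is what the comment in
the diff states):

# With MTP = 2, a uniform decode batch carries 3 tokens per request.
uniform_decode_query_len = 3
bs_to_padded_graph_size = {3: 4}  # hypothetical pre-fix mapping, capture size 4

num_reqs = 1
num_tokens = num_reqs * uniform_decode_query_len  # 3
padded = bs_to_padded_graph_size[num_tokens]      # 4
# 4 tokens is not a whole number of 3-token requests, so a full decode
# cudagraph replayed at this padded size can index one request too far:
assert padded % uniform_decode_query_len != 0

After the fix, every padded decode size is a multiple of
uniform_decode_query_len, so the padded token count always corresponds to a
whole number of requests.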