From aa3b3d76e0db63a4214b45805dc9bc3e5609c30e Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Fri, 11 Apr 2025 02:09:52 -0600
Subject: [PATCH] Enforce valid max_num_batched_tokens when
 disable_chunked_mm_input=True (#16447)

Signed-off-by: mgoin
---
 tests/v1/core/test_scheduler.py       | 9 +++++++++
 vllm/engine/arg_utils.py              | 2 +-
 vllm/v1/core/encoder_cache_manager.py | 8 ++++++++
 3 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 75c507555559..bc17ca32e5b6 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -322,6 +322,15 @@ def test_no_mm_input_chunking():
     assert len(output.finished_req_ids) == 0
     assert output.num_scheduled_tokens[requests[0].request_id] == 800
 
+    # Test that we fail if we disable chunked mm input and use too small
+    # of a max_num_batched_tokens for the mm input.
+    with pytest.raises(ValueError):
+        _ = create_scheduler(
+            model="llava-hf/llava-1.5-7b-hf",
+            max_num_batched_tokens=100,
+            disable_chunked_mm_input=True,
+        )
+
 
 @pytest.mark.parametrize("enable_prefix_caching", [True, False])
 def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 9cc6eca24b5c..3eafb6827d49 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1030,7 +1030,7 @@ class EngineArgs:
             action=StoreBoolean,
             default=EngineArgs.disable_chunked_mm_input,
             nargs="?",
-            const="False",
+            const="True",
             help="Disable multimodal input chunking attention for V1. "
             "If set to true and chunked prefill is enabled, we do not want to"
             " partially schedule a multimodal item. This ensures that if a "
diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py
index dc76df268c58..05d70bb9b977 100644
--- a/vllm/v1/core/encoder_cache_manager.py
+++ b/vllm/v1/core/encoder_cache_manager.py
@@ -133,6 +133,14 @@ def _compute_encoder_budget_multimodal(
     _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
                                     key=lambda item: item[1])
 
+    if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
+            > scheduler_config.max_num_batched_tokens):
+        raise ValueError(
+            "Chunked MM input disabled but max_tokens_per_mm_item "
+            f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens"
+            f" ({scheduler_config.max_num_batched_tokens}). Please increase "
+            "max_num_batched_tokens.")
+
     encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
                                  max_tokens_per_mm_item)
     encoder_cache_size = max(scheduler_config.encoder_cache_size,