diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 041687ae28b2..85976fc1c825 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -36,7 +36,8 @@ from vllm.model_executor.models.interfaces import (is_mixture_of_experts, from vllm.model_executor.models.interfaces_base import ( VllmModelForPooling, is_pooling_model, is_text_generation_model) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, + PlaceholderRange) from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType @@ -51,7 +52,6 @@ from vllm.v1.attention.backends.utils import ( make_kv_sharing_fast_prefill_attention_metadata, make_local_attention_virtual_batches, reorder_batch_to_split_decodes_and_prefills) -from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, @@ -73,7 +73,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import ( from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from ..sample.logits_processor import LogitsProcessorManager -from .utils import (bind_kv_cache, gather_mm_placeholders, +from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -148,14 +148,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, - mm_registry=self.mm_registry, - ) - self.max_num_encoder_input_tokens = encoder_compute_budget - self.encoder_cache_size = encoder_cache_size - # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) @@ -330,6 +322,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.kv_sharing_fast_prefill_logits_indices = torch.zeros( self.max_num_tokens, dtype=torch.int32, device=self.device) + self.mm_budget = (MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + max_model_len=self.max_model_len, + max_num_reqs=self.max_num_reqs, + ) if self.is_multimodal_model else None) + self.reorder_batch_threshold: Optional[int] = None def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: @@ -578,37 +578,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() - def _init_model_kwargs_for_multimodal_model( + def _extract_mm_kwargs( self, - scheduler_output: Optional["SchedulerOutput"] = None, - num_reqs: int = -1, - ) -> dict[str, Any]: - - model_kwargs: dict[str, Any] = {} - if self.is_multimodal_raw_input_supported: - # This model requires the raw multimodal data in input. 
+ scheduler_output: "SchedulerOutput", + ) -> BatchedTensorInputs: + if self.is_multimodal_raw_input_supported: # noqa: SIM102 if scheduler_output: - multi_modal_kwargs_list = [] + multi_modal_kwargs_list = list[MultiModalKwargs]() for req in scheduler_output.scheduled_new_reqs: req_mm_inputs = req.mm_inputs if not isinstance(req_mm_inputs, list): req_mm_inputs = list(req_mm_inputs) multi_modal_kwargs_list.extend(req_mm_inputs) - multi_modal_kwargs = MultiModalKwargs.batch( - multi_modal_kwargs_list) - else: - # The only case where SchedulerOutput is None is for - # a dummy run let's get some dummy data. - dummy_data = [ - self.mm_registry.get_decoder_dummy_data( - model_config=self.model_config, - seq_len=1).multi_modal_data for i in range(num_reqs) - ] - multi_modal_kwargs = MultiModalKwargs.batch(dummy_data) - model_kwargs.update(multi_modal_kwargs) + return MultiModalKwargs.batch(multi_modal_kwargs_list) - return model_kwargs + return {} + + def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: + if self.is_multimodal_raw_input_supported: + mm_budget = self.mm_budget + assert mm_budget is not None + + dummy_modality, _ = mm_budget.get_modality_with_max_tokens() + + return self._get_mm_dummy_batch(dummy_modality, num_seqs) + + return {} def _get_cumsum_and_arange( self, @@ -1517,19 +1513,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. - input_ids = self.input_ids[:num_scheduled_tokens] - - model_kwargs = self._init_model_kwargs_for_multimodal_model( - scheduler_output=scheduler_output) - inputs_embeds = self.model.get_input_embeddings( - input_ids=input_ids, + inputs_embeds_scheduled = self.model.get_input_embeddings( + input_ids=self.input_ids[:num_scheduled_tokens], multimodal_embeddings=mm_embeds or None, ) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) - inputs_embeds = self.inputs_embeds[:num_input_tokens] + self.inputs_embeds[:num_scheduled_tokens].copy_( + inputs_embeds_scheduled) + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] + model_mm_kwargs = self._extract_mm_kwargs(scheduler_output) else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the @@ -1537,7 +1532,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # then the embedding layer is not included in the CUDA graph. 
input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - model_kwargs = {} + model_mm_kwargs = {} if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] else: @@ -1571,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, **MultiModalKwargs.as_kwargs( - model_kwargs, + model_mm_kwargs, device=self.device, ), ) @@ -2149,6 +2144,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): yield input_ids.fill_(0) + def _get_mm_dummy_batch( + self, + modality: str, + max_items_per_batch: int, + ) -> BatchedTensorInputs: + """Dummy data for profiling and precompiling multimodal models.""" + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=self.max_num_tokens, + mm_counts={modality: 1}, + ) + dummy_mm_data = dummy_decoder_data.multi_modal_data + + # Result in the maximum GPU consumption of the model + dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) + dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) + + batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * + max_items_per_batch) + return MultiModalKwargs.as_kwargs( + batched_dummy_mm_inputs, + device=self.device, + ) + @torch.inference_mode() def _dummy_run( self, @@ -2213,16 +2232,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): - model = self.model if self.is_multimodal_model: - model_kwargs = self._init_model_kwargs_for_multimodal_model( - num_reqs=num_reqs) input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] + model_mm_kwargs = self._dummy_mm_kwargs(num_reqs) else: input_ids = self.input_ids[:num_tokens] inputs_embeds = None - model_kwargs = {} + model_mm_kwargs = {} if self.uses_mrope: positions = self.mrope_positions[:, :num_tokens] @@ -2247,13 +2264,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp): - outputs = model( + outputs = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, **MultiModalKwargs.as_kwargs( - model_kwargs, + model_mm_kwargs, device=self.device, ), ) @@ -2423,75 +2440,51 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. - # TODO: handle encoder-decoder models once we support them. - if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 - and self.encoder_cache_size > 0): + if self.is_multimodal_model: + mm_budget = self.mm_budget + assert mm_budget is not None - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - max_tokens_by_modality_dict = self.mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(self.model_config) - dummy_data_modality, max_tokens_per_mm_item = max( - max_tokens_by_modality_dict.items(), key=lambda item: item[1]) + # TODO: handle encoder-decoder models once we support them. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. 
+ ( + dummy_modality, + max_tokens, + ) = mm_budget.get_modality_with_max_tokens() + ( + max_mm_items_per_prompt, + max_mm_items_per_batch, + ) = mm_budget.get_max_items(dummy_modality, max_tokens) - # Check how many items of this modality can be supported by - # the encoder budget. - encoder_budget = min(self.max_num_encoder_input_tokens, - self.encoder_cache_size) + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the maximum " + "feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) - max_num_mm_items_encoder_budget = encoder_budget // \ - max_tokens_per_mm_item + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Check how many items of this modality can be supported by - # the decoder budget. - max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt( - self.model_config)[dummy_data_modality] + # Run multimodal encoder. + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) - # NOTE: We do not consider max_num_batched_tokens on purpose - # because the multimodal embeddings can be generated in advance - # and chunked prefilled. - max_num_mm_items_decoder_budget = self.max_num_reqs * \ - max_mm_items_per_req + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) - max_num_mm_items = max( - 1, - min(max_num_mm_items_encoder_budget, - max_num_mm_items_decoder_budget)) - - logger.info( - "Encoder cache will be initialized with a budget of %s tokens," - " and profiled with %s %s items of the maximum feature size.", - encoder_budget, max_num_mm_items, dummy_data_modality) - - # Create dummy batch of multimodal inputs. - dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data( - model_config=self.model_config, - seq_len=max_tokens_per_mm_item, - mm_counts={ - dummy_data_modality: 1 - }, - ).multi_modal_data - - batched_dummy_mm_inputs = MultiModalKwargs.batch( - [dummy_mm_kwargs] * max_num_mm_items, - pin_memory=self.pin_memory) - batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( - batched_dummy_mm_inputs, - device=self.device, - ) - - # Run multimodal encoder. - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_num_mm_items, - ) - - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + # Cache the dummy encoder outputs. 
+ self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 67cb2f9dd810..5f3188efdb24 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -42,7 +42,6 @@ from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE, PallasAttentionBackend, PallasMetadata, get_page_size_bytes) -from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec) @@ -55,7 +54,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import ( from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch -from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing, +from .utils import (MultiModalBudget, bind_kv_cache, + initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs) if TYPE_CHECKING: @@ -195,14 +195,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # TODO: Support M-RoPE (e.g, Qwen2-VL) assert not self.uses_mrope, "TPU does not support M-RoPE yet." - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, - mm_registry=self.mm_registry, - ) - self.max_num_encoder_input_tokens = encoder_compute_budget - self.encoder_cache_size = encoder_cache_size - self._num_slices_per_kv_cache_update_block = \ _get_num_slices_per_kv_cache_update_block(get_page_size_bytes( block_size=self.block_size, @@ -294,36 +286,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.structured_decode_arange = torch.arange( 0, 32, device="cpu", pin_memory=self.pin_memory) - # Get maximum number of mm items per modality (batch size). - self.max_num_mm_items_by_modality = dict() - if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 - and self.encoder_cache_size > 0): - max_tokens_by_modality_dict = ( - MULTIMODAL_REGISTRY. - get_max_tokens_per_item_by_nonzero_modality(self.model_config)) - for modality, max_tokens in max_tokens_by_modality_dict.items(): - # Check how many items of this modality can be supported by - # the encoder budget. - encoder_budget = min(self.max_num_encoder_input_tokens, - self.encoder_cache_size) - - max_num_mm_items_encoder_budget = cdiv(encoder_budget, - max_tokens) - - # Check how many items of this modality can be supported by - # the decoder budget. - max_mm_items_per_req = self.mm_registry.\ - get_mm_limits_per_prompt(self.model_config)[modality] - - # NOTE: We do not consider max_num_batched_tokens on purpose - # because the multimodal embeddings can be generated in advance - # and chunked prefilled. 
- max_num_mm_items_decoder_budget = self.max_num_reqs * \ - max_mm_items_per_req - - max_num_mm_items = min(max_num_mm_items_encoder_budget, - max_num_mm_items_decoder_budget) - self.max_num_mm_items_by_modality[modality] = max_num_mm_items + self.mm_budget = (MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + max_model_len=self.max_model_len, + max_num_reqs=self.max_num_reqs, + ) if self.is_multimodal_model else None) if not self.use_spmd: self.sample_from_logits_func = torch.compile( @@ -1335,23 +1304,33 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): xm.mark_step() # Captures metadata updates def _precompile_mm_encoder(self) -> None: + if not self.is_multimodal_model: + return + # Pre-compile MM encoder for all supported data modalities. hf_config = self.vllm_config.model_config.hf_config - for mode, max_items_by_mode in \ - self.max_num_mm_items_by_modality.items(): + + mm_budget = self.mm_budget + assert mm_budget is not None + + max_items_per_seq_by_modality = mm_budget.max_items_per_batch_by_modality # noqa: E501 + + for mode, max_items_per_seq in max_items_per_seq_by_modality.items(): logger.info( "Compiling Multimodal %s Encoder with different input" " shapes.", mode) start = time.perf_counter() # No padding for MM encoder just yet. - for num_items in range(1, max_items_by_mode + 1): + for num_items in range(1, max_items_per_seq + 1): logger.info(" -- mode: %s items: %d", mode, num_items) batched_dummy_mm_inputs = self._get_mm_dummy_batch( - mode, num_items) + mode, + num_items, + ) # Run multimodal encoder. xm.mark_step() - mm_embeds = self.model.\ - get_multimodal_embeddings(**batched_dummy_mm_inputs) + mm_embeds = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) xm.mark_step() num_patches = mm_embeds[0].shape[0] items_size = num_patches * num_items @@ -1547,51 +1526,61 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_tokens: int, ) -> None: # Profile with multimodal encoder & encoder cache. - # TODO: handle encoder-decoder models once we support them. - if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 - and self.encoder_cache_size > 0): + if self.is_multimodal_model: + mm_budget = self.mm_budget + assert mm_budget is not None - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - dummy_data_modality, max_num_mm_items = max( - self.max_num_mm_items_by_modality.items(), key=lambda t: t[1]) + # TODO: handle encoder-decoder models once we support them. + if (encoder_budget := mm_budget.get_encoder_budget()) > 0: + # NOTE: Currently model is profiled with a single non-text + # modality with the max possible input tokens even when + # it supports multiple. 
+ ( + dummy_modality, + max_tokens, + ) = mm_budget.get_modality_with_max_tokens() + ( + max_mm_items_per_prompt, + max_mm_items_per_batch, + ) = mm_budget.get_max_items(dummy_modality, max_tokens) - encoder_budget = min(self.max_num_encoder_input_tokens, - self.encoder_cache_size) + logger.info( + "Encoder cache will be initialized with a budget of " + "%s tokens, and profiled with %s %s items of the maximum " + "feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) - logger.info( - "Encoder cache will be initialized with a budget of %d tokens," - " and profiled with %s %s items of the maximum feature size.", - encoder_budget, max_num_mm_items, dummy_data_modality) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_data_modality, max_num_mm_items) + # Run multimodal encoder. + # Isolate encoder graph from post-processing to minimize + # impact of recompilation until it's fixed. + start = time.perf_counter() + xm.mark_step() + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs) + xm.mark_step() + xm.wait_device_ops() + end = time.perf_counter() + logger.info( + "Multimodal Encoder profiling finished in in %.2f [secs].", + end - start) - # Run multimodal encoder. - # Isolate encoder graph from post-processing to minimize - # impact of recompilation until it's fixed. - start = time.perf_counter() - xm.mark_step() - dummy_encoder_outputs = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - xm.mark_step() - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal Encoder profiling finished in in %.2f [secs].", - end - start) + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) - assert len(dummy_encoder_outputs) == max_num_mm_items, ( - "Expected dimension 0 of encoder outputs to match the number " - f"of multimodal data items: {max_num_mm_items}, got " - f"{len(dummy_encoder_outputs)=} instead. This is most likely " - "due to the 'get_multimodal_embeddings' method of the model " - "not implemented correctly.") - - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + # Cache the dummy encoder outputs. + self.encoder_cache["tmp"] = dict( + enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. self._dummy_run(num_tokens, self.num_reqs_max_model_len, @@ -1809,33 +1798,25 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ self.structured_decode_arange.to(logits.device) - def _get_mm_dummy_batch(self, modality: str, - batch_size: int) -> BatchedTensorInputs: - # Dummy data for pre-compiling multimodal models. 
- dummy_request_data = self.mm_registry.get_decoder_dummy_data( + def _get_mm_dummy_batch( + self, + modality: str, + max_items_per_batch: int, + ) -> BatchedTensorInputs: + """Dummy data for profiling and precompiling multimodal models.""" + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, + mm_counts={modality: 1}, ) - dummy_mm_data = dummy_request_data.multi_modal_data + dummy_mm_data = dummy_decoder_data.multi_modal_data - # Dummy data definition in V0 may contain multiple multimodal items - # (e.g, multiple images) for a single request, therefore here we - # always replicate first item by max_num_mm_items times since in V1 - # they are scheduled to be processed separately. - assert isinstance(dummy_mm_data, MultiModalKwargs), ( - "Expected dummy multimodal data to be of type " - f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. " - "This is most likely due to the model not having a merged " - "processor.") - - # When models have a merged processor, their dummy data is - # already batched `MultiModalKwargs`, therefore we take the first - # `MultiModalKwargsItem` from the desired modality to profile on. + # Result in the maximum GPU consumption of the model dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0) dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * - batch_size) + max_items_per_batch) return MultiModalKwargs.as_kwargs( batched_dummy_mm_inputs, device=self.device, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 3ecb1d7dd656..6761b3c5e41d 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -5,14 +5,123 @@ from typing import TYPE_CHECKING, Optional import torch +from vllm.config import ModelConfig, SchedulerConfig from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.registry import MultiModalRegistry +from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec if TYPE_CHECKING: from vllm.attention.layer import Attention +class MultiModalBudget: + """Helper class to calculate budget information for multi-modal models.""" + + def __init__( + self, + model_config: ModelConfig, + scheduler_config: SchedulerConfig, + mm_registry: MultiModalRegistry, + *, + max_model_len: int, + max_num_reqs: int, + ) -> None: + super().__init__() + + self.model_config = model_config + self.scheduler_config = scheduler_config + self.mm_registry = mm_registry + + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + mm_registry=mm_registry, + ) + + self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size + self.max_model_len = max_model_len + self.max_num_reqs = max_num_reqs + + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config) + + max_items_per_prompt_by_modality = dict[str, int]() + max_items_per_batch_by_modality = dict[str, int]() + + max_tokens_by_modality = mm_registry \ + .get_max_tokens_per_item_by_nonzero_modality(model_config) + + for modality, max_tokens in max_tokens_by_modality.items(): + ( + max_items_per_prompt, + max_items_per_batch, + ) = self.get_max_items(modality, max_tokens) + + max_items_per_prompt_by_modality[modality] = max_items_per_prompt + 
max_items_per_batch_by_modality[modality] = max_items_per_batch + + self.max_tokens_by_modality = max_tokens_by_modality + self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality + self.max_items_per_batch_by_modality = max_items_per_batch_by_modality + + def get_modality_with_max_tokens(self) -> tuple[str, int]: + max_tokens_by_modality = self.max_tokens_by_modality + modality, max_tokens = max(max_tokens_by_modality.items(), + key=lambda item: item[1]) + + return modality, max_tokens + + def get_encoder_budget(self) -> int: + return min(self.max_num_encoder_input_tokens, self.encoder_cache_size) + + def get_max_items( + self, + modality: str, + max_tokens_per_item: int, + ) -> tuple[int, int]: + if max_tokens_per_item == 0: + return 0, 0 + + # Check how many items of this modality can be supported by + # the encoder budget. + encoder_budget = self.get_encoder_budget() + + # TODO: handle encoder-decoder models once we support them. + if encoder_budget == 0: + return 0, 0 + + max_encoder_items_per_batch = encoder_budget // max_tokens_per_item + + # Check how many items of this modality can be supported by + # the decoder budget. + mm_limit = self.mm_limits[modality] + + max_items_per_prompt = max( + 1, + min(mm_limit, self.max_model_len // max_tokens_per_item), + ) + + scheduler_config = self.scheduler_config + max_num_reqs = self.max_num_reqs + + if not scheduler_config.enable_chunked_prefill: + max_num_reqs = min( + max_num_reqs, + scheduler_config.max_num_batched_tokens // max_tokens_per_item, + ) + + max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt + + max_items_per_batch = max( + 1, + min(max_encoder_items_per_batch, max_decoder_items_per_batch), + ) + + return max_items_per_prompt, max_items_per_batch + + def sanity_check_mm_encoder_outputs( mm_embeddings: MultiModalEmbeddings, expected_num_items: int,
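
For reference, the arithmetic that the new `MultiModalBudget.get_max_items()` performs can be followed with plain integers. The sketch below is a minimal, self-contained rework of the logic shown in the diff, not the class itself: it takes the budget inputs as explicit parameters instead of reading them from vLLM config objects, and the numbers in the usage example (1,445 tokens per image item, an 8,192-token encoder budget, 128 requests) are purely illustrative.

# Minimal sketch of the MultiModalBudget.get_max_items() arithmetic.
# Standalone on purpose: no vLLM imports, hypothetical numbers only.

def get_max_items(
    max_tokens_per_item: int,
    *,
    encoder_compute_budget: int,
    encoder_cache_size: int,
    mm_limit: int,
    max_model_len: int,
    max_num_reqs: int,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
) -> tuple[int, int]:
    if max_tokens_per_item == 0:
        return 0, 0

    # Encoder side: how many items of this modality fit in the encoder budget.
    encoder_budget = min(encoder_compute_budget, encoder_cache_size)
    if encoder_budget == 0:
        return 0, 0
    max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

    # Decoder side: per-prompt limit, scaled by the number of requests.
    max_items_per_prompt = max(
        1, min(mm_limit, max_model_len // max_tokens_per_item))
    if not enable_chunked_prefill:
        max_num_reqs = min(
            max_num_reqs, max_num_batched_tokens // max_tokens_per_item)
    max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

    # The batch used for profiling is capped by whichever side is tighter.
    max_items_per_batch = max(
        1, min(max_encoder_items_per_batch, max_decoder_items_per_batch))
    return max_items_per_prompt, max_items_per_batch


# Example: an image item costing 1,445 encoder tokens against an 8,192-token
# encoder budget gives 8192 // 1445 = 5 items per batch on the encoder side,
# which is the binding constraint here.
print(get_max_items(
    1445,
    encoder_compute_budget=8192,
    encoder_cache_size=8192,
    mm_limit=1,
    max_model_len=4096,
    max_num_reqs=128,
    enable_chunked_prefill=True,
    max_num_batched_tokens=8192,
))  # -> (1, 5)

Both runners then use this result the same way during profiling: `get_modality_with_max_tokens()` picks the modality with the largest per-item token count, `get_max_items()` yields the per-batch item count, and `_get_mm_dummy_batch()` replicates one maximum-size dummy item that many times before calling `get_multimodal_embeddings()`, which is what makes the encoder-cache profiling path identical on GPU and TPU after this refactor.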