mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 17:25:29 +08:00)
[Core] Factor out common logic for MM budget calculation (#22228)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent e79a12fc3a
commit 811ac13d03
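This commit replaces the per-runner multi-modal budget code (the compute_encoder_budget calls in the GPU and TPU runner constructors and the TPU-side max_num_mm_items_by_modality loop) with a shared MultiModalBudget helper in the worker utils, and splits the GPU runner's _init_model_kwargs_for_multimodal_model into _extract_mm_kwargs (real steps) and _dummy_mm_kwargs (dummy and profile runs). The snippet below is a standalone, hedged sketch of the arithmetic the helper centralizes; it is not part of the patch, the function name sketch_max_items is invented here, and the chunked-prefill adjustment from get_max_items is omitted for brevity.

def sketch_max_items(
    max_tokens_per_item: int,     # encoder tokens one item of this modality may produce
    encoder_compute_budget: int,  # compute budget from compute_encoder_budget()
    encoder_cache_size: int,      # cache budget from compute_encoder_budget()
    mm_limit: int,                # per-prompt item limit for this modality
    max_model_len: int,
    max_num_reqs: int,
) -> tuple[int, int]:
    # The effective encoder budget is the smaller of the compute and cache budgets.
    encoder_budget = min(encoder_compute_budget, encoder_cache_size)
    if max_tokens_per_item == 0 or encoder_budget == 0:
        return 0, 0

    # Per prompt: bounded by the configured limit and by how many items of this
    # size fit into the model's context window (always at least one).
    max_items_per_prompt = max(
        1, min(mm_limit, max_model_len // max_tokens_per_item))

    # Per batch: bounded by both the encoder budget and the decoder budget
    # (requests per batch times items per prompt), again at least one.
    max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
    max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
    return max_items_per_prompt, max(
        1, min(max_encoder_items_per_batch, max_decoder_items_per_batch))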
@@ -36,7 +36,8 @@ from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
+                                    PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
@@ -51,7 +52,6 @@ from vllm.v1.attention.backends.utils import (
     make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches,
     reorder_batch_to_split_decodes_and_prefills)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
@@ -73,7 +73,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (bind_kv_cache, gather_mm_placeholders,
+from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders,
                     initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
 
@@ -148,14 +148,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
 
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            mm_registry=self.mm_registry,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
         # Sampler
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
 
@@ -330,6 +322,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=self.device)
 
+        self.mm_budget = (MultiModalBudget(
+            self.model_config,
+            self.scheduler_config,
+            self.mm_registry,
+            max_model_len=self.max_model_len,
+            max_num_reqs=self.max_num_reqs,
+        ) if self.is_multimodal_model else None)
+
         self.reorder_batch_threshold: Optional[int] = None
 
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
@@ -578,37 +578,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()
 
-    def _init_model_kwargs_for_multimodal_model(
+    def _extract_mm_kwargs(
         self,
-        scheduler_output: Optional["SchedulerOutput"] = None,
-        num_reqs: int = -1,
-    ) -> dict[str, Any]:
-
-        model_kwargs: dict[str, Any] = {}
-        if self.is_multimodal_raw_input_supported:
-            # This model requires the raw multimodal data in input.
+        scheduler_output: "SchedulerOutput",
+    ) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:  # noqa: SIM102
             if scheduler_output:
-                multi_modal_kwargs_list = []
+                multi_modal_kwargs_list = list[MultiModalKwargs]()
                 for req in scheduler_output.scheduled_new_reqs:
                     req_mm_inputs = req.mm_inputs
                     if not isinstance(req_mm_inputs, list):
                         req_mm_inputs = list(req_mm_inputs)
                     multi_modal_kwargs_list.extend(req_mm_inputs)
-                multi_modal_kwargs = MultiModalKwargs.batch(
-                    multi_modal_kwargs_list)
-            else:
-                # The only case where SchedulerOutput is None is for
-                # a dummy run let's get some dummy data.
-                dummy_data = [
-                    self.mm_registry.get_decoder_dummy_data(
-                        model_config=self.model_config,
-                        seq_len=1).multi_modal_data for i in range(num_reqs)
-                ]
-                multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
 
-            model_kwargs.update(multi_modal_kwargs)
+                return MultiModalKwargs.batch(multi_modal_kwargs_list)
 
-        return model_kwargs
+        return {}
 
+    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            dummy_modality, _ = mm_budget.get_modality_with_max_tokens()
+
+            return self._get_mm_dummy_batch(dummy_modality, num_seqs)
+
+        return {}
+
     def _get_cumsum_and_arange(
         self,
@@ -1517,19 +1513,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:num_scheduled_tokens]
-
-            model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                scheduler_output=scheduler_output)
-            inputs_embeds = self.model.get_input_embeddings(
-                input_ids=input_ids,
+            inputs_embeds_scheduled = self.model.get_input_embeddings(
+                input_ids=self.input_ids[:num_scheduled_tokens],
                 multimodal_embeddings=mm_embeds or None,
             )
 
             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            self.inputs_embeds[:num_scheduled_tokens].copy_(
+                inputs_embeds_scheduled)
+
             input_ids = None
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
         else:
             # For text-only models, we use token ids as input.
             # While it is possible to use embeddings as input just like the
@@ -1537,7 +1532,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
             inputs_embeds = None
-            model_kwargs = {}
+            model_mm_kwargs = {}
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
         else:
@@ -1571,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 **MultiModalKwargs.as_kwargs(
-                    model_kwargs,
+                    model_mm_kwargs,
                     device=self.device,
                 ),
             )
@@ -2149,6 +2144,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             yield
             input_ids.fill_(0)
 
+    def _get_mm_dummy_batch(
+        self,
+        modality: str,
+        max_items_per_batch: int,
+    ) -> BatchedTensorInputs:
+        """Dummy data for profiling and precompiling multimodal models."""
+        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
+            model_config=self.model_config,
+            seq_len=self.max_num_tokens,
+            mm_counts={modality: 1},
+        )
+        dummy_mm_data = dummy_decoder_data.multi_modal_data
+
+        # Result in the maximum GPU consumption of the model
+        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
+
+        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
+                                                         max_items_per_batch)
+        return MultiModalKwargs.as_kwargs(
+            batched_dummy_mm_inputs,
+            device=self.device,
+        )
+
     @torch.inference_mode()
     def _dummy_run(
         self,
@@ -2213,16 +2232,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            model = self.model
             if self.is_multimodal_model:
-                model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                    num_reqs=num_reqs)
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
+                model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
             else:
                 input_ids = self.input_ids[:num_tokens]
                 inputs_embeds = None
-                model_kwargs = {}
+                model_mm_kwargs = {}
 
             if self.uses_mrope:
                 positions = self.mrope_positions[:, :num_tokens]
@@ -2247,13 +2264,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     self.vllm_config,
                     num_tokens=num_tokens,
                     num_tokens_across_dp=num_tokens_across_dp):
-                outputs = model(
+                outputs = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                     **MultiModalKwargs.as_kwargs(
-                        model_kwargs,
+                        model_mm_kwargs,
                         device=self.device,
                     ),
                 )
@@ -2423,75 +2440,51 @@
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
-        # TODO: handle encoder-decoder models once we support them.
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
-
-            # NOTE: Currently model is profiled with a single non-text
-            # modality with the max possible input tokens even when
-            # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry \
-                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
-            dummy_data_modality, max_tokens_per_mm_item = max(
-                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-            # Check how many items of this modality can be supported by
-            # the encoder budget.
-            encoder_budget = min(self.max_num_encoder_input_tokens,
-                                 self.encoder_cache_size)
-
-            max_num_mm_items_encoder_budget = encoder_budget // \
-                max_tokens_per_mm_item
-
-            # Check how many items of this modality can be supported by
-            # the decoder budget.
-            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
-                self.model_config)[dummy_data_modality]
-
-            # NOTE: We do not consider max_num_batched_tokens on purpose
-            # because the multimodal embeddings can be generated in advance
-            # and chunked prefilled.
-            max_num_mm_items_decoder_budget = self.max_num_reqs * \
-                max_mm_items_per_req
-
-            max_num_mm_items = max(
-                1,
-                min(max_num_mm_items_encoder_budget,
-                    max_num_mm_items_decoder_budget))
-
-            logger.info(
-                "Encoder cache will be initialized with a budget of %s tokens,"
-                " and profiled with %s %s items of the maximum feature size.",
-                encoder_budget, max_num_mm_items, dummy_data_modality)
-
-            # Create dummy batch of multimodal inputs.
-            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
-                model_config=self.model_config,
-                seq_len=max_tokens_per_mm_item,
-                mm_counts={
-                    dummy_data_modality: 1
-                },
-            ).multi_modal_data
-
-            batched_dummy_mm_inputs = MultiModalKwargs.batch(
-                [dummy_mm_kwargs] * max_num_mm_items,
-                pin_memory=self.pin_memory)
-            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-                batched_dummy_mm_inputs,
-                device=self.device,
-            )
-
-            # Run multimodal encoder.
-            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-                **batched_dummy_mm_inputs)
-
-            sanity_check_mm_encoder_outputs(
-                dummy_encoder_outputs,
-                expected_num_items=max_num_mm_items,
-            )
-
-            # Cache the dummy encoder outputs.
-            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+        if self.is_multimodal_model:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            # TODO: handle encoder-decoder models once we support them.
+            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
+                # NOTE: Currently model is profiled with a single non-text
+                # modality with the max possible input tokens even when
+                # it supports multiple.
+                (
+                    dummy_modality,
+                    max_tokens,
+                ) = mm_budget.get_modality_with_max_tokens()
+                (
+                    max_mm_items_per_prompt,
+                    max_mm_items_per_batch,
+                ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+
+                logger.info(
+                    "Encoder cache will be initialized with a budget of "
+                    "%s tokens, and profiled with %s %s items of the maximum "
+                    "feature size.",
+                    encoder_budget,
+                    max_mm_items_per_batch,
+                    dummy_modality,
+                )
+
+                # Create dummy batch of multimodal inputs.
+                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+                    dummy_modality,
+                    max_mm_items_per_batch,
+                )
+
+                # Run multimodal encoder.
+                dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                    **batched_dummy_mm_inputs)
+
+                sanity_check_mm_encoder_outputs(
+                    dummy_encoder_outputs,
+                    expected_num_items=max_mm_items_per_batch,
+                )
+
+                # Cache the dummy encoder outputs.
+                self.encoder_cache["tmp"] = dict(
+                    enumerate(dummy_encoder_outputs))
 
         # Add `is_profile` here to pre-allocate communication buffers
         hidden_states, last_hidden_states \

@@ -42,7 +42,6 @@ from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
                                                PallasAttentionBackend,
                                                PallasMetadata,
                                                get_page_size_bytes)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
                                         SlidingWindowSpec)
@@ -55,7 +54,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
 
-from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
+from .utils import (MultiModalBudget, bind_kv_cache,
+                    initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs)
 
 if TYPE_CHECKING:
@@ -195,14 +195,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # TODO: Support M-RoPE (e.g, Qwen2-VL)
         assert not self.uses_mrope, "TPU does not support M-RoPE yet."
 
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            mm_registry=self.mm_registry,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
         self._num_slices_per_kv_cache_update_block = \
             _get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
                 block_size=self.block_size,
@@ -294,36 +286,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.structured_decode_arange = torch.arange(
             0, 32, device="cpu", pin_memory=self.pin_memory)
 
-        # Get maximum number of mm items per modality (batch size).
-        self.max_num_mm_items_by_modality = dict()
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
-            max_tokens_by_modality_dict = (
-                MULTIMODAL_REGISTRY.
-                get_max_tokens_per_item_by_nonzero_modality(self.model_config))
-            for modality, max_tokens in max_tokens_by_modality_dict.items():
-                # Check how many items of this modality can be supported by
-                # the encoder budget.
-                encoder_budget = min(self.max_num_encoder_input_tokens,
-                                     self.encoder_cache_size)
-
-                max_num_mm_items_encoder_budget = cdiv(encoder_budget,
-                                                       max_tokens)
-
-                # Check how many items of this modality can be supported by
-                # the decoder budget.
-                max_mm_items_per_req = self.mm_registry.\
-                    get_mm_limits_per_prompt(self.model_config)[modality]
-
-                # NOTE: We do not consider max_num_batched_tokens on purpose
-                # because the multimodal embeddings can be generated in advance
-                # and chunked prefilled.
-                max_num_mm_items_decoder_budget = self.max_num_reqs * \
-                    max_mm_items_per_req
-
-                max_num_mm_items = min(max_num_mm_items_encoder_budget,
-                                       max_num_mm_items_decoder_budget)
-                self.max_num_mm_items_by_modality[modality] = max_num_mm_items
+        self.mm_budget = (MultiModalBudget(
+            self.model_config,
+            self.scheduler_config,
+            self.mm_registry,
+            max_model_len=self.max_model_len,
+            max_num_reqs=self.max_num_reqs,
+        ) if self.is_multimodal_model else None)
 
         if not self.use_spmd:
             self.sample_from_logits_func = torch.compile(
@@ -1335,23 +1304,33 @@
         xm.mark_step()  # Captures metadata updates
 
     def _precompile_mm_encoder(self) -> None:
+        if not self.is_multimodal_model:
+            return
+
         # Pre-compile MM encoder for all supported data modalities.
         hf_config = self.vllm_config.model_config.hf_config
-        for mode, max_items_by_mode in \
-                self.max_num_mm_items_by_modality.items():
+
+        mm_budget = self.mm_budget
+        assert mm_budget is not None
+
+        max_items_per_seq_by_modality = mm_budget.max_items_per_batch_by_modality  # noqa: E501
+
+        for mode, max_items_per_seq in max_items_per_seq_by_modality.items():
             logger.info(
                 "Compiling Multimodal %s Encoder with different input"
                 " shapes.", mode)
             start = time.perf_counter()
             # No padding for MM encoder just yet.
-            for num_items in range(1, max_items_by_mode + 1):
+            for num_items in range(1, max_items_per_seq + 1):
                 logger.info(" -- mode: %s items: %d", mode, num_items)
                 batched_dummy_mm_inputs = self._get_mm_dummy_batch(
-                    mode, num_items)
+                    mode,
+                    num_items,
+                )
                 # Run multimodal encoder.
                 xm.mark_step()
-                mm_embeds = self.model.\
-                    get_multimodal_embeddings(**batched_dummy_mm_inputs)
+                mm_embeds = self.model.get_multimodal_embeddings(
+                    **batched_dummy_mm_inputs)
                 xm.mark_step()
                 num_patches = mm_embeds[0].shape[0]
                 items_size = num_patches * num_items
@@ -1547,51 +1526,61 @@
         num_tokens: int,
     ) -> None:
         # Profile with multimodal encoder & encoder cache.
-        # TODO: handle encoder-decoder models once we support them.
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
-
-            # NOTE: Currently model is profiled with a single non-text
-            # modality with the max possible input tokens even when
-            # it supports multiple.
-            dummy_data_modality, max_num_mm_items = max(
-                self.max_num_mm_items_by_modality.items(), key=lambda t: t[1])
-
-            encoder_budget = min(self.max_num_encoder_input_tokens,
-                                 self.encoder_cache_size)
-
-            logger.info(
-                "Encoder cache will be initialized with a budget of %d tokens,"
-                " and profiled with %s %s items of the maximum feature size.",
-                encoder_budget, max_num_mm_items, dummy_data_modality)
-
-            # Create dummy batch of multimodal inputs.
-            batched_dummy_mm_inputs = self._get_mm_dummy_batch(
-                dummy_data_modality, max_num_mm_items)
-
-            # Run multimodal encoder.
-            # Isolate encoder graph from post-processing to minimize
-            # impact of recompilation until it's fixed.
-            start = time.perf_counter()
-            xm.mark_step()
-            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-                **batched_dummy_mm_inputs)
-            xm.mark_step()
-            xm.wait_device_ops()
-            end = time.perf_counter()
-            logger.info(
-                "Multimodal Encoder profiling finished in in %.2f [secs].",
-                end - start)
-
-            assert len(dummy_encoder_outputs) == max_num_mm_items, (
-                "Expected dimension 0 of encoder outputs to match the number "
-                f"of multimodal data items: {max_num_mm_items}, got "
-                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
-                "due to the 'get_multimodal_embeddings' method of the model "
-                "not implemented correctly.")
-
-            # Cache the dummy encoder outputs.
-            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+        if self.is_multimodal_model:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            # TODO: handle encoder-decoder models once we support them.
+            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
+                # NOTE: Currently model is profiled with a single non-text
+                # modality with the max possible input tokens even when
+                # it supports multiple.
+                (
+                    dummy_modality,
+                    max_tokens,
+                ) = mm_budget.get_modality_with_max_tokens()
+                (
+                    max_mm_items_per_prompt,
+                    max_mm_items_per_batch,
+                ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+
+                logger.info(
+                    "Encoder cache will be initialized with a budget of "
+                    "%s tokens, and profiled with %s %s items of the maximum "
+                    "feature size.",
+                    encoder_budget,
+                    max_mm_items_per_batch,
+                    dummy_modality,
+                )
+
+                # Create dummy batch of multimodal inputs.
+                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+                    dummy_modality,
+                    max_mm_items_per_batch,
+                )
+
+                # Run multimodal encoder.
+                # Isolate encoder graph from post-processing to minimize
+                # impact of recompilation until it's fixed.
+                start = time.perf_counter()
+                xm.mark_step()
+                dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                    **batched_dummy_mm_inputs)
+                xm.mark_step()
+                xm.wait_device_ops()
+                end = time.perf_counter()
+                logger.info(
+                    "Multimodal Encoder profiling finished in in %.2f [secs].",
+                    end - start)
+
+                sanity_check_mm_encoder_outputs(
+                    dummy_encoder_outputs,
+                    expected_num_items=max_mm_items_per_batch,
+                )
+
+                # Cache the dummy encoder outputs.
+                self.encoder_cache["tmp"] = dict(
+                    enumerate(dummy_encoder_outputs))
 
         # Trigger compilation for general shape.
         self._dummy_run(num_tokens, self.num_reqs_max_model_len,
@@ -1809,33 +1798,25 @@
             self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
             self.structured_decode_arange.to(logits.device)
 
-    def _get_mm_dummy_batch(self, modality: str,
-                            batch_size: int) -> BatchedTensorInputs:
-        # Dummy data for pre-compiling multimodal models.
-        dummy_request_data = self.mm_registry.get_decoder_dummy_data(
+    def _get_mm_dummy_batch(
+        self,
+        modality: str,
+        max_items_per_batch: int,
+    ) -> BatchedTensorInputs:
+        """Dummy data for profiling and precompiling multimodal models."""
+        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
             model_config=self.model_config,
             seq_len=self.max_num_tokens,
+            mm_counts={modality: 1},
         )
-        dummy_mm_data = dummy_request_data.multi_modal_data
+        dummy_mm_data = dummy_decoder_data.multi_modal_data
 
-        # Dummy data definition in V0 may contain multiple multimodal items
-        # (e.g, multiple images) for a single request, therefore here we
-        # always replicate first item by max_num_mm_items times since in V1
-        # they are scheduled to be processed separately.
-        assert isinstance(dummy_mm_data, MultiModalKwargs), (
-            "Expected dummy multimodal data to be of type "
-            f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. "
-            "This is most likely due to the model not having a merged "
-            "processor.")
-
-        # When models have a merged processor, their dummy data is
-        # already batched `MultiModalKwargs`, therefore we take the first
-        # `MultiModalKwargsItem` from the desired modality to profile on.
+        # Result in the maximum GPU consumption of the model
         dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
         dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
 
         batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
-                                                         batch_size)
+                                                         max_items_per_batch)
         return MultiModalKwargs.as_kwargs(
             batched_dummy_mm_inputs,
             device=self.device,

@@ -5,14 +5,123 @@ from typing import TYPE_CHECKING, Optional
 
 import torch
 
+from vllm.config import ModelConfig, SchedulerConfig
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.multimodal.registry import MultiModalRegistry
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec
 
 if TYPE_CHECKING:
     from vllm.attention.layer import Attention
 
 
+class MultiModalBudget:
+    """Helper class to calculate budget information for multi-modal models."""
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        scheduler_config: SchedulerConfig,
+        mm_registry: MultiModalRegistry,
+        *,
+        max_model_len: int,
+        max_num_reqs: int,
+    ) -> None:
+        super().__init__()
+
+        self.model_config = model_config
+        self.scheduler_config = scheduler_config
+        self.mm_registry = mm_registry
+
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+            mm_registry=mm_registry,
+        )
+
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
+        self.max_model_len = max_model_len
+        self.max_num_reqs = max_num_reqs
+
+        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
+
+        max_items_per_prompt_by_modality = dict[str, int]()
+        max_items_per_batch_by_modality = dict[str, int]()
+
+        max_tokens_by_modality = mm_registry \
+            .get_max_tokens_per_item_by_nonzero_modality(model_config)
+
+        for modality, max_tokens in max_tokens_by_modality.items():
+            (
+                max_items_per_prompt,
+                max_items_per_batch,
+            ) = self.get_max_items(modality, max_tokens)
+
+            max_items_per_prompt_by_modality[modality] = max_items_per_prompt
+            max_items_per_batch_by_modality[modality] = max_items_per_batch
+
+        self.max_tokens_by_modality = max_tokens_by_modality
+        self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
+        self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
+
+    def get_modality_with_max_tokens(self) -> tuple[str, int]:
+        max_tokens_by_modality = self.max_tokens_by_modality
+        modality, max_tokens = max(max_tokens_by_modality.items(),
+                                   key=lambda item: item[1])
+
+        return modality, max_tokens
+
+    def get_encoder_budget(self) -> int:
+        return min(self.max_num_encoder_input_tokens, self.encoder_cache_size)
+
+    def get_max_items(
+        self,
+        modality: str,
+        max_tokens_per_item: int,
+    ) -> tuple[int, int]:
+        if max_tokens_per_item == 0:
+            return 0, 0
+
+        # Check how many items of this modality can be supported by
+        # the encoder budget.
+        encoder_budget = self.get_encoder_budget()
+
+        # TODO: handle encoder-decoder models once we support them.
+        if encoder_budget == 0:
+            return 0, 0
+
+        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
+
+        # Check how many items of this modality can be supported by
+        # the decoder budget.
+        mm_limit = self.mm_limits[modality]
+
+        max_items_per_prompt = max(
+            1,
+            min(mm_limit, self.max_model_len // max_tokens_per_item),
+        )
+
+        scheduler_config = self.scheduler_config
+        max_num_reqs = self.max_num_reqs
+
+        if not scheduler_config.enable_chunked_prefill:
+            max_num_reqs = min(
+                max_num_reqs,
+                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
+            )
+
+        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
+
+        max_items_per_batch = max(
+            1,
+            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
+        )
+
+        return max_items_per_prompt, max_items_per_batch
+
+
 def sanity_check_mm_encoder_outputs(
     mm_embeddings: MultiModalEmbeddings,
     expected_num_items: int,
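As a hedged illustration of how the runners use the new helper during profiling (see the profile hunks above): they pick the modality with the largest per-item token count and size the dummy batch from the per-modality batch limit. The toy below mirrors that flow with made-up numbers; pick_profiling_plan is an invented name and is not part of the patch.

def pick_profiling_plan(
    max_tokens_by_modality: dict[str, int],
    max_items_per_batch_by_modality: dict[str, int],
    encoder_budget: int,
) -> tuple[str, int]:
    # Profile with the single modality whose items are most expensive.
    dummy_modality, max_tokens = max(
        max_tokens_by_modality.items(), key=lambda item: item[1])
    if encoder_budget <= 0 or max_tokens == 0:
        return dummy_modality, 0
    return dummy_modality, max_items_per_batch_by_modality[dummy_modality]


if __name__ == "__main__":
    # Made-up example: one image costs 576 encoder tokens, one audio clip 188.
    modality, num_items = pick_profiling_plan(
        max_tokens_by_modality={"image": 576, "audio": 188},
        max_items_per_batch_by_modality={"image": 14, "audio": 43},
        encoder_budget=8192,
    )
    print(f"Profile with {num_items} dummy {modality} item(s)")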