[MM] Optimize memory profiling for scattered multimodal embeddings (#25810)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang 2025-09-27 19:17:58 -07:00 committed by GitHub
parent da63274d9f
commit 69311446ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3429,6 +3429,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
expected_num_items=max_mm_items_per_batch,
)
# NOTE: This happens when encoder cache needs to store
# the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size
# (encode_budget, hidden_size) and scatter encoder
# output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape
if encoder_output_shape[0] < encoder_budget:
expanded_outputs = []
for output in dummy_encoder_outputs:
expanded = output.new_zeros(
(encoder_budget, encoder_output_shape[-1]))
num_tokens = output.shape[0]
expanded[:num_tokens].copy_(output)
expanded_outputs.append(expanded)
dummy_encoder_outputs = expanded_outputs
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(
enumerate(dummy_encoder_outputs))