[Bugfix] Fix overallocation in MM profiling (#29386)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang 2025-11-25 04:38:36 -08:00 committed by GitHub
parent 798e87db5c
commit c2c661af9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4245,14 +4245,18 @@ class GPUModelRunner(
# NOTE: This happens when encoder cache needs to store
# the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size
# (encode_budget, hidden_size) and scatter encoder
# output into it.
# (max_tokens_for_modality, hidden_size) and scatter
# encoder output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape
if encoder_output_shape[0] < encoder_budget:
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
dummy_modality
]
if encoder_output_shape[0] < max_mm_tokens_per_item:
encoder_hidden_size = encoder_output_shape[-1]
expanded_outputs = []
for output in dummy_encoder_outputs:
expanded = output.new_zeros(
(encoder_budget, encoder_output_shape[-1])
(max_mm_tokens_per_item, encoder_hidden_size)
)
num_tokens = output.shape[0]
expanded[:num_tokens].copy_(output)