[Bugfix] Fix overallocation in MM profiling (#29386)

Signed-off-by: Roger Wang <hey@rogerw.io>
Roger Wang 2025-11-25 04:38:36 -08:00 committed by GitHub
parent 798e87db5c
commit c2c661af9b


@@ -4245,14 +4245,18 @@ class GPUModelRunner(
 # NOTE: This happens when encoder cache needs to store
 # the embeddings that encoder outputs are scattered onto.
 # In this case we create dummy embeddings of size
-# (encode_budget, hidden_size) and scatter encoder
-# output into it.
+# (max_tokens_for_modality, hidden_size) and scatter
+# encoder output into it.
 encoder_output_shape = dummy_encoder_outputs[0].shape
-if encoder_output_shape[0] < encoder_budget:
+max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
+    dummy_modality
+]
+if encoder_output_shape[0] < max_mm_tokens_per_item:
+    encoder_hidden_size = encoder_output_shape[-1]
     expanded_outputs = []
     for output in dummy_encoder_outputs:
         expanded = output.new_zeros(
-            (encoder_budget, encoder_output_shape[-1])
+            (max_mm_tokens_per_item, encoder_hidden_size)
         )
         num_tokens = output.shape[0]
         expanded[:num_tokens].copy_(output)
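
Why this fixes overallocation: during multimodal profiling, each dummy encoder output was zero-padded up to the entire encoder cache budget, reserving memory as if a single item could fill the whole budget. The fix pads only up to the per-item token maximum of the profiled modality. Below is a minimal standalone sketch of the difference; the sizes are illustrative assumptions, and `output` stands in for one entry of the runner's `dummy_encoder_outputs`.

import torch

encoder_budget = 16384          # total encoder cache budget in tokens (illustrative)
max_mm_tokens_per_item = 2048   # max tokens one item of this modality can emit (illustrative)
hidden_size = 4096              # encoder hidden size (illustrative)

# Dummy encoder output for a single multimodal item during profiling.
output = torch.zeros(1500, hidden_size)

# Old behavior: pad every item to the full budget -> a (16384, 4096) tensor per item.
over_allocated = output.new_zeros((encoder_budget, hidden_size))

# New behavior: pad only to the per-item modality maximum -> a (2048, 4096) tensor.
expanded = output.new_zeros((max_mm_tokens_per_item, hidden_size))
expanded[: output.shape[0]].copy_(output)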