Fix TensorSchema validation test for symbolic dims (#22366)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck 2025-08-10 10:16:44 -07:00 committed by GitHub
parent 8c50d62f5a
commit 68b254d673
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,8 +4,8 @@
import pytest
import torch
from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
@ -129,23 +129,27 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
def test_tensor_schema_with_list_of_symbolic_dim():
flat_data = torch.stack([torch.randn(768) for _ in range(3)]) # (bn=3, fn)
patches_per_image = [64, 64, 64] # len = bn = 3
input_features = torch.randn(3, 10, 160) # (b=3, fi=10, 160)
input_features_mask = torch.randn(3, 8) # (b=3, fo=8)
audio_embed_sizes = [8, 8, 8] # len = b = 3
FuyuImagePatchInputs(
flat_data=flat_data,
patches_per_image=patches_per_image,
GraniteSpeechAudioInputs(
input_features=input_features,
input_features_mask=input_features_mask,
audio_embed_sizes=audio_embed_sizes,
)
def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
flat_data = torch.stack([torch.randn(768) for _ in range(4)]) # (bn=4, fn)
patches_per_image = [64, 64, 64] # len = 3 ≠ bn
input_features = torch.randn(4, 10, 160) # (b=4, fi=10, 160)
input_features_mask = torch.randn(4, 8) # (b=4, fo=8)
audio_embed_sizes = [8, 8, 8] # len = 3 ≠ b
with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
FuyuImagePatchInputs(
flat_data=flat_data,
patches_per_image=patches_per_image,
with pytest.raises(ValueError, match="expected 'b'=4, got 3"):
GraniteSpeechAudioInputs(
input_features=input_features,
input_features_mask=input_features_mask,
audio_embed_sizes=audio_embed_sizes,
)