From 68b254d67300a1740db900a3d0ff4252424715d7 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Sun, 10 Aug 2025 10:16:44 -0700
Subject: [PATCH] Fix TensorSchema validation test for symbolic dims (#22366)

Signed-off-by: Benji Beck
---
 tests/standalone_tests/test_tensor_schema.py | 28 +++++++++++---------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/tests/standalone_tests/test_tensor_schema.py b/tests/standalone_tests/test_tensor_schema.py
index e98aa3f53fb5..69744921b16a 100644
--- a/tests/standalone_tests/test_tensor_schema.py
+++ b/tests/standalone_tests/test_tensor_schema.py
@@ -4,8 +4,8 @@
 import pytest
 import torch
 
-from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
 from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
+from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
 from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
 
 
@@ -129,23 +129,27 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
 
 
 def test_tensor_schema_with_list_of_symbolic_dim():
-    flat_data = torch.stack([torch.randn(768) for _ in range(3)])  # (bn=3, fn)
-    patches_per_image = [64, 64, 64]  # len = bn = 3
+    input_features = torch.randn(3, 10, 160)  # (b=3, fi=10, 160)
+    input_features_mask = torch.randn(3, 8)  # (b=3, fo=8)
+    audio_embed_sizes = [8, 8, 8]  # len = b = 3
 
-    FuyuImagePatchInputs(
-        flat_data=flat_data,
-        patches_per_image=patches_per_image,
+    GraniteSpeechAudioInputs(
+        input_features=input_features,
+        input_features_mask=input_features_mask,
+        audio_embed_sizes=audio_embed_sizes,
     )
 
 
 def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
-    flat_data = torch.stack([torch.randn(768) for _ in range(4)])  # (bn=4, fn)
-    patches_per_image = [64, 64, 64]  # len = 3 ≠ bn
+    input_features = torch.randn(4, 10, 160)  # (b=4, fi=10, 160)
+    input_features_mask = torch.randn(4, 8)  # (b=4, fo=8)
+    audio_embed_sizes = [8, 8, 8]  # len = 3 ≠ b
 
-    with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
-        FuyuImagePatchInputs(
-            flat_data=flat_data,
-            patches_per_image=patches_per_image,
+    with pytest.raises(ValueError, match="expected 'b'=4, got 3"):
+        GraniteSpeechAudioInputs(
+            input_features=input_features,
+            input_features_mask=input_features_mask,
+            audio_embed_sizes=audio_embed_sizes,
         )
 
 