Add more test scenario for tensor schema (#22733)

Signed-off-by: teekenl <teekenlau@gmail.com>
This commit is contained in:
TeeKen Lau 2025-08-13 02:34:41 +10:00 committed by GitHub
parent 5a4b4b3729
commit c42fe0b63a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -33,6 +33,31 @@ def test_tensor_schema_constant_dim_failure():
) )
def test_tensor_schema_invalid_types_in_list():
with pytest.raises(ValueError, match="is not a torch.Tensor"):
Phi3VImagePixelInputs(
data=[
torch.randn(64, 3, 32, 32),
"not_a_tensor",
torch.randn(64, 3, 32, 32),
],
image_sizes=torch.randint(0, 256, (3, 2)),
)
def test_tensor_schema_rank_mismatch():
with pytest.raises(ValueError, match="has rank 3 but expected 5"):
Phi3VImagePixelInputs(
data=torch.randn(16, 64, 3),
image_sizes=torch.randint(0, 256, (16, 2)),
)
def test_tensor_schema_missing_required_field():
with pytest.raises(ValueError, match="Required field 'data' is missing"):
Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), )
def test_tensor_schema_symbolic_dim_mismatch(): def test_tensor_schema_symbolic_dim_mismatch():
with pytest.raises(ValueError, match="expected 'bn'=12, got 16"): with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
Phi3VImagePixelInputs( Phi3VImagePixelInputs(