From 4e1c6a02641e427a6140d33262f1467906817781 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 15 May 2025 21:32:45 -0400 Subject: [PATCH] [Bugfix] fix rotary embedding test for _get_padded_tensor_shape (#18229) Signed-off-by: Lucas Wilkinson --- tests/kernels/core/test_pos_encoding.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index 383a3c83b84ac..f327deb0e549e 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -152,6 +152,10 @@ def test_batched_rotary_embedding( query = torch.randn(query_shape, dtype=dtype) key = torch.randn_like(query) if use_key else None + # slice tensor if required, noop otherwise + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None + # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. ref_query, ref_key = rope.forward_native(positions, query, key)