diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index 383a3c83b84a..f327deb0e549 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -152,6 +152,10 @@ def test_batched_rotary_embedding( query = torch.randn(query_shape, dtype=dtype) key = torch.randn_like(query) if use_key else None + # slice tensor if required, noop otherwise + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None + # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. ref_query, ref_key = rope.forward_native(positions, query, key)