mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 16:46:13 +08:00
[Bugfix] fix rotary embedding test for _get_padded_tensor_shape (#18229)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
c7852a6d9b
commit
4e1c6a0264
@ -152,6 +152,10 @@ def test_batched_rotary_embedding(
|
|||||||
query = torch.randn(query_shape, dtype=dtype)
|
query = torch.randn(query_shape, dtype=dtype)
|
||||||
key = torch.randn_like(query) if use_key else None
|
key = torch.randn_like(query) if use_key else None
|
||||||
|
|
||||||
|
# slice tensor if required, noop otherwise
|
||||||
|
query = query[..., :head_size]
|
||||||
|
key = key[..., :head_size] if use_key else None
|
||||||
|
|
||||||
# NOTE(woosuk): The reference implementation should be executed first
|
# NOTE(woosuk): The reference implementation should be executed first
|
||||||
# because the custom kernel is in-place.
|
# because the custom kernel is in-place.
|
||||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user