From dc0f7ccf8bc1b2b74113fab8ea69d8420a91274a Mon Sep 17 00:00:00 2001
From: wchen61
Date: Sun, 16 Feb 2025 16:59:49 +0800
Subject: [PATCH] [BugFix] Enhance test_pos_encoding to support execution on
 multi-devices (#13187)

Signed-off-by: wchen61
---
 tests/kernels/test_pos_encoding.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index af9bfd2f0f521..bff7f8e57fbf0 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -70,7 +70,7 @@ def test_rotary_embedding(
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
@@ -125,7 +125,7 @@ def test_batched_rotary_embedding(
                         "rope_type": "linear",
                         "factor": (1, )
                     })
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
@@ -184,7 +184,7 @@ def test_batched_rotary_embedding_multi_lora(
                         "rope_type": "linear",
                         "factor": tuple(scaling_factors)
                     })
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query = torch.randn(batch_size,
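
Note on the failure mode this patch addresses (a minimal, self-contained
sketch; the Linear stand-in, the device strings, and the variable names
below are illustrative assumptions, not the test's actual code):

    import torch

    # A module built before the test selects its device keeps its
    # parameters/buffers on the CPU:
    rope_like = torch.nn.Linear(4, 4)

    # The test suite then points tensor factory functions at the device
    # under test (on a multi-GPU host this may be e.g. "cuda:1"):
    torch.set_default_device("cuda:0")

    # New tensors such as positions/query now land on that device ...
    positions = torch.randint(0, 128, (2, 16))

    # ... but .to(dtype=...) alone changes only the dtype, leaving the
    # module's state on the CPU, so the kernel call would mix devices:
    rope_like = rope_like.to(dtype=torch.float16)

    # The patched form also moves the module onto the current default
    # device, keeping module state and inputs together:
    rope_like = rope_like.to(dtype=torch.float16,
                             device=torch.get_default_device())
    assert next(rope_like.parameters()).device == positions.device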