[BugFix] Enhance test_pos_encoding to support execution on multi-devices (#13187)
Signed-off-by: wchen61 <wchen61@foxmail.com>
commit dc0f7ccf8b (parent d3d547e057)
@@ -70,7 +70,7 @@ def test_rotary_embedding(
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
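For context: the fix folds a device move into the existing dtype cast, so the
module lands wherever the test's tensors are created. A minimal sketch, not
part of the patch, assuming PyTorch >= 2.3 (where torch.get_default_device()
exists) and using nn.Linear as a stand-in for the RoPE module:

import torch

# Stand-in for the module returned by get_rope(); dtype is illustrative.
module = torch.nn.Linear(16, 16)                  # parameters start on CPU
module = module.to(dtype=torch.float32,
                   device=torch.get_default_device())

x = torch.randn(2, 16)   # factory functions follow the default device
y = module(x)            # module and input now share a device
assert y.device == x.device

The old rope.to(dtype=dtype) changed only the dtype and left the module on
whatever device it was built on, which breaks once the default device is a
different one.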
@@ -125,7 +125,7 @@ def test_batched_rotary_embedding(
         "rope_type": "linear",
         "factor": (1, )
     })
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
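The same pattern, shown with the construction from this hunk. A sketch
assuming vllm is installed (get_rope lives in
vllm.model_executor.layers.rotary_embedding); head size, base, and dtype are
illustrative values, not taken from the test's parametrization:

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

# Linear scaling with factor (1, ) degenerates to the unscaled rope, which
# makes it a convenient self-check for the batched kernel.
rope = get_rope(64, 64, 4096, 10000, True, {
    "rope_type": "linear",
    "factor": (1, ),
})
rope = rope.to(dtype=torch.float16, device=torch.get_default_device())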
@@ -184,7 +184,7 @@ def test_batched_rotary_embedding_multi_lora(
         "rope_type": "linear",
         "factor": tuple(scaling_factors)
     })
-    rope = rope.to(dtype=dtype)
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
     query = torch.randn(batch_size,
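Putting it together: how a multi-device run might drive these tests. The
helper and device strings below are assumptions for illustration, not part
of the patch; "cpu" keeps the sketch runnable without an accelerator:

import torch

def run_on(device: str) -> None:
    # e.g. "cuda:1" on a multi-GPU host; all factory calls inside the test
    # (torch.randint, torch.randn) then allocate on that device, matching
    # the module moved via rope.to(dtype=dtype,
    # device=torch.get_default_device()).
    torch.set_default_device(device)
    positions = torch.randint(0, 4096, (2, 8))
    assert positions.device.type == torch.device(device).type

run_on("cpu")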