diff --git a/tests/kernels/test_rotary_embedding.py b/tests/kernels/test_rotary_embedding.py index 362bcb35ceabf..c497dd90edda8 100644 --- a/tests/kernels/test_rotary_embedding.py +++ b/tests/kernels/test_rotary_embedding.py @@ -41,7 +41,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, is_neox_style, rotary_dim, head_size, seq_len): batch_size = 1 - base = 0 + base = 10000 num_heads = 7 rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, torch.float32)