diff --git a/tests/kernels/test_mrope.py b/tests/kernels/core/test_mrope.py
similarity index 92%
rename from tests/kernels/test_mrope.py
rename to tests/kernels/core/test_mrope.py
index 5918b7a58b5c0..3f2f330f6dc3b 100644
--- a/tests/kernels/test_mrope.py
+++ b/tests/kernels/core/test_mrope.py
@@ -42,12 +42,13 @@ def unroll_model_tp_dict(model_tp_dict):
 model_tp_dict = {
     "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
     "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
-    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2]
+    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
+    "zai-org/GLM-4.1V-9B-Thinking": [1, 2],
 }
 
 # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
 dtype_atol_rtol_list = [
-    [torch.bfloat16, 1e-5, 1.6e-2],
+    [torch.bfloat16, 1e-2, 1.6e-2],
 ]
 
 num_tokens_list = [11, 8192]
@@ -73,10 +74,12 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
 
     rope_theta = config.rope_theta
     max_position = config.max_position_embeddings
+    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+    rotary_dim = int(head_dim * partial_rotary_factor)
 
     mrope_helper_class = get_rope(
         head_size=head_dim,
-        rotary_dim=head_dim,
+        rotary_dim=rotary_dim,
         max_position=max_position,
         base=rope_theta,
         is_neox_style=is_neox_style,
@@ -110,7 +113,10 @@
                     reason="Skipping CUDA/ROCm only tests.")
 @pytest.mark.parametrize(
     "model_name, tp_size",
-    unroll_model_tp_dict({"Qwen/Qwen2-VL-7B-Instruct": [1, 2]}))
+    unroll_model_tp_dict({
+        "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
+        "zai-org/GLM-4.1V-9B-Thinking": [1, 2]
+    }))
 @pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
 @pytest.mark.parametrize("num_tokens", [4])
 def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
@@ -126,10 +132,12 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
     is_neox_style = True
     rope_theta = config.rope_theta
     max_position = config.max_position_embeddings
+    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+    rotary_dim = int(head_dim * partial_rotary_factor)
 
     mrope_helper_class = get_rope(
         head_size=head_dim,
-        rotary_dim=head_dim,
+        rotary_dim=rotary_dim,
         max_position=max_position,
         base=rope_theta,
         is_neox_style=is_neox_style,
@@ -145,7 +153,7 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
     # Create a wrapper that makes the in-place function appear functional
     def functional_forward_cuda(pos, q, k):
         """Wrapper that converts in-place operation to functional style
-        
+
         CUDA Graph does not support in-place operations.
         This wrapper creates working copies of the input tensors and
         modifies them.
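For reference, the test-side change boils down to deriving a (possibly smaller) rotary dimension from the head size, defaulting to the full head when the config has no partial_rotary_factor. Below is a minimal sketch of that derivation; the helper name and the head sizes are illustrative assumptions, not part of the diff or of any real checkpoint:

# derive_rotary_dim is a hypothetical name used only for illustration.
def derive_rotary_dim(head_dim: int, partial_rotary_factor: float = 1.0) -> int:
    """Only the first `rotary_dim` channels of each head are rotated; the rest pass through."""
    return int(head_dim * partial_rotary_factor)

assert derive_rotary_dim(128, 1.0) == 128  # no partial_rotary_factor: rotary_dim == head_dim
assert derive_rotary_dim(128, 0.5) == 64   # partial-rotary config: only part of each head is rotated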
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index d3b71930b6f17..a091cfb743291 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -25,6 +25,7 @@ def _triton_qwen2vl_mrope_forward(
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
+    rd: tl.constexpr,
     pad_n_qh: tl.constexpr,
     pad_n_kh: tl.constexpr,
     pad_hd: tl.constexpr,
@@ -51,19 +52,19 @@ def _triton_qwen2vl_mrope_forward(
     h_end = t_end + mrope_section_h
 
     # Updated stride calculation for half head_dim
-    half_hd = hd // 2
-    t_cos = cos + pid * half_hd
-    h_cos = t_cos + num_tokens * half_hd
-    w_cos = h_cos + num_tokens * half_hd
-    t_sin = sin + pid * half_hd
-    h_sin = t_sin + num_tokens * half_hd
-    w_sin = h_sin + num_tokens * half_hd
+    half_rd = rd // 2
+    t_cos = cos + pid * half_rd
+    h_cos = t_cos + num_tokens * half_rd
+    w_cos = h_cos + num_tokens * half_rd
+    t_sin = sin + pid * half_rd
+    h_sin = t_sin + num_tokens * half_rd
+    w_sin = h_sin + num_tokens * half_rd
 
     # Updated offsets for half head_dim
     cos_offsets = tl.arange(0, pad_hd // 2)
     t_mask = cos_offsets < t_end
     h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
-    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_hd)
+    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
 
     t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
     h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
@@ -85,9 +86,9 @@ def _triton_qwen2vl_mrope_forward(
     first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
         0, pad_hd // 2)[None, :]
     first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
-        0, pad_hd // 2)[None, :] < hd // 2)
+        0, pad_hd // 2)[None, :] < rd // 2)
     first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
-        0, pad_hd // 2)[None, :] < hd // 2)
+        0, pad_hd // 2)[None, :] < rd // 2)
 
     q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
                        mask=first_q_mask,
@@ -97,8 +98,8 @@ def _triton_qwen2vl_mrope_forward(
                        other=0).to(sin_row.dtype)
 
     # right half of the head
-    second_half_q_offsets = first_half_q_offsets + (hd // 2)
-    second_half_k_offsets = first_half_k_offsets + (hd // 2)
+    second_half_q_offsets = first_half_q_offsets + (rd // 2)
+    second_half_k_offsets = first_half_k_offsets + (rd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
 
@@ -130,6 +131,7 @@ def triton_mrope(
     sin: torch.Tensor,
     mrope_section: list[int],
     head_size: int,
+    rotary_dim: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Qwen2VL mrope kernel.
 
@@ -166,6 +168,7 @@ def triton_mrope(
         n_q_head,
         n_kv_head,
         head_size,
+        rotary_dim,
         pad_n_q_head,
         pad_n_kv_head,
         pad_hd,
@@ -300,6 +303,7 @@ class MRotaryEmbedding(RotaryEmbedding):
             sin,
             self.mrope_section,
             self.head_size,
+            self.rotary_dim,
         )
 
         return q.reshape(query_shape), k.reshape(key_shape)
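To make the kernel change easier to follow, here is a pure-PyTorch reference sketch of partial neox-style rotation. It is an illustrative assumption, not the vLLM kernel itself, and the function name and tensor layout are chosen only for this example: hd stays the stride between heads in q/k, rd is the number of channels that are actually rotated, and the cos/sin caches advance in units of rd // 2.

import torch

def partial_neox_rope_reference(x: torch.Tensor, cos: torch.Tensor,
                                sin: torch.Tensor, rotary_dim: int) -> torch.Tensor:
    """x: [num_tokens, num_heads, head_size]; cos/sin: [num_tokens, rotary_dim // 2]."""
    rot, passthrough = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = rot.chunk(2, dim=-1)                   # neox style: split rotated channels in half
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)   # broadcast over the head dimension
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, passthrough), dim=-1)  # head_size - rotary_dim channels unchanged

This mirrors why the Triton kernel now masks cos/sin offsets against rd // 2 instead of hd // 2 and shifts the second half by rd // 2, while the q/k row offsets continue to use hd.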