[Bugfix][Kernel] Support partial rotary embedding for MRoPE triton kernel (#22593)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py 2025-08-11 00:00:36 +08:00 committed by GitHub
parent b81fe83b2c
commit b76753f0b5
2 changed files with 30 additions and 18 deletions


@@ -42,12 +42,13 @@ def unroll_model_tp_dict(model_tp_dict):
model_tp_dict = {
"Qwen/Qwen2-VL-7B-Instruct": [1, 2],
"Qwen/Qwen2-VL-72B-Instruct": [1, 2],
"Qwen/Qwen2.5-VL-72B-Instruct": [1, 2]
"Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
"zai-org/GLM-4.1V-9B-Thinking": [1, 2],
}
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
dtype_atol_rtol_list = [
- [torch.bfloat16, 1e-5, 1.6e-2],
+ [torch.bfloat16, 1e-2, 1.6e-2],
]
num_tokens_list = [11, 8192]
@@ -73,10 +74,12 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
rope_theta = config.rope_theta
max_position = config.max_position_embeddings
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+ rotary_dim = int(head_dim * partial_rotary_factor)
mrope_helper_class = get_rope(
head_size=head_dim,
- rotary_dim=head_dim,
+ rotary_dim=rotary_dim,
max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style,
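
For context: partial rotary embedding rotates only the leading rotary_dim channels of each head and leaves the rest untouched, so a config with partial_rotary_factor below 1.0 (as in the GLM-4.1V model added to the test matrix above) must not be handed rotary_dim=head_dim. A minimal sketch of the arithmetic the two new lines perform (values illustrative, not read from any real config):

# Illustrative values only.
head_dim = 128
partial_rotary_factor = 0.5                          # < 1.0 => partial rotary
rotary_dim = int(head_dim * partial_rotary_factor)   # 64 channels get rotated
# The remaining head_dim - rotary_dim channels pass through unrotated.
assert (head_dim, rotary_dim) == (128, 64)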
@@ -110,7 +113,10 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize(
"model_name, tp_size",
unroll_model_tp_dict({"Qwen/Qwen2-VL-7B-Instruct": [1, 2]}))
unroll_model_tp_dict({
"Qwen/Qwen2-VL-7B-Instruct": [1, 2],
"zai-org/GLM-4.1V-9B-Thinking": [1, 2]
}))
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
@pytest.mark.parametrize("num_tokens", [4])
def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
@@ -126,10 +132,12 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
is_neox_style = True
rope_theta = config.rope_theta
max_position = config.max_position_embeddings
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+ rotary_dim = int(head_dim * partial_rotary_factor)
mrope_helper_class = get_rope(
head_size=head_dim,
- rotary_dim=head_dim,
+ rotary_dim=rotary_dim,
max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style,
@@ -145,7 +153,7 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
# Create a wrapper that makes the in-place function appear functional
def functional_forward_cuda(pos, q, k):
"""Wrapper that converts in-place operation to functional style
CUDA Graph does not support in-place operations.
This wrapper creates working copies of the
input tensors and modifies them.
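
The diff view truncates the wrapper body here. A hypothetical reconstruction, consistent with the docstring above (the clone-then-mutate pattern is assumed, not copied from the commit):

# Hypothetical sketch; the actual wrapper body is elided by the diff view.
def functional_forward_cuda(pos, q, k):
    q_work = q.clone()  # copies keep the traced graph free of in-place effects
    k_work = k.clone()
    mrope_helper_class.forward_cuda(pos, q_work, k_work)
    return q_work, k_work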


@@ -25,6 +25,7 @@ def _triton_qwen2vl_mrope_forward(
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
+ rd: tl.constexpr,
pad_n_qh: tl.constexpr,
pad_n_kh: tl.constexpr,
pad_hd: tl.constexpr,
@@ -51,19 +52,19 @@ def _triton_qwen2vl_mrope_forward(
h_end = t_end + mrope_section_h
# Updated stride calculation for half head_dim
- half_hd = hd // 2
- t_cos = cos + pid * half_hd
- h_cos = t_cos + num_tokens * half_hd
- w_cos = h_cos + num_tokens * half_hd
- t_sin = sin + pid * half_hd
- h_sin = t_sin + num_tokens * half_hd
- w_sin = h_sin + num_tokens * half_hd
+ half_rd = rd // 2
+ t_cos = cos + pid * half_rd
+ h_cos = t_cos + num_tokens * half_rd
+ w_cos = h_cos + num_tokens * half_rd
+ t_sin = sin + pid * half_rd
+ h_sin = t_sin + num_tokens * half_rd
+ w_sin = h_sin + num_tokens * half_rd
# Updated offsets for half head_dim
cos_offsets = tl.arange(0, pad_hd // 2)
t_mask = cos_offsets < t_end
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
- w_mask = (h_end <= cos_offsets) & (cos_offsets < half_hd)
+ w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
@@ -85,9 +86,9 @@ def _triton_qwen2vl_mrope_forward(
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
0, pad_hd // 2)[None, :]
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
- 0, pad_hd // 2)[None, :] < hd // 2)
+ 0, pad_hd // 2)[None, :] < rd // 2)
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
- 0, pad_hd // 2)[None, :] < hd // 2)
+ 0, pad_hd // 2)[None, :] < rd // 2)
q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
mask=first_q_mask,
@@ -97,8 +98,8 @@ def _triton_qwen2vl_mrope_forward(
other=0).to(sin_row.dtype)
# right half of the head
- second_half_q_offsets = first_half_q_offsets + (hd // 2)
- second_half_k_offsets = first_half_k_offsets + (hd // 2)
+ second_half_q_offsets = first_half_q_offsets + (rd // 2)
+ second_half_k_offsets = first_half_k_offsets + (rd // 2)
second_q_mask = first_q_mask
second_k_mask = first_k_mask
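
Likewise, the second-half offsets must jump by rd // 2, not hd // 2: in NeoX style the rotary slice is split into left and right halves, and with partial rotary that slice is only rd wide. A minimal torch sketch of the per-head math (shapes illustrative; this is not the kernel's memory layout):

import torch

def partial_neox_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                      rd: int) -> torch.Tensor:
    # cos/sin: (..., rd // 2), already gathered per token.
    # Rotate only the first rd channels; channels beyond rd pass through.
    x_rot, x_pass = x[..., :rd], x[..., rd:]
    x1, x2 = x_rot[..., :rd // 2], x_rot[..., rd // 2:]   # NeoX halves
    out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((out, x_pass), dim=-1)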
@@ -130,6 +131,7 @@ def triton_mrope(
sin: torch.Tensor,
mrope_section: list[int],
head_size: int,
+ rotary_dim: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Qwen2VL mrope kernel.
@@ -166,6 +168,7 @@ def triton_mrope(
n_q_head,
n_kv_head,
head_size,
+ rotary_dim,
pad_n_q_head,
pad_n_kv_head,
pad_hd,
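
For reference, the pad_* constants exist because tl.arange needs power-of-two extents; the rd // 2 masks above then clip the padded lanes back to real data. A hedged sketch of how such constants are typically derived (triton.next_power_of_2 is Triton's public helper; the derivation here is an assumption, not quoted from this file):

import triton

head_size, rotary_dim = 128, 64
pad_hd = triton.next_power_of_2(head_size)  # tl.arange extent for the head dim
# Lanes at or beyond rd // 2 are masked off, so padding never touches real data.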
@@ -300,6 +303,7 @@ class MRotaryEmbedding(RotaryEmbedding):
sin,
self.mrope_section,
self.head_size,
+ self.rotary_dim,
)
return q.reshape(query_shape), k.reshape(key_shape)
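
End to end, the fix lets a partial-rotary model reach the Triton MRoPE path with consistent table sizes. A hedged usage sketch mirroring the test setup (rope_scaling keys assumed from typical configs; mrope_section must sum to rotary_dim // 2):

# Sketch only; adapt the rope_scaling dict to the actual model config.
from vllm.model_executor.layers.rotary_embedding import get_rope

head_dim = 128
rotary_dim = int(head_dim * 0.5)         # partial_rotary_factor of 0.5
rope = get_rope(
    head_size=head_dim,
    rotary_dim=rotary_dim,               # < head_size triggers partial rotary
    max_position=32768,
    base=10000,
    is_neox_style=True,
    rope_scaling={"rope_type": "default", "mrope_section": [8, 12, 12]},
)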