mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
[Bugfix][Kernel] Support partial rotary embedding for MRoPE triton kernel (#22593)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
b81fe83b2c
commit
b76753f0b5
@ -42,12 +42,13 @@ def unroll_model_tp_dict(model_tp_dict):
|
||||
model_tp_dict = {
|
||||
"Qwen/Qwen2-VL-7B-Instruct": [1, 2],
|
||||
"Qwen/Qwen2-VL-72B-Instruct": [1, 2],
|
||||
"Qwen/Qwen2.5-VL-72B-Instruct": [1, 2]
|
||||
"Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
|
||||
"zai-org/GLM-4.1V-9B-Thinking": [1, 2],
|
||||
}
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
|
||||
dtype_atol_rtol_list = [
|
||||
[torch.bfloat16, 1e-5, 1.6e-2],
|
||||
[torch.bfloat16, 1e-2, 1.6e-2],
|
||||
]
|
||||
|
||||
num_tokens_list = [11, 8192]
|
||||
@ -73,10 +74,12 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
|
||||
|
||||
rope_theta = config.rope_theta
|
||||
max_position = config.max_position_embeddings
|
||||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
|
||||
rotary_dim = int(head_dim * partial_rotary_factor)
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position,
|
||||
base=rope_theta,
|
||||
is_neox_style=is_neox_style,
|
||||
@ -110,7 +113,10 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
|
||||
reason="Skipping CUDA/ROCm only tests.")
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, tp_size",
|
||||
unroll_model_tp_dict({"Qwen/Qwen2-VL-7B-Instruct": [1, 2]}))
|
||||
unroll_model_tp_dict({
|
||||
"Qwen/Qwen2-VL-7B-Instruct": [1, 2],
|
||||
"zai-org/GLM-4.1V-9B-Thinking": [1, 2]
|
||||
}))
|
||||
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
|
||||
@pytest.mark.parametrize("num_tokens", [4])
|
||||
def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
|
||||
@ -126,10 +132,12 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
|
||||
is_neox_style = True
|
||||
rope_theta = config.rope_theta
|
||||
max_position = config.max_position_embeddings
|
||||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
|
||||
rotary_dim = int(head_dim * partial_rotary_factor)
|
||||
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position,
|
||||
base=rope_theta,
|
||||
is_neox_style=is_neox_style,
|
||||
@ -145,7 +153,7 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
|
||||
# Create a wrapper that makes the in-place function appear functional
|
||||
def functional_forward_cuda(pos, q, k):
|
||||
"""Wrapper that converts in-place operation to functional style
|
||||
|
||||
|
||||
CUDA Graph does not support in-place operations.
|
||||
This wrapper creates working copies of the
|
||||
input tensors and modifies them.
|
||||
@ -25,6 +25,7 @@ def _triton_qwen2vl_mrope_forward(
|
||||
n_qh: tl.constexpr,
|
||||
n_kh: tl.constexpr,
|
||||
hd: tl.constexpr,
|
||||
rd: tl.constexpr,
|
||||
pad_n_qh: tl.constexpr,
|
||||
pad_n_kh: tl.constexpr,
|
||||
pad_hd: tl.constexpr,
|
||||
@ -51,19 +52,19 @@ def _triton_qwen2vl_mrope_forward(
|
||||
h_end = t_end + mrope_section_h
|
||||
|
||||
# Updated stride calculation for half head_dim
|
||||
half_hd = hd // 2
|
||||
t_cos = cos + pid * half_hd
|
||||
h_cos = t_cos + num_tokens * half_hd
|
||||
w_cos = h_cos + num_tokens * half_hd
|
||||
t_sin = sin + pid * half_hd
|
||||
h_sin = t_sin + num_tokens * half_hd
|
||||
w_sin = h_sin + num_tokens * half_hd
|
||||
half_rd = rd // 2
|
||||
t_cos = cos + pid * half_rd
|
||||
h_cos = t_cos + num_tokens * half_rd
|
||||
w_cos = h_cos + num_tokens * half_rd
|
||||
t_sin = sin + pid * half_rd
|
||||
h_sin = t_sin + num_tokens * half_rd
|
||||
w_sin = h_sin + num_tokens * half_rd
|
||||
|
||||
# Updated offsets for half head_dim
|
||||
cos_offsets = tl.arange(0, pad_hd // 2)
|
||||
t_mask = cos_offsets < t_end
|
||||
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
|
||||
w_mask = (h_end <= cos_offsets) & (cos_offsets < half_hd)
|
||||
w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
|
||||
|
||||
t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
|
||||
h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
|
||||
@ -85,9 +86,9 @@ def _triton_qwen2vl_mrope_forward(
|
||||
first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
|
||||
0, pad_hd // 2)[None, :]
|
||||
first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
|
||||
0, pad_hd // 2)[None, :] < hd // 2)
|
||||
0, pad_hd // 2)[None, :] < rd // 2)
|
||||
first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
|
||||
0, pad_hd // 2)[None, :] < hd // 2)
|
||||
0, pad_hd // 2)[None, :] < rd // 2)
|
||||
|
||||
q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
|
||||
mask=first_q_mask,
|
||||
@ -97,8 +98,8 @@ def _triton_qwen2vl_mrope_forward(
|
||||
other=0).to(sin_row.dtype)
|
||||
|
||||
# right half of the head
|
||||
second_half_q_offsets = first_half_q_offsets + (hd // 2)
|
||||
second_half_k_offsets = first_half_k_offsets + (hd // 2)
|
||||
second_half_q_offsets = first_half_q_offsets + (rd // 2)
|
||||
second_half_k_offsets = first_half_k_offsets + (rd // 2)
|
||||
second_q_mask = first_q_mask
|
||||
second_k_mask = first_k_mask
|
||||
|
||||
@ -130,6 +131,7 @@ def triton_mrope(
|
||||
sin: torch.Tensor,
|
||||
mrope_section: list[int],
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Qwen2VL mrope kernel.
|
||||
|
||||
@ -166,6 +168,7 @@ def triton_mrope(
|
||||
n_q_head,
|
||||
n_kv_head,
|
||||
head_size,
|
||||
rotary_dim,
|
||||
pad_n_q_head,
|
||||
pad_n_kv_head,
|
||||
pad_hd,
|
||||
@ -300,6 +303,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
sin,
|
||||
self.mrope_section,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
)
|
||||
|
||||
return q.reshape(query_shape), k.reshape(key_shape)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user