[Bugfix][Kernel] Support partial rotary embedding for MRoPE triton kernel (#22593)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
commit b76753f0b5
parent b81fe83b2c
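Note: "partial rotary embedding" here means that only the first rotary_dim = int(head_dim * partial_rotary_factor) dimensions of each head are rotated; the remaining head_dim - rotary_dim dimensions pass through unchanged. A minimal PyTorch sketch of the idea follows (illustrative only, not the vLLM kernel; the tensor shapes are assumptions):

import torch

def apply_partial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       rotary_dim: int) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_dim]; cos/sin: [num_tokens, rotary_dim // 2]
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot[..., :rotary_dim // 2], x_rot[..., rotary_dim // 2:]
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over heads
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    # the tail of each head (head_dim - rotary_dim dims) is returned untouched
    return torch.cat((rotated, x_pass), dim=-1)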
@@ -42,12 +42,13 @@ def unroll_model_tp_dict(model_tp_dict):
 model_tp_dict = {
     "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
     "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
-    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2]
+    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
+    "zai-org/GLM-4.1V-9B-Thinking": [1, 2],
 }
 
 # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
 dtype_atol_rtol_list = [
-    [torch.bfloat16, 1e-5, 1.6e-2],
+    [torch.bfloat16, 1e-2, 1.6e-2],
 ]
 
 num_tokens_list = [11, 8192]
@@ -73,10 +74,12 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
 
     rope_theta = config.rope_theta
     max_position = config.max_position_embeddings
+    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+    rotary_dim = int(head_dim * partial_rotary_factor)
 
     mrope_helper_class = get_rope(
         head_size=head_dim,
-        rotary_dim=head_dim,
+        rotary_dim=rotary_dim,
         max_position=max_position,
         base=rope_theta,
         is_neox_style=is_neox_style,
@@ -110,7 +113,10 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
                     reason="Skipping CUDA/ROCm only tests.")
 @pytest.mark.parametrize(
     "model_name, tp_size",
-    unroll_model_tp_dict({"Qwen/Qwen2-VL-7B-Instruct": [1, 2]}))
+    unroll_model_tp_dict({
+        "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
+        "zai-org/GLM-4.1V-9B-Thinking": [1, 2]
+    }))
 @pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
 @pytest.mark.parametrize("num_tokens", [4])
 def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
@@ -126,10 +132,12 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
     is_neox_style = True
     rope_theta = config.rope_theta
     max_position = config.max_position_embeddings
+    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
+    rotary_dim = int(head_dim * partial_rotary_factor)
 
     mrope_helper_class = get_rope(
         head_size=head_dim,
-        rotary_dim=head_dim,
+        rotary_dim=rotary_dim,
         max_position=max_position,
         base=rope_theta,
         is_neox_style=is_neox_style,
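Note: the two lines added to both tests show how rotary_dim is derived from the model config. A quick illustration with hypothetical configs (SimpleNamespace stands in for the real HF config object; the 0.5 factor is only an example value):

from types import SimpleNamespace

head_dim = 128
full_rope = SimpleNamespace()                              # no partial_rotary_factor set
partial_rope = SimpleNamespace(partial_rotary_factor=0.5)  # hypothetical partial-RoPE config

for config in (full_rope, partial_rope):
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    rotary_dim = int(head_dim * partial_rotary_factor)
    print(partial_rotary_factor, rotary_dim)               # prints 1.0 128, then 0.5 64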
@@ -25,6 +25,7 @@ def _triton_qwen2vl_mrope_forward(
     n_qh: tl.constexpr,
     n_kh: tl.constexpr,
     hd: tl.constexpr,
+    rd: tl.constexpr,
     pad_n_qh: tl.constexpr,
     pad_n_kh: tl.constexpr,
     pad_hd: tl.constexpr,
@@ -51,19 +52,19 @@ def _triton_qwen2vl_mrope_forward(
     h_end = t_end + mrope_section_h
 
     # Updated stride calculation for half head_dim
-    half_hd = hd // 2
-    t_cos = cos + pid * half_hd
-    h_cos = t_cos + num_tokens * half_hd
-    w_cos = h_cos + num_tokens * half_hd
-    t_sin = sin + pid * half_hd
-    h_sin = t_sin + num_tokens * half_hd
-    w_sin = h_sin + num_tokens * half_hd
+    half_rd = rd // 2
+    t_cos = cos + pid * half_rd
+    h_cos = t_cos + num_tokens * half_rd
+    w_cos = h_cos + num_tokens * half_rd
+    t_sin = sin + pid * half_rd
+    h_sin = t_sin + num_tokens * half_rd
+    w_sin = h_sin + num_tokens * half_rd
 
     # Updated offsets for half head_dim
     cos_offsets = tl.arange(0, pad_hd // 2)
     t_mask = cos_offsets < t_end
     h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
-    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_hd)
+    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
 
     t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
     h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
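Note: the pointer arithmetic above assumes the cos cache stores three contiguous planes, one per mrope section (temporal, height, width), each of shape [num_tokens, rotary_dim // 2], and likewise for sin; after this change the per-token row length is sized by rd rather than hd. A hedged sketch of the same index math in plain Python (names are illustrative, not a vLLM API):

def mrope_cos_base_offsets(pid: int, num_tokens: int, rotary_dim: int):
    # pid indexes the token handled by this program instance
    half_rd = rotary_dim // 2
    t_cos = pid * half_rd                  # this token's row in the temporal plane
    h_cos = t_cos + num_tokens * half_rd   # same row in the height plane
    w_cos = h_cos + num_tokens * half_rd   # same row in the width plane
    return t_cos, h_cos, w_cos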
@@ -85,9 +86,9 @@ def _triton_qwen2vl_mrope_forward(
     first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
         0, pad_hd // 2)[None, :]
     first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
-        0, pad_hd // 2)[None, :] < hd // 2)
+        0, pad_hd // 2)[None, :] < rd // 2)
     first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
-        0, pad_hd // 2)[None, :] < hd // 2)
+        0, pad_hd // 2)[None, :] < rd // 2)
 
     q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
                        mask=first_q_mask,
@@ -97,8 +98,8 @@ def _triton_qwen2vl_mrope_forward(
                        other=0).to(sin_row.dtype)
 
     # right half of the head
-    second_half_q_offsets = first_half_q_offsets + (hd // 2)
-    second_half_k_offsets = first_half_k_offsets + (hd // 2)
+    second_half_q_offsets = first_half_q_offsets + (rd // 2)
+    second_half_k_offsets = first_half_k_offsets + (rd // 2)
     second_q_mask = first_q_mask
     second_k_mask = first_k_mask
 
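Note: with partial rotary, the "right half" partner of lane i sits at i + rd // 2 rather than i + hd // 2, and the masks above never touch lanes at or beyond rd, so the tail of each head is left as-is. A tiny self-contained check of that pairing (assumed sizes, plain PyTorch, not the kernel):

import torch

hd, rd = 128, 64                               # hypothetical head_dim / rotary_dim
lanes = torch.arange(hd)
rotated = lanes < rd                           # only the first rd lanes rotate
partner = torch.where(rotated, (lanes + rd // 2) % rd, lanes)
assert partner[0].item() == rd // 2            # lane 0 pairs with lane rd // 2
assert torch.equal(partner[rd:], lanes[rd:])   # pass-through lanes map to themselves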
@@ -130,6 +131,7 @@ def triton_mrope(
     sin: torch.Tensor,
     mrope_section: list[int],
     head_size: int,
+    rotary_dim: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Qwen2VL mrope kernel.
 
@@ -166,6 +168,7 @@ def triton_mrope(
         n_q_head,
         n_kv_head,
         head_size,
+        rotary_dim,
         pad_n_q_head,
         pad_n_kv_head,
         pad_hd,
@@ -300,6 +303,7 @@ class MRotaryEmbedding(RotaryEmbedding):
             sin,
             self.mrope_section,
             self.head_size,
+            self.rotary_dim,
         )
 
         return q.reshape(query_shape), k.reshape(key_shape)