diff --git a/benchmarks/kernels/benchmark_mrope.py b/benchmarks/kernels/benchmark_mrope.py
new file mode 100644
index 0000000000000..b9147361708fd
--- /dev/null
+++ b/benchmarks/kernels/benchmark_mrope.py
@@ -0,0 +1,328 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# This script benchmarks the mrope kernel (mainly for Qwen2-VL and Qwen2.5-VL models).
+# It generates test data, runs benchmarks, and saves the results to a CSV file.
+#
+# The CSV file (named with the current date/time) contains these columns:
+# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
+# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
+# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
+# speedup
+#
+# == Usage Examples ==
+#
+# Single model benchmark:
+# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
+#     --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
+#
+# All models benchmark (a built-in per-model list of TP sizes is used,
+# so --tp-size is ignored):
+# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
+#     --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
+#
+# All models with different token counts:
+# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
+#     --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
+import csv
+import os
+import time
+from datetime import datetime
+from typing import Any, Optional
+
+import numpy as np
+import torch
+
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform
+from vllm.transformers_utils.config import get_config
+from vllm.utils import FlexibleArgumentParser
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def generate_test_data(
+    num_tokens: int,
+    num_q_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    max_position_embeddings: int,
+    dtype: torch.dtype,
+    device: torch.device,
+):
+    """Generate test data for the given configuration."""
+    # Create multimodal positions of shape (3, num_tokens): one row each for
+    # the temporal, height, and width position ids
+    positions = torch.randint(
+        0, max_position_embeddings // 4, (3, num_tokens), device=device
+    )
+
+    # Create query and key tensors
+    query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
+    key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)
+
+    return positions, query, key
+
+
+def calculate_stats(times: list[float]) -> dict[str, float]:
+    """Calculate statistics (in seconds) from a list of timings."""
+    times_array = np.array(times)
+    return {
+        "mean": np.mean(times_array),
+        "median": np.median(times_array),
+        "p99": np.percentile(times_array, 99),
+        "min": np.min(times_array),
+        "max": np.max(times_array),
+    }
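+
+
+# The timing loops below use wall-clock time.perf_counter() bracketed by
+# torch.cuda.synchronize(), which is simple but includes some host-side
+# overhead. A lower-noise alternative (sketch only, not used here) would be
+# CUDA events:
+#
+#   start = torch.cuda.Event(enable_timing=True)
+#   end = torch.cuda.Event(enable_timing=True)
+#   start.record()
+#   fn()
+#   end.record()
+#   torch.cuda.synchronize()
+#   elapsed_s = start.elapsed_time(end) / 1000  # elapsed_time returns ms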
+
+
+def benchmark_mrope(
+    model_name: str,
+    num_tokens: int,
+    head_dim: int,
+    tp_size: int,
+    num_heads: int,
+    num_kv_heads: int,
+    max_position: int = 8192,
+    rope_theta: float = 10000,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[dict[str, Any]] = None,
+    dtype: torch.dtype = torch.bfloat16,
+    seed: int = 0,
+    warmup_iter: int = 10,
+    benchmark_iter: int = 100,
+    csv_writer=None,
+):
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+    # num_heads and num_kv_heads are the per-rank counts used to size q/k/v,
+    # i.e. the totals from the model config divided by tp_size
+    mrope_helper_class = get_rope(
+        head_size=head_dim,
+        rotary_dim=head_dim,
+        max_position=max_position,
+        base=rope_theta,
+        is_neox_style=is_neox_style,
+        rope_scaling=rope_scaling,
+        dtype=dtype,
+    ).to(device=device)
+
+    print(80 * "=")
+    print(
+        f"Evaluating model: {model_name} "
+        f"with tp_size: {tp_size} "
+        f"and num_tokens: {num_tokens}, "
+        f"dtype: {dtype}"
+    )
+
+    # Create query/key and rotary position embedding input tensors
+    positions, query, key = generate_test_data(
+        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
+    )
+
+    # Warm up
+    for _ in range(warmup_iter):
+        mrope_helper_class.forward_native(
+            positions,
+            query.clone(),
+            key.clone(),
+        )
+
+        mrope_helper_class.forward_cuda(
+            positions,
+            query.clone(),
+            key.clone(),
+        )
+
+    torch.cuda.synchronize()
+
+    # Time the reference implementation
+    torch_times = []
+    for _ in range(benchmark_iter):
+        query_clone = query.clone()
+        key_clone = key.clone()
+        torch.cuda.synchronize()
+        start_time = time.perf_counter()
+
+        mrope_helper_class.forward_native(
+            positions,
+            query_clone,
+            key_clone,
+        )
+
+        torch.cuda.synchronize()
+        torch_times.append(time.perf_counter() - start_time)
+
+    # Time the Triton kernel implementation
+    triton_times = []
+    for _ in range(benchmark_iter):
+        query_clone = query.clone()
+        key_clone = key.clone()
+        torch.cuda.synchronize()
+        start_time = time.perf_counter()
+        mrope_helper_class.forward_cuda(
+            positions,
+            query_clone,
+            key_clone,
+        )
+        torch.cuda.synchronize()
+        triton_times.append(time.perf_counter() - start_time)
+
+    # Calculate statistics
+    torch_stats = calculate_stats(torch_times)
+    triton_stats = calculate_stats(triton_times)
+    print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")
+
+    print(
+        f"Torch implementation: "
+        f"mean={torch_stats['mean']:.8f}s, "
+        f"median={torch_stats['median']:.8f}s, "
+        f"p99={torch_stats['p99']:.8f}s"
+    )
+
+    print(
+        f"Triton implementation: "
+        f"mean={triton_stats['mean']:.8f}s, "
+        f"median={triton_stats['median']:.8f}s, "
+        f"p99={triton_stats['p99']:.8f}s"
+    )
+
+    print(
+        f"Triton speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
+    )
+
+    # Write to CSV
+    if csv_writer:
+        row = [
+            model_name,
+            tp_size,
+            num_tokens,
+            num_heads,
+            num_kv_heads,
+            head_dim,
+            max_position,
+            rope_theta,
+            is_neox_style,
+            str(rope_scaling),
+            str(dtype).split(".")[-1],
+            torch_stats["mean"],
+            torch_stats["median"],
+            torch_stats["p99"],
+            torch_stats["min"],
+            torch_stats["max"],
+            triton_stats["mean"],
+            triton_stats["median"],
+            triton_stats["p99"],
+            triton_stats["min"],
+            triton_stats["max"],
+            torch_stats["mean"] / triton_stats["mean"],  # speedup
+        ]
+        csv_writer.writerow(row)
+
+    return torch_stats, triton_stats
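+
+
+# Example of invoking benchmark_mrope directly (illustrative numbers taken
+# from Qwen/Qwen2-VL-7B-Instruct at tp_size=1; rope_scaling comes from the
+# model's HF config and must contain an "mrope_section" entry). With
+# csv_writer=None the call only prints the statistics:
+#
+#   benchmark_mrope(
+#       model_name="Qwen/Qwen2-VL-7B-Instruct",
+#       num_tokens=1024,
+#       head_dim=128,
+#       tp_size=1,
+#       num_heads=28,
+#       num_kv_heads=4,
+#       max_position=32768,
+#       rope_theta=1000000,
+#       rope_scaling={"rope_type": "default", "mrope_section": [16, 24, 24]},
+#       dtype=torch.bfloat16,
+#   )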
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description="Benchmark the rotary embedding kernels."
+    )
+    parser.add_argument("--model-name", type=str, default="")
+    parser.add_argument("--tp-size", type=int, default=1)
+    parser.add_argument("--warmup-iter", type=int, default=10)
+    parser.add_argument("--benchmark-iter", type=int, default=100)
+    parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
+    parser.add_argument("--trust-remote-code", action="store_true")
+    parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
+    args = parser.parse_args()
+    print(args)
+
+    # Create CSV file for results
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"
+
+    with open(csv_filename, "w", newline="") as csvfile:
+        csv_writer = csv.writer(csvfile)
+        # Write header
+        header = [
+            "model_name",
+            "tp_size",
+            "num_tokens",
+            "num_heads",
+            "num_kv_heads",
+            "head_dim",
+            "max_position",
+            "rope_theta",
+            "is_neox_style",
+            "rope_scaling",
+            "dtype",
+            "torch_mean",
+            "torch_median",
+            "torch_p99",
+            "torch_min",
+            "torch_max",
+            "triton_mean",
+            "triton_median",
+            "triton_p99",
+            "triton_min",
+            "triton_max",
+            "speedup",
+        ]
+        csv_writer.writerow(header)
+
+        model_tp_dict = {}
+        if args.model_name == "":
+            model_tp_dict = {
+                "Qwen/Qwen2-VL-2B-Instruct": [1],
+                "Qwen/Qwen2-VL-7B-Instruct": [1],
+                "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
+                "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
+                "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
+                "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
+            }
+        else:
+            model_tp_dict[args.model_name] = [args.tp_size]
+
+        if args.num_tokens is None:
+            num_tokens_list = [2**i for i in range(0, 18)]
+        else:
+            num_tokens_list = args.num_tokens
+
+        for model_name, tp_list in model_tp_dict.items():
+            config = get_config(model_name, trust_remote_code=args.trust_remote_code)
+            for tp_size in tp_list:
+                # Derive the per-rank head counts from the model config
+                total_num_kv_heads = config.num_key_value_heads
+                total_num_heads = config.num_attention_heads
+                num_heads = total_num_heads // tp_size
+                num_kv_heads = max(1, total_num_kv_heads // tp_size)
+                head_dim = config.hidden_size // total_num_heads
+                is_neox_style = True
+                rope_theta = config.rope_theta
+                max_position = config.max_position_embeddings
+
+                for num_tokens in num_tokens_list:
+                    benchmark_mrope(
+                        model_name=model_name,
+                        num_tokens=num_tokens,
+                        head_dim=head_dim,
+                        tp_size=tp_size,
+                        num_heads=num_heads,
+                        num_kv_heads=num_kv_heads,
+                        max_position=max_position,
+                        rope_theta=rope_theta,
+                        is_neox_style=is_neox_style,
+                        rope_scaling=config.rope_scaling,
+                        dtype=getattr(torch, args.dtype),
+                        seed=args.seed,
+                        warmup_iter=args.warmup_iter,
+                        benchmark_iter=args.benchmark_iter,
+                        csv_writer=csv_writer,
+                    )
+
+    print(f"Benchmark results saved to {csv_filename}")
diff --git a/tests/kernels/test_mrope.py b/tests/kernels/test_mrope.py
new file mode 100644
index 0000000000000..5918b7a58b5c0
--- /dev/null
+++ b/tests/kernels/test_mrope.py
@@ -0,0 +1,207 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+from transformers import AutoConfig
+
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
+                       head_size: int, max_position_embeddings: int,
+                       dtype: torch.dtype, device: torch.device):
+    """Generate test data for the given configuration."""
+    # Create multimodal positions of shape (3, num_tokens)
+    positions = torch.randint(0,
+                              max_position_embeddings // 4, (3, num_tokens),
+                              device=device)
+
+    # Create query and key tensors
+    query = torch.randn(num_tokens,
+                        num_q_heads * head_size,
+                        dtype=dtype,
+                        device=device)
+    key = torch.randn(num_tokens,
+                      num_kv_heads * head_size,
+                      dtype=dtype,
+                      device=device)
+
+    return positions, query, key
+
+
+def unroll_model_tp_dict(model_tp_dict):
+    return [(model_name, tp_size)
+            for model_name, tp_sizes in model_tp_dict.items()
+            for tp_size in tp_sizes]
+
+
+model_tp_dict = {
+    "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
+    "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
+    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2]
+}
+
+# Tolerances follow the defaults PyTorch uses for bfloat16 comparisons:
+# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
+dtype_atol_rtol_list = [
+    [torch.bfloat16, 1e-5, 1.6e-2],
+]
+
+num_tokens_list = [11, 8192]
+
+
+@pytest.mark.skipif(not current_platform.is_cuda_alike(),
+                    reason="Skipping CUDA/ROCm only tests.")
+@pytest.mark.parametrize("model_name, tp_size",
+                         unroll_model_tp_dict(model_tp_dict))
+@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
+@pytest.mark.parametrize("num_tokens", num_tokens_list)
+def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
+    config = AutoConfig.from_pretrained(model_name)
+
+    # Derive the per-rank head counts from the model config
+    total_num_kv_heads = config.num_key_value_heads
+    total_num_heads = config.num_attention_heads
+    num_heads = total_num_heads // tp_size
+    num_kv_heads = max(1, total_num_kv_heads // tp_size)
+    head_dim = config.hidden_size // total_num_heads
+    is_neox_style = True
+
+    rope_theta = config.rope_theta
+    max_position = config.max_position_embeddings
+
+    mrope_helper_class = get_rope(
+        head_size=head_dim,
+        rotary_dim=head_dim,
+        max_position=max_position,
+        base=rope_theta,
+        is_neox_style=is_neox_style,
+        rope_scaling=config.rope_scaling,
+        dtype=dtype,
+    ).to(device=device)
+
+    # Create query/key and rotary position embedding input tensors
+    positions, query, key = generate_test_data(num_tokens, num_heads,
+                                               num_kv_heads, head_dim,
+                                               max_position, dtype, device)
+
+    query_native, key_native = mrope_helper_class.forward_native(
+        positions,
+        query.clone(),
+        key.clone(),
+    )
+
+    query_cuda, key_cuda = mrope_helper_class.forward_cuda(
+        positions,
+        query.clone(),
+        key.clone(),
+    )
+
+    torch.testing.assert_close(query_native, query_cuda, atol=atol, rtol=rtol)
+    torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol)
+
+
+@pytest.mark.skipif(not current_platform.is_cuda_alike(),
+                    reason="Skipping CUDA/ROCm only tests.")
+@pytest.mark.parametrize(
+    "model_name, tp_size",
+    unroll_model_tp_dict({"Qwen/Qwen2-VL-7B-Instruct": [1, 2]}))
+@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
+@pytest.mark.parametrize("num_tokens", [4])
+def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
                                     num_tokens):
+    config = AutoConfig.from_pretrained(model_name)
+
+    # Derive the per-rank head counts from the model config
+    total_num_kv_heads = config.num_key_value_heads
+    total_num_heads = config.num_attention_heads
+    num_heads = total_num_heads // tp_size
+    num_kv_heads = max(1, total_num_kv_heads // tp_size)
+    head_dim = config.hidden_size // total_num_heads
+    is_neox_style = True
+    rope_theta = config.rope_theta
+    max_position = config.max_position_embeddings
+
+    mrope_helper_class = get_rope(
+        head_size=head_dim,
+        rotary_dim=head_dim,
+        max_position=max_position,
+        base=rope_theta,
+        is_neox_style=is_neox_style,
+        rope_scaling=config.rope_scaling,
+        dtype=dtype,
+    ).to(device=device)
+
+    # Generate test data
+    positions, query, key = generate_test_data(num_tokens, num_heads,
+                                               num_kv_heads, head_dim,
+                                               max_position, dtype, device)
+
+    # Create a wrapper that makes the in-place function appear functional
+    def functional_forward_cuda(pos, q, k):
+        """Wrapper that converts the in-place operation to functional style.
+
+        CUDA graph capture (used by mode="reduce-overhead") does not support
+        in-place mutation of the inputs, so this wrapper clones the input
+        tensors and modifies the copies instead.
+        """
+        q_work = q.clone()
+        k_work = k.clone()
+        # forward_cuda modifies q_work and k_work in place
+        mrope_helper_class.forward_cuda(pos, q_work, k_work)
+        return q_work, k_work  # Return the modified tensors
+
+    # Get reference results
+    query_native, key_native = mrope_helper_class.forward_native(
+        positions,
+        query.clone(),
+        key.clone(),
+    )
+
+    try:
+        compiled_forward_cuda = torch.compile(functional_forward_cuda,
+                                              fullgraph=True,
+                                              backend="inductor",
+                                              mode="reduce-overhead",
+                                              dynamic=False)
+
+        # Run the compiled version
+        query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda(
+            positions,
+            query,
+            key,
+        )
+
+        # Run the original version for comparison
+        query_cuda = query.clone()
+        key_cuda = key.clone()
+        mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda)
+
+        # Verify results
+        torch.testing.assert_close(query_compiled_cuda,
+                                   query_cuda,
+                                   atol=atol,
+                                   rtol=rtol)
+        torch.testing.assert_close(key_compiled_cuda,
+                                   key_cuda,
+                                   atol=atol,
+                                   rtol=rtol)
+        torch.testing.assert_close(query_compiled_cuda,
+                                   query_native,
+                                   atol=atol,
+                                   rtol=rtol)
+        torch.testing.assert_close(key_compiled_cuda,
+                                   key_native,
+                                   atol=atol,
+                                   rtol=rtol)
+
+        print("✓ forward_cuda successfully traced with torch.compile inductor")
+
+    except Exception as e:
+        pytest.fail(
+            f"forward_cuda failed to trace with torch.compile inductor: {e}")
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index a75b9e5eb435c..d3b71930b6f17 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -8,10 +8,173 @@
 import numpy as np
 import torch
 from transformers import PretrainedConfig
+from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
+
 from .base import RotaryEmbedding
 from .common import apply_rotary_emb_dispatch
 
 
+@triton.jit
+def _triton_qwen2vl_mrope_forward(
+    q_ptr,
+    k_ptr,
+    cos,
+    sin,
+    num_tokens,
+    n_qh: tl.constexpr,
+    n_kh: tl.constexpr,
+    hd: tl.constexpr,
+    pad_n_qh: tl.constexpr,
+    pad_n_kh: tl.constexpr,
+    pad_hd: tl.constexpr,
+    mrope_section_t: tl.constexpr,
+    mrope_section_h: tl.constexpr,
+):
+    # Adapted from
+    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
+    # This version supports flattened input tensors from vLLM
+    # and a cos and sin cache with shape (3, num_tokens, head_dim // 2)
+    # instead of (3, bsz, seq_len, head_dim)
+    pid = tl.program_id(0)
+    # locate the start address of this token's heads
+    q_ptr = q_ptr + pid * (n_qh * hd)
+    k_ptr = k_ptr + pid * (n_kh * hd)
+
+    # ####################################################################
+    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
+    # m of this program instance
+    # ####################################################################
+    # Note: cos and sin have shape (3, num_tokens, head_dim // 2)
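+    # Illustrative example: with head_dim = 128 and mrope_section =
+    # [16, 24, 24] (the Qwen2-VL values), half_hd = 64, so rotary dims 0..15
+    # take cos/sin computed from the temporal position, dims 16..39 from the
+    # height position, and dims 40..63 from the width position of this token.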
+
+    t_end = mrope_section_t
+    h_end = t_end + mrope_section_h
+
+    # Stride calculation for the half head_dim layout
+    half_hd = hd // 2
+    t_cos = cos + pid * half_hd
+    h_cos = t_cos + num_tokens * half_hd
+    w_cos = h_cos + num_tokens * half_hd
+    t_sin = sin + pid * half_hd
+    h_sin = t_sin + num_tokens * half_hd
+    w_sin = h_sin + num_tokens * half_hd
+
+    # Offsets for the half head_dim
+    cos_offsets = tl.arange(0, pad_hd // 2)
+    t_mask = cos_offsets < t_end
+    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
+    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_hd)
+
+    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
+    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
+    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
+    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
+    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
+    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
+
+    # The three masks are disjoint, so this sum selects the T/H/W entry
+    # for each rotary dimension
+    cos_row = t_cos_row + h_cos_row + w_cos_row
+    sin_row = t_sin_row + h_sin_row + w_sin_row
+
+    # ####################################################################
+    # Load the left and right half of q and k for the current
+    # program instance (i.e. for the current token) separately
+    # ####################################################################
+    # left half of the head
+    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(
+        0, pad_hd // 2)[None, :]
+    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
+        0, pad_hd // 2)[None, :]
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
+        0, pad_hd // 2)[None, :] < hd // 2)
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
+        0, pad_hd // 2)[None, :] < hd // 2)
+
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
+                       mask=first_q_mask,
+                       other=0).to(sin_row.dtype)
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets,
+                       mask=first_k_mask,
+                       other=0).to(sin_row.dtype)
+
+    # right half of the head
+    second_half_q_offsets = first_half_q_offsets + (hd // 2)
+    second_half_k_offsets = first_half_k_offsets + (hd // 2)
+    second_q_mask = first_q_mask
+    second_k_mask = first_k_mask
+
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets,
+                       mask=second_q_mask,
+                       other=0).to(sin_row.dtype)
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets,
+                       mask=second_k_mask,
+                       other=0).to(sin_row.dtype)
+
+    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+    # Since cos and sin are stored at half size,
+    # the same cos_row and sin_row are applied to both halves
+    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
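+
+
+# Launch-shape example (illustrative): for Qwen/Qwen2-VL-7B-Instruct at
+# tp_size=1, head_size = 128 with 28 query heads and 4 KV heads per rank,
+# so pad_hd = 128, pad_n_q_head = 32, pad_n_kv_head = 4, and the kernel is
+# launched with one program per token, i.e. a grid of (num_tokens,).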
+
+
+def triton_mrope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    mrope_section: list[int],
+    head_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Qwen2VL mrope kernel.
+
+    Args:
+        q: [num_tokens, num_heads * head_size]
+        k: [num_tokens, num_kv_heads * head_size]
+        cos: [3, num_tokens, head_size // 2]
+            (T/H/W positions with multimodal inputs)
+        sin: [3, num_tokens, head_size // 2]
+            (T/H/W positions with multimodal inputs)
+        mrope_section: [t, h, w]
+        head_size: int
+    """
+    n_row, n_q_head_head_dim = q.shape
+    n_q_head = n_q_head_head_dim // head_size
+    n_kv_head = k.shape[1] // head_size
+    pad_hd = triton.next_power_of_2(head_size)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+
+    # Ensure the tensors passed into the kernel are contiguous.
+    # This is a no-op if they are already contiguous.
+    q = q.contiguous()
+    k = k.contiguous()
+    cos = cos.contiguous()
+    sin = sin.contiguous()
+
+    _triton_qwen2vl_mrope_forward[(n_row, )](
+        q,
+        k,
+        cos,
+        sin,
+        n_row,
+        n_q_head,
+        n_kv_head,
+        head_size,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        mrope_section[0],
+        mrope_section[1],
+    )
+    return q, k
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
 
@@ -36,11 +199,34 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
+        self.use_triton = current_platform.is_cuda_alike()
+
     def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """MRope forward.
+
+        Args:
+            positions:
+                [num_tokens,] (text only) or
+                [3, num_tokens] (T/H/W positions with multimodal inputs)
+            query: [num_tokens, num_heads * head_size]
+            key: [num_tokens, num_kv_heads * head_size]
+        """
+        if self.use_triton:
+            return self.forward_cuda(positions, query, key)
+        else:
+            return self.forward_native(positions, query, key)
+
+    def forward_native(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward().
 
@@ -88,6 +274,51 @@ class MRotaryEmbedding(RotaryEmbedding):
             key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        assert positions.ndim == 1 or positions.ndim == 2
+        assert key is not None
+
+        num_tokens = positions.shape[-1]
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        query_shape = query.shape
+        key_shape = key.shape
+        if positions.ndim == 2:
+            assert self.mrope_section
+
+            q, k = triton_mrope(
+                query,
+                key,
+                cos,
+                sin,
+                self.mrope_section,
+                self.head_size,
+            )
+
+            return q.reshape(query_shape), k.reshape(key_shape)
+
+        # 1D (text-only) positions fall back to the standard rotary path
+        query = query.view(num_tokens, -1, self.head_size)
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
+                                              self.is_neox_style)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        key = key.view(num_tokens, -1, self.head_size)
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
+                                            self.is_neox_style)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
+
     @classmethod
     def get_input_positions(
         cls,
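
After this patch, MRotaryEmbedding.forward dispatches to forward_cuda (the Triton kernel) on CUDA/ROCm platforms and to forward_native elsewhere. A minimal usage sketch, assuming a CUDA device and Qwen2-VL-7B-style shapes (28 query heads, 4 KV heads, head_size 128); the rope_scaling dict mirrors the model's HF config and must contain an "mrope_section" entry summing to rotary_dim // 2:

    import torch
    from vllm.model_executor.layers.rotary_embedding import get_rope

    rope = get_rope(
        head_size=128,
        rotary_dim=128,
        max_position=32768,
        base=1000000,
        is_neox_style=True,
        rope_scaling={"rope_type": "default", "mrope_section": [16, 24, 24]},
        dtype=torch.bfloat16,
    ).to(device="cuda")
    # One temporal/height/width position id per token
    positions = torch.zeros(3, 16, dtype=torch.long, device="cuda")
    query = torch.randn(16, 28 * 128, dtype=torch.bfloat16, device="cuda")
    key = torch.randn(16, 4 * 128, dtype=torch.bfloat16, device="cuda")
    query, key = rope.forward(positions, query, key)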