From 9a3b88328f7e434cac35b90ee463de6689f9a833 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Date: Tue, 24 Jun 2025 10:01:26 +0400 Subject: [PATCH] [PERF] Speedup of MRoPE prepare inputs (#19939) Signed-off-by: Vadim Gimpelson --- vllm/model_executor/layers/rotary_embedding.py | 18 +++++++++--------- vllm/v1/worker/gpu_model_runner.py | 17 ++++++++--------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 9de2338968a1c..b7bb2affc4fab 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -26,6 +26,7 @@ import math from typing import Any, Optional, Union +import numpy as np import torch import torch.nn as nn from transformers import PretrainedConfig @@ -1458,15 +1459,14 @@ class MRotaryEmbedding(RotaryEmbedding): ] @staticmethod - def get_next_input_positions_tensor( - mrope_position_delta: int, - context_len: int, - seq_len: int, - ) -> torch.Tensor: - return torch.arange( - mrope_position_delta + context_len, - mrope_position_delta + seq_len, - ).expand(3, -1) + def get_next_input_positions_tensor(out: np.ndarray, out_offset: int, + mrope_position_delta: int, + context_len: int, num_new_tokens: int): + + values = np.arange(mrope_position_delta + context_len, + mrope_position_delta + context_len + num_new_tokens, + dtype=out.dtype) + out[:, out_offset:out_offset + num_new_tokens] = values @classmethod def omni_get_updates_use_audio_in_video( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 520d8fb186f4f..40639fdf24338 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -262,6 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): dtype=torch.int64, device="cpu", pin_memory=self.pin_memory) + self.mrope_positions_np = self.mrope_positions_cpu.numpy() # Only relevant for models using ALiBi (e.g, MPT) self.use_alibi = check_use_alibi(model_config) @@ -889,15 +890,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): dst_start = mrope_pos_ptr dst_end = mrope_pos_ptr + completion_part_len - self.mrope_positions_cpu[:, dst_start:dst_end] = \ - MRotaryEmbedding.get_next_input_positions_tensor( - req.mrope_position_delta, - context_len=num_computed_tokens + - prompt_part_len, - seq_len=num_computed_tokens + - prompt_part_len + - completion_part_len, - ) + MRotaryEmbedding.get_next_input_positions_tensor( + out=self.mrope_positions_np, + out_offset=dst_start, + mrope_position_delta=req.mrope_position_delta, + context_len=num_computed_tokens + prompt_part_len, + num_new_tokens=completion_part_len, + ) mrope_pos_ptr += completion_part_len