[PERF] Speedup of MRoPE prepare inputs (#19939)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
Vadim Gimpelson 2025-06-24 10:01:26 +04:00 committed by GitHub
parent 3014c920da
commit 9a3b88328f
2 changed files with 17 additions and 18 deletions

File 1 of 2

@@ -26,6 +26,7 @@
 import math
 from typing import Any, Optional, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
@@ -1458,15 +1459,14 @@ class MRotaryEmbedding(RotaryEmbedding):
         ]
 
     @staticmethod
-    def get_next_input_positions_tensor(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.arange(
-            mrope_position_delta + context_len,
-            mrope_position_delta + seq_len,
-        ).expand(3, -1)
+    def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
+                                        mrope_position_delta: int,
+                                        context_len: int, num_new_tokens: int):
+
+        values = np.arange(mrope_position_delta + context_len,
+                           mrope_position_delta + context_len + num_new_tokens,
+                           dtype=out.dtype)
+        out[:, out_offset:out_offset + num_new_tokens] = values
 
     @classmethod
     def omni_get_updates_use_audio_in_video(
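The removed helper built a fresh torch.Tensor (torch.arange(...).expand(3, -1)) for every request on every scheduler step, which the caller then copied into the CPU positions buffer. The rewritten helper computes the same range with NumPy and writes it directly into a slice of a caller-provided buffer, so no intermediate tensor is allocated. The following standalone sketch is not part of the diff; fill_next_positions and the toy buffer sizes are illustrative stand-ins that mirror the new in-place behavior and check it against the values the old tensor-returning version produced.

import numpy as np
import torch

# Illustrative stand-in for the new helper: fill a slice of a preallocated
# (3, max_tokens) buffer instead of returning a new tensor each call.
def fill_next_positions(out: np.ndarray, out_offset: int,
                        mrope_position_delta: int,
                        context_len: int, num_new_tokens: int) -> None:
    values = np.arange(mrope_position_delta + context_len,
                       mrope_position_delta + context_len + num_new_tokens,
                       dtype=out.dtype)
    # The same 1-D range is broadcast into all three MRoPE position rows.
    out[:, out_offset:out_offset + num_new_tokens] = values

buf = np.zeros((3, 16), dtype=np.int64)   # toy staging buffer
fill_next_positions(buf, out_offset=4, mrope_position_delta=-2,
                    context_len=10, num_new_tokens=3)

# The old API produced the same numbers, but allocated a new tensor per call.
old = torch.arange(-2 + 10, -2 + 10 + 3).expand(3, -1).contiguous().numpy()
assert (buf[:, 4:7] == old).all()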

File 2 of 2

@@ -262,6 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 dtype=torch.int64,
                 device="cpu",
                 pin_memory=self.pin_memory)
+            self.mrope_positions_np = self.mrope_positions_cpu.numpy()
 
         # Only relevant for models using ALiBi (e.g, MPT)
         self.use_alibi = check_use_alibi(model_config)
@@ -889,15 +890,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 dst_start = mrope_pos_ptr
                 dst_end = mrope_pos_ptr + completion_part_len
-                self.mrope_positions_cpu[:, dst_start:dst_end] = \
-                    MRotaryEmbedding.get_next_input_positions_tensor(
-                        req.mrope_position_delta,
-                        context_len=num_computed_tokens +
-                        prompt_part_len,
-                        seq_len=num_computed_tokens +
-                        prompt_part_len +
-                        completion_part_len,
-                    )
+                MRotaryEmbedding.get_next_input_positions_tensor(
+                    out=self.mrope_positions_np,
+                    out_offset=dst_start,
+                    mrope_position_delta=req.mrope_position_delta,
+                    context_len=num_computed_tokens + prompt_part_len,
+                    num_new_tokens=completion_part_len,
+                )
                 mrope_pos_ptr += completion_part_len
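On the runner side, self.mrope_positions_np is created once via torch.Tensor.numpy(), which for a CPU tensor returns a zero-copy view sharing the pinned buffer's memory. Writing positions through the NumPy view therefore lands directly in the staging buffer that is later copied to the GPU, replacing a per-request tensor allocation plus tensor-to-tensor copy with a single NumPy slice assignment. A minimal sketch of that shared-memory relationship follows; positions_cpu and positions_np are stand-ins for self.mrope_positions_cpu and self.mrope_positions_np, with toy sizes and pinning disabled so it also runs without CUDA.

import numpy as np
import torch

max_num_tokens = 8                                   # toy value, not vLLM's config
positions_cpu = torch.zeros((3, max_num_tokens + 1), dtype=torch.int64,
                            pin_memory=False)        # vLLM pins this buffer when CUDA is used
positions_np = positions_cpu.numpy()                 # zero-copy view of the same memory

positions_np[:, 2:5] = np.arange(7, 10)              # write through the NumPy view
assert positions_cpu[0, 4].item() == 9               # the write is visible in the tensor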