[PERF] Speedup of MRoPE prepare inputs (#19939)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
This commit is contained in:
Vadim Gimpelson 2025-06-24 10:01:26 +04:00 committed by GitHub
parent 3014c920da
commit 9a3b88328f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 18 deletions

View File

@ -26,6 +26,7 @@
import math import math
from typing import Any, Optional, Union from typing import Any, Optional, Union
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
@ -1458,15 +1459,14 @@ class MRotaryEmbedding(RotaryEmbedding):
] ]
@staticmethod @staticmethod
def get_next_input_positions_tensor( def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
mrope_position_delta: int, mrope_position_delta: int,
context_len: int, context_len: int, num_new_tokens: int):
seq_len: int,
) -> torch.Tensor: values = np.arange(mrope_position_delta + context_len,
return torch.arange( mrope_position_delta + context_len + num_new_tokens,
mrope_position_delta + context_len, dtype=out.dtype)
mrope_position_delta + seq_len, out[:, out_offset:out_offset + num_new_tokens] = values
).expand(3, -1)
@classmethod @classmethod
def omni_get_updates_use_audio_in_video( def omni_get_updates_use_audio_in_video(

View File

@ -262,6 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int64, dtype=torch.int64,
device="cpu", device="cpu",
pin_memory=self.pin_memory) pin_memory=self.pin_memory)
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
# Only relevant for models using ALiBi (e.g, MPT) # Only relevant for models using ALiBi (e.g, MPT)
self.use_alibi = check_use_alibi(model_config) self.use_alibi = check_use_alibi(model_config)
@ -889,15 +890,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dst_start = mrope_pos_ptr dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len dst_end = mrope_pos_ptr + completion_part_len
self.mrope_positions_cpu[:, dst_start:dst_end] = \ MRotaryEmbedding.get_next_input_positions_tensor(
MRotaryEmbedding.get_next_input_positions_tensor( out=self.mrope_positions_np,
req.mrope_position_delta, out_offset=dst_start,
context_len=num_computed_tokens + mrope_position_delta=req.mrope_position_delta,
prompt_part_len, context_len=num_computed_tokens + prompt_part_len,
seq_len=num_computed_tokens + num_new_tokens=completion_part_len,
prompt_part_len + )
completion_part_len,
)
mrope_pos_ptr += completion_part_len mrope_pos_ptr += completion_part_len