mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:05:02 +08:00
[PERF] Speedup of MRoPE prepare inputs (#19939)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
This commit is contained in:
parent
3014c920da
commit
9a3b88328f
@ -26,6 +26,7 @@
|
||||
import math
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
@ -1458,15 +1459,14 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_next_input_positions_tensor(
|
||||
mrope_position_delta: int,
|
||||
context_len: int,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.arange(
|
||||
mrope_position_delta + context_len,
|
||||
mrope_position_delta + seq_len,
|
||||
).expand(3, -1)
|
||||
def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
|
||||
mrope_position_delta: int,
|
||||
context_len: int, num_new_tokens: int):
|
||||
|
||||
values = np.arange(mrope_position_delta + context_len,
|
||||
mrope_position_delta + context_len + num_new_tokens,
|
||||
dtype=out.dtype)
|
||||
out[:, out_offset:out_offset + num_new_tokens] = values
|
||||
|
||||
@classmethod
|
||||
def omni_get_updates_use_audio_in_video(
|
||||
|
||||
@ -262,6 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
|
||||
|
||||
# Only relevant for models using ALiBi (e.g, MPT)
|
||||
self.use_alibi = check_use_alibi(model_config)
|
||||
@ -889,15 +890,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
dst_start = mrope_pos_ptr
|
||||
dst_end = mrope_pos_ptr + completion_part_len
|
||||
|
||||
self.mrope_positions_cpu[:, dst_start:dst_end] = \
|
||||
MRotaryEmbedding.get_next_input_positions_tensor(
|
||||
req.mrope_position_delta,
|
||||
context_len=num_computed_tokens +
|
||||
prompt_part_len,
|
||||
seq_len=num_computed_tokens +
|
||||
prompt_part_len +
|
||||
completion_part_len,
|
||||
)
|
||||
MRotaryEmbedding.get_next_input_positions_tensor(
|
||||
out=self.mrope_positions_np,
|
||||
out_offset=dst_start,
|
||||
mrope_position_delta=req.mrope_position_delta,
|
||||
context_len=num_computed_tokens + prompt_part_len,
|
||||
num_new_tokens=completion_part_len,
|
||||
)
|
||||
|
||||
mrope_pos_ptr += completion_part_len
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user