mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 11:15:01 +08:00
[PERF] Speedup of MRoPE prepare inputs (#19939)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
This commit is contained in:
parent
3014c920da
commit
9a3b88328f
@ -26,6 +26,7 @@
|
|||||||
import math
|
import math
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
@ -1458,15 +1459,14 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_next_input_positions_tensor(
|
def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
|
||||||
mrope_position_delta: int,
|
mrope_position_delta: int,
|
||||||
context_len: int,
|
context_len: int, num_new_tokens: int):
|
||||||
seq_len: int,
|
|
||||||
) -> torch.Tensor:
|
values = np.arange(mrope_position_delta + context_len,
|
||||||
return torch.arange(
|
mrope_position_delta + context_len + num_new_tokens,
|
||||||
mrope_position_delta + context_len,
|
dtype=out.dtype)
|
||||||
mrope_position_delta + seq_len,
|
out[:, out_offset:out_offset + num_new_tokens] = values
|
||||||
).expand(3, -1)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def omni_get_updates_use_audio_in_video(
|
def omni_get_updates_use_audio_in_video(
|
||||||
|
|||||||
@ -262,6 +262,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory)
|
pin_memory=self.pin_memory)
|
||||||
|
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
|
||||||
|
|
||||||
# Only relevant for models using ALiBi (e.g, MPT)
|
# Only relevant for models using ALiBi (e.g, MPT)
|
||||||
self.use_alibi = check_use_alibi(model_config)
|
self.use_alibi = check_use_alibi(model_config)
|
||||||
@ -889,15 +890,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dst_start = mrope_pos_ptr
|
dst_start = mrope_pos_ptr
|
||||||
dst_end = mrope_pos_ptr + completion_part_len
|
dst_end = mrope_pos_ptr + completion_part_len
|
||||||
|
|
||||||
self.mrope_positions_cpu[:, dst_start:dst_end] = \
|
MRotaryEmbedding.get_next_input_positions_tensor(
|
||||||
MRotaryEmbedding.get_next_input_positions_tensor(
|
out=self.mrope_positions_np,
|
||||||
req.mrope_position_delta,
|
out_offset=dst_start,
|
||||||
context_len=num_computed_tokens +
|
mrope_position_delta=req.mrope_position_delta,
|
||||||
prompt_part_len,
|
context_len=num_computed_tokens + prompt_part_len,
|
||||||
seq_len=num_computed_tokens +
|
num_new_tokens=completion_part_len,
|
||||||
prompt_part_len +
|
)
|
||||||
completion_part_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
mrope_pos_ptr += completion_part_len
|
mrope_pos_ptr += completion_part_len
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user