mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-19 19:54:28 +08:00
[V1] Simplify M-RoPE (#12352)
Signed-off-by: Roger Wang <ywang@roblox.com> Co-authored-by: imkero <kerorek@outlook.com>
This commit is contained in:
parent
d07efb31c5
commit
99d01a5e3d
@ -144,28 +144,24 @@ class GPUModelRunner:
|
||||
|
||||
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
|
||||
if self.model_config.uses_mrope:
|
||||
# NOTE: `mrope_positions` is implemented as a permuted tensor to
|
||||
# satisfy the following properties to allow `torch.compile` to work
|
||||
# properly:
|
||||
# - shape: (3, <variable>)
|
||||
# - stride: (1, 3)
|
||||
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256
|
||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||
# position on purpose to make it non-contiguous so that it can work
|
||||
# with `torch.compile`.
|
||||
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
|
||||
|
||||
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
||||
# the modality of inputs. For text-only inputs, each dimension has
|
||||
# identical position IDs, making M-RoPE functionally equivalent to
|
||||
# 1D-RoPE.
|
||||
# See page 5 of https://arxiv.org/abs/2409.12191
|
||||
self.mrope_positions = torch.zeros((self.max_num_tokens, 3),
|
||||
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3),
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
self.mrope_positions = self.mrope_positions.permute((1, 0))
|
||||
self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0))
|
||||
self.mrope_positions_cpu = torch.zeros(
|
||||
(3, self.max_num_tokens + 1),
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user