commit 1aa2f81b43
parent d54af615d5

[Misc] Update type annotation for rotary embedding base (#18914)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -22,7 +22,7 @@ def benchmark_rope_kernels_multi_lora(
     seed: int,
     device: str,
     max_position: int = 8192,
-    base: int = 10000,
+    base: float = 10000,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)

@@ -70,7 +70,7 @@ def test_rotary_embedding(
     device: str,
     use_key: bool,
     max_position: int = 8192,
-    base: int = 10000,
+    base: float = 10000,
 ) -> None:
     if rotary_dim is None:
         rotary_dim = head_size
@@ -135,7 +135,7 @@ def test_batched_rotary_embedding(
     device: str,
     use_key: bool,
     max_position: int = 8192,
-    base: int = 10000,
+    base: float = 10000,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -203,7 +203,7 @@ def test_batched_rotary_embedding_multi_lora(
     device: str,
     use_key: bool,
     max_position: int = 8192,
-    base: int = 10000,
+    base: float = 10000,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)

@@ -96,7 +96,7 @@ class RotaryEmbedding(CustomOp):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
     ) -> None:
@@ -113,7 +113,7 @@ class RotaryEmbedding(CustomOp):
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
@@ -404,7 +404,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factors: Union[list[float], float],
         dtype: torch.dtype,
@@ -464,7 +464,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -474,7 +474,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding):
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
         inv_freq = super()._compute_inv_freq(base)
 
@@ -501,7 +501,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -582,7 +582,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -644,7 +644,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
         rotary_dim: int,
         max_position_embeddings: int,
         original_max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         short_factor: list[float],
@@ -769,7 +769,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         scaling_factor: float,
         dtype: torch.dtype,
@@ -877,7 +877,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         scaling_factor: float,
@@ -892,7 +892,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         inv_freqs = super()._compute_inv_freq(base)
         low_freq_wavelen = self.orig_max_position / self.low_freq_factor
         high_freq_wavelen = self.orig_max_position / self.high_freq_factor
@@ -923,14 +923,14 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
     ):
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         inv_freqs = super()._compute_inv_freq(base)
         inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
         return inv_freqs
@@ -989,7 +989,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
@@ -1529,7 +1529,7 @@ class DualChunkRotaryEmbedding(CustomOp):
         head_size: int,
         rotary_dim: int,
         max_position_embeddings: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         dtype: torch.dtype,
         chunk_size: int,
@@ -1558,7 +1558,7 @@ class DualChunkRotaryEmbedding(CustomOp):
                              q_inter_cache,
                              persistent=False)
 
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
         # However, we use `torch.arange(..., dtype=torch.float)` instead to
@@ -1705,7 +1705,7 @@ def get_rope(
     head_size: int,
     rotary_dim: int,
     max_position: int,
-    base: int,
+    base: float,
     is_neox_style: bool = True,
     rope_scaling: Optional[dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,

@@ -141,7 +141,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
         head_size: int,
         rotary_dim: int,
         max_position: int,
-        base: int,
+        base: float,
         is_neox_style: bool,
         cache_dtype: torch.dtype,
     ) -> None:
@@ -155,10 +155,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
         cache = self._compute_cos_sin_cache().to(cache_dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
-    def _compute_inv_freq(
-        self,
-        base: Union[int, float],
-    ) -> torch.Tensor:
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
         inv_freq = 1.0 / (base**(torch.arange(
             0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
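For context on why the annotation changes from int to float: model configs typically carry the RoPE base as a floating-point "rope_theta" value (for example 500000.0 in Llama-3-style configs), and that value is forwarded to get_rope, so typing the parameter as int was misleading even though the default of 10000 is integral. Below is a minimal, hedged sketch of such a caller; the import path and the concrete values are illustrative assumptions, not part of this commit.

import torch

# Assumed module path for get_rope; adjust to your checkout if it differs.
from vllm.model_executor.layers.rotary_embedding import get_rope

# Construct a plain RoPE layer with a float base, which the updated
# annotation (base: float) now reflects. Values are illustrative only.
rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=8192,
    base=500000.0,  # float rope_theta as read from a model config
    is_neox_style=True,
    dtype=torch.float16,
)
print(rope.cos_sin_cache.shape)  # cached cos/sin table built from the float base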