mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-03 06:41:47 +08:00
parent
ce26b16268
commit
538fab93cd
@ -509,15 +509,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if rotary_dim != head_size:
|
||||
raise ValueError(
|
||||
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
|
||||
rotary_dim != head_size ({rotary_dim}!={head_size}).")
|
||||
if is_neox_style is False:
|
||||
raise ValueError(
|
||||
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
|
||||
)
|
||||
|
||||
self.rotary_dim = rotary_dim
|
||||
self.head_size = head_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
@ -557,7 +554,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
|
||||
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
|
||||
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
|
||||
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
|
||||
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(
|
||||
@ -596,8 +593,15 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
cos = cos.repeat(1, 2).unsqueeze(-2)
|
||||
sin = sin.repeat(1, 2).unsqueeze(-2)
|
||||
|
||||
query = query * cos + _rotate_neox(query) * sin
|
||||
key = key * cos + _rotate_neox(key) * sin
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
|
||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
|
||||
key = torch.cat((key_rot, key_pass), dim=-1)
|
||||
|
||||
return query.flatten(-2), key.flatten(-2)
|
||||
|
||||
|
||||
@ -128,6 +128,9 @@ class LlamaAttention(nn.Module):
|
||||
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
||||
self.head_dim = getattr(config, "head_dim",
|
||||
self.hidden_size // self.total_num_heads)
|
||||
# Phi models introduced a partial_rotary_factor parameter in the config
|
||||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
|
||||
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
@ -159,7 +162,7 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
rotary_dim=self.rotary_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user