Add comments on RoPE initialization (#1176)
commit 03ffd0a022
parent a425bd9a9a
@@ -264,6 +264,15 @@ class PagedAttentionWithRoPE(PagedAttention):
         self.is_neox_style = is_neox_style

         # Create the cos and sin cache.
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
         inv_freq = 1.0 / (base**(torch.arange(
             0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
         t = torch.arange(max_position, dtype=torch.float, device="cuda")
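
For context, the code the new comments describe builds the RoPE cos/sin cache roughly as follows. This is a minimal standalone sketch, not the vLLM module itself: the parameter values are illustrative, `device="cuda"` is dropped so it runs on CPU, and the outer-product and concatenation steps in the middle are an assumption based on the `[max_position, rotary_dim]` cache shape noted in the second hunk, since they are not shown in this diff.

import torch

# Illustrative values; the real ones come from the model config.
rotary_dim = 128
max_position = 2048
base = 10000

# As the new comment notes, the index tensor is created directly in float32
# via dtype=torch.float rather than with `torch.arange(...).float()`.
inv_freq = 1.0 / (base**(torch.arange(
    0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position, dtype=torch.float)

# Assumed intermediate steps: outer product of positions and inverse
# frequencies, then cos and sin tables concatenated into one cache.
freqs = torch.einsum("i,j->ij", t, inv_freq)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
print(cache.shape)  # torch.Size([2048, 128]), i.e. [max_position, rotary_dim]

The second NOTE in the hunk concerns where this computation runs: building the cache directly on GPU is faster at startup but can differ slightly from HF, which computes it on CPU and then moves it over.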
@@ -274,7 +283,6 @@ class PagedAttentionWithRoPE(PagedAttention):

         # FIXME(woosuk): This assumes that we configure the default dtype when
         # initializing the model.
-        # TODO(woosuk): Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
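
The FIXME in this hunk relies on the model loader having already set the process-wide default dtype before the layer is constructed: the cache is built in float32 and then cast down to that dtype. A minimal sketch of the behavior, with an explicit set_default_dtype call standing in for whatever the loader does:

import torch

# Stand-in for the model loader configuring the default dtype elsewhere;
# float16 here is just an example.
torch.set_default_dtype(torch.float16)

# Stand-in for the cos/sin cache, which is computed in float32 first.
cache = torch.zeros(2048, 128, dtype=torch.float)

# The layer picks up whatever default dtype was configured and casts to it.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
print(torch_dtype, cache.dtype)  # torch.float16 torch.float16

Since the default dtype is process-global, real code would typically restore it after model initialization; the sketch leaves it changed for brevity.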