mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 16:05:53 +08:00
Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
84135b1489
commit
e45271b09c
@ -4,7 +4,7 @@
|
|||||||
import math
|
import math
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
from typing import Callable
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -72,7 +72,9 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
|
|||||||
|
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
|
def dispatch_rotary_emb_function(
|
||||||
|
default: Optional[Callable[..., torch.Tensor]] = None
|
||||||
|
) -> Callable[..., torch.Tensor]:
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
return apply_rotary_emb
|
return apply_rotary_emb
|
||||||
|
|
||||||
@ -85,6 +87,9 @@ def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
|
|||||||
"flash_attn is not installed. Falling back to PyTorch "
|
"flash_attn is not installed. Falling back to PyTorch "
|
||||||
"implementation for rotary embeddings.")
|
"implementation for rotary embeddings.")
|
||||||
|
|
||||||
|
if default is not None:
|
||||||
|
return default
|
||||||
|
else:
|
||||||
return apply_rotary_emb_torch
|
return apply_rotary_emb_torch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -276,7 +276,8 @@ def apply_rotary_emb_torch(x: torch.Tensor,
|
|||||||
|
|
||||||
def apply_rotary_pos_emb_vision(t: torch.Tensor,
|
def apply_rotary_pos_emb_vision(t: torch.Tensor,
|
||||||
freqs: torch.Tensor) -> torch.Tensor:
|
freqs: torch.Tensor) -> torch.Tensor:
|
||||||
rotary_emb_function = dispatch_rotary_emb_function()
|
rotary_emb_function = dispatch_rotary_emb_function(
|
||||||
|
default=apply_rotary_emb_torch)
|
||||||
t_ = t.float()
|
t_ = t.float()
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
sin = freqs.sin()
|
sin = freqs.sin()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user