[Hardware][intel GPU] bump up ipex version to 2.3 (#8365)

Co-authored-by: Yan Ma <yan.ma@intel.com>

parent 9ba0817ff1
commit 851725202a
@@ -1,15 +1,23 @@
-FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04
+FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04
 
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
     echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
     chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
     rm /etc/apt/sources.list.d/intel-graphics.list && \
     wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
     echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
     chmod 644 /usr/share/keyrings/intel-graphics.gpg
 
 RUN apt-get update -y \
 && apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1
 
+RUN git clone https://github.com/intel/pti-gpu && \
+    cd pti-gpu/sdk && \
+    mkdir build && \
+    cd build && \
+    cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \
+    make -j && \
+    cmake --install . --config Release --prefix "/usr/local"
+
 COPY ./ /workspace/vllm
 
 WORKDIR /workspace/vllm
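Not part of the diff itself, but a quick smoke test for the bumped stack: the Python sketch below assumes it is run inside a container built from the updated Dockerfile, and it only reports the installed torch/IPEX versions and whether an XPU device is visible.

import torch
import intel_extension_for_pytorch as ipex  # importing IPEX enables the XPU backend

print("torch:", torch.__version__)                 # expected: 2.3.1+cxx11.abi
print("ipex:", ipex.__version__)                   # expected: 2.3.110+xpu
print("xpu available:", torch.xpu.is_available())
print("xpu devices:", torch.xpu.device_count())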
@@ -3,9 +3,10 @@
 
 setuptools < 70.0.0 # IPEX's torch have some dependency. to be removed.
 
-torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl
-intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl
-oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl
+torch == 2.3.1+cxx11.abi
+intel-extension-for-pytorch == 2.3.110+xpu
+oneccl_bind_pt == 2.3.100+xpu
 
-triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+triton-xpu == 3.0.0b2
 
+--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
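Since the XPU wheels are now resolved by version pin from Intel's extra index rather than by direct URL, a small check like the sketch below (plain importlib.metadata, package names and versions taken from the pins above) can confirm at runtime that the expected builds were actually installed.

from importlib.metadata import version

pins = {
    "torch": "2.3.1+cxx11.abi",
    "intel-extension-for-pytorch": "2.3.110+xpu",
    "oneccl_bind_pt": "2.3.100+xpu",
    "triton-xpu": "3.0.0b2",
}
for pkg, expected in pins.items():
    installed = version(pkg)
    flag = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{pkg}: {installed} {flag}")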
@@ -27,29 +27,27 @@ class ipex_ops:
 
     @staticmethod
     def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.silu_mul(x1, x2, out)
+        ipex.llm.functional.silu_and_mul(x, out)
 
     @staticmethod
     def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "none")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
     @staticmethod
     def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
     @staticmethod
-    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
 
     @staticmethod
-    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_new(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
 
-    # TODO add implementation of gelu_quick here
-    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+    @staticmethod
+    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+        ipex.llm.functional.gelu_quick(x, out)
 
     @staticmethod
     def paged_attention_v1(
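The fused *_and_mul wrappers keep their out-parameter signatures but now hand the unsplit activation tensor straight to ipex.llm.functional. A caller-side sketch (assuming an XPU device and the vllm._ipex_ops import path used elsewhere in this diff): the input packs the gate and up projections along the last dimension, and the preallocated output has half that width.

import torch
from vllm._ipex_ops import ipex_ops as ops

x = torch.randn(8, 2 * 128, dtype=torch.float16, device="xpu")  # [tokens, 2*d]
d = x.shape[-1] // 2
out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)  # out = silu(x[..., :d]) * x[..., d:]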
@@ -160,29 +158,10 @@ class ipex_ops:
         cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
         is_neox: bool,
     ) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[positions.long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        rot_dim = cos_sin_cache.size(1)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim)
 
     @staticmethod
     def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
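The manual reshape, cos/sin gather, and neox-versus-GPT-J interleaving is replaced by a single fused call. For testing that fused path, a pure-PyTorch reference along these lines can be handy; it is a sketch only, covering the neox layout and assuming the usual vLLM cache convention of cos_sin_cache holding [cos | sin] per position.

import torch

def rotary_reference_neox(positions, query, head_size, cos_sin_cache):
    # query: [num_tokens, num_heads * head_size]; rotates the first rot_dim
    # elements of each head and leaves the rest untouched.
    rot_dim = cos_sin_cache.size(1)
    q = query.view(*query.shape[:-1], -1, head_size)
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]
    cos, sin = cos_sin_cache[positions.long()].chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over heads
    half = rot_dim // 2
    q1, q2 = q_rot[..., :half], q_rot[..., half:]
    q_rot = torch.cat((q1 * cos - q2 * sin, q2 * cos + q1 * sin), dim=-1)
    return torch.cat((q_rot, q_pass), dim=-1).reshape(query.shape)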
@@ -190,37 +169,15 @@ class ipex_ops:
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-        cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[torch.add(positions,
-                                          cos_sin_cache_offsets).long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim,
+                                                     cos_sin_cache_offsets)
 
     @staticmethod
-    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-                 epsilon: float) -> None:
-        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
-        out.copy_(tmp)
+    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
+                 epsilon: float) -> torch.Tensor:
+        return ipex.llm.functional.rms_norm(input, weight, epsilon)
 
     @staticmethod
     def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
@@ -246,11 +203,14 @@ class ipex_ops:
         return_softmax: bool,
         gen_: torch.Generator,
     ) -> None:
-        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
-                                             seqlen_k, max_seqlen_q,
-                                             max_seqlen_k, pdropout,
-                                             softmax_scale, zero_tensors,
-                                             is_causal, return_softmax, gen_)
+        ipex.llm.functional.varlen_attention(query.contiguous(),
+                                             key.contiguous(),
+                                             value.contiguous(), out,
+                                             seqlen_q.int(), seqlen_k.int(),
+                                             max_seqlen_q, max_seqlen_k,
+                                             pdropout, softmax_scale,
+                                             zero_tensors, is_causal,
+                                             return_softmax, gen_)
 
     @staticmethod
     def reshape_and_cache(
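The functional change here is defensive: the tensors handed to IPEX are made contiguous and the sequence-length tensors are cast to int32, presumably because upstream slicing can produce non-contiguous views and int64 lengths. The same normalization, written out as a tiny helper sketch (hypothetical name, not part of the diff):

def _prep_varlen_inputs(query, key, value, seqlen_q, seqlen_k):
    # Mirror the conversions applied above before calling
    # ipex.llm.functional.varlen_attention.
    return (query.contiguous(), key.contiguous(), value.contiguous(),
            seqlen_q.int(), seqlen_k.int())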
@@ -49,14 +49,18 @@ class IpexAttnBackend(AttentionBackend):
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        from vllm._ipex_ops import ipex_ops as ops
+        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        from vllm._ipex_ops import ipex_ops as ops
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
 
 @dataclass
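swap_blocks and copy_blocks now route through vllm._ipex_ops instead of the CUDA PagedAttention helpers, and copy_blocks first splits each per-layer cache into its key and value halves. Spelled out as a sketch, the layout assumption encoded above is that kv_cache[0] holds the key cache and kv_cache[1] the value cache:

from typing import List, Tuple

import torch

def split_kv_caches(
        kv_caches: List[torch.Tensor]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # Per-layer cache tensors stack the key and value caches along dim 0.
    key_caches = [kv_cache[0] for kv_cache in kv_caches]
    value_caches = [kv_cache[1] for kv_cache in kv_caches]
    return key_caches, value_caches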
@@ -114,9 +114,7 @@ class NewGELU(CustomOp):
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         from vllm._ipex_ops import ipex_ops as ops
 
-        out = torch.empty_like(x)
-        ops.gelu_new(out, x)
-        return out
+        return ops.gelu_new(x)
 
 
 class FastGELU(CustomOp):
@@ -136,9 +134,7 @@ class FastGELU(CustomOp):
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         from vllm._ipex_ops import ipex_ops as ops
 
-        out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
-        return out
+        return ops.gelu_fast(x)
 
 
 class QuickGELU(CustomOp):
@@ -155,6 +151,13 @@ class QuickGELU(CustomOp):
         ops.gelu_quick(out, x)
         return out
 
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm._ipex_ops import ipex_ops as ops
+
+        out = torch.empty_like(x)
+        ops.gelu_quick(out, x)
+        return out
+
     # TODO implement forward_xpu for QuickGELU
     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
 
@@ -82,14 +82,11 @@ class RMSNorm(CustomOp):
                 self.variance_epsilon,
             )
             return x, residual
-        out = torch.empty_like(x)
-        ops.rms_norm(
-            out,
+        return ops.rms_norm(
             x,
             self.weight.data,
             self.variance_epsilon,
         )
-        return out
 
     def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"
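With ops.rms_norm now returning its result, forward_xpu no longer allocates and copies into an out buffer. For checking the XPU path against a known-good baseline, a pure-PyTorch RMSNorm reference (a sketch following the common accumulate-in-float32-then-cast convention) looks like this:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight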