[Hardware][Intel GPU] Upgrade to torch 2.7 (#17444)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Qiming Zhang <qiming1.zhang@intel.com>
Kunshang Ji 2025-04-30 15:03:58 +08:00 committed by GitHub
parent 6ed9f6047e
commit ed6cfb90c8
5 changed files with 18 additions and 35 deletions


@@ -40,12 +40,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=bind,source=.git,target=.git \
     python3 setup.py install
 
-# Please refer xpu doc, we need manually install intel-extension-for-pytorch 2.6.10+xpu due to there are some conflict dependencies with torch 2.6.0+xpu
-# FIXME: This will be fix in ipex 2.7. just leave this here for awareness.
-RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install intel-extension-for-pytorch==2.6.10+xpu \
-    --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-
 CMD ["/bin/bash"]
 
 FROM vllm-base AS vllm-openai


@@ -35,13 +35,6 @@ pip install -v -r requirements/xpu.txt
 VLLM_TARGET_DEVICE=xpu python setup.py install
 ```
 
-- Finally, due to a known issue of conflict dependency(oneapi related) in torch-xpu 2.6 and ipex-xpu 2.6, we install ipex here. This will be fixed in the ipex-xpu 2.7.
-
-```console
-pip install intel-extension-for-pytorch==2.6.10+xpu \
-    --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-```
-
 :::{note}
 - FP16 is the default data type in the current XPU backend. The BF16 data
   type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet.
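
With the separate ipex install step removed, a quick sanity check of the installed stack can be useful. A minimal sketch, assuming the pins from requirements/xpu.txt in this commit (torch 2.7.0+xpu, intel-extension-for-pytorch 2.7.10+xpu) installed cleanly and an XPU device is visible to the driver:

```python
# Sanity-check the XPU stack after installation (sketch only).
import torch
import intel_extension_for_pytorch as ipex

print("torch:", torch.__version__)           # expected to start with 2.7.0
print("ipex:", ipex.__version__)             # expected to start with 2.7.10 and end with "xpu"
print("XPU available:", torch.xpu.is_available())
print("XPU device count:", torch.xpu.device_count())

# Run one tiny op on the device to confirm the runtime actually works.
if torch.xpu.is_available():
    x = torch.randn(4, 4, device="xpu")
    print((x @ x.T).sum().item())
```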
@@ -81,5 +74,3 @@ python -m vllm.entrypoints.openai.api_server \
 ```
 
 By default, a ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equals to `parallel_config.world_size`. We recommend properly starting a ray cluster before execution, referring to the <gh-file:examples/online_serving/run_cluster.sh> helper script.
-
-There are some new features coming with ipex-xpu 2.6, e.g. **chunked prefill**, **V1 engine support**, **lora**, **MoE**, etc.
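
The same tensor-parallel path can be exercised offline, with the Ray worker count following `parallel_config.world_size` as described above. A minimal sketch, assuming two XPU devices; the model name, parallel size, and sampling settings are placeholders, not taken from this commit:

```python
# Offline tensor-parallel inference sketch for the XPU backend.
# vLLM will launch a Ray instance automatically if none is detected.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-13b",            # placeholder model
    tensor_parallel_size=2,              # one worker per XPU in this sketch
    distributed_executor_backend="ray",
    dtype="float16",                     # FP16 is the XPU default per the note above
)

outputs = llm.generate(
    ["The Intel Data Center GPU is"],
    SamplingParams(max_tokens=32, temperature=0.8),
)
print(outputs[0].outputs[0].text)
```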


@@ -10,7 +10,7 @@ wheel
 jinja2>=3.1.6
 datasets  # for benchmark scripts
-torch==2.6.0+xpu
+torch==2.7.0+xpu
 torchaudio
 torchvision
 pytorch-triton-xpu
@@ -18,6 +18,6 @@ pytorch-triton-xpu
 # Please refer xpu doc, we need manually install intel-extension-for-pytorch 2.6.10+xpu due to there are some conflict dependencies with torch 2.6.0+xpu
 # FIXME: This will be fix in ipex 2.7. just leave this here for awareness.
-# intel-extension-for-pytorch==2.6.10+xpu
-oneccl_bind_pt==2.6.0+xpu
+intel-extension-for-pytorch==2.7.10+xpu
+oneccl_bind_pt==2.7.0+xpu
 --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
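
The `oneccl_bind_pt` pin moves in lockstep with torch because it supplies the oneCCL (`ccl`) `torch.distributed` backend used for multi-XPU communication. A rough sketch of how that backend is typically brought up; the rank/world-size/env values are placeholders for a single-node two-rank run and are not part of this commit:

```python
# Sketch: initializing torch.distributed with the oneCCL backend provided by
# oneccl_bind_pt==2.7.0+xpu. Importing the bindings registers the "ccl" backend.
import os
import torch
import torch.distributed as dist
import oneccl_bindings_for_pytorch  # noqa: F401

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

dist.init_process_group(
    backend="ccl",
    rank=int(os.environ.get("RANK", "0")),
    world_size=int(os.environ.get("WORLD_SIZE", "1")),
)

t = torch.ones(1, device="xpu")
dist.all_reduce(t)            # sums the tensor across all ranks
print(t.item())
dist.destroy_process_group()
```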


@@ -177,6 +177,7 @@ class ipex_ops:
         out: torch.Tensor,
         seqlen_q: torch.Tensor,
         seqlen_k: torch.Tensor,
+        alibi_slopes: torch.Tensor,
         max_seqlen_q: int,
         max_seqlen_k: int,
         pdropout: float,
@@ -185,6 +186,8 @@ class ipex_ops:
         is_causal: bool,
         return_softmax: bool,
         gen_: torch.Generator,
+        window_size_left: float,
+        window_size_right: float,
         logits_soft_cap: float,
     ) -> None:
         if ipex.__version__.endswith("cpu"):
@@ -200,15 +203,12 @@ class ipex_ops:
                 is_causal, return_softmax,
                 gen_)
         else:  # XPU build
-            ipex.llm.functional.varlen_attention(query.contiguous(),
-                                                 key.contiguous(),
-                                                 value.contiguous(), out,
-                                                 seqlen_q.int(),
-                                                 seqlen_k.int(), max_seqlen_q,
-                                                 max_seqlen_k, pdropout,
-                                                 softmax_scale, zero_tensors,
-                                                 is_causal, return_softmax,
-                                                 gen_, logits_soft_cap)
+            ipex.llm.functional.varlen_attention(
+                query.contiguous(), key.contiguous(), value.contiguous(), out,
+                seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q,
+                max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal,
+                return_softmax, gen_, window_size_left, window_size_right,
+                logits_soft_cap)
 
     @staticmethod
     def reshape_and_cache(
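
The XPU branch now threads `alibi_slopes`, `window_size_left`, and `window_size_right` straight through to `ipex.llm.functional.varlen_attention`. A hedged sketch of calling the updated `ipex_ops.varlen_attention` wrapper, mirroring the defaults the attention backend passes in the hunk below; the packed `[total_tokens, num_heads, head_size]` layout and the cumulative seqlen offsets are assumptions for illustration, and `alibi_slopes=None` follows the backend (which passes `self.alibi_slopes`, `None` for models without ALiBi) rather than the annotated type:

```python
# Illustrative call into the updated wrapper (requires an XPU build of IPEX).
import torch
from vllm._ipex_ops import ipex_ops

num_heads, head_size = 8, 64
seq_lens = [5, 3]                       # two sequences packed together
total_tokens = sum(seq_lens)

q = torch.randn(total_tokens, num_heads, head_size, device="xpu", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)

# Cumulative sequence-length offsets (assumed layout for the varlen kernel).
seqlen = torch.tensor([0, 5, 8], device="xpu", dtype=torch.int32)

ipex_ops.varlen_attention(
    q, k, v, out,
    seqlen, seqlen,                     # seqlen_q, seqlen_k (same for prefill)
    None,                               # alibi_slopes: None when ALiBi is unused
    max(seq_lens), max(seq_lens),       # max_seqlen_q, max_seqlen_k
    pdropout=0.0,
    softmax_scale=head_size ** -0.5,
    zero_tensors=False,
    is_causal=True,
    return_softmax=False,
    gen_=None,
    window_size_left=-1,                # -1 disables sliding-window attention
    window_size_right=-1,
    logits_soft_cap=-1,                 # -1 now means "no soft cap" (see backend change)
)
print(out.shape)
```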


@@ -143,10 +143,9 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        self.need_mask = (self.alibi_slopes is not None
-                          or self.sliding_window is not None)
+        self.need_mask = (self.sliding_window is not None)
         if logits_soft_cap is None:
-            logits_soft_cap = 0
+            logits_soft_cap = -1
         self.logits_soft_cap = logits_soft_cap
 
         supported_head_sizes = PagedAttention.get_supported_head_sizes()
@@ -234,11 +233,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                     dim=1)
 
             if attn_metadata.attn_bias is None:
-                if self.alibi_slopes is not None:
-                    att_masks = _make_alibi_bias(
-                        self.alibi_slopes, query.dtype,
-                        attn_metadata.seq_lens)  # type: ignore
-                elif self.sliding_window is not None:
+                if self.sliding_window is not None:
                     att_masks = _make_sliding_window_bias(
                         attn_metadata.seq_lens, self.sliding_window,
                         query.dtype)  # type: ignore
@@ -258,6 +253,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                     output,
                     attn_metadata.seqlen_q,
                     attn_metadata.seqlen_q,
+                    self.alibi_slopes,
                     attn_metadata.max_seqlen,
                     attn_metadata.max_seqlen,
                     pdropout=0.0,
@@ -266,6 +262,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                     is_causal=True,
                     return_softmax=False,
                     gen_=None,
+                    window_size_left=-1,
+                    window_size_right=-1,
                     logits_soft_cap=self.logits_soft_cap,
                 )
             else:
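
With this change the backend no longer folds ALiBi into a precomputed bias mask via `_make_alibi_bias`; `self.alibi_slopes` is handed to the kernel directly. For context, a hedged sketch of the textbook per-head slope construction that such a tensor typically contains; this is the standard ALiBi recipe, not necessarily the exact helper vLLM's model code uses before the slopes reach `IpexAttnBackendImpl`:

```python
# Textbook ALiBi slopes: geometric sequence 2^(-8/n), 2^(-16/n), ... per head.
import math
import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Closest power of two at or below num_heads gets the "clean" slopes ...
    n = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / n)
    slopes = [base ** (i + 1) for i in range(n)]
    # ... and any remaining heads interleave slopes from a 2n-head schedule.
    if n < num_heads:
        extra_base = 2.0 ** (-4.0 / n)   # equals the base for a 2n-head schedule
        slopes.extend(extra_base ** (2 * i + 1) for i in range(num_heads - n))
    return torch.tensor(slopes, dtype=torch.float32)

print(alibi_slopes(8))    # 8 heads: 1/2, 1/4, ..., 1/256
print(alibi_slopes(12))   # non-power-of-two head count uses the interpolation rule
```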