[ROCm] support Radeon™ 7900 series (gfx1100) without using flash-attention (#2768)
parent 3711811b1d
commit 0580aab02f
@@ -17,6 +17,12 @@ RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
 ARG FA_BRANCH="3d2b6f5"
 RUN echo "FA_BRANCH is $FA_BRANCH"
 
+# whether to build flash-attention
+# if 0, will not build flash attention
+# this is useful for gfx target where flash-attention is not supported
+# In that case, we need to use the python reference attention implementation in vllm
+ARG BUILD_FA="1"
+
 # Install some basic utilities
 RUN apt-get update && apt-get install python3 python3-pip -y
 
@@ -50,7 +56,8 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
 ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
 
 # Install ROCm flash-attention
-RUN mkdir libs \
+RUN if [ "$BUILD_FA" == "1" ]; then \
+    mkdir libs \
     && cd libs \
     && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
     && cd flash-attention \
@@ -60,7 +67,8 @@ RUN mkdir libs \
     && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
         patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
     && python3 setup.py install \
-    && cd ..
+    && cd ..; \
+    fi
 
 COPY ./ /app/vllm
 
@@ -75,7 +83,8 @@ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
 RUN cd /app \
     && cd vllm \
     && pip install -U -r requirements-rocm.txt \
-    && bash patch_xformers.rocm.sh \
+    && if [ "$BUILD_FA" == "1" ]; then \
+        bash patch_xformers.rocm.sh; fi \
     && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
     && python3 setup.py install \
     && cd ..
@@ -12,7 +12,7 @@ Requirements
 
 * OS: Linux
 * Python: 3.8 -- 3.11
-* GPU: MI200s (gfx90a), MI300 (gfx942)
+* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
 * Pytorch 2.0.1/2.1.1/2.2
 * ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)
 
@@ -105,6 +105,7 @@ The `Dokerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later
 * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
 * `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
 * `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
+* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 before flash-attention supports this target.
 
 Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
 
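As a concrete illustration of the ``--build-arg`` usage documented above (a hedged example; the image tag ``vllm-rocm`` and the trailing build context are arbitrary choices, not part of this commit), a gfx1100 build that skips flash-attention could be invoked as ``docker build -f Dockerfile.rocm --build-arg BUILD_FA="0" -t vllm-rocm .``.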
setup.py
@@ -24,7 +24,7 @@ MAIN_CUDA_VERSION = "12.1"
 
 # Supported NVIDIA GPU architectures.
 NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
-ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}
 # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
 
 
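The hunk above only widens the set of supported GFX targets. As a hedged aside, the sketch below shows the kind of validation such a set enables; the validate_rocm_archs helper is hypothetical and is not part of vLLM's setup.py.

# Hypothetical sketch: validate a semicolon-separated GFX target list against
# the ROCM_SUPPORTED_ARCHS set extended above. Illustrative only.
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}

def validate_rocm_archs(requested: str) -> set:
    archs = {arch.strip() for arch in requested.split(";") if arch.strip()}
    unsupported = archs - ROCM_SUPPORTED_ARCHS
    if unsupported:
        raise ValueError(f"Unsupported ROCm arch(s): {sorted(unsupported)}; "
                         f"supported: {sorted(ROCM_SUPPORTED_ARCHS)}")
    return archs

print(validate_rocm_archs("gfx90a;gfx1100"))  # e.g. {'gfx90a', 'gfx1100'}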
@@ -1,6 +1,7 @@
 """Multi-head attention."""
 from typing import List, Optional
 
+import importlib
 import torch
 import torch.nn as nn
 from xformers import ops as xops
@@ -58,6 +59,40 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
+        self.use_ref_attention = self.check_use_ref_attention()
+
+    def check_use_ref_attention(self) -> bool:
+        if not is_hip():
+            return False
+        # For ROCm, check whether flash attention is installed or not.
+        # if not, use_ref_attention needs to be True
+        return importlib.util.find_spec("flash_attn") is None
+
+    def ref_masked_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        seq_len, _, _ = query.shape
+        attn_mask = torch.triu(torch.ones(seq_len,
+                                          seq_len,
+                                          dtype=query.dtype,
+                                          device=query.device),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(query.dtype).min
+
+        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
+                                                 key).float()
+        attn_weights = attn_weights + attn_mask.float()
+        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
+        return out
+
     def forward(
         self,
         query: torch.Tensor,
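To make the new reference path easier to inspect in isolation, here is a minimal standalone sketch of the same masked-attention pattern on random tensors. It assumes nothing from vLLM (no PagedAttention instance, no ROCm device) and uses made-up sizes, so it only demonstrates the causal mask and einsum contractions added above.

import torch

# Illustrative sizes; the real module derives these from the model config.
num_heads, head_size, seq_len = 4, 64, 8
scale = head_size ** -0.5

query = torch.randn(seq_len, num_heads, head_size)
key = torch.randn(seq_len, num_heads, head_size)
value = torch.randn(seq_len, num_heads, head_size)

# Causal mask: strictly upper-triangular positions get a large negative bias,
# so softmax drives their attention weights to (near) zero.
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype), diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min

# "qhd,khd->hqk": per-head query/key dot products -> (num_heads, seq_len, seq_len).
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)

# "hqk,khd->qhd": weighted sum over key positions -> (seq_len, num_heads, head_size).
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
print(out.shape)  # torch.Size([8, 4, 64])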
@@ -137,6 +172,16 @@ class PagedAttention(nn.Module):
                 self.alibi_slopes, self.num_kv_heads, batch_size,
                 seq_len, query.dtype)
 
+        if self.use_ref_attention:
+            output = self.ref_masked_attention(
+                query,
+                key,
+                value,
+            )
+            # Using view got RuntimeError: view size is not compatible with input tensor's size and stride
+            # (at least one dimension spans across two contiguous subspaces). Use reshape instead
+            return output.reshape(batch_size, seq_len, hidden_size)
+
         # TODO(woosuk): Too many view operations. Let's try to reduce
         # them in the future for code readability.
         if self.alibi_slopes is None:
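The view-versus-reshape comment in the new fallback can be reproduced outside vLLM. The snippet below (an illustration, not vLLM code) builds a non-contiguous tensor for which .view() raises the quoted RuntimeError while .reshape() succeeds by copying.

import torch

# transpose() yields a non-contiguous tensor whose strides cannot be flattened
# in place, which is the situation the comment above describes.
x = torch.randn(2, 3, 4).transpose(0, 1)
print(x.is_contiguous())  # False

try:
    x.view(-1)  # raises: view size is not compatible with input tensor's size and stride
except RuntimeError as exc:
    print("view() failed:", exc)

y = x.reshape(-1)  # reshape() falls back to a copy and succeeds
print(y.shape)  # torch.Size([24])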