[ROCm] support Radeon™ 7900 series (gfx1100) without using flash-attention (#2768)

Hongxia Yang 2024-02-11 02:14:37 -05:00 committed by GitHub
parent 3711811b1d
commit 0580aab02f
4 changed files with 60 additions and 5 deletions

Dockerfile.rocm

@@ -17,6 +17,12 @@ RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
 ARG FA_BRANCH="3d2b6f5"
 RUN echo "FA_BRANCH is $FA_BRANCH"
 
+# whether to build flash-attention
+# if 0, will not build flash attention
+# this is useful for gfx target where flash-attention is not supported
+# In that case, we need to use the python reference attention implementation in vllm
+ARG BUILD_FA="1"
+
 # Install some basic utilities
 RUN apt-get update && apt-get install python3 python3-pip -y
@@ -50,7 +56,8 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
 ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
 
 # Install ROCm flash-attention
-RUN mkdir libs \
+RUN if [ "$BUILD_FA" == "1" ]; then \
+    mkdir libs \
     && cd libs \
     && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
     && cd flash-attention \
@@ -60,7 +67,8 @@ RUN mkdir libs \
     && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
     && python3 setup.py install \
-    && cd ..
+    && cd ..; \
+    fi
 
 COPY ./ /app/vllm
@@ -75,7 +83,8 @@ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
 RUN cd /app \
     && cd vllm \
     && pip install -U -r requirements-rocm.txt \
-    && bash patch_xformers.rocm.sh \
+    && if [ "$BUILD_FA" == "1" ]; then \
+       bash patch_xformers.rocm.sh; fi \
     && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
     && python3 setup.py install \
     && cd ..

docs/source/getting_started/amd-installation.rst

@@ -12,7 +12,7 @@ Requirements
 * OS: Linux
 * Python: 3.8 -- 3.11
-* GPU: MI200s (gfx90a), MI300 (gfx942)
+* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
 * Pytorch 2.0.1/2.1.1/2.2
 * ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)
@@ -105,6 +105,7 @@ The `Dokerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later
 * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
 * `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
 * `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
+* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 until flash-attention supports this target.
 
 Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
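For example, building for the Radeon RX 7900 series (gfx1100) disables the flash-attention stage via ``--build-arg``. A hedged example invocation (the image tag is illustrative; ``Dockerfile.rocm``, ``BUILD_FA``, and ``FA_GFX_ARCHS`` come from the Dockerfile shown above):

    # Radeon RX 7900 series (gfx1100): skip flash-attention and rely on the
    # Python reference attention implementation in vLLM.
    docker build -f Dockerfile.rocm --build-arg BUILD_FA="0" -t vllm-rocm .

    # MI200/MI300 (default): BUILD_FA defaults to "1", so flash-attention is built.
    docker build -f Dockerfile.rocm --build-arg FA_GFX_ARCHS="gfx90a;gfx942" -t vllm-rocm .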

setup.py

@@ -24,7 +24,7 @@ MAIN_CUDA_VERSION = "12.1"
 # Supported NVIDIA GPU architectures.
 NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
-ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}
 # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
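For context, a minimal sketch of how such a supported-architecture set is typically consulted during a build. The `PYTORCH_ROCM_ARCH` environment variable and the error message are illustrative assumptions, not necessarily vLLM's exact setup.py logic:

    import os

    ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}

    # Hypothetical helper: read the requested targets, e.g. "gfx90a;gfx1100".
    def requested_rocm_archs() -> set:
        env = os.environ.get("PYTORCH_ROCM_ARCH", "")
        return {arch for arch in env.split(";") if arch}

    unsupported = requested_rocm_archs() - ROCM_SUPPORTED_ARCHS
    if unsupported:
        raise RuntimeError(f"Unsupported ROCm arch(s): {sorted(unsupported)}; "
                           f"supported: {sorted(ROCM_SUPPORTED_ARCHS)}")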

vllm/model_executor/layers/attention.py

@@ -1,6 +1,7 @@
 """Multi-head attention."""
 from typing import List, Optional
 
+import importlib
 import torch
 import torch.nn as nn
 from xformers import ops as xops
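The added import is used further down (see the new `check_use_ref_attention` method) to detect at runtime whether flash-attention is installed. As a standalone sketch the idiom looks like this; note that outside this file the explicit `import importlib.util` form is the safer one, since `import importlib` alone does not guarantee the `util` submodule is loaded:

    import importlib.util

    # flash_attn importable -> the ROCm flash-attention kernels can be used;
    # otherwise fall back to the pure-PyTorch reference attention path.
    use_ref_attention = importlib.util.find_spec("flash_attn") is None
    print(use_ref_attention)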
@@ -58,6 +59,40 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
+        self.use_ref_attention = self.check_use_ref_attention()
+
+    def check_use_ref_attention(self) -> bool:
+        if not is_hip():
+            return False
+        # For ROCm, check whether flash attention is installed or not.
+        # if not, use_ref_attention needs to be True
+        return importlib.util.find_spec("flash_attn") is None
+
+    def ref_masked_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        seq_len, _, _ = query.shape
+        attn_mask = torch.triu(torch.ones(seq_len,
+                                          seq_len,
+                                          dtype=query.dtype,
+                                          device=query.device),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(query.dtype).min
+
+        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
+                                                 key).float()
+        attn_weights = attn_weights + attn_mask.float()
+        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
+        return out
+
     def forward(
         self,
         query: torch.Tensor,
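The fallback above is ordinary causal softmax attention written with einsum. Below is a self-contained sketch, not part of the commit, that reproduces the same math and cross-checks it against PyTorch's scaled_dot_product_attention; shapes and names are illustrative, and it assumes num_kv_heads == num_heads:

    import torch
    import torch.nn.functional as F

    def ref_masked_attention(query, key, value, scale):
        # query/key/value: [seq_len, num_heads, head_size]
        seq_len = query.shape[0]
        # Upper-triangular causal mask filled with the dtype's minimum value.
        attn_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=query.dtype),
            diagonal=1) * torch.finfo(query.dtype).min
        attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
        attn_weights = attn_weights + attn_mask.float()
        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
        return torch.einsum("hqk,khd->qhd", attn_weights, value)

    seq_len, num_heads, head_size = 8, 4, 64
    q, k, v = (torch.randn(seq_len, num_heads, head_size) for _ in range(3))
    scale = head_size ** -0.5

    out_ref = ref_masked_attention(q, k, v, scale)
    # scaled_dot_product_attention expects [..., seq_len, head_size] with the
    # head dimension in front, so transpose in and out.
    out_sdpa = F.scaled_dot_product_attention(
        q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1),
        is_causal=True).transpose(0, 1)
    print(torch.allclose(out_ref, out_sdpa, atol=1e-5))  # expected: True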
@@ -137,6 +172,16 @@ class PagedAttention(nn.Module):
                         self.alibi_slopes, self.num_kv_heads, batch_size,
                         seq_len, query.dtype)
 
+            if self.use_ref_attention:
+                output = self.ref_masked_attention(
+                    query,
+                    key,
+                    value,
+                )
+                # Using view got RuntimeError: view size is not compatible with input tensor's size and stride
+                # (at least one dimension spans across two contiguous subspaces). Use reshape instead
+                return output.reshape(batch_size, seq_len, hidden_size)
+
             # TODO(woosuk): Too many view operations. Let's try to reduce
             # them in the future for code readability.
             if self.alibi_slopes is None:
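The view-vs-reshape comment above describes standard PyTorch behaviour: `.view()` requires a stride-compatible memory layout, while `.reshape()` falls back to a copy when no such view exists. A minimal, vLLM-independent illustration of the error the comment mentions:

    import torch

    # A non-contiguous tensor, e.g. the result of a transpose.
    x = torch.randn(4, 8, 16).transpose(0, 1)   # shape [8, 4, 16]
    print(x.is_contiguous())                     # False

    try:
        x.view(8, 64)        # fails: strides are incompatible with this view
    except RuntimeError as err:
        print("view failed:", err)

    y = x.reshape(8, 64)     # works: reshape copies when a view is impossible
    print(y.shape)           # torch.Size([8, 64])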