[Hardware][Intel-Gaudi] Add Intel Gaudi (HPU) inference backend (#6143)

Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Signed-off-by: Bob Zhu <bob.zhu@intel.com>
Signed-off-by: zehao-intel <zehao.huang@intel.com>
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: Sanju C Sudhakaran <scsudhakaran@habana.ai>
Co-authored-by: Michal Adamczyk <madamczyk@habana.ai>
Co-authored-by: Marceli Fylcek <mfylcek@habana.ai>
Co-authored-by: Himangshu Lahkar <49579433+hlahkar@users.noreply.github.com>
Co-authored-by: Vivek Goel <vgoel@habana.ai>
Co-authored-by: yuwenzho <yuwen.zhou@intel.com>
Co-authored-by: Dominika Olszewska <dolszewska@habana.ai>
Co-authored-by: barak goldberg <149692267+bgoldberg-habana@users.noreply.github.com>
Co-authored-by: Michal Szutenberg <37601244+szutenberg@users.noreply.github.com>
Co-authored-by: Jan Kaniecki <jkaniecki@habana.ai>
Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyniewicz-habana@users.noreply.github.com>
Co-authored-by: Krzysztof Wisniewski <kwisniewski@habana.ai>
Co-authored-by: Dudi Lester <160421192+dudilester@users.noreply.github.com>
Co-authored-by: Ilia Taraban <tarabanil@gmail.com>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>
Co-authored-by: Jakub Maksymczuk <jmaksymczuk@habana.ai>
Co-authored-by: Tomasz Zielinski <85164140+tzielinski-habana@users.noreply.github.com>
Co-authored-by: Sun Choi <schoi@habana.ai>
Co-authored-by: Iryna Boiko <iboiko@habana.ai>
Co-authored-by: Bob Zhu <41610754+czhu15@users.noreply.github.com>
Co-authored-by: hlin99 <73271530+hlin99@users.noreply.github.com>
Co-authored-by: Zehao Huang <zehao.huang@intel.com>
Co-authored-by: Andrzej Kotłowski <Andrzej.Kotlowski@intel.com>
Co-authored-by: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com>
Co-authored-by: Nir David <ndavid@habana.ai>
Co-authored-by: Yu-Zhou <yu.zhou@intel.com>
Co-authored-by: Ruheena Suhani Shaik <rsshaik@habana.ai>
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Co-authored-by: Marcin Swiniarski <mswiniarski@habana.ai>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Jacek Czaja <jacek.czaja@intel.com>
Co-authored-by: Jacek Czaja <jczaja@habana.ai>
Co-authored-by: Yuan <yuan.zhou@outlook.com>
Konrad Zawora 2024-11-06 10:09:10 +01:00 committed by GitHub
parent a5fda50a10
commit a02a50e6e5
31 changed files with 4279 additions and 20 deletions

Dockerfile.hpu Normal file

@ -0,0 +1,16 @@
FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
COPY ./ /workspace/vllm
WORKDIR /workspace/vllm
RUN pip install -v -r requirements-hpu.txt
ENV no_proxy=localhost,127.0.0.1
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install
WORKDIR /workspace/
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]

docs/source/getting_started/gaudi-installation.rst Normal file

@ -0,0 +1,402 @@
Installation with Intel® Gaudi® AI Accelerators
===============================================
This page provides instructions on running vLLM with Intel Gaudi devices.
Requirements and Installation
=============================
Please follow the instructions provided in the `Gaudi Installation
Guide <https://docs.habana.ai/en/latest/Installation_Guide/index.html>`__
to set up the execution environment. To achieve the best performance,
please follow the methods outlined in the `Optimizing Training Platform
Guide <https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_Training_Platform.html>`__.
Requirements
------------
- OS: Ubuntu 22.04 LTS
- Python: 3.10
- Intel Gaudi accelerator
- Intel Gaudi software version 1.18.0
Quick start using Dockerfile
----------------------------
.. code:: console
$ docker build -f Dockerfile.hpu -t vllm-hpu-env .
$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --rm vllm-hpu-env
.. tip::
If you observe the following error: ``docker: Error response from daemon: Unknown runtime specified habana.``, refer to the "Install Using Containers" section of `Intel Gaudi Software Stack and Driver Installation <https://docs.habana.ai/en/v1.18.0/Installation_Guide/Bare_Metal_Fresh_OS.html>`__. Make sure the ``habana-container-runtime`` package is installed and that the ``habana`` container runtime is registered.
Build from source
-----------------
Environment verification
~~~~~~~~~~~~~~~~~~~~~~~~
To verify that the Intel Gaudi software was correctly installed, run:
.. code:: console
$ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
$ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed
$ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed
$ pip list | grep neural # verify that neural_compressor is installed
Refer to `Intel Gaudi Software Stack
Verification <https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade>`__
for more details.
Run Docker Image
~~~~~~~~~~~~~~~~
It is highly recommended to use the latest Docker image from the Intel
Gaudi vault. Refer to the `Intel Gaudi
documentation <https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#pull-prebuilt-containers>`__
for more details.
Use the following commands to run a Docker image:
.. code:: console
$ docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
Build and Install vLLM
~~~~~~~~~~~~~~~~~~~~~~
To build and install vLLM from source, run:
.. code:: console
$ git clone https://github.com/vllm-project/vllm.git
$ cd vllm
$ python setup.py develop
Currently, the latest features and performance optimizations are developed in Gaudi's `vLLM-fork <https://github.com/HabanaAI/vllm-fork>`__ and are periodically upstreamed to the vLLM main repo. To install the latest `HabanaAI/vLLM-fork <https://github.com/HabanaAI/vllm-fork>`__, run the following:
.. code:: console
$ git clone https://github.com/HabanaAI/vllm-fork.git
$ cd vllm-fork
$ git checkout habana_main
$ python setup.py develop
Supported Features
==================
- `Offline batched
inference <https://docs.vllm.ai/en/latest/getting_started/quickstart.html#offline-batched-inference>`__
- Online inference via `OpenAI-Compatible
Server <https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server>`__
- HPU autodetection - no need to manually select device within vLLM
- Paged KV cache with algorithms enabled for Intel Gaudi accelerators
- Custom Intel Gaudi implementations of Paged Attention, KV cache ops,
prefill attention, Root Mean Square Layer Normalization, Rotary
Positional Encoding
- Tensor parallelism support for multi-card inference
- Inference with `HPU Graphs <https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html>`__
for accelerating low-batch latency and throughput
- Attention with Linear Biases (ALiBi)
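For instance, offline batched inference (listed above) works through the standard vLLM Python API with no HPU-specific changes; a minimal sketch, using a model from the supported configurations below:

.. code-block:: python

   from vllm import LLM, SamplingParams

   # HPU is autodetected; no device selection is needed.
   llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
   params = SamplingParams(temperature=0.8, max_tokens=64)
   outputs = llm.generate(["Hello, my name is"], params)
   print(outputs[0].outputs[0].text)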
Unsupported Features
====================
- Beam search
- LoRA adapters
- Quantization
- Prefill chunking (mixed-batch inferencing)
Supported Configurations
========================
The following configurations have been validated to function with
Gaudi2 devices. Configurations that are not listed may or may not work.
- `meta-llama/Llama-2-7b <https://huggingface.co/meta-llama/Llama-2-7b>`__
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
datatype with random or greedy sampling
- `meta-llama/Llama-2-7b-chat-hf <https://huggingface.co/meta-llama/Llama-2-7b-chat-hf>`__
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3-8B <https://huggingface.co/meta-llama/Meta-Llama-3-8B>`__
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3-8B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct>`__
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-8B <https://huggingface.co/meta-llama/Meta-Llama-3.1-8B>`__
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-8B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct>`__
on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
datatype with random or greedy sampling
- `meta-llama/Llama-2-70b <https://huggingface.co/meta-llama/Llama-2-70b>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Llama-2-70b-chat-hf <https://huggingface.co/meta-llama/Llama-2-70b-chat-hf>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3-70B <https://huggingface.co/meta-llama/Meta-Llama-3-70B>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3-70B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-70B <https://huggingface.co/meta-llama/Meta-Llama-3.1-70B>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- `meta-llama/Meta-Llama-3.1-70B-Instruct <https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct>`__
with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
Performance Tuning
==================
Execution modes
---------------
vLLM for HPU currently supports four execution modes, depending on the selected HPU PyTorch Bridge backend (set via the ``PT_HPU_LAZY_MODE`` environment variable) and the ``--enforce-eager`` flag.
.. list-table:: vLLM execution modes
:widths: 25 25 50
:header-rows: 1
* - ``PT_HPU_LAZY_MODE``
- ``enforce_eager``
- execution mode
* - 0
- 0
- torch.compile
* - 0
- 1
- PyTorch eager mode
* - 1
- 0
- HPU Graphs
* - 1
- 1
- PyTorch lazy mode
.. warning::
In 1.18.0, all modes utilizing ``PT_HPU_LAZY_MODE=0`` are highly experimental and should only be used for validating functional correctness. Their performance will be improved in upcoming releases. For the best performance in 1.18.0, use HPU Graphs or PyTorch lazy mode.
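For example, the recommended HPU Graphs mode and the experimental torch.compile mode correspond to the following launch commands (illustrative; the model name is taken from the supported configurations above):

.. code:: console

   $ # HPU Graphs: lazy backend, eager execution not enforced
   $ PT_HPU_LAZY_MODE=1 python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf

   $ # torch.compile: highly experimental in 1.18.0
   $ PT_HPU_LAZY_MODE=0 python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf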
Bucketing mechanism
-------------------
Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. `Intel Gaudi Graph Compiler <https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime>`__ is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution.
In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently this is achieved by "bucketing" the model's forward pass across two dimensions - ``batch_size`` and ``sequence_length``.
.. note::
Bucketing allows us to reduce the number of required graphs significantly, but it does not handle graph compilation or device code generation - that is done during the warmup and HPUGraph capture phases.
Bucketing ranges are determined by three parameters - ``min``, ``step`` and ``max``. They can be set separately for the prompt and decode phases, and for the batch size and sequence length dimensions. These parameters can be observed in logs during vLLM startup:
.. code-block::
INFO 08-01 21:37:59 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024]
INFO 08-01 21:37:59 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)]
INFO 08-01 21:37:59 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048]
INFO 08-01 21:37:59 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]
``min`` determines the lowest value of the bucket, ``step`` determines the interval between buckets, and ``max`` determines the upper bound of the range. Additionally, the interval between ``min`` and ``step`` has special handling: ``min`` is multiplied by consecutive powers of two until ``step`` is reached. We call this the ramp-up phase; it handles lower batch sizes with minimal wastage, while allowing larger padding at larger batch sizes.
Example (with ramp-up)
.. code-block::
min = 2, step = 32, max = 64
=> ramp_up = (2, 4, 8, 16)
=> stable = (32, 64)
=> buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64)
Example (without ramp-up)
.. code-block::
min = 128, step = 128, max = 512
=> ramp_up = ()
=> stable = (128, 256, 384, 512)
=> buckets = ramp_up + stable => (128, 256, 384, 512)
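The two examples above can be reproduced with a short sketch of this arithmetic (``generate_buckets`` is an illustrative helper, not vLLM's actual implementation):

.. code-block:: python

   def generate_buckets(bmin: int, bstep: int, bmax: int) -> list:
       # Ramp-up phase: powers-of-two multiples of ``bmin`` below ``bstep``.
       ramp_up, value = [], bmin
       while value < bstep and value <= bmax:
           ramp_up.append(value)
           value *= 2
       # Stable phase: multiples of ``bstep`` from ``bstep`` up to ``bmax``.
       stable = list(range(bstep, bmax + 1, bstep))
       return ramp_up + stable

   assert generate_buckets(2, 32, 64) == [2, 4, 8, 16, 32, 64]
   assert generate_buckets(128, 128, 512) == [128, 256, 384, 512]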
In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with the specified tensor shapes. Whenever a batch of requests is processed, it is padded across the batch and sequence length dimensions to the smallest possible bucket.
.. warning::
If a request exceeds the maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such a scenario.
As an example, if a request with 3 sequences and a maximum sequence length of 412 comes in to an idle vLLM server, it will be padded and executed as a ``(4, 512)`` prefill bucket: ``batch_size`` (the number of sequences) is padded to 4 (the closest batch size bucket higher than 3), and the maximum sequence length is padded to 512 (the closest sequence length bucket higher than 412). After the prefill stage, it will be executed as a ``(4, 512)`` decode bucket and will continue in that bucket until either the batch dimension changes (e.g. due to a request finishing), in which case it becomes a ``(2, 512)`` bucket, or the context length grows above 512 tokens, in which case it becomes a ``(4, 640)`` bucket.
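The padding logic from this example can be sketched as follows (``find_bucket`` is an illustrative helper; vLLM's actual implementation differs):

.. code-block:: python

   def find_bucket(batch_size: int, seq_len: int, buckets):
       """Return the smallest bucket that fits the request, or None
       if the request exceeds every bucket in some dimension."""
       fitting = [(bs, sl) for bs, sl in buckets
                  if bs >= batch_size and sl >= seq_len]
       return min(fitting, default=None)

   prompt_buckets = [(bs, sl) for bs in (1, 2, 4)
                     for sl in range(128, 1025, 128)]
   # 3 sequences, longest one 412 tokens -> padded to the (4, 512) bucket.
   assert find_bucket(3, 412, prompt_buckets) == (4, 512)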
.. note::
Bucketing is transparent to the client - padding in the sequence length dimension is never returned to the client, and padding in the batch dimension does not create new requests.
Warmup
------
Warmup is an optional, but highly recommended, step that occurs before the vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and avoid any graph compilation overhead within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup:
.. code-block::
INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB
INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB
INFO 08-01 22:26:48 hpu_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB
...
INFO 08-01 22:26:59 hpu_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB
INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB
INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB
INFO 08-01 22:27:01 hpu_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB
...
INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB
INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB
This example uses the same buckets as in the *Bucketing mechanism* section. Each output line corresponds to the execution of a single bucket. When a bucket is executed for the first time, its graph is compiled; it can then be reused later, skipping further graph compilations.
.. tip::
Compiling all the buckets might take some time and can be turned off with the ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do, you may face graph compilations when a given bucket is executed for the first time. Disabling warmup is fine for development, but enabling it is highly recommended in deployment.
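For example (illustrative command; model name from the supported configurations above):

.. code:: console

   $ # Development only: skip warmup to start the server faster
   $ VLLM_SKIP_WARMUP=true python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf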
HPU Graph capture
-----------------
`HPU Graphs <https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html>`__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management.
When HPU Graphs are used, they share the common memory pool ("usable memory") with the KV cache, as determined by the ``gpu_memory_utilization`` flag (``0.9`` by default).
Before the KV cache is allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data to estimate memory usage.
Only then is the ``gpu_memory_utilization`` flag applied: at its default value, it marks 90% of the free device memory at that point as usable.
Next, the KV cache is allocated, the model is warmed up, and HPU Graphs are captured.
Environment variable ``VLLM_GRAPH_RESERVED_MEM`` defines the ratio of memory reserved for HPU Graphs capture.
With its default value (``VLLM_GRAPH_RESERVED_MEM=0.1``), 10% of usable memory will be reserved for graph capture (later referred to as "usable graph memory"), and the remaining 90% will be utilized for KV cache.
Environment variable ``VLLM_GRAPH_PROMPT_RATIO`` determines the ratio of usable graph memory reserved for prefill and decode graphs. With the default value (``VLLM_GRAPH_PROMPT_RATIO=0.3``), 30% of usable graph memory is reserved for prefill graphs and the remaining 70% for decode graphs.
A lower value corresponds to less usable graph memory reserved for the prefill stage; e.g. ``VLLM_GRAPH_PROMPT_RATIO=0.2`` reserves 20% of usable graph memory for prefill graphs and 80% of usable graph memory for decode graphs.
.. note::
``gpu_memory_utilization`` does not correspond to the absolute memory usage across the HPU. It specifies the memory margin after loading the model and performing a profile run. If the device has 100 GiB of total memory and 50 GiB of free memory after loading the model weights and executing the profiling run, ``gpu_memory_utilization`` at its default value marks 90% of those 50 GiB as usable, leaving a 5 GiB margin regardless of total device memory.
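The accounting above can be condensed into a short sketch (``memory_split`` is an illustrative helper; the sample values reproduce the example log further below, which was captured with non-default ``gpu_memory_utilization=0.5`` and ``VLLM_GRAPH_RESERVED_MEM=0.4``):

.. code-block:: python

   def memory_split(free_mem: float,
                    gpu_memory_utilization: float = 0.9,
                    graph_reserved_mem: float = 0.1):
       """Split free device memory into usable memory, HPU Graph
       capture memory and KV cache memory."""
       usable = free_mem * gpu_memory_utilization
       graph_mem = usable * graph_reserved_mem
       kv_cache_mem = usable - graph_mem
       return usable, graph_mem, kv_cache_mem

   # 79.16 GiB free, gpu_memory_utilization=0.5, VLLM_GRAPH_RESERVED_MEM=0.4:
   usable, graphs, kv = memory_split(79.16, 0.5, 0.4)
   # -> 39.58 GiB usable, 15.83 GiB for HPU Graphs, 23.75 GiB for KV cache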
Users can also configure the strategy for capturing HPU Graphs separately for the prompt and decode stages. The strategy affects the order in which graphs are captured. Two strategies are implemented:
- ``max_bs`` - the graph capture queue is sorted in descending order by batch size; buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. ``(64, 128)``, ``(64, 256)``, ``(32, 128)``, ``(32, 256)``, ``(1, 128)``, ``(1, 256)``); this is the default strategy for decode
- ``min_tokens`` - the graph capture queue is sorted in ascending order by the number of tokens each graph processes (``batch_size*sequence_length``); this is the default strategy for prompt
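Both orderings can be expressed as simple sort keys; a minimal sketch, reusing the bucket list from the ``max_bs`` example above:

.. code-block:: python

   buckets = [(1, 128), (1, 256), (32, 128), (32, 256), (64, 128), (64, 256)]

   # ``max_bs``: descending batch size, then ascending sequence length.
   max_bs_order = sorted(buckets, key=lambda b: (-b[0], b[1]))
   # -> [(64, 128), (64, 256), (32, 128), (32, 256), (1, 128), (1, 256)]

   # ``min_tokens``: ascending batch_size * sequence_length.
   min_tokens_order = sorted(buckets, key=lambda b: b[0] * b[1])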
When there is a large number of requests pending, the vLLM scheduler attempts to fill the maximum decode batch size as soon as possible. When a request finishes, the decode batch size decreases, and vLLM attempts to schedule a prefill iteration for requests in the waiting queue to restore the decode batch size to its previous state. This means that under full load the decode batch size is often at its maximum, which makes capturing large-batch HPU Graphs crucial, as reflected by the ``max_bs`` strategy. On the other hand, prefill is executed most frequently with very low batch sizes (1-4), which is reflected in the ``min_tokens`` strategy.
.. note::
``VLLM_GRAPH_PROMPT_RATIO`` does not set a hard limit on the memory taken by graphs for each stage (prefill and decode). vLLM first attempts to use the entirety of usable prefill graph memory (usable graph memory * ``VLLM_GRAPH_PROMPT_RATIO``) for capturing prefill HPU Graphs; it then attempts to do the same for decode graphs and the usable decode graph memory pool. If one stage is fully captured and there is unused memory left within the usable graph memory pool, vLLM attempts further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding the reserved memory pool. The behavior of this mechanism can be observed in the example below.
Each described step is logged by the vLLM server as follows (negative values correspond to memory being released):
.. code-block::
INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024]
INFO 08-02 17:37:44 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)]
INFO 08-02 17:37:44 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048]
INFO 08-02 17:37:44 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]
INFO 08-02 17:37:52 hpu_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used)
INFO 08-02 17:37:52 hpu_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used)
INFO 08-02 17:37:52 hpu_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used)
INFO 08-02 17:37:54 hpu_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used)
INFO 08-02 17:37:54 hpu_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache
INFO 08-02 17:37:54 hpu_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0
INFO 08-02 17:37:54 hpu_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used)
INFO 08-02 17:37:54 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB
...
INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB
INFO 08-02 17:38:22 hpu_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3)
INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB
...
INFO 08-02 17:38:26 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB
INFO 08-02 17:38:27 hpu_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB
...
INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB
INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB
INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB
INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB
INFO 08-02 17:38:43 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB
INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)]
INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]
INFO 08-02 17:38:43 hpu_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory
INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used)
Recommended vLLM Parameters
---------------------------
- We recommend running inference on Gaudi 2 with ``block_size`` of 128
for BF16 data type. Using default values (16, 32) might lead to
sub-optimal performance due to Matrix Multiplication Engine
under-utilization (see `Gaudi
Architecture <https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html>`__).
- For max throughput on Llama 7B, we recommend running with a batch size
  of 128 or 256 and a max context length of 2048 with HPU Graphs enabled.
  If you encounter out-of-memory issues, see the troubleshooting section.
Environment variables
---------------------
**Diagnostic and profiling knobs:**
- ``VLLM_PROFILER_ENABLED``: if ``true``, the high-level profiler will be enabled. Resulting JSON traces can be viewed in `perfetto.habana.ai <https://perfetto.habana.ai/#!/viewer>`__. Disabled by default.
- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION``: if ``true``, will log graph compilations for each vLLM engine step, but only when any occur - highly recommended to use alongside ``PT_HPU_METRICS_GC_DETAILS=1``. Disabled by default.
- ``VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL``: if ``true``, will log graph compilations for each vLLM engine step, even if none occur. Disabled by default.
- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS``: if ``true``, will log CPU fallbacks for each vLLM engine step, but only when any occur. Disabled by default.
- ``VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL``: if ``true``, will log CPU fallbacks for each vLLM engine step, even if none occur. Disabled by default.
**Performance tuning knobs:**
- ``VLLM_SKIP_WARMUP``: if ``true``, warmup will be skipped, ``false`` by default
- ``VLLM_GRAPH_RESERVED_MEM``: fraction of usable memory dedicated to HPUGraph capture, ``0.1`` by default
- ``VLLM_GRAPH_PROMPT_RATIO``: fraction of reserved graph memory dedicated to prompt graphs, ``0.3`` by default
- ``VLLM_GRAPH_PROMPT_STRATEGY``: strategy determining order of prompt graph capture, ``min_tokens`` or ``max_bs``, ``min_tokens`` by default
- ``VLLM_GRAPH_DECODE_STRATEGY``: strategy determining order of decode graph capture, ``min_tokens`` or ``max_bs``, ``max_bs`` by default
- ``VLLM_{phase}_{dim}_BUCKET_{param}`` - collection of 12 environment variables configuring ranges of bucketing mechanism
- ``{phase}`` is either ``PROMPT`` or ``DECODE``
- ``{dim}`` is either ``BS``, ``SEQ`` or ``BLOCK``
- ``{param}`` is either ``MIN``, ``STEP`` or ``MAX``
- Default values:
- Prompt:
- batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
- batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)``
- sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size``
- sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size``
- sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len``
- Decode:
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
- batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs``
- block size min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``block_size``
- block size step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``block_size``
- block size max (``VLLM_DECODE_BLOCK_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)``
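For example, the prompt bucketing ranges can be overridden before starting the server (the values shown are illustrative, not recommendations):

.. code:: console

   $ export VLLM_PROMPT_BS_BUCKET_MIN=1
   $ export VLLM_PROMPT_BS_BUCKET_STEP=32
   $ export VLLM_PROMPT_BS_BUCKET_MAX=64
   $ export VLLM_PROMPT_SEQ_BUCKET_MAX=2048  # raise the upper prompt sequence length boundary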
Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:
- ``PT_HPU_LAZY_MODE``: if ``0``, the PyTorch Eager backend for Gaudi is used; if ``1``, the PyTorch Lazy backend for Gaudi is used. ``1`` is the default.
- ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPU Graphs
Troubleshooting: Tweaking HPU Graphs
====================================
If you experience device out-of-memory issues or want to attempt
inference at higher batch sizes, try tweaking HPU Graphs as follows:
- Tweak the ``gpu_memory_utilization`` knob. This decreases the
  allocation of KV cache, leaving some headroom for capturing graphs
  with larger batch sizes. By default, ``gpu_memory_utilization`` is set
  to 0.9; it attempts to allocate ~90% of the HBM left after the short
  profiling run for KV cache. Note that decreasing it reduces the number
  of available KV cache blocks, and therefore the effective maximum
  number of tokens that can be handled at a given time.
- If this method is not sufficient, you can disable ``HPUGraph``
  completely. With HPU Graphs disabled, you are trading latency and
  throughput at lower batch sizes for potentially higher throughput at
  higher batch sizes. You can do that by adding the ``--enforce-eager``
  flag to the server (for online inference), or by passing
  ``enforce_eager=True`` to the ``LLM`` constructor (for offline inference).
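A minimal sketch of both options, with an illustrative ``gpu_memory_utilization`` value and a model from the supported configurations above:

.. code:: console

   $ # Online inference: lower gpu_memory_utilization and disable HPU Graphs
   $ python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --gpu-memory-utilization 0.8 --enforce-eager

.. code-block:: python

   from vllm import LLM

   # Offline inference with HPU Graphs disabled.
   llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
             gpu_memory_utilization=0.8,
             enforce_eager=True)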

docs/source/index.rst

@ -43,7 +43,7 @@ vLLM is flexible and easy to use with:
* Tensor parallelism and pipeline parallelism support for distributed inference
* Streaming outputs
* OpenAI-compatible API server
* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
* Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, PowerPC CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
* Prefix caching support
* Multi-lora support
@ -66,6 +66,7 @@ Documentation
getting_started/amd-installation
getting_started/openvino-installation
getting_started/cpu-installation
getting_started/gaudi-installation
getting_started/neuron-installation
getting_started/tpu-installation
getting_started/xpu-installation

requirements-hpu.txt Normal file

@ -0,0 +1,11 @@
# Common dependencies
-r requirements-common.txt
# Dependencies for HPU code
ray
triton
pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@fd7f2e6

setup.py

@ -253,6 +253,24 @@ class cmake_build_ext(build_ext):
self.copy_file(file, dst_file)
def _is_hpu() -> bool:
is_hpu_available = True
try:
subprocess.run(["hl-smi"], capture_output=True, check=True)
except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
if not os.path.exists('/dev/accel/accel0') and not os.path.exists(
'/dev/accel/accel_controlD0'):
# last resort...
try:
output = subprocess.check_output(
'lsmod | grep habanalabs | wc -l', shell=True)
is_hpu_available = int(output) > 0
except (ValueError, FileNotFoundError, PermissionError,
subprocess.CalledProcessError):
is_hpu_available = False
return is_hpu_available or VLLM_TARGET_DEVICE == "hpu"
def _no_device() -> bool:
return VLLM_TARGET_DEVICE == "empty"
@ -260,7 +278,7 @@ def _no_device() -> bool:
def _is_cuda() -> bool:
has_cuda = torch.version.cuda is not None
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
and not (_is_neuron() or _is_tpu()))
and not (_is_neuron() or _is_tpu() or _is_hpu()))
def _is_hip() -> bool:
@ -356,6 +374,23 @@ def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)
def get_gaudi_sw_version():
"""
Returns the driver version.
"""
# Enable console printing for `hl-smi` check
output = subprocess.run("hl-smi",
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env={"ENABLE_CONSOLE": "true"})
if output.returncode == 0 and output.stdout:
return output.stdout.split("\n")[2].replace(
" ", "").split(":")[1][:-1].split("-")[0]
return "0.0.0" # when hl-smi is not available
def get_vllm_version() -> str:
version = get_version(
write_to="vllm/_version.py", # TODO: move this to pyproject.toml
@ -385,6 +420,12 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"{sep}neuron{neuron_version_str}"
elif _is_hpu():
# Get the Intel Gaudi Software Suite version
gaudi_sw_version = str(get_gaudi_sw_version())
if gaudi_sw_version != MAIN_CUDA_VERSION:
gaudi_sw_version = gaudi_sw_version.replace(".", "")[:3]
version += f"{sep}gaudi{gaudi_sw_version}"
elif _is_openvino():
version += f"{sep}openvino"
elif _is_tpu():
@ -443,6 +484,8 @@ def get_requirements() -> List[str]:
requirements = _read_requirements("requirements-rocm.txt")
elif _is_neuron():
requirements = _read_requirements("requirements-neuron.txt")
elif _is_hpu():
requirements = _read_requirements("requirements-hpu.txt")
elif _is_openvino():
requirements = _read_requirements("requirements-openvino.txt")
elif _is_tpu():
@ -453,7 +496,7 @@ def get_requirements() -> List[str]:
requirements = _read_requirements("requirements-xpu.txt")
else:
raise ValueError(
"Unsupported platform, please use CUDA, ROCm, Neuron, "
"Unsupported platform, please use CUDA, ROCm, Neuron, HPU, "
"OpenVINO, or CPU.")
return requirements

vllm/_custom_ops.py

@ -12,7 +12,7 @@ from vllm.scalar_type import ScalarType
logger = init_logger(__name__)
if not current_platform.is_tpu():
if not current_platform.is_tpu() and not current_platform.is_hpu():
try:
import vllm._C
except ImportError as e:

vllm/attention/backends/hpu_attn.py Normal file

@ -0,0 +1,264 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
HPUPagedAttentionMetadata)
from vllm.logger import init_logger
logger = init_logger(__name__)
class HPUAttentionBackend(AttentionBackend):
@staticmethod
def get_impl_cls() -> Type["HPUAttentionImpl"]:
return HPUAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return HPUAttentionMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
"""Metadata for HPUAttentionbackend."""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when HPU Graphs are used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.matmul_qk = Matmul()
self.softmax = Softmax()
self.matmul_av = Matmul()
self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache()
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.alibi_slopes = alibi_slopes
if alibi_slopes is not None:
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
self.alibi_slopes = alibi_slopes_tensor
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
if self.prefill_usefusedsdpa:
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)
if attn_metadata.is_prompt:
# Prompt run.
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
else:
attn_bias = None
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_len: int,
) -> torch.Tensor:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias = bias[None, :] - bias[:, None]
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias

vllm/attention/ops/hpu_paged_attn.py Normal file

@ -0,0 +1,103 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from vllm_hpu_extension import cache_ops, ops
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class HPUPagedAttentionMetadata:
"""Metadata for PagedAttention."""
block_list: Optional[torch.Tensor]
block_mapping: Optional[torch.Tensor]
block_usage: Optional[torch.Tensor]
block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
class HPUPagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor, kv_cache_dtype: str,
is_prompt: bool) -> None:
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, is_prompt)
@staticmethod
def forward_decode(**kwargs) -> torch.Tensor:
return ops.flat_pa(**kwargs)
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
) -> torch.Tensor:
raise NotImplementedError(
"forward_prefix is not implemented for HPUPagedAttention")
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)

vllm/attention/selector.py

@ -23,6 +23,7 @@ class _Backend(enum.Enum):
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
@ -145,6 +146,10 @@ def get_attn_backend(
logger.info("Using Flashinfer backend.")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.HPU_ATTN:
logger.info("Using HPUAttention backend.")
from vllm.attention.backends.hpu_attn import HPUAttentionBackend
return HPUAttentionBackend
elif backend == _Backend.PALLAS:
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
@ -220,6 +225,9 @@ def which_attn_to_use(
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH
if current_platform.is_hpu():
return _Backend.HPU_ATTN
if envs.VLLM_USE_V1:
return _Backend.FLASH_ATTN_VLLM_V1

vllm/config.py

@ -466,9 +466,10 @@ class ModelConfig:
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo becomes valid
if device_config.device_type not in ("cuda", "tpu", "xpu"):
if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"):
logger.warning(
"Async output processing is only supported for CUDA, TPU, XPU. "
"Async output processing is only supported for CUDA, TPU, XPU "
"and HPU."
"Disabling it for other platforms.")
self.use_async_output_proc = False
return
@ -860,7 +861,6 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
@ -964,6 +964,13 @@ class ParallelConfig:
raise ValueError(
"TPU backend only supports Ray for distributed inference.")
if current_platform.is_hpu() and self.world_size > 1:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
if self.distributed_executor_backend != "ray":
raise ValueError(
"HPU backend only supports Ray for distributed inference.")
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
@ -1166,6 +1173,8 @@ class DeviceConfig:
self.device_type = "cuda"
elif current_platform.is_neuron():
self.device_type = "neuron"
elif current_platform.is_hpu():
self.device_type = "hpu"
elif current_platform.is_openvino():
self.device_type = "openvino"
elif current_platform.is_tpu():
@ -1745,6 +1754,13 @@ def _get_and_verify_dtype(
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
if current_platform.is_hpu() and config_dtype == torch.float16:
logger.info(
"For HPU, we cast models to bfloat16 instead of"
"using float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")

vllm/core/block/cpu_gpu_block_allocator.py

@ -4,6 +4,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
from vllm.platforms import current_platform
from vllm.utils import Device
@ -52,7 +53,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
- The block IDs are assigned contiguously, with GPU block IDs coming
before CPU block IDs.
"""
block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
# For HPU, block id 0 is used only for padding
reserved_blocks = 1 if current_platform.is_hpu() else 0
block_ids = list(
range(reserved_blocks, num_gpu_blocks + num_cpu_blocks))
num_gpu_blocks -= reserved_blocks
gpu_block_ids = block_ids[:num_gpu_blocks]
cpu_block_ids = block_ids[num_gpu_blocks:]

vllm/distributed/device_communicators/hpu_communicator.py Normal file

@ -0,0 +1,48 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
if current_platform.is_hpu():
import habana_frameworks.torch as htorch # noqa: F401
class HpuCommunicator:
def __init__(self, group: ProcessGroup):
if not current_platform.is_hpu():
self.disabled = True
return
self.disabled = False
self.group = group
self.world_size = dist.get_world_size(self.group)
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
dist.all_reduce(x, group=self.group)
return x
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
if dim < 0:
# Convert negative dim to positive.
dim += x.dim()
input_size = x.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=x.dtype,
device=x.device)
# All-gather.
htorch.core.mark_step()
dist.all_gather_into_tensor(output_tensor, x, group=self.group)
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor

vllm/distributed/parallel_state.py

@ -177,6 +177,7 @@ class GroupCoordinator:
use_pynccl: bool,
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_hpu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
@ -213,6 +214,7 @@ class GroupCoordinator:
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
self.use_tpu_communicator = use_tpu_communicator
self.use_hpu_communicator = use_hpu_communicator
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
@ -241,6 +243,12 @@ class GroupCoordinator:
if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
from vllm.distributed.device_communicators.hpu_communicator import (
HpuCommunicator)
self.hpu_communicator: Optional[HpuCommunicator]
if use_hpu_communicator and self.world_size > 1:
self.hpu_communicator = HpuCommunicator(group=self.device_group)
from vllm.distributed.device_communicators.shm_broadcast import (
MessageQueue)
self.mq_broadcaster: Optional[MessageQueue] = None
@ -362,6 +370,10 @@ class GroupCoordinator:
# TPU handles Dynamo with its own logic.
return self.tpu_communicator.all_reduce(input_)
if self.hpu_communicator is not None and \
not self.hpu_communicator.disabled:
return self.hpu_communicator.all_reduce(input_)
if self.ca_comm is not None and \
not self.ca_comm.disabled and \
self.ca_comm.should_custom_ar(input_):
@ -400,6 +412,11 @@ class GroupCoordinator:
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_gather(input_, dim)
# For HPUs, use HPU communicator.
hpu_comm = self.hpu_communicator
if hpu_comm is not None and not hpu_comm.disabled:
return hpu_comm.all_gather(input_, dim)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
@ -879,6 +896,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_pynccl=False,
use_custom_allreduce=False,
use_tpu_communicator=False,
use_hpu_communicator=False,
group_name="world",
)
@ -900,6 +918,7 @@ def init_model_parallel_group(
use_pynccl=True,
use_custom_allreduce=use_custom_allreduce,
use_tpu_communicator=True,
use_hpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)

vllm/engine/arg_utils.py

@ -17,6 +17,7 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.utils import check_gguf_file
@ -37,6 +38,7 @@ DEVICE_OPTIONS = [
"openvino",
"tpu",
"xpu",
"hpu",
]
@ -110,7 +112,9 @@ class EngineArgs:
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
# NOTE(kzawora): default block size for Gaudi should be 128
# smaller sizes still work, but very inefficiently
block_size: int = 16 if not current_platform.is_hpu() else 128
enable_prefix_caching: bool = False
disable_sliding_window: bool = False
use_v2_block_manager: bool = True
@ -397,7 +401,7 @@ class EngineArgs:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
choices=[8, 16, 32, 64, 128],
help='Token block size for contiguous chunks of '
'tokens. This is ignored on neuron devices and '
'set to max-model-len')
@ -1132,8 +1136,7 @@ class EngineArgs:
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
policy=self.scheduling_policy,
)
policy=self.scheduling_policy)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,

vllm/engine/async_llm_engine.py

@ -627,6 +627,14 @@ class AsyncLLMEngine(EngineClient):
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.device_config.device_type == "hpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync
executor_class = RayHPUExecutorAsync
else:
from vllm.executor.hpu_executor import HPUExecutorAsync
executor_class = HPUExecutorAsync
elif engine_config.device_config.device_type == "openvino":
assert distributed_executor_backend is None, (
"Distributed execution is not supported with "


@@ -528,6 +528,14 @@ class LLMEngine:
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.device_config.device_type == "hpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_hpu_executor import RayHPUExecutor
executor_class = RayHPUExecutor
else:
from vllm.executor.hpu_executor import HPUExecutor
executor_class = HPUExecutor
elif engine_config.device_config.device_type == "openvino":
from vllm.executor.openvino_executor import OpenVINOExecutor
executor_class = OpenVINOExecutor
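
Both engine entry points resolve the HPU executors from ``device_config.device_type``; a hedged end-to-end sketch of reaching the branch above (the model name is a placeholder, and ``device="hpu"`` is assumed to be forwarded to ``EngineArgs``)::

    from vllm import LLM

    # Resolves to HPUExecutor; with distributed_executor_backend="ray",
    # the RayHPUExecutor branch above is taken instead.
    llm = LLM(model="facebook/opt-125m", device="hpu")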


@@ -0,0 +1,205 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import contextlib
import os
from typing import Any, Dict, List, Optional, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class HPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self._init_worker()
def _get_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
"""Return worker init args for a given rank."""
if distributed_init_method is None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=rank == 0,
)
def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.hpu_worker",
worker_class_name="HPUWorker",
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
def _init_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
return self.driver_worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log at the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# HPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
from vllm_hpu_extension.profiler import HabanaMemoryProfiler
with HabanaMemoryProfiler() as cache_init_m:
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
msg = f"init_cache_engine took {cache_init_m.get_summary_string()}"
logger.info(msg)
def finish_measurements(self):
self.driver_worker.finish_measurements()
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501
log_graph_compilation_all = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0'
log_graph_compilation = os.environ.get(
'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION',
'0') != '0' or log_graph_compilation_all
log_cpu_fallbacks_all = os.environ.get(
'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0'
log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS',
'0') != '0' or log_cpu_fallbacks_all
if log_graph_compilation or log_cpu_fallbacks:
from habana_frameworks.torch.hpu.metrics import metric_localcontext
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
is_prompt = any([
seq_group_metadata.is_prompt
for seq_group_metadata in seq_group_metadata_list
])
max_context_len = max([
max([
len(v.prompt_token_ids) + len(v.output_token_ids)
for v in seq_group_metadata.seq_data.values()
]) for seq_group_metadata in seq_group_metadata_list
]) # whoa, that's some spicy stuff right here
max_num_blocks = (
(max_context_len - 1) // self.cache_config.block_size) + 1
input_stats = (f'is_prompt: {is_prompt}, '
f'num_seqs: {len(seq_group_metadata_list)}, '
f'max_context_len: {max_context_len}, '
f'max_num_blocks {max_num_blocks}')
gc_ctx = metric_localcontext(
"graph_compilation"
) if log_graph_compilation else contextlib.nullcontext()
cpu_fallback_ctx = metric_localcontext(
"cpu_fallback"
) if log_cpu_fallbacks else contextlib.nullcontext()
with gc_ctx as gc_local_metric, \
cpu_fallback_ctx as cpu_fallback_local_metric:
output = self.driver_worker.execute_model(execute_model_req)
if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0
) or log_graph_compilation_all:
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
f"{gc_local_metric.stats()}, {input_stats}")
logger.warning(msg)
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] >
0) or log_cpu_fallbacks_all:
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
logger.warning(msg)
return output
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.driver_worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def check_health(self) -> None:
# HPUExecutor will always be healthy as long as
# it's running.
return
def start_profile(self) -> None:
self.driver_worker.start_profile()
def stop_profile(self) -> None:
self.driver_worker.stop_profile()
def shutdown(self) -> None:
self.driver_worker.shutdown_inc()
class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase):
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
    execute_model_req=execute_model_req)
return output
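
The four debug switches read at the top of ``execute_model`` are plain environment variables; a hedged sketch of enabling the combination recommended in the comments (only the variable names come from this diff)::

    import os

    # Log graph compilations per engine step, only when any occurred.
    os.environ["VLLM_HPU_LOG_STEP_GRAPH_COMPILATION"] = "1"
    # The comment above recommends pairing this with detailed GC metrics.
    os.environ["PT_HPU_METRICS_GC_DETAILS"] = "1"
    # Construct the engine afterwards so HPUExecutor reads the variables.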


@@ -0,0 +1,554 @@
import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Type)
import msgspec
import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.msgspec_utils import encode_hook
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, get_vllm_instance_id,
make_async)
from vllm.worker.worker_base import WorkerBase
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
class RayHPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def _init_executor(self) -> None:
self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if self.use_ray_compiled_dag:
assert self.use_ray_spmd_worker, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1")
if self.use_ray_spmd_worker:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert self.use_ray_compiled_dag, (
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1")
assert self.uses_ray
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])
def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.forward_dag = None
def finish_measurements(self):
self._run_workers("finish_measurements")
def _get_worker_module_and_class(
self
) -> Tuple[str, str, Optional[Callable[[],
Type[WorkerBase]]]]: # noqa: F821
worker_class_fn = None
if self.scheduler_config.is_multi_step:
raise NotImplementedError(
"Multi-step execution is not implemented for HPU")
elif self.speculative_config:
raise NotImplementedError(
"Speculative decoding is not implemented for HPU")
else:
worker_module_name = "vllm.worker.hpu_worker"
worker_class_name = "HPUWorker"
return (worker_module_name, worker_class_name, worker_class_fn)
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
(worker_module_name, worker_class_name,
worker_class_fn) = self._get_worker_module_and_class()
return dict(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
trust_remote_code=self.model_config.trust_remote_code,
)
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
# Each ray worker is allocated one full HPU.
num_gpus = 1
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("HPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={'HPU': num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
if self.use_ray_spmd_worker:
self.workers.append(worker)
else:
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
**worker_wrapper_kwargs)
else:
# Else, added to the list of workers.
self.workers.append(worker)
logger.debug("workers: %s", self.workers)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
worker_ips = [
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
for worker in self.workers
]
ip_counts: Dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(worker):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first.
"""
ip = ray.get(worker.get_node_ip.remote())
return (ip != driver_ip, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
# Get the set of HPU IDs used on each node.
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
use_dummy_driver=True)
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids = [int(x) for x in gpu_ids]
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP` or "
"`HOST_IP` environment variable, make sure it is unique for"
" each node.")
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [({
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for (node_id, _) in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
if self.use_ray_spmd_worker:
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(
self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size
) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []
# Enforce rank order for correct rank to return final output.
for index, worker in enumerate(self.workers):
# The driver worker is rank 0 and not in self.workers.
rank = index + 1
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
return self.driver_worker.execute_method("execute_model",
execute_model_req)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
output = self.output_decoder.decode(outputs[0])
return output
def _run_workers(
self,
method: str,
*args,
async_run_tensor_parallel_workers_only: bool = False,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers. Can be used in the following
ways:
Args:
- async_run_tensor_parallel_workers_only: If True the method will be
run only in the remote TP workers, not the driver worker.
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
"spmd mode.")
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers)
# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, first_worker_args_index, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, first_worker_args_index, None)
# Start the ray workers first.
ray_workers = self.workers
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
]
if async_run_tensor_parallel_workers_only:
# Just return futures
return ray_worker_outputs
driver_worker_output = []
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = [
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
]
else:
assert self.driver_dummy_worker is not None
driver_worker_output = [
ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
]
# Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks)
def _check_ray_adag_installation(self):
import pkg_resources
from packaging import version
required_version = version.parse("2.35")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
# TODO: update the constraint once we adapt to the backward
# incompatible API change from ray 2.36
if current_version != required_version:
raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}")
import importlib.util
adag_spec = importlib.util.find_spec(
"ray.experimental.compiled_dag_ref")
if adag_spec is None:
raise ValueError("Ray accelerated DAG is not installed. "
"Run `pip install ray[adag]` to install it.")
def _compiled_ray_dag(self, enable_asyncio: bool):
assert self.parallel_config.use_ray
self._check_ray_adag_installation()
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
# (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501
# -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501
# -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501
# -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501
# All workers in the first TP group will take in the
# ExecuteModelRequest as input.
outputs = [input_data for _ in self.pp_tp_workers[0]]
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
# Each PP worker takes in the output of the previous PP worker,
# and the TP group executes in SPMD fashion.
outputs = [
worker.execute_model_spmd.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]
last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank:
# Specify how intermediate tensors should be passed
# between pp stages, no need to specify for the last
# pp stage.
transport = "auto"
outputs = [
output.with_type_hint(
TorchTensorType(transport=transport))
for output in outputs
]
forward_dag = MultiOutputNode(outputs)
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def __del__(self):
self.shutdown()
class RayHPUExecutorAsync(RayHPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pp_locks: Optional[List[asyncio.Lock]] = None
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return await super().execute_model_async(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
serialized_data = self.input_encoder.encode(execute_model_req)
dag_future = await self.forward_dag.execute_async(serialized_data)
outputs = await dag_future
return self.output_decoder.decode(outputs[0])
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
if not self.tp_driver_workers:
return await self.driver_exec_method("execute_model",
execute_model_req)
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self.pp_locks = [
asyncio.Lock()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
tasks = [
asyncio.create_task(
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
"execute_model", execute_model_req))
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method.remote,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self):
assert not self.use_ray_spmd_worker, (
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
coros = [
worker.execute_method.remote("start_worker_execution_loop")
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)
def __del__(self):
self.shutdown()
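
Note that each Ray actor above is requested with ``num_gpus=0`` and a custom ``{'HPU': 1}`` resource, so scheduling follows HPU availability rather than GPUs. A hedged multi-card launch sketch (assumes a Ray cluster whose nodes advertise ``HPU`` resources; the model name and sizes are placeholders)::

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m",
              device="hpu",
              tensor_parallel_size=8,
              distributed_executor_backend="ray")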


@@ -249,7 +249,11 @@ def initialize_ray_cluster(
# Placement group is already set.
return
device_str = "GPU" if not current_platform.is_tpu() else "TPU"
device_str = "GPU"
if current_platform.is_tpu():
device_str = "TPU"
elif current_platform.is_hpu():
device_str = 'HPU'
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:


@@ -55,10 +55,9 @@ class CustomOp(nn.Module):
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)
- def forward_gaudi(self, *args, **kwargs):
+ def forward_hpu(self, *args, **kwargs):
# By default, we assume that Gaudi ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)
def dispatch_forward(self):
@@ -76,6 +75,8 @@ class CustomOp(nn.Module):
return self.forward_hip
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_hpu():
return self.forward_hpu
elif current_platform.is_tpu():
return self.forward_tpu
elif current_platform.is_xpu():


@@ -92,6 +92,25 @@ class RMSNorm(CustomOp):
)
return out
def forward_hpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm_hpu_extension.ops import HPUFusedRMSNorm
if HPUFusedRMSNorm is None:
return self.forward_native(x, residual)
if residual is not None:
orig_shape = x.shape
residual += x.view(residual.shape)
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
x = HPUFusedRMSNorm.apply(residual, self.weight,
self.variance_epsilon)
return x.view(orig_shape), residual
x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
return x
def forward_xpu(
self,
x: torch.Tensor,

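When available, the fused kernel replaces the native computation one-to-one. A hedged reference of what the fused call is expected to compute, written in plain PyTorch and mirroring the native RMSNorm formula (illustrative only)::

    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                           eps: float) -> torch.Tensor:
        # x * rsqrt(mean(x^2) + eps), scaled by the learned weight
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight
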

@@ -111,8 +111,14 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
- return hidden_states.index_select(0,
-     sampling_metadata.selected_token_indices)
+ # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios
+ # (warmup, profile_run) we might not have selected_token_indices,
+ # so we skip pruning.
+ if sampling_metadata.selected_token_indices is not None:
+     return hidden_states.index_select(
+         0, sampling_metadata.selected_token_indices)
+ else:
+     return hidden_states
def _apply_logits_processors(


@@ -194,6 +194,61 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache, self.is_neox_style)
return query, key
def forward_hpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
positions = positions.flatten()
if offsets is not None:
positions = positions + offsets
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions).view(
num_tokens, 1, -1)
cos, sin = cos_sin.chunk(2, dim=-1)
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
# to query hidden dimension, so the original tensors need to be
# expanded
# GPT-NeoX kernel requires position_ids = None, offset = 0, mode = BLOCKWISE
# and expansion of cos/sin tensors via concatenation
# GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
# and expansion of cos/sin tensors via repeat_interleave
rope_mode: RotaryPosEmbeddingMode
if self.is_neox_style:
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
else:
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
sin = torch.repeat_interleave(sin,
2,
dim=-1,
output_size=cos_sin.shape[-1])
cos = torch.repeat_interleave(cos,
2,
dim=-1,
output_size=cos_sin.shape[-1])
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0,
rope_mode)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"

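The two cos/sin expansion schemes described in the comments above differ only in element order; a small hedged illustration (shapes arbitrary)::

    import torch

    cos = torch.arange(4.).view(1, 1, 4)                # [c0, c1, c2, c3]
    blockwise = torch.cat((cos, cos), dim=-1)           # GPT-NeoX: c0 c1 c2 c3 c0 c1 c2 c3
    pairwise = torch.repeat_interleave(cos, 2, dim=-1)  # GPT-J:    c0 c0 c1 c1 c2 c2 c3 c3
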

@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -382,8 +383,20 @@ class VocabParallelEmbedding(torch.nn.Module):
# Copy the data.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
- param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
- param[loaded_weight.shape[0]:].data.fill_(0)
+ if current_platform.is_hpu():
+     # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
+     # so we're using a workaround. Remove this when fixed in
+     # HPU PT bridge.
+     padded_weight = torch.cat([
+         loaded_weight,
+         torch.zeros(param.shape[0] - loaded_weight.shape[0],
+                     *loaded_weight.shape[1:])
+     ])
+     param.data.copy_(padded_weight)
+ else:
+     param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
+     param[loaded_weight.shape[0]:].data.fill_(0)
def forward(self, input_):
if self.tp_size > 1:

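The workaround trades the sliced in-place copy for a single full-tensor copy. A hedged standalone illustration of the padding step, with invented sizes::

    import torch

    param = torch.empty(10, 4)
    loaded_weight = torch.ones(7, 4)
    padded_weight = torch.cat([
        loaded_weight,
        torch.zeros(param.shape[0] - loaded_weight.shape[0],
                    *loaded_weight.shape[1:])
    ])
    param.data.copy_(padded_weight)  # one full copy, no sliced writes
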

@@ -284,7 +284,8 @@ def _prepare_seq_groups(
else:
# Decode
prompt_logprob_len = 0
- query_len = query_lens[i] if query_lens is not None else 1
+ query_len = query_lens[i] if query_lens is not None and len(
+     query_lens) > 0 else 1
sample_len = len(seq_ids) * query_len if do_sample else 0
if sampling_params.seed is not None and generators is not None:


@@ -42,6 +42,13 @@ try:
except Exception:
pass
is_hpu = False
try:
from importlib import util
is_hpu = util.find_spec('habana_frameworks') is not None
except Exception:
pass
is_xpu = False
try:
@@ -86,6 +93,9 @@ elif is_cuda:
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_hpu:
from .hpu import HpuPlatform
current_platform = HpuPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()

vllm/platforms/hpu.py (new file, 11 lines)

@@ -0,0 +1,11 @@
import torch
from .interface import Platform, PlatformEnum
class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
@staticmethod
def inference_mode():
return torch.no_grad()
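
Downstream code can branch on the detected platform instead of probing for packages itself; a hedged usage sketch::

    from vllm.platforms import current_platform

    if current_platform.is_hpu():
        block_size = 128  # e.g. the Gaudi default chosen in EngineArgs above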


@@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
TPU = enum.auto()
HPU = enum.auto()
XPU = enum.auto()
CPU = enum.auto()
NEURON = enum.auto()
@@ -46,6 +47,9 @@ class Platform:
def is_tpu(self) -> bool:
return self._enum == PlatformEnum.TPU
def is_hpu(self) -> bool:
return self._enum == PlatformEnum.HPU
def is_xpu(self) -> bool:
return self._enum == PlatformEnum.XPU


@@ -728,6 +728,9 @@ def is_pin_memory_available() -> bool:
elif current_platform.is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif current_platform.is_hpu():
print_warning_once("Pin memory is not supported on HPU.")
return False
elif current_platform.is_cpu() or current_platform.is_openvino():
return False
return True

File diff suppressed because it is too large.

vllm/worker/hpu_worker.py (new file, 410 lines)

@@ -0,0 +1,410 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import gc
import os
from typing import List, Optional, Set, Tuple, Type
import habana_frameworks.torch as htorch # noqa:F401
import torch
import torch.distributed
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.hpu_model_runner import HPUModelRunner
from vllm.worker.model_runner_base import ModelRunnerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class HPUWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a HPU.
Each worker is associated with a single HPU. The worker is responsible for
maintaining the KV cache and executing the model on the HPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner: HPUModelRunner = HPUModelRunner(
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[HPUCacheEngine]
# Initialize hpu_cache as embedding models don't initialize kv_caches
self.hpu_cache: Optional[List[List[torch.Tensor]]] = None
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.HPU,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
def start_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.start()
def stop_profile(self):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
self.profiler.stop()
def _set_env_vars(self):
local_rank = self.local_rank
if self.parallel_config.world_size == 1:
local_rank = -1
import os
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["ID"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size)
os.environ["RANK"] = str(self.rank)
def init_device(self) -> None:
if self.device_config.device.type == "hpu":
self.device = torch.device("hpu")
torch.hpu.set_device(self.device)
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
if self.model_config.quantization == 'inc':
self._set_env_vars()
init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of HPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with HabanaMemoryProfiler() as m:
self.model_runner.profile_run()
torch.hpu.synchronize()
msg = ("Model profiling run "
f"took {m.get_summary_string()}")
logger.info(msg)
# At this point we should've allocated the maximum workspace for all
# recipes; we will use the extra memory for graphs/blocks.
free_hpu_memory = torch.hpu.mem_get_info()[0]
cache_block_size = self.get_cache_block_size_bytes()
graph_reserved_mem = (float(
os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1'))
if not self.model_config.enforce_eager else 0)
graph_headroom = 1 - graph_reserved_mem
available_hpu_memory = free_hpu_memory * \
self.cache_config.gpu_memory_utilization
hpu_memory_margin = free_hpu_memory * (
1 - self.cache_config.gpu_memory_utilization)
self.model_runner.mem_margin = hpu_memory_margin
cache_size_bytes = available_hpu_memory * graph_headroom
graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom)
msg = (
f"Free device memory: {format_bytes(free_hpu_memory)}, "
f"{format_bytes(available_hpu_memory)} usable "
f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization}),"
f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs "
f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), "
f"{format_bytes(cache_size_bytes)} reserved for KV cache")
logger.info(msg)
num_hpu_blocks = int(cache_size_bytes // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size)
num_hpu_blocks = max(num_hpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()
gc.collect()
return num_hpu_blocks, num_cpu_blocks
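# Worked numeric example of the split above (illustrative values, not
# from the diff): with 100 GiB free, gpu_memory_utilization=0.9 and
# VLLM_GRAPH_RESERVED_MEM=0.1 in non-eager mode:
#   available_hpu_memory = 100 GiB * 0.9 = 90 GiB
#   graph_headroom_bytes = 90 GiB * 0.1  =  9 GiB  (kept for HPUGraphs)
#   cache_size_bytes     = 90 GiB * 0.9  = 81 GiB  (KV cache)
#   num_hpu_blocks       = 81 GiB // cache_block_size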
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
This also warms up the model, which may record CUDA graphs.
"""
raise_if_cache_size_invalid(num_gpu_blocks,
self.cache_config.block_size,
self.model_config.max_model_len)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
with HabanaMemoryProfiler() as m:
self._init_cache_engine()
torch.hpu.synchronize()
msg = ("Initializing cache engine "
f"took {m.get_summary_string()}")
logger.info(msg)
self._warm_up_model()
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
HPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.hpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
def _warm_up_model(self) -> None:
# NOTE(kzawora): We should use virtual engine index here
# for pipeline parallelism. Using 0 for now.
assert self.hpu_cache is not None
self.model_runner.warmup_model(self.hpu_cache[0])
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def finish_measurements(self):
self.model_runner.finish_measurements()
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.hpu_cache
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# They contain parameters for launching the asynchronous device memcpy.
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
device="cpu",
dtype=torch.int64).view(-1, 2)
blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
device="cpu",
dtype=torch.int64).view(-1, 2)
# `blocks_to_copy` is a device tensor. The src and tgt of
# blocks to copy are on the same device, and `blocks_to_copy`
# can be used directly within device kernels.
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine[virtual_engine].swap_in(
worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine[virtual_engine].swap_out(
worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
def add_prompt_adapter(
self, prompt_adapter_request: PromptAdapterRequest) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def list_prompt_adapters(self) -> Set[int]:
raise NotImplementedError(
"Prompt Adapter is not implemented for HPU backend.")
def shutdown_inc(self):
self.model_runner.shutdown_inc()
@property
def max_model_len(self) -> int:
return self.model_config.max_model_len
@property
def vocab_size(self) -> int:
return self.model_runner.vocab_size
def get_cache_block_size_bytes(self) -> int:
"""Get the size of the KV cache block size in bytes.
"""
return HPUCacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
init_distributed_environment(parallel_config.world_size,
rank,
distributed_init_method,
local_rank,
backend='hccl')
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
torch.distributed.init_process_group(
backend="hccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup & checking conformance.
dummy_tensor_hpu = torch.ones(1).to('hpu')
torch.distributed.all_reduce(dummy_tensor_hpu)
assert dummy_tensor_hpu.item() == parallel_config.world_size
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
max_model_len) -> None:
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * num_gpu_blocks
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
class HPUCacheEngine(CacheEngine):
def _allocate_kv_cache(
self,
num_blocks: int,
device: str,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
for _ in range(self.num_attention_layers):
key_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)
value_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
device=device)
kv_layer = (key_cache, value_cache)
kv_cache.append(kv_layer)
return kv_cache
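
For sizing intuition, every attention layer above allocates one key and one value tensor per cache block, matching the standard CacheEngine block-size arithmetic. A hedged back-of-envelope helper (names and example numbers are invented)::

    def kv_block_bytes(block_size: int, num_kv_heads: int, head_size: int,
                       num_layers: int, dtype_bytes: int = 2) -> int:
        # 2 tensors (key + value), per layer, per token slot in the block
        return 2 * block_size * num_kv_heads * head_size * num_layers * dtype_bytes

    # e.g. block_size=128, 8 KV heads, head_size=128, 32 layers, bf16:
    print(kv_block_bytes(128, 8, 128, 32))  # 16777216 bytes = 16 MiB per block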