[TPU] Reduce compilation time & Upgrade PyTorch XLA version (#6856)
commit fad5576c58
parent f954d0715c
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240713"
+ARG NIGHTLY_DATE="20240726"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

 FROM $BASE_IMAGE
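For context, a hedged sketch of rebuilding the TPU image after this nightly-date bump, assuming the hunk above belongs to the repository's Dockerfile.tpu (the file path and image tag are not shown in this diff and are assumptions):

$ # Rebuild the TPU image against the newer nightly base (filename and tag assumed).
$ docker build -f Dockerfile.tpu -t vllm-tpu .
$ # Or override the date at build time without editing the file:
$ docker build -f Dockerfile.tpu --build-arg NIGHTLY_DATE="20240726" -t vllm-tpu .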
@@ -56,7 +56,7 @@ First, install the dependencies:
 $ pip uninstall torch torch-xla -y

 $ # Install PyTorch and PyTorch XLA.
-$ export DATE="+20240713"
+$ export DATE="+20240726"
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl

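As a quick sanity check after switching to the new wheels (not part of this diff, just an illustrative snippet in the same style as the docs above):

$ # Confirm the nightly builds import cleanly and report their versions.
$ python -c "import torch, torch_xla; print(torch.__version__, torch_xla.__version__)"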
@@ -75,6 +75,13 @@ Next, build vLLM from source. This will only take a few seconds:
 $ VLLM_TARGET_DEVICE="tpu" python setup.py develop


+.. note::
+
+    Since TPU relies on XLA, which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape.
+    The compilation may take 20~30 minutes in the first run.
+    However, the compilation time reduces to ~5 minutes afterwards because the XLA graphs are cached on disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default).
+
+
 .. tip::

     If you encounter the following error:
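The note above refers to the on-disk XLA graph cache; a hedged illustration of pointing it at persistent storage so warm starts reuse the compiled graphs (the path below is made up for the example):

$ # Cache compiled XLA graphs on persistent storage; later runs skip recompilation.
$ export VLLM_XLA_CACHE_PATH=/mnt/persistent/vllm_xla_cache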
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type

 import torch
 import torch_xla.experimental.custom_kernel  # Required to register custom ops.
-import torch_xla.experimental.dynamo_set_buffer_donor

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
@@ -6,6 +6,7 @@ from vllm.platforms import current_platform

 if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
     from torch_xla._internal import pjrt


@@ -20,7 +21,7 @@ class TpuCommunicator:
         local_rank = dist.get_rank(group)
         world_size = dist.get_world_size(group)
         pjrt.initialize_multiprocess(local_rank, world_size)
-        xm._init_world_size_ordinal()
+        xr._init_world_size_ordinal()

     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
         return xm.all_reduce(xm.REDUCE_SUM, x)
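To illustrate the xm-to-xr migration in the TpuCommunicator hunk above, here is a standalone sketch of the same torch_xla calls. This is not vLLM's API, just the underlying functions; it assumes a TPU host with torch_xla installed and one process per chip via xmp.spawn.

# Standalone sketch (not vLLM code): rank query and all-reduce as used above.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr


def _worker(index: int) -> None:
    # xr.global_ordinal() replaces the older xm.get_ordinal() rank query.
    rank = xr.global_ordinal()
    x = torch.ones(4, device=xm.xla_device()) * rank
    # Sum the tensor across all replicas, as TpuCommunicator.all_reduce does.
    reduced = xm.all_reduce(xm.REDUCE_SUM, x)
    xm.mark_step()
    print(f"rank {rank}: {reduced.cpu().tolist()}")


if __name__ == "__main__":
    xmp.spawn(_worker)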
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr

 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
@@ -127,7 +128,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         # determine the order of concatenating the output tensors.
         # As a workaround, we use the xm's rank assignment only when loading
         # the embedding weights.
-        xm_tp_rank = xm.get_ordinal()
+        xm_tp_rank = xr.global_ordinal()
         with patch(
                 "vllm.model_executor.layers.vocab_parallel_embedding."
                 "get_tensor_model_parallel_rank",
@@ -146,7 +147,17 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         xm.wait_device_ops()

         model = ModelWrapper(model)
-        self.model = torch.compile(model, backend="openxla", fullgraph=True)
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Setting dynamic=True can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph will still require static shapes and needs to
+        # be re-compiled for every different shape. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs on disk (VLLM_XLA_CACHE_PATH).
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=True)

     def _dummy_run(
         self,
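The NOTE added above describes the two compilation stages. Below is a minimal, self-contained sketch of the same torch.compile settings on a toy module; it assumes torch_xla is installed so the "openxla" backend is registered, and TinyModel is made up for illustration rather than taken from vLLM.

# Minimal sketch: torch.compile with the openxla backend and dynamic=True,
# mirroring the settings introduced above (toy module, not vLLM's ModelWrapper).
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


device = xm.xla_device()
model = TinyModel().to(device)
# dynamic=True lets torch.compile reuse one FX graph across input shapes;
# XLA still specializes and caches a separate graph per concrete shape.
compiled = torch.compile(model, backend="openxla", fullgraph=True, dynamic=True)

for batch_size in (1, 2, 4):
    x = torch.randn(batch_size, 16, device=device)
    y = compiled(x)
    xm.mark_step()  # flush the lazy XLA graph for this shape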
@@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch_xla.core.xla_model as xm
-import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr

 import vllm.envs as envs