From 90bab18f24dce1967282fbb1ebcd2c9aecc67d30 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sat, 10 Aug 2024 18:12:22 -0700
Subject: [PATCH] [TPU] Use mark_dynamic to reduce compilation time (#7340)

---
 Dockerfile.tpu                            |  2 +-
 .../getting_started/tpu-installation.rst  |  4 +-
 vllm/worker/tpu_model_runner.py           | 60 +++++++++++++++----
 3 files changed, 50 insertions(+), 16 deletions(-)

diff --git a/Dockerfile.tpu b/Dockerfile.tpu
index ef0a53a7afe1..1cf43247e978 100644
--- a/Dockerfile.tpu
+++ b/Dockerfile.tpu
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240726"
+ARG NIGHTLY_DATE="20240808"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE
diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst
index 2e6c522422c2..57b7917b5853 100644
--- a/docs/source/getting_started/tpu-installation.rst
+++ b/docs/source/getting_started/tpu-installation.rst
@@ -56,7 +56,7 @@ First, install the dependencies:
     $ pip uninstall torch torch-xla -y
 
     $ # Install PyTorch and PyTorch XLA.
-    $ export DATE="+20240726"
+    $ export DATE="+20240808"
     $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
     $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
 
@@ -65,7 +65,7 @@ First, install the dependencies:
     $ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 
     $ # Install other build dependencies.
-    $ pip install packaging aiohttp
+    $ pip install -r requirements-tpu.txt
 
 
 Next, build vLLM from source. This will only take a few seconds:
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index cf4cc5535ba5..685ae0fd7cc8 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -147,19 +147,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         )
         model = model.eval()
         xm.wait_device_ops()
-
-        model = ModelWrapper(model)
-        # NOTE(woosuk): There are two stages of compilation: torch.compile and
-        # XLA compilation. Setting dynamic=True can reduce the torch.compile
-        # overhead by reusing the FX graph for different shapes.
-        # However, the XLA graph will still require static shapes and needs to
-        # be re-compiled for every different shapes. This overhead is inevitable
-        # in the first run, but can be skipped afterwards as we cache the XLA
-        # graphs in the disk (VLLM_XLA_CACHE_PATH).
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=True)
+        self.model = CompiledModelWrapper(model)
 
     def _dummy_run(
         self,
@@ -697,6 +685,52 @@ class ModelWrapper(nn.Module):
         return next_token_ids
 
 
+class CompiledModelWrapper:
+
+    def __init__(self, model: nn.Module):
+        model = ModelWrapper(model)
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=False)
+
+    def __call__(
+        self,
+        token_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        input_lens: torch.Tensor,
+        t: torch.Tensor,
+        p: torch.Tensor,
+        num_samples: int,
+        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
+    ) -> torch.Tensor:
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph will still require static shapes and needs to
+        # be re-compiled for every different shape. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs on disk (VLLM_XLA_CACHE_PATH).
+        if attn_metadata.num_prefills > 0:
+            # Prefill
+            torch._dynamo.mark_dynamic(token_ids, 1)
+            torch._dynamo.mark_dynamic(position_ids, 1)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
+        else:
+            # Decode
+            torch._dynamo.mark_dynamic(token_ids, 0)
+            torch._dynamo.mark_dynamic(position_ids, 0)
+            torch._dynamo.mark_dynamic(input_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
+            torch._dynamo.mark_dynamic(t, 0)
+            torch._dynamo.mark_dynamic(p, 0)
+        return self.model(token_ids, position_ids, attn_metadata, input_lens,
+                          t, p, num_samples, kv_caches)
+
+
 def _get_padded_prefill_len(x: int) -> int:
     # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
     # length to be a multiple of 16. We pad the prompt length to the nearest
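For reference, the core idea of the change is to replace `dynamic=True` with `dynamic=False` plus explicit `torch._dynamo.mark_dynamic` calls on the dimensions that actually vary (sequence length during prefill, batch size during decode), so Dynamo reuses a single FX graph per mode instead of re-tracing for every new shape, while XLA still recompiles per concrete shape and relies on the on-disk cache (VLLM_XLA_CACHE_PATH). Below is a minimal standalone sketch of the same mechanism, not part of the patch: it uses a toy nn.Linear on the default torch.compile backend rather than vLLM's ModelWrapper and the openxla backend, and the model, sizes, and loop are illustrative assumptions only.

    # sketch.py - illustrative only; model/sizes are assumptions, not from vLLM
    import torch
    import torch.nn as nn

    # dynamic=False makes Dynamo specialize on input shapes by default,
    # mirroring CompiledModelWrapper above (which also passes backend="openxla").
    compiled = torch.compile(nn.Linear(16, 16), dynamic=False)

    for batch_size in (2, 4, 8, 16):
        x = torch.randn(batch_size, 16)
        # Mark dim 0 (batch) as dynamic before the call, as the patch does per
        # forward pass, so the traced FX graph is reused across batch sizes
        # instead of being re-captured for each new shape.
        torch._dynamo.mark_dynamic(x, 0)
        y = compiled(x)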