Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[TPU] Use mark_dynamic to reduce compilation time (#7340)

parent: 4c5d8e8ea9
commit: 90bab18f24
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240726"
+ARG NIGHTLY_DATE="20240808"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE

@@ -56,7 +56,7 @@ First, install the dependencies:
 $ pip uninstall torch torch-xla -y
 
 $ # Install PyTorch and PyTorch XLA.
-$ export DATE="+20240726"
+$ export DATE="+20240808"
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
 
@@ -65,7 +65,7 @@ First, install the dependencies:
 $ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 
 $ # Install other build dependencies.
-$ pip install packaging aiohttp
+$ pip install -r requirements-tpu.txt
 
 
 Next, build vLLM from source. This will only take a few seconds:
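As a quick sanity check after the installation steps above (an illustrative step, not part of the documented instructions), the installed nightly versions can be confirmed from Python:

# Illustrative sanity check, not from the docs: verify that the torch and
# torch_xla nightlies for the chosen DATE are what actually got installed.
from importlib.metadata import version

import torch        # should import without error
import torch_xla    # should import without error

print("torch:", version("torch"))
print("torch_xla:", version("torch_xla"))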
@@ -147,19 +147,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         )
         model = model.eval()
         xm.wait_device_ops()
-
-        model = ModelWrapper(model)
-        # NOTE(woosuk): There are two stages of compilation: torch.compile and
-        # XLA compilation. Setting dynamic=True can reduce the torch.compile
-        # overhead by reusing the FX graph for different shapes.
-        # However, the XLA graph will still require static shapes and needs to
-        # be re-compiled for every different shapes. This overhead is inevitable
-        # in the first run, but can be skipped afterwards as we cache the XLA
-        # graphs in the disk (VLLM_XLA_CACHE_PATH).
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=True)
+        self.model = CompiledModelWrapper(model)
 
     def _dummy_run(
         self,
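The lines removed here compiled the wrapped model with dynamic=True, letting torch.compile treat every input dimension as potentially dynamic; the replacement CompiledModelWrapper (added in the next hunk) compiles with static shapes and marks only the dimensions that actually vary. A minimal sketch of the mark_dynamic idea, using a toy module and the default backend rather than vLLM's ModelWrapper and the openxla backend (ModelWrapper, backend="openxla", and dynamic=False come from the diff; everything else below is illustrative):

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Toy stand-in for the wrapped model; only the input shapes matter."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# The commit passes backend="openxla" and dynamic=False; the defaults are used
# here so the sketch runs without torch_xla or a TPU.
compiled = torch.compile(ToyModel(), fullgraph=True)

for batch_size in (8, 16, 32):
    x = torch.randn(batch_size, 16)
    # Mark the batch dimension as dynamic before the call so Dynamo traces one
    # FX graph with a symbolic size for dim 0 and reuses it for every batch.
    torch._dynamo.mark_dynamic(x, 0)
    compiled(x)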
@@ -697,6 +685,52 @@ class ModelWrapper(nn.Module):
         return next_token_ids
 
 
+class CompiledModelWrapper:
+
+    def __init__(self, model: nn.Module):
+        model = ModelWrapper(model)
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=False)
+
+    def __call__(
+        self,
+        token_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        input_lens: torch.Tensor,
+        t: torch.Tensor,
+        p: torch.Tensor,
+        num_samples: int,
+        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
+    ) -> torch.Tensor:
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph will still require static shapes and needs to
+        # be re-compiled for every different shape. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs in the disk (VLLM_XLA_CACHE_PATH).
+        if attn_metadata.num_prefills > 0:
+            # Prefill
+            torch._dynamo.mark_dynamic(token_ids, 1)
+            torch._dynamo.mark_dynamic(position_ids, 1)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
+        else:
+            # Decode
+            torch._dynamo.mark_dynamic(token_ids, 0)
+            torch._dynamo.mark_dynamic(position_ids, 0)
+            torch._dynamo.mark_dynamic(input_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
+            torch._dynamo.mark_dynamic(t, 0)
+            torch._dynamo.mark_dynamic(p, 0)
+        return self.model(token_ids, position_ids, attn_metadata, input_lens,
+                          t, p, num_samples, kv_caches)
+
+
 def _get_padded_prefill_len(x: int) -> int:
     # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
     # length to be a multiple of 16. We pad the prompt length to the nearest
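To see the effect this commit is after (one FX graph reused across shapes instead of a torch.compile recompilation per shape), a trivial counting backend can be used. This is an illustrative check, not code from the commit; counting_backend and the toy model are made-up names:

import torch
import torch.nn as nn

compile_count = 0


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Minimal torch.compile backend: count how many times Dynamo hands us a
    new FX graph (i.e. how many (re)compilations happen), then run it eagerly."""
    global compile_count
    compile_count += 1
    return gm.forward


model = nn.Linear(16, 16)
compiled = torch.compile(model, backend=counting_backend, fullgraph=True)

for batch_size in (8, 16, 32, 64):
    x = torch.randn(batch_size, 16)
    # With mark_dynamic, dim 0 is symbolic from the first trace, so the same
    # graph serves every batch size and compile_count stays at 1.
    torch._dynamo.mark_dynamic(x, 0)
    compiled(x)

print("graphs compiled:", compile_count)  # 1

Without the mark_dynamic calls, compile_count would grow as new batch sizes arrive: for every new shape under dynamic=False, or once more (before PyTorch's automatic dynamic shapes take over) under the default setting.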