[TPU] Implement prefix caching for TPUs (#10307)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent c68f7ede6a
commit 2f77b6cfec

@@ -16,8 +16,8 @@ ray[default]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.6.0.dev20241028+cpu
-torchvision==0.20.0.dev20241028+cpu
-torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl
+torch==2.6.0.dev20241114+cpu
+torchvision==0.20.0.dev20241114+cpu
+torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241114-cp310-cp310-linux_x86_64.whl
 jaxlib==0.4.32.dev20240829
 jax==0.4.32.dev20240829
@@ -65,6 +65,7 @@ class PallasMetadata(AttentionMetadata):
     # or all decoding.
     block_tables: Optional[torch.Tensor] = None
     context_lens: Optional[torch.Tensor] = None
+    effective_query_lens: Optional[torch.Tensor] = None
 
     @property
     def prefill_metadata(self) -> Optional["PallasMetadata"]:
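The new effective_query_lens field records, per sequence, how many query tokens are actually computed in this step, while context_lens keeps the total number of tokens those queries attend to (cached prefix included). A small illustration with made-up values, not taken from this commit:

import torch

# One sequence: a 50-token prompt whose first 32 tokens (two 16-token blocks)
# are already in the KV cache, so only 18 new queries are run.
context_lens = torch.tensor([50], dtype=torch.int32)
effective_query_lens = torch.tensor([50 - 32], dtype=torch.int32)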
@@ -72,8 +73,6 @@ class PallasMetadata(AttentionMetadata):
             return None
 
         assert self.num_decode_tokens == 0
-        assert self.block_tables is None
-        assert self.context_lens is None
         return self
 
     @property
@@ -186,20 +185,24 @@ class PallasAttentionBackendImpl(AttentionImpl):
 
         query = query * self.scale
         if attn_metadata.num_prefills > 0:
+            if attn_metadata.block_tables is None:
+                # Prefill without paged KV cache.
                 assert seq_len % 16 == 0, (
                     "Pallas FlashAttention kernel requires seq_len to be a "
                     f"multiple of 16 but got {seq_len}")
 
                 # Handle GQA/MQA.
                 if self.num_kv_heads != self.num_heads:
-                    key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
+                    key = key.repeat_interleave(self.num_queries_per_kv,
+                                                dim=-2)
                     key = key.view(batch_size, seq_len, self.num_heads,
                                    self.head_size)
                     value = value.repeat_interleave(self.num_queries_per_kv,
                                                     dim=-2)
                     value = value.view(batch_size, seq_len, self.num_heads,
                                        self.head_size)
-                # FlashAttention requires [batch_size, num_heads, seq_len, d_model]
+                # FlashAttention kernel requires the input shape to be
+                # [batch_size, num_heads, seq_len, d_model]
                 # while the input is [batch_size, seq_len, num_heads, d_model].
                 # Permute the input to match the required format.
                 output = torch.ops.xla.flash_attention(
@@ -209,6 +212,23 @@ class PallasAttentionBackendImpl(AttentionImpl):
                     True,
                 )
                 output = output.permute(0, 2, 1, 3)
+            else:
+                # Prefill with paged KV cache.
+                # TODO(woosuk): Tune the below knobs.
+                num_kv_pages_per_compute_block = 16
+                num_queries_per_compute_block = 16
+                assert seq_len % num_queries_per_compute_block == 0
+                output = torch.ops.xla.multi_queries_paged_attention(
+                    query,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.context_lens,
+                    attn_metadata.block_tables,
+                    attn_metadata.effective_query_lens,
+                    num_kv_pages_per_compute_block,
+                    num_queries_per_compute_block,
+                    use_kernel=True,
+                )
         else:
             # Decoding run.
             assert kv_cache[0].numel() > 0
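Together, the two hunks above make the prefill path branch on whether a paged KV cache (and therefore a cached prefix) is available. A condensed sketch of that dispatch; the permutes mirror the surrounding code that is not shown in this diff, torch_xla must provide these ops, and the comment on the True flag is an assumption about the kernel signature:

import torch

def _prefill(query, key, value, key_cache, value_cache, attn_metadata):
    # Sketch only; `query` is assumed to be pre-scaled as in the code above.
    if attn_metadata.block_tables is None:
        # No cached prefix: plain FlashAttention over the whole prompt.
        out = torch.ops.xla.flash_attention(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
            True,  # causal flag (assumed)
        )
        return out.permute(0, 2, 1, 3)
    # Cached prefix: read K/V from the paged cache. context_lens gives the
    # total attended length, effective_query_lens the number of new tokens.
    return torch.ops.xla.multi_queries_paged_attention(
        query, key_cache, value_cache,
        attn_metadata.context_lens,
        attn_metadata.block_tables,
        attn_metadata.effective_query_lens,
        16,  # num_kv_pages_per_compute_block (knob from the diff)
        16,  # num_queries_per_compute_block (knob from the diff)
        use_kernel=True,
    )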
@@ -1,3 +1,4 @@
+import enum
 import time
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
@@ -11,7 +12,6 @@ import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
 
 from vllm.attention import AttentionMetadata, get_attn_backend
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -39,6 +39,15 @@ _ENABLE_TOP_P = False
 _MAX_NUM_SAMPLES = 128
 
 
+class ExecutionMode(enum.Enum):
+    PREFILL = enum.auto()
+    DECODE = enum.auto()
+    PREFIX_PREFILL = enum.auto()
+
+    def is_prefill(self) -> bool:
+        return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
+
+
 @dataclass(frozen=True)
 class ModelInputForTPU(ModelRunnerInputBase):
     token_ids: torch.Tensor
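The enum replaces the old is_prompt boolean so the dummy-run and warmup paths can tell ordinary prefill, prefix-cached prefill, and decode apart. A tiny usage example (assumes the ExecutionMode enum defined in the hunk above is in scope):

mode = ExecutionMode.PREFIX_PREFILL
assert mode.is_prefill()                      # PREFILL and PREFIX_PREFILL both count as prefill
assert not ExecutionMode.DECODE.is_prefill()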
@@ -140,16 +149,21 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         model = get_model(vllm_config=self.vllm_config)
         model = model.eval()
         xm.wait_device_ops()
-        self.model = ModelWrapper(model, self.vllm_config)
+        model = ModelWrapper(model)
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=False)
 
     def _dummy_run(
         self,
         batch_size: int,
         seq_len: int,
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-        is_prompt: bool,
+        exec_mode: ExecutionMode,
     ) -> None:
-        if is_prompt:
+        exec_mode = ExecutionMode(exec_mode)
+        if exec_mode.is_prefill():
             seq_len = (seq_len + 15) // 16 * 16
             token_ids = torch.zeros((batch_size, seq_len),
                                     dtype=torch.int32,
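Compilation now happens once in load_model: the eager model is wrapped in a plain nn.Module and the wrapper is compiled with the openxla backend, instead of routing calls through TorchCompileWrapperWithCustomDispatcher. A self-contained sketch of the same pattern; the toy module is not vLLM code, and torch_xla must be installed for the openxla backend to exist:

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    # Thin wrapper so torch.compile traces a single forward.
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

model = Wrapper(nn.Linear(8, 8)).eval()
compiled = torch.compile(model, backend="openxla", fullgraph=True, dynamic=False)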
@@ -160,6 +174,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             slot_mapping = torch.zeros((batch_size, seq_len),
                                        dtype=torch.int64,
                                        device=self.device)
+            input_lens = torch.ones((batch_size, ),
+                                    dtype=torch.int32,
+                                    device=self.device)
+            if exec_mode == ExecutionMode.PREFILL:
                 attn_metadata = self.attn_backend.make_metadata(
                     num_prefills=batch_size,
                     num_prefill_tokens=batch_size * seq_len,
@@ -168,10 +186,26 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                     multi_modal_placeholder_index_maps=None,
                     block_tables=None,
                     context_lens=None,
+                    effective_query_lens=None,
                 )
-            input_lens = torch.ones((batch_size, ),
+            else:
+                context_lens = torch.ones((batch_size, ),
                                           dtype=torch.int32,
                                           device=self.device)
+                block_tables = torch.tensor(self.block_tables[:batch_size],
+                                            dtype=torch.int32,
+                                            device=self.device)
+                effective_query_lens = torch.ones_like(context_lens)
+                attn_metadata = self.attn_backend.make_metadata(
+                    num_prefills=batch_size,
+                    num_prefill_tokens=batch_size * seq_len,
+                    num_decode_tokens=0,
+                    slot_mapping=slot_mapping,
+                    multi_modal_placeholder_index_maps=None,
+                    block_tables=block_tables,
+                    context_lens=context_lens,
+                    effective_query_lens=effective_query_lens,
+                )
         else:
             assert seq_len == 1
             token_ids = torch.zeros((batch_size, seq_len),
@@ -204,7 +238,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             )
         t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
         p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
-        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
+        num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1
 
         # NOTE(woosuk): There are two stages of compilation: torch.compile and
         # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
@@ -213,7 +247,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         # be re-compiled for every different shapes. This overhead is inevitable
         # in the first run, but can be skipped afterwards as we cache the XLA
         # graphs in the disk (VLLM_XLA_CACHE_PATH).
-        if is_prompt:
+        if exec_mode.is_prefill():
             # Prefll
             torch._dynamo.mark_dynamic(token_ids, 1)
             torch._dynamo.mark_dynamic(position_ids, 1)
@@ -229,15 +263,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             torch._dynamo.mark_dynamic(t, 0)
             torch._dynamo.mark_dynamic(p, 0)
         # Dummy run.
-        self.model(token_ids,
-                   position_ids,
-                   attn_metadata,
-                   input_lens,
-                   t,
-                   p,
-                   num_samples,
-                   kv_caches,
-                   is_prompt=is_prompt)
+        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
+                   num_samples, kv_caches)
 
     def warmup_model(
         self,
@@ -248,13 +275,13 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         start = time.time()
         for batch_size in [1]:
             seq_len = 16
-            while True:
-                self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
+            while seq_len <= self.model_config.max_model_len:
+                self._dummy_run(batch_size,
+                                seq_len,
+                                kv_caches,
+                                exec_mode=ExecutionMode.PREFILL)
                 xm.wait_device_ops()
                 logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
-
-                if seq_len >= self.model_config.max_model_len:
-                    break
                 num_tokens = batch_size * seq_len
                 if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                     break
@@ -263,12 +290,39 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         end = time.time()
         logger.info("Compilation for prefill done in %.2f s.", end - start)
 
+        # Prefix prefill
+        if self.cache_config.enable_prefix_caching:
+            logger.info("Compiling the model with different input shapes for "
+                        "prefix prefill...")
+            start = time.time()
+            for batch_size in [1]:
+                seq_len = 16
+                while seq_len <= self.model_config.max_model_len:
+                    self._dummy_run(batch_size,
+                                    seq_len,
+                                    kv_caches,
+                                    exec_mode=ExecutionMode.PREFIX_PREFILL)
+                    xm.wait_device_ops()
+                    logger.info("batch_size: %d, seq_len: %d", batch_size,
+                                seq_len)
+                    num_tokens = batch_size * seq_len
+                    if (num_tokens >=
+                            self.scheduler_config.max_num_batched_tokens):
+                        break
+                    seq_len = seq_len * 2
+            end = time.time()
+            logger.info("Compilation for prefix prefill done in %.2f s.",
+                        end - start)
+
         # Decode
         start = time.time()
         seq_len = 1
         batch_size = 8  # Must be in sync with _get_padded_batch_size()
         while True:
-            self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
+            self._dummy_run(batch_size,
+                            seq_len,
+                            kv_caches,
+                            exec_mode=ExecutionMode.DECODE)
             xm.wait_device_ops()
             logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
 
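Each warmup loop compiles one XLA graph per (batch_size, seq_len) bucket, doubling seq_len from 16 until it covers max_model_len or the batched-token budget. A standalone sketch of the buckets that get compiled, with made-up limits:

max_model_len = 2048          # assumed example value
max_num_batched_tokens = 512  # assumed example value
batch_size, seq_len = 1, 16
while seq_len <= max_model_len:
    print(f"compile graph for batch_size={batch_size}, seq_len={seq_len}")
    if batch_size * seq_len >= max_num_batched_tokens:
        break
    seq_len *= 2
# Compiles buckets for seq_len = 16, 32, 64, 128, 256, 512.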
@@ -287,9 +341,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         input_tokens: List[int] = []
         input_positions: List[int] = []
         prompt_lens: List[int] = []
+        context_lens: List[int] = []
         slot_mapping: List[int] = []
 
-        for seq_group_metadata in seq_group_metadata_list:
+        for batch_idx, seq_group_metadata in enumerate(
+                seq_group_metadata_list):
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
             assert len(seq_ids) == 1
@@ -298,19 +354,31 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             seq_data = seq_group_metadata.seq_data[seq_id]
             # Could include output tokens when a request is preempted.
             prompt_tokens = seq_data.get_token_ids()
+            seq_len = len(prompt_tokens)
+
+            num_computed_blocks = len(seq_group_metadata.computed_block_nums)
+            num_computed_tokens = num_computed_blocks * self.block_size
+            if num_computed_tokens > 0:
+                prompt_tokens = prompt_tokens[num_computed_tokens:]
+                context_lens.append(seq_len)
+            else:
+                context_lens.append(0)
+
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
             input_tokens.extend(prompt_tokens)
-            input_positions.extend(list(range(prompt_len)))
+            input_positions.extend(range(num_computed_tokens, seq_len))
 
             assert seq_group_metadata.block_tables is not None
             block_table = seq_group_metadata.block_tables[seq_id]
-            for i in range(prompt_len):
+            for i in range(num_computed_tokens, seq_len):
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
+            if num_computed_tokens > 0:
+                self.block_tables[batch_idx, :len(block_table)] = block_table
 
             # Add paddings to EACH prompt to the smallest power of 2 that is
             # greater than or equal to the prompt length.
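A worked example of the slicing above, with assumed values: block_size = 16 and a 50-token prompt whose first two blocks are already cached.

block_size = 16
prompt_tokens = list(range(50))   # stand-in token ids
computed_block_nums = [3, 7]      # two cached blocks; the ids are arbitrary

seq_len = len(prompt_tokens)                                 # 50
num_computed_tokens = len(computed_block_nums) * block_size  # 32
new_tokens = prompt_tokens[num_computed_tokens:]             # tokens 32..49 (18 left)
context_len = seq_len if num_computed_tokens > 0 else 0      # 50
positions = list(range(num_computed_tokens, seq_len))        # [32, 33, ..., 49]
# Only the 18 uncached tokens become queries; the paged-attention kernel
# still attends over all 50 positions via block_tables and context_lens.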
@@ -338,14 +406,21 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         prompt_lens = torch.tensor(prompt_lens,
                                    dtype=torch.int32,
                                    device="cpu")
+        context_lens = torch.tensor(context_lens,
+                                    dtype=torch.int32,
+                                    device="cpu")
+        block_tables = torch.tensor(self.block_tables[:num_prefills],
+                                    dtype=torch.int32,
+                                    device="cpu")
         attn_metadata = self.attn_backend.make_metadata(
             num_prefills=num_prefills,
             num_prefill_tokens=0,  # NOTE: This is not used.
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
-            block_tables=None,
-            context_lens=None,
+            block_tables=block_tables,
+            context_lens=context_lens,
+            effective_query_lens=prompt_lens,
         )
         return input_tokens, input_positions, attn_metadata, prompt_lens
 
@@ -550,6 +625,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             # process them separately. This is a temporary hack that should be
             # optimized by using SplashAttention.
             orig_slot_mapping = model_input.attn_metadata.slot_mapping
+            orig_block_tables = model_input.attn_metadata.block_tables
+            orig_context_lens = model_input.attn_metadata.context_lens
+            orig_effective_query_lens = \
+                model_input.attn_metadata.effective_query_lens
             batch_size = model_input.input_lens.shape[0]
             start_idx = 0
             next_token_ids = []
@@ -568,18 +647,24 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                 attn_metadata.num_prefills = 1
                 attn_metadata.slot_mapping = orig_slot_mapping[
                     None, start_idx:end_idx].to(self.device)
+                if orig_context_lens[i].item() > 0:
+                    attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
+                        self.device)
+                    attn_metadata.block_tables = orig_block_tables[
+                        i].unsqueeze(0).to(self.device)
+                    attn_metadata.effective_query_lens = \
+                        orig_effective_query_lens[i:i + 1].to(self.device)
+                else:
+                    attn_metadata.context_lens = None
+                    attn_metadata.block_tables = None
+                    attn_metadata.effective_query_lens = None
                 input_lens = model_input.input_lens[i:i + 1].to(self.device)
                 t = model_input.t[i:i + 1].to(self.device)
                 p = model_input.p[i:i + 1].to(self.device)
-                output_token_ids = self.model(token_ids,
-                                              position_ids,
-                                              attn_metadata,
-                                              input_lens,
-                                              t,
-                                              p,
+                output_token_ids = self.model(token_ids, position_ids,
+                                              attn_metadata, input_lens, t, p,
                                               model_input.num_samples,
-                                              kv_caches,
-                                              is_prompt=True)
+                                              kv_caches)
                 next_token_ids.append(output_token_ids[0])
                 start_idx = end_idx
 
@@ -624,15 +709,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         input_lens = model_input.input_lens.to(self.device)
         for i in range(num_steps):
             slot_mapping = attn_metadata.slot_mapping
-            output_token_ids = self.model(token_ids,
-                                          position_ids,
-                                          attn_metadata,
-                                          input_lens,
-                                          t,
-                                          p,
+            output_token_ids = self.model(token_ids, position_ids,
+                                          attn_metadata, input_lens, t, p,
                                           model_input.num_samples,
-                                          kv_caches,
-                                          is_prompt=False)
+                                          kv_caches)
             self.cached_step_outputs.append(output_token_ids)
 
             if i < num_steps - 1:
@@ -667,34 +747,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         return [sampler_output]
 
 
-class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
+class ModelWrapper(nn.Module):
 
-    def __init__(self, model: nn.Module, vllm_config: VllmConfig):
+    def __init__(self, model: nn.Module):
+        super().__init__()
         self.model = model
-        compiled_callable = torch.compile(self.forward,
-                                          backend="openxla",
-                                          fullgraph=True,
-                                          dynamic=False)
-        super().__init__(
-            compiled_callable,
-            compilation_level=vllm_config.compilation_config.level)
-
-    def __call__(self, *args, is_prompt: bool, **kwargs):
-        if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
-            # not fully compiled yet, or not using the custom dispatcher,
-            # let PyTorch handle it
-            return self.compiled_callable(*args, **kwargs)
-        # the 3 compiled codes are:
-        # 0: for profiling
-        # 1: for prompt
-        # 2: for decode
-        # dispatch to the compiled code directly, skip PyTorch
-        if is_prompt:
-            with self.dispatch_to_code(1):
-                return self.forward(*args, **kwargs)
-        else:
-            with self.dispatch_to_code(2):
-                return self.forward(*args, **kwargs)
 
     def forward(
         self,
@@ -13,7 +13,7 @@ from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
-from vllm.worker.tpu_model_runner import TPUModelRunner
+from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                      LoraNotSupportedWorkerBase, WorkerBase,
                                      WorkerInput)
@@ -112,7 +112,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             batch_size=1,
             seq_len=self.scheduler_config.max_num_batched_tokens,
             kv_caches=kv_caches,
-            is_prompt=True,
+            exec_mode=ExecutionMode.PREFILL,
         )
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()