[TPU] Add TPU specific var VLLM_TPU_MOST_MODEL_LEN (#19919)

Signed-off-by: Chenyaaang <chenyangli@google.com>
Chenyaaang 2025-06-25 15:51:02 -07:00 committed by GitHub
parent 55c65ab495
commit 2d7620c3eb
5 changed files with 184 additions and 76 deletions


@ -587,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
vllm_config = get_vllm_config()
vllm_config.model_config.max_model_len = 32000
vllm_config.scheduler_config.max_num_seqs = 1200
model_runner = get_model_runner(vllm_config)
# verify model runner will adjust num_reqs to avoid SMEM OOM.
assert model_runner.num_reqs_most_model_len == 1200
# num_page_per_req = 32k // 128
# num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
assert model_runner.num_reqs_max_model_len == 524
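For reference, the expected value in the final assertion can be reproduced directly. This is a standalone sketch: the formula is the one PallasAttentionBackend.get_max_num_seqs() introduces later in this commit, while reading its constants as an SMEM budget of 1024 * 1024 // 2 bytes with 4 bytes per page-table entry is an interpretation, not something the test states.

# The arithmetic behind the expected value, spelled out.
num_page_per_req = 32000 // 128                       # 250 pages per full-length request
num_reqs = 1024 * 1024 // 2 // num_page_per_req // 4
assert num_reqs == 524                                # matches num_reqs_max_model_len above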


@ -119,6 +119,7 @@ if TYPE_CHECKING:
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@ -833,6 +834,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TPU_BUCKET_PADDING_GAP":
lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
"VLLM_TPU_MOST_MODEL_LEN":
lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)),
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM":


@ -122,16 +122,6 @@ class TpuPlatform(Platform):
PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) # type: ignore[assignment]
min_page_size = PallasAttentionBackend.get_min_page_size(
vllm_config)
if min_page_size > cache_config.block_size:
logger.warning(
"Increase the page size from %s to %s to make sure there's"
"no SMEM OOM",
cache_config.block_size,
min_page_size,
)
cache_config.block_size = min_page_size # type: ignore[assignment]
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config


@ -71,6 +71,11 @@ class PallasAttentionBackend(AttentionBackend):
min_page_size = 1 << (min_page_size - 1).bit_length()
return min_page_size
@staticmethod
def get_max_num_seqs(model_len: int, page_size: int) -> int:
num_page_per_req = cdiv(model_len, page_size)
return 1024 * 1024 // 2 // num_page_per_req // 4
# TPU has limited SREGs (scalar registers); if page_size is too small, we
# can easily spill SREGs, which leads to bad performance. The strategy we
# apply here is to split max-model-len into 16 pages, which makes the
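Because the per-request page count divides a fixed budget, capping most requests at a shorter length is what buys extra concurrency, which is the point of VLLM_TPU_MOST_MODEL_LEN. A quick comparison under an assumed page_size of 128; the model runner later clamps both results by max_num_seqs:

from math import ceil

def get_max_num_seqs(model_len: int, page_size: int) -> int:
    # Same formula as above: pages per request divide a fixed budget.
    num_page_per_req = ceil(model_len / page_size)
    return 1024 * 1024 // 2 // num_page_per_req // 4

print(get_max_num_seqs(2048, 128))    # 8192 requests when requests stay under the "most" length
print(get_max_num_seqs(32000, 128))   # 524 requests at the full 32k max model length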


@ -37,8 +37,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
@ -150,7 +150,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.num_blocks_per_most_len_req = cdiv(
self.most_model_len,
self.block_size) if self.most_model_len is not None else None
# InputBatch needs to work with sampling tensors greater than padding
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
@ -220,12 +224,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32,
device="cpu")
self.positions_np = self.positions_cpu.numpy()
self.block_table_cpu = torch.zeros(
(self.max_num_reqs, self.max_num_blocks_per_req),
dtype=torch.int32,
device="cpu")
# adjust num_reqs to avoid SMEM OOM.
self.num_reqs_most_model_len = min(
PallasAttentionBackend.get_max_num_seqs(self.most_model_len,
self.block_size),
self.max_num_reqs) if self.most_model_len is not None else None
self.num_reqs_max_model_len = min(
PallasAttentionBackend.get_max_num_seqs(self.max_model_len,
self.block_size),
self.max_num_reqs)
self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
dtype=torch.int32,
device="cpu",
@ -515,25 +526,50 @@ class TPUModelRunner(LoRAModelRunnerMixin):
return kv_cache_spec
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
start_index: int):
assert scheduler_output.total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
assert start_index < num_reqs
# Get the number of scheduled tokens for each request.
use_max_model_len = self.most_model_len is None
num_scheduled_tokens_per_req = []
max_num_scheduled_tokens_all_reqs = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
end_index = start_index
# Use either most_model_len or max_model_len depending on request size.
for i in range(start_index, num_reqs):
req_id = self.input_batch.req_ids[i]
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
if not use_max_model_len and num_tokens > self.most_model_len:
use_max_model_len = True
num_scheduled_tokens_per_req.append(num_tokens)
max_num_scheduled_tokens_all_reqs = max(
max_num_scheduled_tokens_all_reqs, num_tokens)
if use_max_model_len:
if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len:
num_scheduled_tokens_per_req = \
num_scheduled_tokens_per_req[:self.num_reqs_max_model_len]
end_index = start_index + self.num_reqs_max_model_len
else:
end_index = num_reqs
else:
if len(num_scheduled_tokens_per_req
) > self.num_reqs_most_model_len:
num_scheduled_tokens_per_req = \
num_scheduled_tokens_per_req[:self.num_reqs_most_model_len]
end_index = start_index + self.num_reqs_most_model_len
else:
end_index = num_reqs
max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req)
num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
dtype=np.int32)
total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req)
assert max_num_scheduled_tokens_all_reqs > 0
num_reqs = len(num_scheduled_tokens_per_req)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
# For each scheduled token, what is the corresponding req index.
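The reshaped _prepare_inputs above now plans one execution at a time. The sketch below is a simplified, standalone restatement of that policy rather than the vLLM API: scan the remaining requests from start_index, fall back to the max-model-len path as soon as any request exceeds most_model_len, truncate to that path's request cap, and hand back end_index so the caller can resume there.

def plan_slice(scheduled_tokens, start_index, most_model_len,
               num_reqs_most_model_len, num_reqs_max_model_len):
    # Classify the remaining requests: a single request above most_model_len
    # pushes the whole slice onto the max-model-len path.
    use_max_model_len = most_model_len is None
    tokens = []
    for n in scheduled_tokens[start_index:]:
        if not use_max_model_len and n > most_model_len:
            use_max_model_len = True
        tokens.append(n)
    # Truncate to the request cap of the chosen path; end_index tells the
    # caller where the next execution should resume.
    cap = num_reqs_max_model_len if use_max_model_len else num_reqs_most_model_len
    tokens = tokens[:cap]
    return tokens, start_index + len(tokens), use_max_model_len

# Two long prefills with a cap of 2 leave the third request for a second pass.
print(plan_slice([4000, 4000, 100], 0, 2048, 4, 2))   # ([4000, 4000], 2, True)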
@ -615,13 +651,29 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.block_table[0].
slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
self.device))
block_tables = self.block_table_cpu[:self.max_num_reqs]
if use_max_model_len:
block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, :
self.max_num_blocks_per_req]
block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
block_tables = block_tables.to(self.device)
query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
query_start_loc = self.query_start_loc_cpu[:self.
num_reqs_max_model_len +
1].to(self.device)
seq_lens = self.seq_lens_cpu[:self.num_reqs_max_model_len].to(
self.device)
seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)
else:
block_tables = self.block_table_cpu[:self.
num_reqs_most_model_len, :self.
num_blocks_per_most_len_req]
block_tables[:num_reqs, :self.num_blocks_per_most_len_req] = (
self.input_batch.block_table[0].get_cpu_tensor()
[:num_reqs, :self.num_blocks_per_most_len_req])
query_start_loc = self.query_start_loc_cpu[:self.
num_reqs_most_model_len +
1].to(self.device)
seq_lens = self.seq_lens_cpu[:self.num_reqs_most_model_len].to(
self.device)
block_tables = block_tables.to(self.device)
if self.lora_config is not None:
# We need to respect padding when activating LoRA adapters
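The two branches pad the block table to very different device shapes, which is where the memory saving comes from: the most-model-len path trims both the request and the block dimension. A shape-only illustration with assumed sizes (block_size 128, max_model_len 32000, most_model_len 2048, and the caps from the earlier examples):

import torch

max_num_blocks_per_req = -(-32000 // 128)            # 250 blocks for a full-length request
num_blocks_per_most_len_req = -(-2048 // 128)        # 16 blocks for a "most"-length request
block_table_cpu = torch.zeros((1200, max_num_blocks_per_req), dtype=torch.int32)

max_len_view = block_table_cpu[:524, :max_num_blocks_per_req]
most_len_view = block_table_cpu[:1200, :num_blocks_per_most_len_req]
print(max_len_view.shape)    # torch.Size([524, 250])
print(most_len_view.shape)   # torch.Size([1200, 16])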
@ -672,7 +724,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
layer_name: attn_metadata
for layer_name in layer_names
}
return per_layer_attn_metadata, logits_indices, padded_num_reqs
return per_layer_attn_metadata, logits_indices, padded_num_reqs,\
num_reqs, end_index
def _scatter_placeholders(
self,
@ -847,13 +900,17 @@ class TPUModelRunner(LoRAModelRunnerMixin):
else:
mm_embeds = []
xm.mark_step()
# Prepare inputs
attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
scheduler_output)
# Prepare inputs. The requests might be split into multiple
# executions, so combine the results of each execution.
start_index = 0
combined_selected_tokens: list[torch.Tensor] = []
combined_logprobs: list[LogprobsLists] = []
while start_index < self.input_batch.num_reqs:
attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
end_index = self._prepare_inputs(scheduler_output, start_index)
input_ids, inputs_embeds = self._get_model_inputs(
self.input_ids, mm_embeds)
xm.mark_step()
num_reqs = self.input_batch.num_reqs
# Run the decoder
with set_forward_context(
attn_metadata,
@ -871,28 +928,56 @@ class TPUModelRunner(LoRAModelRunnerMixin):
from_input_batch(self.input_batch, padded_num_reqs, self.device)
if scheduler_output.grammar_bitmask is not None:
require_struct_decoding, grammar_bitmask_padded, arange = \
self.prepare_structured_decoding_input(logits, scheduler_output)
self.prepare_structured_decoding_input(logits,
scheduler_output)
logits = self.structured_decode(require_struct_decoding,
grammar_bitmask_padded, logits,
arange)
selected_token_ids = self.sample_from_logits_func(
logits, tpu_sampling_metadata)
# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce it due
# to recompilations outside torch.compiled code, so just make sure
# `sample_from_logits` does not modify the logits in-place.
# temperature scaling) for the top-k logprobs. We can't enforce it
# due to recompilations outside torch.compiled code, so just make
# sure `sample_from_logits` does not modify the logits in-place.
logprobs = self.gather_logprobs(logits, selected_token_ids) \
if tpu_sampling_metadata.logprobs else None
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
logprobs_lists = logprobs.tolists() \
if tpu_sampling_metadata.logprobs else None
combined_selected_tokens.append(selected_token_ids)
if tpu_sampling_metadata.logprobs:
combined_logprobs.append(logprobs.tolists())
start_index = end_index
selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
if tpu_sampling_metadata.logprobs:
def concat_lists(input_lists):
result = []
for input_list in input_lists:
result.extend(input_list)
return result
logprobs_lists = LogprobsLists(logprob_token_ids=concat_lists(
[lp.logprob_token_ids for lp in combined_logprobs]),
logprobs=concat_lists([
lp.logprobs
for lp in combined_logprobs
]),
sampled_token_ranks=concat_lists([
lp.sampled_token_ranks
for lp in combined_logprobs
]))
else:
logprobs_lists = None
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
discard_sampled_tokens_req_indices = []
num_reqs = self.input_batch.num_reqs
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_state = self.requests[req_id]
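Since a step may now run in several passes over the input batch, the per-pass outputs are stitched back together before post-processing: selected token ids via torch.cat, logprobs field by field. A minimal sketch of the logprobs merge, assuming LogprobsLists is a plain container of three parallel per-request lists, as the field names above suggest:

from typing import NamedTuple

class LogprobsLists(NamedTuple):
    logprob_token_ids: list
    logprobs: list
    sampled_token_ranks: list

def merge(chunks: list) -> LogprobsLists:
    # Concatenate each field across the per-execution chunks.
    return LogprobsLists(
        logprob_token_ids=[row for c in chunks for row in c.logprob_token_ids],
        logprobs=[row for c in chunks for row in c.logprobs],
        sampled_token_ranks=[row for c in chunks for row in c.sampled_token_ranks],
    )

first = LogprobsLists([[11]], [[-0.10]], [0])
second = LogprobsLists([[22]], [[-0.25]], [1])
print(merge([first, second]))   # all three fields concatenated across executions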
@ -1020,7 +1105,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.sampler = TPUSampler()
@torch.no_grad()
def _dummy_run(self, num_tokens: int) -> None:
def _dummy_run(self, num_tokens: int, num_reqs: int,
num_blocks: int) -> None:
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@ -1030,20 +1116,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
input_ids = torch.zeros((num_tokens),
dtype=torch.int32).to(self.device)
inputs_embeds = None
actual_num_reqs = min(num_tokens, self.max_num_reqs)
actual_num_reqs = min(num_tokens, num_reqs)
position_ids = torch.zeros(num_tokens,
dtype=torch.int32).to(self.device)
slot_mapping = torch.zeros(num_tokens,
dtype=torch.int64).to(self.device)
block_tables = torch.zeros(
(self.max_num_reqs, self.block_table_cpu.shape[1]),
block_tables = torch.zeros((num_reqs, num_blocks),
dtype=torch.int32).to(self.device)
query_lens = [1] * self.max_num_reqs
query_lens = [1] * num_reqs
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.int32),
dim=0,
dtype=torch.int32).to(self.device)
context_lens = torch.ones((self.max_num_reqs, ),
context_lens = torch.ones((num_reqs, ),
dtype=torch.int32).to(self.device)
num_seqs = torch.tensor([actual_num_reqs],
dtype=torch.int32).to(self.device)
@ -1061,6 +1146,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1))
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
layer_names = get_layers_from_vllm_config(self.vllm_config,
Attention).keys()
@ -1152,7 +1240,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
start = time.perf_counter()
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(num_tokens)
self._dummy_run(num_tokens, self.num_reqs_max_model_len,
self.max_num_blocks_per_req)
if self.most_model_len is not None:
self._dummy_run(num_tokens, self.num_reqs_most_model_len,
self.num_blocks_per_most_len_req)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
@ -1341,7 +1433,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape.
self._dummy_run(num_tokens)
self._dummy_run(num_tokens, self.num_reqs_max_model_len,
self.max_num_blocks_per_req)
if self.most_model_len is not None:
self._dummy_run(num_tokens, self.num_reqs_most_model_len,
self.num_blocks_per_most_len_req)
xm.mark_step()
xm.wait_device_ops()