[TPU] Add TPU specific var VLLM_TPU_MOST_MODEL_LEN (#19919)

Signed-off-by: Chenyaaang <chenyangli@google.com>
Authored by Chenyaaang on 2025-06-25 15:51:02 -07:00; committed by GitHub
parent 55c65ab495
commit 2d7620c3eb
5 changed files with 184 additions and 76 deletions
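
The commit title is the only description, so for context: VLLM_TPU_MOST_MODEL_LEN declares the sequence length that most requests stay under. The TPU model runner then keeps two request-count caps (one for that common length, one for max_model_len), splits a scheduler step into several executions when needed, and precompiles both shapes. A hedged usage sketch in Python; the model name and the concrete values are illustrative assumptions, not part of this commit:

    import os

    # Must be set before vLLM reads its environment variables.
    os.environ["VLLM_TPU_MOST_MODEL_LEN"] = "2048"  # length most requests fit in

    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
        max_model_len=32000,                 # rare long requests still supported
        max_num_seqs=1200,
    )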

View File

@@ -587,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid():
     assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
     assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
     assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
+    vllm_config = get_vllm_config()
+    vllm_config.model_config.max_model_len = 32000
+    vllm_config.scheduler_config.max_num_seqs = 1200
+    model_runner = get_model_runner(vllm_config)
+
+    # verify model runner will adjust num_reqs to avoid SMEM OOM.
+    assert model_runner.num_reqs_most_model_len == 1200
+    # num_page_per_req = 32k // 128
+    # num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
+    assert model_runner.num_reqs_max_model_len == 524

View File

@@ -119,6 +119,7 @@ if TYPE_CHECKING:
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
+    VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -833,6 +834,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TPU_BUCKET_PADDING_GAP":
     lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
    if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
+    "VLLM_TPU_MOST_MODEL_LEN":
+    lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)),

     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM":

View File

@@ -122,16 +122,6 @@ class TpuPlatform(Platform):
                 PallasAttentionBackend)
             cache_config.block_size = PallasAttentionBackend.get_page_size(
                 vllm_config)  # type: ignore[assignment]
-            min_page_size = PallasAttentionBackend.get_min_page_size(
-                vllm_config)
-            if min_page_size > cache_config.block_size:
-                logger.warning(
-                    "Increase the page size from %s to %s to make sure there's"
-                    "no SMEM OOM",
-                    cache_config.block_size,
-                    min_page_size,
-                )
-                cache_config.block_size = min_page_size  # type: ignore[assignment]

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config

View File

@@ -71,6 +71,11 @@ class PallasAttentionBackend(AttentionBackend):
         min_page_size = 1 << (min_page_size - 1).bit_length()
         return min_page_size

+    @staticmethod
+    def get_max_num_seqs(model_len: int, page_size: int) -> int:
+        num_page_per_req = cdiv(model_len, page_size)
+        return 1024 * 1024 // 2 // num_page_per_req // 4
+
     # TPU has limited SREGs (scalar registers), if page_size is too small, we
     # can spill SREGs easily which leads to bad performance. The strategy we
     # apply here is trying to split max-model-len to 16 pages which make the
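
Plugging the numbers from the new unit test into get_max_num_seqs makes the cap concrete. A self-contained restatement of the formula (a page size of 128 is assumed, matching the test's comment):

    # Worked example of the new get_max_num_seqs formula.
    def cdiv(a: int, b: int) -> int:
        # Ceiling division, as used for page counts.
        return -(-a // b)

    def get_max_num_seqs(model_len: int, page_size: int) -> int:
        num_page_per_req = cdiv(model_len, page_size)
        return 1024 * 1024 // 2 // num_page_per_req // 4

    print(get_max_num_seqs(2048, 128))   # 8192 (16 pages per request)
    print(get_max_num_seqs(32000, 128))  # 524  (250 pages per request)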

View File

@ -37,8 +37,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, KVCacheConfig, KVCacheSpec,
SlidingWindowSpec) SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
ModelRunnerOutput) LogprobsTensors, ModelRunnerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
@@ -150,7 +150,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
+        self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
+        self.num_blocks_per_most_len_req = cdiv(
+            self.most_model_len,
+            self.block_size) if self.most_model_len is not None else None
         # InputBatch needs to work with sampling tensors greater than padding
         # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
@@ -220,12 +224,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                          dtype=torch.int32,
                                          device="cpu")
         self.positions_np = self.positions_cpu.numpy()

         self.block_table_cpu = torch.zeros(
             (self.max_num_reqs, self.max_num_blocks_per_req),
             dtype=torch.int32,
             device="cpu")
+        # adjust num_reqs to avoid SMEM OOM.
+        self.num_reqs_most_model_len = min(
+            PallasAttentionBackend.get_max_num_seqs(self.most_model_len,
+                                                    self.block_size),
+            self.max_num_reqs) if self.most_model_len is not None else None
+        self.num_reqs_max_model_len = min(
+            PallasAttentionBackend.get_max_num_seqs(self.max_model_len,
+                                                    self.block_size),
+            self.max_num_reqs)

         self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                                dtype=torch.int32,
                                                device="cpu",
@@ -515,25 +526,50 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         return kv_cache_spec

-    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        assert total_num_scheduled_tokens > 0
+    def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
+                        start_index: int):
+        assert scheduler_output.total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
+        assert start_index < num_reqs

         # Get the number of scheduled tokens for each request.
+        use_max_model_len = self.most_model_len is None
         num_scheduled_tokens_per_req = []
         max_num_scheduled_tokens_all_reqs = 0
-        for req_id in self.input_batch.req_ids[:num_reqs]:
+        end_index = start_index
+        # Use either most_model_len or max_model_len depending on request size.
+        for i in range(start_index, num_reqs):
+            req_id = self.input_batch.req_ids[i]
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            if not use_max_model_len and num_tokens > self.most_model_len:
+                use_max_model_len = True
             num_scheduled_tokens_per_req.append(num_tokens)
-            max_num_scheduled_tokens_all_reqs = max(
-                max_num_scheduled_tokens_all_reqs, num_tokens)
+        if use_max_model_len:
+            if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len:
+                num_scheduled_tokens_per_req = \
+                    num_scheduled_tokens_per_req[:self.num_reqs_max_model_len]
+                end_index = start_index + self.num_reqs_max_model_len
+            else:
+                end_index = num_reqs
+        else:
+            if len(num_scheduled_tokens_per_req
+                   ) > self.num_reqs_most_model_len:
+                num_scheduled_tokens_per_req = \
+                    num_scheduled_tokens_per_req[:self.num_reqs_most_model_len]
+                end_index = start_index + self.num_reqs_most_model_len
+            else:
+                end_index = num_reqs
+        max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req)
         num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
                                                 dtype=np.int32)
+        total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req)
         assert max_num_scheduled_tokens_all_reqs > 0
+        num_reqs = len(num_scheduled_tokens_per_req)

         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         # For each scheduled token, what are the corresponding req index.
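
The windowing rule introduced here can be read in isolation: starting at start_index, if every remaining request is scheduled within most_model_len tokens the window may hold up to num_reqs_most_model_len requests, otherwise it falls back to the smaller num_reqs_max_model_len cap. A simplified, stand-alone restatement (names and the sample values are illustrative, not part of the diff):

    from typing import Optional

    def pick_window(scheduled_tokens: list, start: int,
                    most_model_len: Optional[int],
                    cap_most: int, cap_max: int):
        # Return (end_index, used_max_model_len) for the window beginning at `start`.
        window = scheduled_tokens[start:]
        use_max = most_model_len is None or any(t > most_model_len for t in window)
        cap = cap_max if use_max else cap_most
        return start + min(len(window), cap), use_max

    # Three short requests plus one long one, most_model_len = 2048.
    print(pick_window([64, 128, 32, 4096], 0, 2048, cap_most=1200, cap_max=524))
    # -> (4, True): the long request forces the max-model-len path.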
@@ -615,13 +651,29 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                 self.input_batch.block_table[0].
                 slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
                     self.device))
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
+        if use_max_model_len:
+            block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, :
+                                                self.max_num_blocks_per_req]
+            block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
+                self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
+            query_start_loc = self.query_start_loc_cpu[:self.
+                                                       num_reqs_max_model_len +
+                                                       1].to(self.device)
+            seq_lens = self.seq_lens_cpu[:self.num_reqs_max_model_len].to(
+                self.device)
+        else:
+            block_tables = self.block_table_cpu[:self.
+                                                num_reqs_most_model_len, :self.
+                                                num_blocks_per_most_len_req]
+            block_tables[:num_reqs, :self.num_blocks_per_most_len_req] = (
+                self.input_batch.block_table[0].get_cpu_tensor()
+                [:num_reqs, :self.num_blocks_per_most_len_req])
+            query_start_loc = self.query_start_loc_cpu[:self.
+                                                       num_reqs_most_model_len +
+                                                       1].to(self.device)
+            seq_lens = self.seq_lens_cpu[:self.num_reqs_most_model_len].to(
+                self.device)
         block_tables = block_tables.to(self.device)
-        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
-            self.device)
-        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

         if self.lora_config is not None:
             # We need to respect padding when activating LoRA adapters
@@ -672,7 +724,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             layer_name: attn_metadata
             for layer_name in layer_names
         }
-        return per_layer_attn_metadata, logits_indices, padded_num_reqs
+        return per_layer_attn_metadata, logits_indices, padded_num_reqs,\
+            num_reqs, end_index

     def _scatter_placeholders(
         self,
@@ -847,52 +900,84 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         else:
             mm_embeds = []
         xm.mark_step()
-        # Prepare inputs
-        attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
-            scheduler_output)
-        input_ids, inputs_embeds = self._get_model_inputs(
-            self.input_ids, mm_embeds)
-        xm.mark_step()
-        num_reqs = self.input_batch.num_reqs
-        # Run the decoder
-        with set_forward_context(
-                attn_metadata,
-                self.vllm_config,
-                num_tokens=scheduler_output.total_num_scheduled_tokens):
-            hidden_states = self.model(
-                input_ids=input_ids,
-                positions=self.position_ids,
-                inputs_embeds=inputs_embeds,
-            )
-        hidden_states = self.select_hidden_states(hidden_states,
-                                                  logits_indices)
-        logits = self.compute_logits(hidden_states)
-        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
-            from_input_batch(self.input_batch, padded_num_reqs, self.device)
-        if scheduler_output.grammar_bitmask is not None:
-            require_struct_decoding, grammar_bitmask_padded, arange = \
-                self.prepare_structured_decoding_input(logits, scheduler_output)
-            logits = self.structured_decode(require_struct_decoding,
-                                            grammar_bitmask_padded, logits,
-                                            arange)
-        selected_token_ids = self.sample_from_logits_func(
-            logits, tpu_sampling_metadata)
-        # NOTE (NickLucche) Use the original logits (before any penalties or
-        # temperature scaling) for the top-k logprobs. We can't enforce it due
-        # to recompilations outside torch.compiled code, so just make sure
-        # `sample_from_logits` does not modify the logits in-place.
-        logprobs = self.gather_logprobs(logits, selected_token_ids) \
-            if tpu_sampling_metadata.logprobs else None
-        # Remove padding on cpu and keep dynamic op outside of xla graph.
-        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
-        logprobs_lists = logprobs.tolists() \
-            if tpu_sampling_metadata.logprobs else None
+        # Prepare inputs, the requests might be splitted into multiple
+        # executions, combine the result of each execution.
+        start_index = 0
+        combined_selected_tokens: list[torch.Tensor] = []
+        combined_logprobs: list[LogprobsLists] = []
+        while start_index < self.input_batch.num_reqs:
+            attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
+                end_index = self._prepare_inputs(scheduler_output, start_index)
+            input_ids, inputs_embeds = self._get_model_inputs(
+                self.input_ids, mm_embeds)
+            xm.mark_step()
+            # Run the decoder
+            with set_forward_context(
+                    attn_metadata,
+                    self.vllm_config,
+                    num_tokens=scheduler_output.total_num_scheduled_tokens):
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=self.position_ids,
+                    inputs_embeds=inputs_embeds,
+                )
+            hidden_states = self.select_hidden_states(hidden_states,
+                                                      logits_indices)
+            logits = self.compute_logits(hidden_states)
+            tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
+                from_input_batch(self.input_batch, padded_num_reqs, self.device)
+            if scheduler_output.grammar_bitmask is not None:
+                require_struct_decoding, grammar_bitmask_padded, arange = \
+                    self.prepare_structured_decoding_input(logits,
+                                                           scheduler_output)
+                logits = self.structured_decode(require_struct_decoding,
+                                                grammar_bitmask_padded, logits,
+                                                arange)
+            selected_token_ids = self.sample_from_logits_func(
+                logits, tpu_sampling_metadata)
+            # NOTE (NickLucche) Use the original logits (before any penalties or
+            # temperature scaling) for the top-k logprobs. We can't enforce it
+            # due to recompilations outside torch.compiled code, so just make
+            # sure `sample_from_logits` does not modify the logits in-place.
+            logprobs = self.gather_logprobs(logits, selected_token_ids) \
+                if tpu_sampling_metadata.logprobs else None
+            # Remove padding on cpu and keep dynamic op outside of xla graph.
+            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+
+            combined_selected_tokens.append(selected_token_ids)
+            if tpu_sampling_metadata.logprobs:
+                combined_logprobs.append(logprobs.tolists())
+            start_index = end_index
+
+        selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
+        if tpu_sampling_metadata.logprobs:
+
+            def concat_lists(input_lists):
+                result = []
+                for input_list in input_lists:
+                    result.extend(input_list)
+                return result
+
+            logprobs_lists = LogprobsLists(logprob_token_ids=concat_lists(
+                [lp.logprob_token_ids for lp in combined_logprobs]),
+                                           logprobs=concat_lists([
+                                               lp.logprobs
+                                               for lp in combined_logprobs
+                                           ]),
+                                           sampled_token_ranks=concat_lists([
+                                               lp.sampled_token_ranks
+                                               for lp in combined_logprobs
+                                           ]))
+        else:
+            logprobs_lists = None

         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
         discard_sampled_tokens_req_indices = []
+        num_reqs = self.input_batch.num_reqs
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
             req_state = self.requests[req_id]
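
Since execute_model may now run the decoder several times per scheduler step, the per-window results are stitched back together afterwards: sampled token ids via torch.cat and each LogprobsLists field by flattening the per-window lists. A minimal sketch of that merge with made-up values (plain lists stand in for the real LogprobsLists fields):

    import torch

    # Two hypothetical execution windows of sampled token ids.
    window_tokens = [torch.tensor([[11], [12]]), torch.tensor([[13]])]
    selected_token_ids = torch.cat(window_tokens, dim=0)  # shape [3, 1]

    # Per-window logprob token ids, flattened field by field as in the diff.
    window_token_ids = [[[1, 2], [3, 4]], [[5, 6]]]
    flat_token_ids = [row for window in window_token_ids for row in window]

    print(selected_token_ids.tolist())  # [[11], [12], [13]]
    print(flat_token_ids)               # [[1, 2], [3, 4], [5, 6]]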
@@ -1020,7 +1105,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.sampler = TPUSampler()

     @torch.no_grad()
-    def _dummy_run(self, num_tokens: int) -> None:
+    def _dummy_run(self, num_tokens: int, num_reqs: int,
+                   num_blocks: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -1030,20 +1116,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             input_ids = torch.zeros((num_tokens),
                                     dtype=torch.int32).to(self.device)
             inputs_embeds = None
-        actual_num_reqs = min(num_tokens, self.max_num_reqs)
+        actual_num_reqs = min(num_tokens, num_reqs)
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros(num_tokens,
                                    dtype=torch.int64).to(self.device)
-        block_tables = torch.zeros(
-            (self.max_num_reqs, self.block_table_cpu.shape[1]),
-            dtype=torch.int32).to(self.device)
-        query_lens = [1] * self.max_num_reqs
+        block_tables = torch.zeros((num_reqs, num_blocks),
+                                   dtype=torch.int32).to(self.device)
+        query_lens = [1] * num_reqs
         query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                     dtype=torch.int32),
                                        dim=0,
                                        dtype=torch.int32).to(self.device)
-        context_lens = torch.ones((self.max_num_reqs, ),
+        context_lens = torch.ones((num_reqs, ),
                                   dtype=torch.int32).to(self.device)
         num_seqs = torch.tensor([actual_num_reqs],
                                 dtype=torch.int32).to(self.device)
@@ -1061,6 +1146,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+        torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1))
+        torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+        torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)

         layer_names = get_layers_from_vllm_config(self.vllm_config,
                                                   Attention).keys()
@@ -1152,7 +1240,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info(" -- num_tokens: %d", num_tokens)
-            self._dummy_run(num_tokens)
+            self._dummy_run(num_tokens, self.num_reqs_max_model_len,
+                            self.max_num_blocks_per_req)
+            if self.most_model_len is not None:
+                self._dummy_run(num_tokens, self.num_reqs_most_model_len,
+                                self.num_blocks_per_most_len_req)
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)
@@ -1341,7 +1433,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

         # Trigger compilation for general shape.
-        self._dummy_run(num_tokens)
+        self._dummy_run(num_tokens, self.num_reqs_max_model_len,
+                        self.max_num_blocks_per_req)
+        if self.most_model_len is not None:
+            self._dummy_run(num_tokens, self.num_reqs_most_model_len,
+                            self.num_blocks_per_most_len_req)
         xm.mark_step()
         xm.wait_device_ops()