[TPU] Add TPU specific var VLLM_TPU_MOST_MODEL_LEN (#19919)
Signed-off-by: Chenyaaang <chenyangli@google.com>
commit 2d7620c3eb
parent 55c65ab495
@@ -587,3 +587,17 @@ def test_init_kv_cache_with_kv_sharing_valid():
     assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
     assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
     assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_most_model_len(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048")
+    vllm_config = get_vllm_config()
+    vllm_config.model_config.max_model_len = 32000
+    vllm_config.scheduler_config.max_num_seqs = 1200
+    model_runner = get_model_runner(vllm_config)
+
+    # verify model runner will adjust num_reqs to avoid SMEM OOM.
+    assert model_runner.num_reqs_most_model_len == 1200
+    # num_page_per_req = 32k // 128
+    # num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524
+    assert model_runner.num_reqs_max_model_len == 524
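
Review note: the expected values here follow directly from PallasAttentionBackend.get_max_num_seqs, which this commit adds in a later hunk. A quick standalone check of the arithmetic, assuming the 128-token page size implied by the test's own comment (illustrative only, not part of the diff):

page_size = 128
num_page_per_req = (32000 + page_size - 1) // page_size       # cdiv -> 250
assert 1024 * 1024 // 2 // num_page_per_req // 4 == 524        # num_reqs_max_model_len
# For most_model_len = 2048: 2048 // 128 = 16 pages per request,
# 1024 * 1024 // 2 // 16 // 4 = 8192, so the cap stays at max_num_seqs = 1200.
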
@@ -119,6 +119,7 @@ if TYPE_CHECKING:
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
+    VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -833,6 +834,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TPU_BUCKET_PADDING_GAP":
     lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
     if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
+    "VLLM_TPU_MOST_MODEL_LEN":
+    lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)),
 
     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM":
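
Review note: the new variable follows the same lazily-evaluated getter pattern as the other entries in environment_variables. A minimal self-contained sketch of that pattern; maybe_convert_int is re-declared below only so the sketch runs on its own (the real helper lives in vllm/envs.py):

import os
from typing import Optional


def maybe_convert_int(value: Optional[str]) -> Optional[int]:
    # Unset -> None, otherwise parse as an int.
    return None if value is None else int(value)


# The dict maps each variable name to a zero-arg callable, so the environment
# is read at access time rather than at import time.
getter = lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None))

os.environ["VLLM_TPU_MOST_MODEL_LEN"] = "2048"
assert getter() == 2048
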
@@ -122,16 +122,6 @@ class TpuPlatform(Platform):
                 PallasAttentionBackend)
             cache_config.block_size = PallasAttentionBackend.get_page_size(
                 vllm_config)  # type: ignore[assignment]
-            min_page_size = PallasAttentionBackend.get_min_page_size(
-                vllm_config)
-            if min_page_size > cache_config.block_size:
-                logger.warning(
-                    "Increase the page size from %s to %s to make sure there's"
-                    "no SMEM OOM",
-                    cache_config.block_size,
-                    min_page_size,
-                )
-                cache_config.block_size = min_page_size  # type: ignore[assignment]
 
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
@@ -71,6 +71,11 @@ class PallasAttentionBackend(AttentionBackend):
         min_page_size = 1 << (min_page_size - 1).bit_length()
         return min_page_size
 
+    @staticmethod
+    def get_max_num_seqs(model_len: int, page_size: int) -> int:
+        num_page_per_req = cdiv(model_len, page_size)
+        return 1024 * 1024 // 2 // num_page_per_req // 4
+
     # TPU has limited SREGs (scalar registers), if page_size is too small, we
     # can spill SREGs easily which leads to bad performance. The strategy we
     # apply here is trying to split max-model-len to 16 pages which make the
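
Review note: get_max_num_seqs reads as an SMEM budget, caps the request count so that the 4-byte page indices for every request fit in roughly half of a 1 MiB scratch budget (my interpretation, not stated in the diff). The context line above it uses the usual round-up-to-a-power-of-two idiom; a standalone illustration of that idiom, not part of the diff:

def next_power_of_two(n: int) -> int:
    # 1 << (n - 1).bit_length() rounds n (>= 1) up to the nearest power of two.
    return 1 << (n - 1).bit_length()


assert [next_power_of_two(n) for n in (1, 5, 8, 129)] == [1, 8, 8, 256]
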
@@ -37,8 +37,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
                                         SlidingWindowSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
-                             ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
+                             LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
@@ -150,7 +150,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
+        self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
+        self.num_blocks_per_most_len_req = cdiv(
+            self.most_model_len,
+            self.block_size) if self.most_model_len is not None else None
         # InputBatch needs to work with sampling tensors greater than padding
         # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
@@ -220,12 +224,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                          dtype=torch.int32,
                                          device="cpu")
         self.positions_np = self.positions_cpu.numpy()
 
         self.block_table_cpu = torch.zeros(
             (self.max_num_reqs, self.max_num_blocks_per_req),
             dtype=torch.int32,
             device="cpu")
 
+        # adjust num_reqs to avoid SMEM OOM.
+        self.num_reqs_most_model_len = min(
+            PallasAttentionBackend.get_max_num_seqs(self.most_model_len,
+                                                    self.block_size),
+            self.max_num_reqs) if self.most_model_len is not None else None
+        self.num_reqs_max_model_len = min(
+            PallasAttentionBackend.get_max_num_seqs(self.max_model_len,
+                                                    self.block_size),
+            self.max_num_reqs)
         self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                                dtype=torch.int32,
                                                device="cpu",
@@ -515,25 +526,50 @@ class TPUModelRunner(LoRAModelRunnerMixin):
 
         return kv_cache_spec
 
-    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        assert total_num_scheduled_tokens > 0
+    def _prepare_inputs(self, scheduler_output: "SchedulerOutput",
+                        start_index: int):
+        assert scheduler_output.total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
+        assert start_index < num_reqs
 
         # Get the number of scheduled tokens for each request.
+        use_max_model_len = self.most_model_len is None
         num_scheduled_tokens_per_req = []
         max_num_scheduled_tokens_all_reqs = 0
-        for req_id in self.input_batch.req_ids[:num_reqs]:
+        end_index = start_index
+
+        # Use either most_model_len or max_model_len depending on request size.
+        for i in range(start_index, num_reqs):
+            req_id = self.input_batch.req_ids[i]
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            if not use_max_model_len and num_tokens > self.most_model_len:
+                use_max_model_len = True
             num_scheduled_tokens_per_req.append(num_tokens)
-            max_num_scheduled_tokens_all_reqs = max(
-                max_num_scheduled_tokens_all_reqs, num_tokens)
+        if use_max_model_len:
+            if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len:
+                num_scheduled_tokens_per_req = \
+                    num_scheduled_tokens_per_req[:self.num_reqs_max_model_len]
+                end_index = start_index + self.num_reqs_max_model_len
+            else:
+                end_index = num_reqs
+        else:
+            if len(num_scheduled_tokens_per_req
+                   ) > self.num_reqs_most_model_len:
+                num_scheduled_tokens_per_req = \
+                    num_scheduled_tokens_per_req[:self.num_reqs_most_model_len]
+                end_index = start_index + self.num_reqs_most_model_len
+            else:
+                end_index = num_reqs
+        max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req)
         num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
                                                 dtype=np.int32)
+        total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req)
         assert max_num_scheduled_tokens_all_reqs > 0
 
+        num_reqs = len(num_scheduled_tokens_per_req)
+
         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         # For each scheduled token, what are the corresponding req index.
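
Review note: a simplified, standalone sketch of the window-selection logic this hunk adds to _prepare_inputs. Names and types below are illustrative, not the vLLM API; the idea is that one bucket-size request budget is used as long as every request in the window schedules at most most_model_len tokens, and the smaller max_model_len budget otherwise:

from typing import List, Optional, Tuple


def select_window(scheduled: List[int], start: int,
                  most_model_len: Optional[int], num_reqs_most: Optional[int],
                  num_reqs_max: int) -> Tuple[List[int], int]:
    # Collect scheduled-token counts from `start`; if any request needs more
    # than most_model_len tokens this step, fall back to the max_model_len
    # request budget for the whole window, then truncate to that budget.
    use_max = most_model_len is None
    window = []
    for n in scheduled[start:]:
        if not use_max and n > most_model_len:
            use_max = True
        window.append(n)
    limit = num_reqs_max if use_max else num_reqs_most
    if len(window) > limit:
        return window[:limit], start + limit
    return window, len(scheduled)


assert select_window([1, 1, 1, 1], 0, 2048, 2, 1) == ([1, 1], 2)
assert select_window([4096, 1], 0, 2048, 8, 1) == ([4096], 1)
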
@@ -615,13 +651,29 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             self.input_batch.block_table[0].
             slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
                 self.device))
-        block_tables = self.block_table_cpu[:self.max_num_reqs]
-        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
-            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
+        if use_max_model_len:
+            block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, :
+                                                self.max_num_blocks_per_req]
+            block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
+                self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
+            query_start_loc = self.query_start_loc_cpu[:self.
+                                                       num_reqs_max_model_len +
+                                                       1].to(self.device)
+            seq_lens = self.seq_lens_cpu[:self.num_reqs_max_model_len].to(
+                self.device)
+        else:
+            block_tables = self.block_table_cpu[:self.
+                                                num_reqs_most_model_len, :self.
+                                                num_blocks_per_most_len_req]
+            block_tables[:num_reqs, :self.num_blocks_per_most_len_req] = (
+                self.input_batch.block_table[0].get_cpu_tensor()
+                [:num_reqs, :self.num_blocks_per_most_len_req])
+            query_start_loc = self.query_start_loc_cpu[:self.
+                                                       num_reqs_most_model_len +
+                                                       1].to(self.device)
+            seq_lens = self.seq_lens_cpu[:self.num_reqs_most_model_len].to(
+                self.device)
         block_tables = block_tables.to(self.device)
-        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
-            self.device)
-        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)
 
         if self.lora_config is not None:
             # We need to respect padding when activating LoRA adapters
@@ -672,7 +724,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             layer_name: attn_metadata
             for layer_name in layer_names
         }
-        return per_layer_attn_metadata, logits_indices, padded_num_reqs
+        return per_layer_attn_metadata, logits_indices, padded_num_reqs,\
+            num_reqs, end_index
 
     def _scatter_placeholders(
         self,
@@ -847,52 +900,84 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         else:
             mm_embeds = []
         xm.mark_step()
-        # Prepare inputs
-        attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
-            scheduler_output)
-        input_ids, inputs_embeds = self._get_model_inputs(
-            self.input_ids, mm_embeds)
-        xm.mark_step()
-        num_reqs = self.input_batch.num_reqs
-        # Run the decoder
-        with set_forward_context(
-                attn_metadata,
-                self.vllm_config,
-                num_tokens=scheduler_output.total_num_scheduled_tokens):
-            hidden_states = self.model(
-                input_ids=input_ids,
-                positions=self.position_ids,
-                inputs_embeds=inputs_embeds,
-            )
-        hidden_states = self.select_hidden_states(hidden_states,
-                                                  logits_indices)
-        logits = self.compute_logits(hidden_states)
-        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
-            from_input_batch(self.input_batch, padded_num_reqs, self.device)
-        if scheduler_output.grammar_bitmask is not None:
-            require_struct_decoding, grammar_bitmask_padded, arange = \
-                self.prepare_structured_decoding_input(logits, scheduler_output)
-            logits = self.structured_decode(require_struct_decoding,
-                                            grammar_bitmask_padded, logits,
-                                            arange)
-        selected_token_ids = self.sample_from_logits_func(
-            logits, tpu_sampling_metadata)
-        # NOTE (NickLucche) Use the original logits (before any penalties or
-        # temperature scaling) for the top-k logprobs. We can't enforce it due
-        # to recompilations outside torch.compiled code, so just make sure
-        # `sample_from_logits` does not modify the logits in-place.
-        logprobs = self.gather_logprobs(logits, selected_token_ids) \
-            if tpu_sampling_metadata.logprobs else None
+        # Prepare inputs; the requests might be split into multiple
+        # executions, combine the result of each execution.
+        start_index = 0
+        combined_selected_tokens: list[torch.Tensor] = []
+        combined_logprobs: list[LogprobsLists] = []
+        while start_index < self.input_batch.num_reqs:
+            attn_metadata, logits_indices, padded_num_reqs, num_reqs,\
+                end_index = self._prepare_inputs(scheduler_output, start_index)
+            input_ids, inputs_embeds = self._get_model_inputs(
+                self.input_ids, mm_embeds)
+            xm.mark_step()
+            # Run the decoder
+            with set_forward_context(
+                    attn_metadata,
+                    self.vllm_config,
+                    num_tokens=scheduler_output.total_num_scheduled_tokens):
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=self.position_ids,
+                    inputs_embeds=inputs_embeds,
+                )
+            hidden_states = self.select_hidden_states(hidden_states,
+                                                      logits_indices)
+            logits = self.compute_logits(hidden_states)
+            tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
+                from_input_batch(self.input_batch, padded_num_reqs, self.device)
+            if scheduler_output.grammar_bitmask is not None:
+                require_struct_decoding, grammar_bitmask_padded, arange = \
+                    self.prepare_structured_decoding_input(logits,
+                                                           scheduler_output)
+                logits = self.structured_decode(require_struct_decoding,
+                                                grammar_bitmask_padded, logits,
+                                                arange)
+            selected_token_ids = self.sample_from_logits_func(
+                logits, tpu_sampling_metadata)
+            # NOTE (NickLucche) Use the original logits (before any penalties or
+            # temperature scaling) for the top-k logprobs. We can't enforce it
+            # due to recompilations outside torch.compiled code, so just make
+            # sure `sample_from_logits` does not modify the logits in-place.
+            logprobs = self.gather_logprobs(logits, selected_token_ids) \
+                if tpu_sampling_metadata.logprobs else None
 
-        # Remove padding on cpu and keep dynamic op outside of xla graph.
-        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
-        logprobs_lists = logprobs.tolists() \
-            if tpu_sampling_metadata.logprobs else None
+            # Remove padding on cpu and keep dynamic op outside of xla graph.
+            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+
+            combined_selected_tokens.append(selected_token_ids)
+            if tpu_sampling_metadata.logprobs:
+                combined_logprobs.append(logprobs.tolists())
+
+            start_index = end_index
+
+        selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
+        if tpu_sampling_metadata.logprobs:
+
+            def concat_lists(input_lists):
+                result = []
+                for input_list in input_lists:
+                    result.extend(input_list)
+                return result
+
+            logprobs_lists = LogprobsLists(logprob_token_ids=concat_lists(
+                [lp.logprob_token_ids for lp in combined_logprobs]),
+                                           logprobs=concat_lists([
+                                               lp.logprobs
+                                               for lp in combined_logprobs
+                                           ]),
+                                           sampled_token_ranks=concat_lists([
+                                               lp.sampled_token_ranks
+                                               for lp in combined_logprobs
+                                           ]))
+        else:
+            logprobs_lists = None
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
         request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
         discard_sampled_tokens_req_indices = []
+        num_reqs = self.input_batch.num_reqs
         for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
             assert req_id is not None
             req_state = self.requests[req_id]
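
Review note: the per-window results are merged by plain field-wise list concatenation. A toy standalone version of that merge, using a stand-in named tuple rather than the real vllm.v1.outputs.LogprobsLists:

from typing import NamedTuple


class FakeLogprobsLists(NamedTuple):
    logprob_token_ids: list
    logprobs: list
    sampled_token_ranks: list


def merge(chunks: list) -> FakeLogprobsLists:
    # Same idea as concat_lists() in the diff: extend each field in order.
    return FakeLogprobsLists(
        [x for c in chunks for x in c.logprob_token_ids],
        [x for c in chunks for x in c.logprobs],
        [x for c in chunks for x in c.sampled_token_ranks])


a = FakeLogprobsLists([[1]], [[-0.1]], [[3]])
b = FakeLogprobsLists([[2]], [[-0.2]], [[1]])
assert merge([a, b]).logprobs == [[-0.1], [-0.2]]
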
@@ -1020,7 +1105,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.sampler = TPUSampler()
 
     @torch.no_grad()
-    def _dummy_run(self, num_tokens: int) -> None:
+    def _dummy_run(self, num_tokens: int, num_reqs: int,
+                   num_blocks: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -1030,20 +1116,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             input_ids = torch.zeros((num_tokens),
                                     dtype=torch.int32).to(self.device)
             inputs_embeds = None
-        actual_num_reqs = min(num_tokens, self.max_num_reqs)
+        actual_num_reqs = min(num_tokens, num_reqs)
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros(num_tokens,
                                    dtype=torch.int64).to(self.device)
-        block_tables = torch.zeros(
-            (self.max_num_reqs, self.block_table_cpu.shape[1]),
-            dtype=torch.int32).to(self.device)
-        query_lens = [1] * self.max_num_reqs
+        block_tables = torch.zeros((num_reqs, num_blocks),
+                                   dtype=torch.int32).to(self.device)
+        query_lens = [1] * num_reqs
         query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                     dtype=torch.int32),
                                        dim=0,
                                        dtype=torch.int32).to(self.device)
-        context_lens = torch.ones((self.max_num_reqs, ),
+        context_lens = torch.ones((num_reqs, ),
                                   dtype=torch.int32).to(self.device)
         num_seqs = torch.tensor([actual_num_reqs],
                                 dtype=torch.int32).to(self.device)
@@ -1061,6 +1146,9 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+        torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1))
+        torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+        torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
 
         layer_names = get_layers_from_vllm_config(self.vllm_config,
                                                   Attention).keys()
@@ -1152,7 +1240,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info(" -- num_tokens: %d", num_tokens)
-            self._dummy_run(num_tokens)
+            self._dummy_run(num_tokens, self.num_reqs_max_model_len,
+                            self.max_num_blocks_per_req)
+            if self.most_model_len is not None:
+                self._dummy_run(num_tokens, self.num_reqs_most_model_len,
+                                self.num_blocks_per_most_len_req)
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in %.2f [secs].", end - start)
@@ -1341,7 +1433,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
         # Trigger compilation for general shape.
-        self._dummy_run(num_tokens)
+        self._dummy_run(num_tokens, self.num_reqs_max_model_len,
+                        self.max_num_blocks_per_req)
+        if self.most_model_len is not None:
+            self._dummy_run(num_tokens, self.num_reqs_most_model_len,
+                            self.num_blocks_per_most_len_req)
 
         xm.mark_step()
         xm.wait_device_ops()