commit 248c5b632d (parent 950f349492)
Author: Alexander Matveev <amatveev@redhat.com>
Date:   2025-01-28 22:50:24 +00:00

    Signed-off-by: Alexander Matveev <amatveev@redhat.com>

3 changed files with 36 additions and 26 deletions


@@ -20,7 +20,7 @@ TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
 EXPECTED_VALUE = 0.58
-DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
+DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
 MORE_ARGS_LIST = [
     [],  # Default
     ["--enable-chunked-prefill"],  # Chunked


@@ -101,17 +101,10 @@ class TpuPlatform(Platform):
         # Adjust scheduler config for V1
         # TODO: Add support for these
-        if envs.VLLM_USE_V1:
-            if vllm_config.cache_config.enable_prefix_caching:
-                logger.warning("[V1][TPU] Disable prefix caching")
-                vllm_config.cache_config.enable_prefix_caching = False
+        if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
+            logger.warning("[V1][TPU] Disable prefix caching")
+            vllm_config.cache_config.enable_prefix_caching = False
 
-            if vllm_config.scheduler_config.chunked_prefill_enabled:
-                logger.warning("[V1][TPU] Disable chunked prefill")
-                vllm_config.scheduler_config.chunked_prefill_enabled = False
+        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
+            "Chunked prefill is not yet supported for TPU backend")
+
+        assert not vllm_config.speculative_config, (
+            "Speculative decoding is not yet supported for TPU backend")


@@ -100,10 +100,6 @@ class TPUModelRunner(ModelRunnerBase):
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
-        # OPTIMIZATION: Start copying the block table first.
-        # This way, we can overlap the copy with the following CPU operations.
-        self.input_batch.block_table.commit(num_reqs)
-
         req_ids = []
         prompt_lens = []
         input_tokens_list = []
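
For context, the removed OPTIMIZATION block followed a common overlap pattern: issue the block-table copy first so the Python-side input preparation below runs while the transfer is still in flight. A generic sketch of that pattern, with "cpu" standing in for the accelerator device (on a real TPU/GPU the non-blocking copy actually overlaps):

    import torch

    block_table_cpu = torch.zeros((64, 16), dtype=torch.int32)

    # Kick off the device copy first ...
    block_table_dev = block_table_cpu.to("cpu", non_blocking=True)

    # ... then do unrelated host-side work (e.g. building per-request
    # token tensors) while the copy is, on a real device, in flight.
    input_tokens = [torch.zeros((1, 8), dtype=torch.int32) for _ in range(4)]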
@@ -135,20 +131,28 @@ class TPUModelRunner(ModelRunnerBase):
             # Seq len
             seq_len = num_computed_tokens + prompt_len
             padded_seq_len = num_computed_tokens + padded_prompt_len
 
             # Input tokens
-            input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
-                req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
-            input_tokens[:, prompt_len:] = 0
+            input_tokens = torch.zeros((1, padded_prompt_len),
+                                       dtype=torch.int32,
+                                       device="cpu")
+            input_tokens[:, :prompt_len] = torch.from_numpy(
+                self.input_batch.token_ids_cpu[req_index,
+                                               num_computed_tokens:seq_len])
+            # input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
+            #     req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
+            # input_tokens[:, prompt_len:] = 0
             input_tokens_list.append(input_tokens.to(self.device))
 
             # Input positions
-            input_positions = self.prefill_input_positions[
-                :, num_computed_tokens:padded_seq_len].clone()
-            input_positions[:, prompt_len:] = 0
+            input_positions = torch.zeros((1, padded_prompt_len),
+                                          dtype=torch.int32,
+                                          device="cpu")
+            input_positions[:, :prompt_len] = self.prefill_input_positions[
+                :, num_computed_tokens:seq_len]
+            # input_positions[:, prompt_len:] = 0
             input_positions_list.append(input_positions.to(self.device))
 
             # Slot mapping
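
The new construction allocates the padded buffer up front and copies in only the prompt_len valid elements, instead of slicing padded_seq_len columns out of the CPU buffer (which can reach past the populated region) and zeroing the tail afterwards. A self-contained sketch of the scheme with toy shapes (all names local to the sketch):

    import numpy as np
    import torch

    token_ids_cpu = np.arange(32, dtype=np.int32).reshape(1, 32)
    num_computed_tokens, prompt_len, padded_prompt_len = 0, 5, 8
    seq_len = num_computed_tokens + prompt_len

    # Zero-filled padded buffer; only the valid prefix is copied in, so the
    # slice never runs past seq_len and the tail needs no explicit zeroing.
    input_tokens = torch.zeros((1, padded_prompt_len), dtype=torch.int32)
    input_tokens[:, :prompt_len] = torch.from_numpy(
        token_ids_cpu[0, num_computed_tokens:seq_len])
    print(input_tokens)  # tensor([[0, 1, 2, 3, 4, 0, 0, 0]], dtype=torch.int32)

The input positions follow the same pattern: a zeroed (1, padded_prompt_len) buffer whose valid prefix is filled from the precomputed position range.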
@@ -167,8 +171,8 @@ class TPUModelRunner(ModelRunnerBase):
             # Block table
             block_table = None
             if num_computed_tokens > 0:
-                block_table = self.input_batch.block_table.get_device_tensor()
-                block_table = block_table[req_index].unsqueeze(0)
+                block_table = block_table_cpu_tensor[req_index].unsqueeze(0)
+                block_table = block_table.to(self.device)
 
             # Context len
             context_len = 0
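
block_table_cpu_tensor is presumably the host-side block table whose definition sits outside this hunk; indexing one row on the host and shipping only that row avoids materializing and indexing the full device-side tensor per request. A toy sketch of the gather-then-transfer step:

    import torch

    device = "cpu"  # stand-in for self.device in this sketch
    block_table_cpu_tensor = torch.arange(12, dtype=torch.int32).reshape(3, 4)
    req_index = 1

    # Gather on the host, then copy only the (1, num_blocks) row over.
    block_table = block_table_cpu_tensor[req_index].unsqueeze(0)
    block_table = block_table.to(device)
    print(block_table)  # tensor([[4, 5, 6, 7]], dtype=torch.int32)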
@@ -197,6 +201,19 @@ class TPUModelRunner(ModelRunnerBase):
                 effective_query_lens=effective_query_lens.to(self.device),
             ))
 
+            # TODO: Remove this
+            # if num_computed_tokens > 0:
+            #     print("-------------------")
+            #     print("input_tokens.shape = {}".format(input_tokens.shape))
+            #     print("input_positions.shape = {}".format(
+            #         input_positions.shape))
+            #     print("slot_mapping.shape = {}".format(slot_mapping.shape))
+            #     print("block_table.shape = {}".format(block_table.shape))
+            #     print("context_lens.shape = {} data = {}".format(
+            #         context_lens.shape, context_lens))
+            #     print("effective_query_lens.shape = {} data = {}".format(
+            #         effective_query_lens.shape, effective_query_lens))
+
         return PromptInputData(
             req_ids=req_ids,
             prompt_lens=prompt_lens,