From 248c5b632d8844d146b8346a11989d6f6da76004 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Tue, 28 Jan 2025 22:50:24 +0000 Subject: [PATCH] works Signed-off-by: Alexander Matveev --- tests/entrypoints/openai/test_accuracy.py | 2 +- vllm/platforms/tpu.py | 13 ++----- vllm/v1/worker/tpu_model_runner.py | 47 +++++++++++++++-------- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 976c8f3473d1a..d39da14c1177e 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -20,7 +20,7 @@ TASK = "gsm8k" FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 -DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"] +DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] MORE_ARGS_LIST = [ [], # Default ["--enable-chunked-prefill"], # Chunked diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 65bfff7311cd4..c0601e2afa12b 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -101,17 +101,10 @@ class TpuPlatform(Platform): # Adjust scheduler config for V1 # TODO: Add support for these - if envs.VLLM_USE_V1: - if vllm_config.cache_config.enable_prefix_caching: - logger.warning("[V1][TPU] Disable prefix caching") - vllm_config.cache_config.enable_prefix_caching = False + if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching: + logger.warning("[V1][TPU] Disable prefix caching") + vllm_config.cache_config.enable_prefix_caching = False - if vllm_config.scheduler_config.chunked_prefill_enabled: - logger.warning("[V1][TPU] Disable chunked prefill") - vllm_config.scheduler_config.chunked_prefill_enabled = False - - assert not vllm_config.scheduler_config.chunked_prefill_enabled, ( - "Chunked prefill is not yet supported for TPU backend") assert not vllm_config.speculative_config, ( "Speculative decoding is not yet supported for TPU backend") diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index cfe4792a95d19..5baa086507b34 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -100,10 +100,6 @@ class TPUModelRunner(ModelRunnerBase): num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table.commit(num_reqs) - req_ids = [] prompt_lens = [] input_tokens_list = [] @@ -135,20 +131,28 @@ class TPUModelRunner(ModelRunnerBase): # Seq len seq_len = num_computed_tokens + prompt_len - padded_seq_len = num_computed_tokens + padded_prompt_len # Input tokens - input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[ - req_index, num_computed_tokens:padded_seq_len].reshape(1, -1)) - input_tokens[:, prompt_len:] = 0 + input_tokens = torch.zeros((1, padded_prompt_len), + dtype=torch.int32, + device="cpu") + input_tokens[:, :prompt_len] = torch.from_numpy( + self.input_batch.token_ids_cpu[req_index, + num_computed_tokens:seq_len]) + # input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[ + # req_index, num_computed_tokens:padded_seq_len].reshape(1, -1)) + # input_tokens[:, prompt_len:] = 0 input_tokens_list.append(input_tokens.to(self.device)) # Input positions - input_positions = self.prefill_input_positions[:, - num_computed_tokens: - padded_seq_len].clone( - ) - input_positions[:, prompt_len:] = 0 + input_positions = torch.zeros((1, padded_prompt_len), + dtype=torch.int32, + device="cpu") + input_positions[:, : + prompt_len] = self.prefill_input_positions[:, + num_computed_tokens: + seq_len] + # input_positions[:, prompt_len:] = 0 input_positions_list.append(input_positions.to(self.device)) # Slot mapping @@ -167,8 +171,8 @@ class TPUModelRunner(ModelRunnerBase): # Block table block_table = None if num_computed_tokens > 0: - block_table = self.input_batch.block_table.get_device_tensor() - block_table = block_table[req_index].unsqueeze(0) + block_table = block_table_cpu_tensor[req_index].unsqueeze(0) + block_table = block_table.to(self.device) # Context len context_len = 0 @@ -197,6 +201,19 @@ class TPUModelRunner(ModelRunnerBase): effective_query_lens=effective_query_lens.to(self.device), )) + # TODO: Remove this + # if num_computed_tokens > 0: + # print("-------------------") + # print("input_tokens.shape = {}".format(input_tokens.shape)) + # print("input_positions.shape = {}".format( + # input_positions.shape)) + # print("slot_mapping.shape = {}".format(slot_mapping.shape)) + # print("block_table.shape = {}".format(block_table.shape)) + # print("context_lens.shape = {} data = {}".format( + # context_lens.shape, context_lens)) + # print("effective_query_lens.shape = {} data = {}".format( + # effective_query_lens.shape, effective_query_lens)) + return PromptInputData( req_ids=req_ids, prompt_lens=prompt_lens,