works
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
parent 950f349492
commit 248c5b632d
@@ -20,7 +20,7 @@ TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
 EXPECTED_VALUE = 0.58
-DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
+DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
 MORE_ARGS_LIST = [
     [], # Default
     ["--enable-chunked-prefill"], # Chunked
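The hunk above only bumps --max-model-len from 2048 to 4096 in the accuracy test's default server arguments. For orientation, a minimal sketch of how such constants are typically combined per test case, assuming a pytest-parametrized test; the test name and body here are illustrative, not from this commit:

import pytest

DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [
    [],                            # Default
    ["--enable-chunked-prefill"],  # Chunked
]

@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
def test_gsm8k_accuracy(more_args):
    # Each case would start the OpenAI-compatible server with the shared
    # defaults plus one variant, run the gsm8k harness, and compare
    # exact_match against EXPECTED_VALUE within RTOL (omitted here).
    server_args = DEFAULT_ARGS + more_args
    assert "--max-model-len" in server_args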
@@ -101,17 +101,10 @@ class TpuPlatform(Platform):
 
-        # Adjust scheduler config for V1
-        # TODO: Add support for these
-        if envs.VLLM_USE_V1:
-            if vllm_config.cache_config.enable_prefix_caching:
-                logger.warning("[V1][TPU] Disable prefix caching")
-                vllm_config.cache_config.enable_prefix_caching = False
+        if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
+            logger.warning("[V1][TPU] Disable prefix caching")
+            vllm_config.cache_config.enable_prefix_caching = False
 
-            if vllm_config.scheduler_config.chunked_prefill_enabled:
-                logger.warning("[V1][TPU] Disable chunked prefill")
-                vllm_config.scheduler_config.chunked_prefill_enabled = False
-
         assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
             "Chunked prefill is not yet supported for TPU backend")
         assert not vllm_config.speculative_config, (
             "Speculative decoding is not yet supported for TPU backend")
 
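The replacement above flattens the V1-only block: prefix caching is still disabled with a warning when VLLM_USE_V1 is set, while chunked prefill goes back to being rejected by the pre-existing assert instead of being silently turned off. A standalone sketch of that pattern, using simplified stand-in config objects rather than vLLM's real classes:

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class CacheConfig:
    enable_prefix_caching: bool = True

@dataclass
class SchedulerConfig:
    chunked_prefill_enabled: bool = False

def check_and_update_config(use_v1: bool, cache: CacheConfig,
                            scheduler: SchedulerConfig) -> None:
    # Optional feature the backend cannot serve yet: disable it and warn.
    if use_v1 and cache.enable_prefix_caching:
        logger.warning("[V1][TPU] Disable prefix caching")
        cache.enable_prefix_caching = False
    # Feature the user must not request at all: fail loudly.
    assert not scheduler.chunked_prefill_enabled, (
        "Chunked prefill is not yet supported for TPU backend")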
@@ -100,10 +100,6 @@ class TPUModelRunner(ModelRunnerBase):
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
-        # OPTIMIZATION: Start copying the block table first.
-        # This way, we can overlap the copy with the following CPU operations.
-        self.input_batch.block_table.commit(num_reqs)
-
         req_ids = []
         prompt_lens = []
         input_tokens_list = []
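The removed lines dropped an early block-table commit whose comment describes a copy/compute overlap trick. As a generic illustration of that idea in plain PyTorch (not the TPU path touched by this commit): start the host-to-device transfer first so it can run while later CPU-side bookkeeping proceeds.

import torch

def copy_block_table_early(block_table_cpu: torch.Tensor,
                           device: torch.device) -> torch.Tensor:
    # Kick off the transfer immediately; non_blocking lets it overlap with
    # the CPU work that follows (it degrades to a synchronous copy on
    # devices that do not support async copies).
    block_table_dev = block_table_cpu.to(device, non_blocking=True)
    # ... per-request CPU bookkeeping would happen here while the copy is
    # in flight ...
    return block_table_dev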
@@ -135,20 +131,28 @@ class TPUModelRunner(ModelRunnerBase):
 
             # Seq len
             seq_len = num_computed_tokens + prompt_len
-            padded_seq_len = num_computed_tokens + padded_prompt_len
 
             # Input tokens
-            input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
-                req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
-            input_tokens[:, prompt_len:] = 0
+            input_tokens = torch.zeros((1, padded_prompt_len),
+                                       dtype=torch.int32,
+                                       device="cpu")
+            input_tokens[:, :prompt_len] = torch.from_numpy(
+                self.input_batch.token_ids_cpu[req_index,
+                                               num_computed_tokens:seq_len])
+            # input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
+            #     req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
+            # input_tokens[:, prompt_len:] = 0
             input_tokens_list.append(input_tokens.to(self.device))
 
             # Input positions
-            input_positions = self.prefill_input_positions[:,
-                                                            num_computed_tokens:
-                                                            padded_seq_len].clone(
-                                                            )
-            input_positions[:, prompt_len:] = 0
+            input_positions = torch.zeros((1, padded_prompt_len),
+                                          dtype=torch.int32,
+                                          device="cpu")
+            input_positions[:, :
+                            prompt_len] = self.prefill_input_positions[:,
+                                                                       num_computed_tokens:
+                                                                       seq_len]
+            # input_positions[:, prompt_len:] = 0
             input_positions_list.append(input_positions.to(self.device))
 
             # Slot mapping
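The new version above builds each prompt's inputs by allocating a zero tensor of the padded length and copying the real tokens into its front, instead of slicing past the prompt and zeroing the tail. A minimal sketch of that padding pattern on its own (function name and example values are illustrative):

import torch

def pad_to_fixed_len(token_ids: torch.Tensor, padded_len: int) -> torch.Tensor:
    # Allocate the padded buffer up front so the downstream graph always
    # sees the same (1, padded_len) shape, then fill only the real tokens.
    prompt_len = token_ids.numel()
    padded = torch.zeros((1, padded_len), dtype=torch.int32, device="cpu")
    padded[0, :prompt_len] = token_ids
    return padded

# pad_to_fixed_len(torch.tensor([11, 22, 33], dtype=torch.int32), padded_len=8)
# -> tensor([[11, 22, 33,  0,  0,  0,  0,  0]], dtype=torch.int32)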
@@ -167,8 +171,8 @@ class TPUModelRunner(ModelRunnerBase):
             # Block table
             block_table = None
             if num_computed_tokens > 0:
-                block_table = self.input_batch.block_table.get_device_tensor()
-                block_table = block_table[req_index].unsqueeze(0)
+                block_table = block_table_cpu_tensor[req_index].unsqueeze(0)
+                block_table = block_table.to(self.device)
 
             # Context len
             context_len = 0
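Here the per-request block table row is now taken from a CPU-side tensor (block_table_cpu_tensor), given a leading batch dimension, and only then moved to the device, rather than indexing a tensor already resident on the device. A sketch of that selection step under the same assumption:

import torch

def block_table_row(block_table_cpu_tensor: torch.Tensor, req_index: int,
                    device: torch.device) -> torch.Tensor:
    # Select one request's row on the CPU, add a batch dim of 1, then copy
    # just that (1, max_num_blocks) slice to the device.
    row = block_table_cpu_tensor[req_index].unsqueeze(0)
    return row.to(device)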
@@ -197,6 +201,19 @@ class TPUModelRunner(ModelRunnerBase):
                     effective_query_lens=effective_query_lens.to(self.device),
                 ))
 
+            # TODO: Remove this
+            # if num_computed_tokens > 0:
+            #     print("-------------------")
+            #     print("input_tokens.shape = {}".format(input_tokens.shape))
+            #     print("input_positions.shape = {}".format(
+            #         input_positions.shape))
+            #     print("slot_mapping.shape = {}".format(slot_mapping.shape))
+            #     print("block_table.shape = {}".format(block_table.shape))
+            #     print("context_lens.shape = {} data = {}".format(
+            #         context_lens.shape, context_lens))
+            #     print("effective_query_lens.shape = {} data = {}".format(
+            #         effective_query_lens.shape, effective_query_lens))
+
         return PromptInputData(
             req_ids=req_ids,
             prompt_lens=prompt_lens,