commit 2b0526fa15
Parent: 7be649256f
Author: Alexander Matveev
Date:   2025-02-05 16:54:57 +00:00

@@ -100,11 +100,6 @@ class TPUModelRunner(ModelRunnerBase):
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
 
-        # Cached lists
-        self.req_ids = []
-        self.prompt_token_ids = []
-        self.sampled_token_ids = []
-
     def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
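Note: the cached self.req_ids / self.prompt_token_ids / self.sampled_token_ids lists removed above are a classic aliasing hazard: any caller that holds on to a previously returned list sees it rewritten when the runner clears and refills it on the next step. A minimal standalone sketch of that failure mode (hypothetical class, not vLLM code):

    class CachedRunner:
        def __init__(self):
            self.sampled_token_ids = []  # reused across steps

        def step(self, new_ids):
            # Clearing mutates the list the previous caller still holds.
            self.sampled_token_ids.clear()
            self.sampled_token_ids.extend(new_ids)
            return self.sampled_token_ids  # caller receives an alias

    runner = CachedRunner()
    out1 = runner.step([1, 2])
    out2 = runner.step([3])
    assert out1 == [3]  # out1 was silently rewritten; per-step lists avoid this

Building fresh lists per call, as this commit does, breaks the aliasing at the cost of a small per-step allocation.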
@@ -170,11 +165,22 @@ class TPUModelRunner(ModelRunnerBase):
         seq_len = num_computed_tokens + prompt_len
         padded_seq_len = num_computed_tokens + padded_prompt_len
 
+        print("_prepare_prompt:")
+        print(" prompt_len = {}".format(prompt_len))
+        print(" padded_prompt_len = {}".format(padded_prompt_len))
+        print(" num_computed_tokens = {}".format(num_computed_tokens))
+        print(" num_prompt_tokens = {}".format(num_prompt_tokens))
+        print(" seq_len = {}".format(seq_len))
+        print(" padded_seq_len = {}".format(padded_seq_len))
+
         # Input tokens
         input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
             req_index, num_computed_tokens:padded_seq_len]
         input_tokens_cpu[prompt_len:] = 0
+
+        print(" input_tokens_cpu.shape = {} val = {}".format(
+            input_tokens_cpu.shape, input_tokens_cpu))
 
         # Input positions
         input_positions_np = self.input_positions_np[:padded_prompt_len]
         np.add(num_computed_tokens,
@@ -182,6 +188,9 @@ class TPUModelRunner(ModelRunnerBase):
                out=input_positions_np)
         input_positions_np[prompt_len:] = 0
+
+        print(" input_positions_np.shape = {} val = {}".format(
+            input_positions_np.shape, input_positions_np))
 
         # Slot mapping
         block_table_np = \
             self.input_batch.block_table.get_numpy_array()
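Note: the input-positions arithmetic above is easy to check in isolation. A standalone numpy sketch with illustrative values (the names mirror the runner's locals, but nothing here is vLLM API):

    import numpy as np

    num_computed_tokens, prompt_len, padded_prompt_len = 16, 5, 8
    arange_np = np.arange(padded_prompt_len, dtype=np.int32)

    input_positions_np = np.empty(padded_prompt_len, dtype=np.int32)
    np.add(num_computed_tokens, arange_np, out=input_positions_np)
    input_positions_np[prompt_len:] = 0  # zero the padding tail
    # input_positions_np -> [16 17 18 19 20  0  0  0]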
@@ -195,12 +204,17 @@ class TPUModelRunner(ModelRunnerBase):
                out=slot_mapping_np)
         slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
+
+        print(" slot_mapping_np.shape = {} val = {}".format(
+            slot_mapping_np.shape, slot_mapping_np))
 
         # Block table
         block_table_cpu = None
         if num_computed_tokens > 0:
             block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
             block_table_cpu = block_table_cpu[req_index]
+
+            print(" block_table_cpu = {}".format(block_table_cpu))
 
         # Context len
         self.prompt_context_lens_cpu[0] = 0
         if num_computed_tokens > 0:
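Note: the slot mapping padded with _PAD_SLOT_ID above follows the usual paged-KV-cache layout, where each token position is translated through the block table into a flat cache slot. A hedged numpy sketch of that translation, assuming the conventional slot = block_number * block_size + block_offset scheme (toy values, not the runner's actual code):

    import numpy as np

    _PAD_SLOT_ID = -1  # assumption: sentinel for padded positions
    block_size = 4
    block_table_np = np.array([7, 2, 9], dtype=np.int32)  # logical -> physical block
    positions = np.array([4, 5, 6, 7, 8, 0, 0, 0], dtype=np.int32)
    prompt_len = 5

    block_numbers = block_table_np[positions // block_size]  # physical block ids
    block_offsets = positions % block_size                    # offset inside block
    slot_mapping_np = block_numbers * block_size + block_offsets
    slot_mapping_np[prompt_len:] = _PAD_SLOT_ID  # padded tail must hit no real slot
    # slot_mapping_np -> [ 8  9 10 11 36 -1 -1 -1]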
@@ -222,6 +236,18 @@ class TPUModelRunner(ModelRunnerBase):
         effective_query_lens = self.prompt_effective_query_lens_cpu.to(
             self.device)
+
+        print(" input_tokens.shape = {} val = {}".format(
+            input_tokens.shape, input_tokens))
+        print(" input_positions.shape = {} val = {}".format(
+            input_positions.shape, input_positions))
+        print(" slot_mapping.shape = {} val = {}".format(
+            slot_mapping.shape, slot_mapping))
+        print(" block_table = {}".format(block_table))
+        print(" context_lens.shape = {} val = {}".format(
+            context_lens.shape, context_lens))
+        print(" effective_query_lens.shape = {} val = {}".format(
+            effective_query_lens.shape, effective_query_lens))
 
         # Attn metadata
         attn_metadata = PallasMetadata(
             num_prefills=1,
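Note: the prints added in these hunks run unconditionally on every prefill step. If they are meant to stick around beyond local debugging, one common pattern is to gate them behind an environment variable so the hot path stays silent by default. A hypothetical helper (the VLLM_TPU_DEBUG flag and debug_print name are illustrative, not part of this commit):

    import os

    _TPU_DEBUG = os.environ.get("VLLM_TPU_DEBUG", "0") == "1"

    def debug_print(fmt, *args):
        # No-op unless explicitly enabled.
        if _TPU_DEBUG:
            print(fmt.format(*args))

    debug_print(" slot_mapping.shape = {} val = {}", (8,), [8, 9, 10, 11])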
@@ -372,17 +398,18 @@ class TPUModelRunner(ModelRunnerBase):
         # Init
         num_prompts = len(pd_info.prompt_req_ids)
         num_decodes = len(pd_info.decode_req_ids)
         decode_token_ids_list = None
+        decode_data = None
-        self.req_ids.clear()
-        self.prompt_token_ids.clear()
-        self.sampled_token_ids.clear()
+        prompt_sampled_token_ids = []
+        decode_sampled_token_ids = []
+        sampled_token_ids = []
         # Run each prompt individually
         is_first = True
         for i in range(num_prompts):
             req_id = pd_info.prompt_req_ids[i]
             req_index = num_decodes + i
             assert req_index == self.input_batch.req_id_to_index[
                 req_id]  # TODO: Remove
             req_state = self.requests[req_id]
             num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
             prompt_len = num_scheduled_tokens
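Note: the req_index = num_decodes + i arithmetic (and the assert marked TODO) encodes the batch layout invariant: decode requests occupy indices 0..num_decodes-1 and prompts follow. A tiny standalone check of the same invariant with toy ids:

    decode_req_ids = ["d0", "d1"]
    prompt_req_ids = ["p0", "p1", "p2"]
    num_decodes = len(decode_req_ids)

    # Batch order: decodes first, then prompts.
    req_id_to_index = {rid: i for i, rid in enumerate(decode_req_ids + prompt_req_ids)}

    for i, req_id in enumerate(prompt_req_ids):
        req_index = num_decodes + i
        assert req_index == req_id_to_index[req_id]  # what the TODO assert checks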
@@ -419,8 +446,12 @@ class TPUModelRunner(ModelRunnerBase):
 
             # Get output token
             token_id = selected_token_ids_cpu[prompt_len - 1].item()
-            self.prompt_token_ids.append(token_id)
+            prompt_sampled_token_ids.append(token_id)
+            print(
+                " -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
+                .format(token_id, prompt_len, req_id, req_index,
+                        selected_token_ids_cpu))
 
             # Add output token to the request
             self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
             self.input_batch.num_tokens[req_index] += 1
@@ -451,29 +482,28 @@ class TPUModelRunner(ModelRunnerBase):
         for i in range(num_decodes):
             req_id = pd_info.decode_req_ids[i]
             req_index = i
             assert req_index == self.input_batch.req_id_to_index[
                 req_id]  # TODO: Remove
             req_state = self.requests[req_id]
             seq_len = req_state.num_computed_tokens + 1
 
             token_id = decode_token_ids_list[i]
+            decode_sampled_token_ids.append(token_id)
             self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
             self.input_batch.num_tokens[req_index] += 1
             req_state.output_token_ids.append(token_id)
 
-        # Create final req_id => token lists.
-        # This must match the actual batch index positions,
-        # so we put decodes first and then prompts.
-        self.req_ids.extend(pd_info.decode_req_ids)
-        self.req_ids.extend(pd_info.prompt_req_ids)
-        if decode_token_ids_list is not None:
-            self.sampled_token_ids.extend(decode_token_ids_list)
-        self.sampled_token_ids.extend(self.prompt_token_ids)
+        # Create the final sampled token id list. This must match the actual
+        # batch index positions, so we put decodes first and then prompts.
+        sampled_token_ids.extend(decode_sampled_token_ids)
+        sampled_token_ids.extend(prompt_sampled_token_ids)
 
         # Create output
         model_runner_output = ModelRunnerOutput(
-            req_ids=self.req_ids,
+            req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=self.sampled_token_ids,
+            sampled_token_ids=sampled_token_ids,
             logprob_token_ids_cpu=None,
             logprobs_cpu=None,
        )
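Note: the rewritten output path relies on sampled_token_ids being index-aligned with self.input_batch.req_ids (decodes first, then prompts). Under that assumption, a consumer can pair the two lists positionally; a minimal sanity check with toy values:

    req_ids = ["d0", "d1", "p0"]  # batch order: decodes, then prompts
    decode_sampled_token_ids = [11, 12]
    prompt_sampled_token_ids = [21]

    sampled_token_ids = decode_sampled_token_ids + prompt_sampled_token_ids
    assert len(sampled_token_ids) == len(req_ids)
    for rid, tok in zip(req_ids, sampled_token_ids):
        print(rid, "->", tok)  # d0 -> 11, d1 -> 12, p0 -> 21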