mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-17 08:17:08 +08:00
works!
This commit is contained in:
parent
7be649256f
commit
2b0526fa15
@ -100,11 +100,6 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
# Used to initialize positions / context_lens / seq_lens
|
||||
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
|
||||
|
||||
# Cached lists
|
||||
self.req_ids = []
|
||||
self.prompt_token_ids = []
|
||||
self.sampled_token_ids = []
|
||||
|
||||
def _get_prompts_and_decodes(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@ -170,11 +165,22 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
seq_len = num_computed_tokens + prompt_len
|
||||
padded_seq_len = num_computed_tokens + padded_prompt_len
|
||||
|
||||
print("_prepare_prompt:")
|
||||
print(" prompt_len = {}".format(prompt_len))
|
||||
print(" padded_prompt_len = {}".format(padded_prompt_len))
|
||||
print(" num_computed_tokens = {}".format(num_computed_tokens))
|
||||
print(" num_prompt_tokens = {}".format(num_prompt_tokens))
|
||||
print(" seq_len = {}".format(seq_len))
|
||||
print(" padded_seq_len = {}".format(padded_seq_len))
|
||||
|
||||
# Input tokens
|
||||
input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
|
||||
req_index, num_computed_tokens:padded_seq_len]
|
||||
input_tokens_cpu[prompt_len:] = 0
|
||||
|
||||
print(" input_tokens_cpu.shape = {} val = {}".format(
|
||||
input_tokens_cpu.shape, input_tokens_cpu))
|
||||
|
||||
# Input positions
|
||||
input_positions_np = self.input_positions_np[:padded_prompt_len]
|
||||
np.add(num_computed_tokens,
|
||||
@ -182,6 +188,9 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
out=input_positions_np)
|
||||
input_positions_np[prompt_len:] = 0
|
||||
|
||||
print(" input_positions_np.shape = {} val = {}".format(
|
||||
input_positions_np.shape, input_positions_np))
|
||||
|
||||
# Slot mapping
|
||||
block_table_np = \
|
||||
self.input_batch.block_table.get_numpy_array()
|
||||
@ -195,12 +204,17 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
out=slot_mapping_np)
|
||||
slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
|
||||
|
||||
print(" slot_mapping_np.shape = {} val = {}".format(
|
||||
slot_mapping_np.shape, slot_mapping_np))
|
||||
|
||||
# Block table
|
||||
block_table_cpu = None
|
||||
if num_computed_tokens > 0:
|
||||
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
|
||||
block_table_cpu = block_table_cpu[req_index]
|
||||
|
||||
print(" block_table_cpu = {}".format(block_table_cpu))
|
||||
|
||||
# Context len
|
||||
self.prompt_context_lens_cpu[0] = 0
|
||||
if num_computed_tokens > 0:
|
||||
@ -222,6 +236,18 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
effective_query_lens = self.prompt_effective_query_lens_cpu.to(
|
||||
self.device)
|
||||
|
||||
print(" input_tokens.shape = {} val = {}".format(
|
||||
input_tokens.shape, input_tokens))
|
||||
print(" input_positions.shape = {} val = {}".format(
|
||||
input_positions.shape, input_positions))
|
||||
print(" slot_mapping.shape = {} val = {}".format(
|
||||
slot_mapping.shape, slot_mapping))
|
||||
print(" block_table = {}".format(block_table))
|
||||
print(" context_lens.shape = {} val = {}".format(
|
||||
context_lens.shape, context_lens))
|
||||
print(" effective_query_lens.shape = {} val = {}".format(
|
||||
effective_query_lens.shape, effective_query_lens))
|
||||
|
||||
# Attn metadata
|
||||
attn_metadata = PallasMetadata(
|
||||
num_prefills=1,
|
||||
@ -372,17 +398,18 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
# Init
|
||||
num_prompts = len(pd_info.prompt_req_ids)
|
||||
num_decodes = len(pd_info.decode_req_ids)
|
||||
decode_token_ids_list = None
|
||||
decode_data = None
|
||||
self.req_ids.clear()
|
||||
self.prompt_token_ids.clear()
|
||||
self.sampled_token_ids.clear()
|
||||
prompt_sampled_token_ids = []
|
||||
decode_sampled_token_ids = []
|
||||
sampled_token_ids = []
|
||||
|
||||
# Run each prompt individually
|
||||
is_first = True
|
||||
for i in range(num_prompts):
|
||||
req_id = pd_info.prompt_req_ids[i]
|
||||
req_index = num_decodes + i
|
||||
assert req_index == self.input_batch.req_id_to_index[
|
||||
req_id] # TODO: Remove
|
||||
req_state = self.requests[req_id]
|
||||
num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
|
||||
prompt_len = num_scheduled_tokens
|
||||
@ -419,8 +446,12 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
|
||||
# Get output token
|
||||
token_id = selected_token_ids_cpu[prompt_len - 1].item()
|
||||
self.prompt_token_ids.append(token_id)
|
||||
prompt_sampled_token_ids.append(token_id)
|
||||
|
||||
print(
|
||||
" -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
|
||||
.format(token_id, prompt_len, req_id, req_index,
|
||||
selected_token_ids_cpu))
|
||||
# Add output token to the request
|
||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||
self.input_batch.num_tokens[req_index] += 1
|
||||
@ -451,29 +482,28 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
for i in range(num_decodes):
|
||||
req_id = pd_info.decode_req_ids[i]
|
||||
req_index = i
|
||||
assert req_index == self.input_batch.req_id_to_index[
|
||||
req_id] # TODO: Remove
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = req_state.num_computed_tokens + 1
|
||||
|
||||
token_id = decode_token_ids_list[i]
|
||||
decode_sampled_token_ids.append(token_id)
|
||||
|
||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||
self.input_batch.num_tokens[req_index] += 1
|
||||
req_state.output_token_ids.append(token_id)
|
||||
|
||||
# Create final req_id => token lists.
|
||||
# This must match the actual batch index positions,
|
||||
# so we put decodes first and then prompts.
|
||||
self.req_ids.extend(pd_info.decode_req_ids)
|
||||
self.req_ids.extend(pd_info.prompt_req_ids)
|
||||
if decode_token_ids_list is not None:
|
||||
self.sampled_token_ids.extend(decode_token_ids_list)
|
||||
self.sampled_token_ids.extend(self.prompt_token_ids)
|
||||
# Create the final sampled token id list. This must match the actual
|
||||
# batch index positions, so we put decodes first and then prompts.
|
||||
sampled_token_ids.extend(decode_sampled_token_ids)
|
||||
sampled_token_ids.extend(prompt_sampled_token_ids)
|
||||
|
||||
# Create output
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.req_ids,
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=self.sampled_token_ids,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprob_token_ids_cpu=None,
|
||||
logprobs_cpu=None,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user