Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-06 09:25:45 +08:00)
works!
commit 2b0526fa15
parent 7be649256f
@@ -100,11 +100,6 @@ class TPUModelRunner(ModelRunnerBase):
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
 
-        # Cached lists
-        self.req_ids = []
-        self.prompt_token_ids = []
-        self.sampled_token_ids = []
-
     def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
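Note on the removal above: the cached req_ids / prompt_token_ids / sampled_token_ids lists were per-step scratch state that had to be cleared on every call; the later hunks rebuild them as plain locals. A minimal sketch of that difference, with hypothetical names rather than the runner's real API:

# Sketch only, illustrative names: per-step results as locals instead of
# instance lists that must be cleared at the start of every step.
class StatefulRunner:
    def __init__(self):
        self.sampled_token_ids = []        # cached list, reset each step

    def step(self, new_ids):
        self.sampled_token_ids.clear()     # forgetting this leaks stale ids
        self.sampled_token_ids.extend(new_ids)
        return self.sampled_token_ids


class StatelessRunner:
    def step(self, new_ids):
        sampled_token_ids = list(new_ids)  # fresh local, nothing to reset
        return sampled_token_ids


print(StatefulRunner().step([1, 2]))   # [1, 2]
print(StatelessRunner().step([1, 2]))  # [1, 2]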
@@ -170,11 +165,22 @@ class TPUModelRunner(ModelRunnerBase):
         seq_len = num_computed_tokens + prompt_len
         padded_seq_len = num_computed_tokens + padded_prompt_len
 
+        print("_prepare_prompt:")
+        print(" prompt_len = {}".format(prompt_len))
+        print(" padded_prompt_len = {}".format(padded_prompt_len))
+        print(" num_computed_tokens = {}".format(num_computed_tokens))
+        print(" num_prompt_tokens = {}".format(num_prompt_tokens))
+        print(" seq_len = {}".format(seq_len))
+        print(" padded_seq_len = {}".format(padded_seq_len))
+
         # Input tokens
         input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
             req_index, num_computed_tokens:padded_seq_len]
         input_tokens_cpu[prompt_len:] = 0
+
+        print(" input_tokens_cpu.shape = {} val = {}".format(
+            input_tokens_cpu.shape, input_tokens_cpu))
 
         # Input positions
         input_positions_np = self.input_positions_np[:padded_prompt_len]
         np.add(num_computed_tokens,
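The slice above grabs padded_prompt_len tokens so the TPU sees a fixed shape, then zeroes everything past the real prompt_len. A self-contained numpy sketch of that pad-and-mask pattern, with made-up lengths and token values:

import numpy as np

prompt_len = 5            # real number of scheduled prompt tokens
padded_prompt_len = 8     # padded up to a fixed bucket for XLA-friendly shapes
num_computed_tokens = 2   # tokens already processed in earlier chunks
padded_seq_len = num_computed_tokens + padded_prompt_len

token_ids = np.arange(100, 120, dtype=np.int32)   # stand-in token buffer

# Slice the padded window out of the buffer, then zero the tail past prompt_len.
input_tokens = token_ids[num_computed_tokens:padded_seq_len].copy()
input_tokens[prompt_len:] = 0

# Positions follow the same pattern: offset by num_computed_tokens, zero the pad.
input_positions = num_computed_tokens + np.arange(padded_prompt_len, dtype=np.int32)
input_positions[prompt_len:] = 0

print(input_tokens)     # [102 103 104 105 106   0   0   0]
print(input_positions)  # [2 3 4 5 6 0 0 0]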
@@ -182,6 +188,9 @@ class TPUModelRunner(ModelRunnerBase):
                out=input_positions_np)
         input_positions_np[prompt_len:] = 0
 
+        print(" input_positions_np.shape = {} val = {}".format(
+            input_positions_np.shape, input_positions_np))
+
         # Slot mapping
         block_table_np = \
             self.input_batch.block_table.get_numpy_array()
@@ -195,12 +204,17 @@ class TPUModelRunner(ModelRunnerBase):
                out=slot_mapping_np)
         slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
 
+        print(" slot_mapping_np.shape = {} val = {}".format(
+            slot_mapping_np.shape, slot_mapping_np))
+
         # Block table
         block_table_cpu = None
         if num_computed_tokens > 0:
             block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
             block_table_cpu = block_table_cpu[req_index]
 
+        print(" block_table_cpu = {}".format(block_table_cpu))
+
         # Context len
         self.prompt_context_lens_cpu[0] = 0
         if num_computed_tokens > 0:
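slot_mapping translates each token position into a physical slot of the paged KV cache (block number times block size plus the offset inside the block), and the padded tail is set to _PAD_SLOT_ID so padding never writes into the cache. A rough numpy illustration of that computation; the block size, block table values, and the -1 sentinel are made up for the example:

import numpy as np

_PAD_SLOT_ID = -1         # stand-in sentinel; the real constant is defined in the runner
block_size = 4            # made-up KV-cache block size
prompt_len = 5
padded_prompt_len = 8
num_computed_tokens = 2

block_table = np.array([7, 3, 9], dtype=np.int32)   # logical block -> physical block
positions = num_computed_tokens + np.arange(padded_prompt_len, dtype=np.int32)

# slot = physical_block_number * block_size + offset_within_block
block_numbers = block_table[positions // block_size]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets

# The padded tail must not write into the KV cache.
slot_mapping[prompt_len:] = _PAD_SLOT_ID
print(slot_mapping)   # [30 31 12 13 14 -1 -1 -1]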
@@ -222,6 +236,18 @@ class TPUModelRunner(ModelRunnerBase):
         effective_query_lens = self.prompt_effective_query_lens_cpu.to(
             self.device)
 
+        print(" input_tokens.shape = {} val = {}".format(
+            input_tokens.shape, input_tokens))
+        print(" input_positions.shape = {} val = {}".format(
+            input_positions.shape, input_positions))
+        print(" slot_mapping.shape = {} val = {}".format(
+            slot_mapping.shape, slot_mapping))
+        print(" block_table = {}".format(block_table))
+        print(" context_lens.shape = {} val = {}".format(
+            context_lens.shape, context_lens))
+        print(" effective_query_lens.shape = {} val = {}".format(
+            effective_query_lens.shape, effective_query_lens))
+
         # Attn metadata
         attn_metadata = PallasMetadata(
             num_prefills=1,
@@ -372,17 +398,18 @@ class TPUModelRunner(ModelRunnerBase):
         # Init
         num_prompts = len(pd_info.prompt_req_ids)
         num_decodes = len(pd_info.decode_req_ids)
-        decode_token_ids_list = None
         decode_data = None
-        self.req_ids.clear()
-        self.prompt_token_ids.clear()
-        self.sampled_token_ids.clear()
+        prompt_sampled_token_ids = []
+        decode_sampled_token_ids = []
+        sampled_token_ids = []
 
         # Run each prompt individually
         is_first = True
         for i in range(num_prompts):
             req_id = pd_info.prompt_req_ids[i]
             req_index = num_decodes + i
+            assert req_index == self.input_batch.req_id_to_index[
+                req_id]  # TODO: Remove
             req_state = self.requests[req_id]
             num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
             prompt_len = num_scheduled_tokens
@@ -419,8 +446,12 @@ class TPUModelRunner(ModelRunnerBase):
 
             # Get output token
             token_id = selected_token_ids_cpu[prompt_len - 1].item()
-            self.prompt_token_ids.append(token_id)
+            prompt_sampled_token_ids.append(token_id)
 
+            print(
+                " -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
+                .format(token_id, prompt_len, req_id, req_index,
+                        selected_token_ids_cpu))
             # Add output token to the request
             self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
             self.input_batch.num_tokens[req_index] += 1
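The sampled token for a prefill is read from index prompt_len - 1: the model returns one predicted token per (padded) input position, and only the prediction at the last real prompt position is the newly generated token, while the padded tail is discarded. A toy example with invented values:

# Toy example: one predicted token id per (padded) input position.
prompt_len = 5
selected_token_ids = [11, 12, 13, 14, 42, 0, 0, 0]   # invented values, padded to 8

# The prediction at the last *real* prompt position is the first generated token;
# everything at index prompt_len and beyond belongs to padding and is ignored.
token_id = selected_token_ids[prompt_len - 1]
print(token_id)   # 42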
@@ -451,29 +482,28 @@ class TPUModelRunner(ModelRunnerBase):
         for i in range(num_decodes):
             req_id = pd_info.decode_req_ids[i]
             req_index = i
+            assert req_index == self.input_batch.req_id_to_index[
+                req_id]  # TODO: Remove
             req_state = self.requests[req_id]
             seq_len = req_state.num_computed_tokens + 1
 
             token_id = decode_token_ids_list[i]
+            decode_sampled_token_ids.append(token_id)
 
             self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
             self.input_batch.num_tokens[req_index] += 1
             req_state.output_token_ids.append(token_id)
 
-        # Create final req_id => token lists.
-        # This must match the actual batch index positions,
-        # so we put decodes first and then prompts.
-        self.req_ids.extend(pd_info.decode_req_ids)
-        self.req_ids.extend(pd_info.prompt_req_ids)
-        if decode_token_ids_list is not None:
-            self.sampled_token_ids.extend(decode_token_ids_list)
-        self.sampled_token_ids.extend(self.prompt_token_ids)
+        # Create the final sampled token id list. This must match the actual
+        # batch index positions, so we put decodes first and then prompts.
+        sampled_token_ids.extend(decode_sampled_token_ids)
+        sampled_token_ids.extend(prompt_sampled_token_ids)
 
         # Create output
         model_runner_output = ModelRunnerOutput(
-            req_ids=self.req_ids,
+            req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=self.sampled_token_ids,
+            sampled_token_ids=sampled_token_ids,
             logprob_token_ids_cpu=None,
             logprobs_cpu=None,
         )
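The final sampled_token_ids list is built decodes-first because decode requests occupy batch indices 0..num_decodes-1 and prompts follow at num_decodes..N-1, so the list lines up with req_id_to_index and with input_batch.req_ids. A small sketch of that ordering, with hypothetical request ids and token ids:

# Sketch: the output list must follow batch index order (decodes, then prompts).
decode_req_ids = ["d0", "d1"]          # batch indices 0, 1
prompt_req_ids = ["p0"]                # batch index 2 (= num_decodes + 0)
req_id_to_index = {"d0": 0, "d1": 1, "p0": 2}

decode_sampled_token_ids = [7, 8]
prompt_sampled_token_ids = [42]

sampled_token_ids = []
sampled_token_ids.extend(decode_sampled_token_ids)   # decodes first
sampled_token_ids.extend(prompt_sampled_token_ids)   # then prompts

for req_id in decode_req_ids + prompt_req_ids:
    idx = req_id_to_index[req_id]
    print(req_id, "->", sampled_token_ids[idx])
# d0 -> 7
# d1 -> 8
# p0 -> 42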