mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 00:08:04 +08:00
swap works!
This commit is contained in:
parent
2b0526fa15
commit
996b92ccb4
@ -448,8 +448,8 @@ def swap_positions(b: InputBatch, id_1, id_2):
|
||||
assert id_2 == b.req_id_to_index[req_id_2]
|
||||
|
||||
b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
|
||||
b.req_id_to_index[id_1], b.req_id_to_index[id_2] = b.req_id_to_index[
|
||||
id_2], b.req_id_to_index[id_1]
|
||||
b.req_id_to_index[req_id_1], b.req_id_to_index[
|
||||
req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1]
|
||||
|
||||
ids = [id_1, id_2]
|
||||
rev_ids = [id_2, id_1]
|
||||
@ -471,8 +471,13 @@ def swap_positions(b: InputBatch, id_1, id_2):
|
||||
id_1]
|
||||
b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
|
||||
id_2], b.stop_token_ids[id_1]
|
||||
b.generators[id_1], b.generators[id_2] = b.generators[id_2], b.generators[
|
||||
id_1]
|
||||
|
||||
gen_1 = b.generators.pop(id_1, None)
|
||||
gen_2 = b.generators.pop(id_2, None)
|
||||
if gen_1 is not None:
|
||||
b.generators[id_2] = gen_1
|
||||
if gen_2 is not None:
|
||||
b.generators[id_1] = gen_2
|
||||
|
||||
|
||||
def ensure_decodes_first(b: InputBatch):
|
||||
@ -504,6 +509,4 @@ def ensure_decodes_first(b: InputBatch):
|
||||
break
|
||||
|
||||
# Swap
|
||||
print("Swapping first_prompt_index = {} with last_decode_index = {}".
|
||||
format(first_prompt_index, last_decode_index))
|
||||
swap_positions(b, first_prompt_index, last_decode_index)
|
||||
|
||||
@ -69,37 +69,57 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||
|
||||
# Cached torch/numpy tensors
|
||||
self.input_ids_cpu = torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
self.input_ids_np = self.input_ids_cpu.numpy()
|
||||
self.num_swaps = 2
|
||||
self.cur_swap_id = 0
|
||||
self.input_ids_cpu = []
|
||||
self.input_ids_np = []
|
||||
self.input_positions_cpu = []
|
||||
self.input_positions_np = []
|
||||
self.slot_mapping_cpu = []
|
||||
self.slot_mapping_np = []
|
||||
self.prompt_context_lens_cpu = []
|
||||
self.prompt_effective_query_lens_cpu = []
|
||||
self.decode_context_lens_cpu = []
|
||||
self.decode_context_lens_np = []
|
||||
for _ in range(self.num_swaps):
|
||||
self.input_ids_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu"))
|
||||
self.input_ids_np.append(self.input_ids_cpu[-1].numpy())
|
||||
|
||||
self.input_positions_cpu = torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
self.input_positions_np = self.input_positions_cpu.numpy()
|
||||
self.input_positions_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu"))
|
||||
self.input_positions_np.append(
|
||||
self.input_positions_cpu[-1].numpy())
|
||||
|
||||
self.slot_mapping_cpu = torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device="cpu")
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
self.slot_mapping_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device="cpu"))
|
||||
self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())
|
||||
|
||||
self.prompt_context_lens_cpu = torch.empty((1),
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
self.prompt_effective_query_lens_cpu = torch.empty((1),
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
self.prompt_context_lens_cpu.append(
|
||||
torch.empty((1), dtype=torch.int32, device="cpu"))
|
||||
self.prompt_effective_query_lens_cpu.append(
|
||||
torch.empty((1), dtype=torch.int32, device="cpu"))
|
||||
|
||||
self.decode_context_lens_cpu = torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
self.decode_context_lens_np = self.decode_context_lens_cpu.numpy()
|
||||
self.decode_context_lens_cpu.append(
|
||||
torch.empty(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu"))
|
||||
self.decode_context_lens_np.append(
|
||||
self.decode_context_lens_cpu[-1].numpy())
|
||||
|
||||
# Range tensor with values [0 .. self.max_num_tokens - 1].
|
||||
# Used to initialize positions / context_lens / seq_lens
|
||||
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
|
||||
|
||||
def swap_step(self):
|
||||
self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps
|
||||
|
||||
def _get_prompts_and_decodes(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@ -165,31 +185,35 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
seq_len = num_computed_tokens + prompt_len
|
||||
padded_seq_len = num_computed_tokens + padded_prompt_len
|
||||
|
||||
print("_prepare_prompt:")
|
||||
print(" prompt_len = {}".format(prompt_len))
|
||||
print(" padded_prompt_len = {}".format(padded_prompt_len))
|
||||
print(" num_computed_tokens = {}".format(num_computed_tokens))
|
||||
print(" num_prompt_tokens = {}".format(num_prompt_tokens))
|
||||
print(" seq_len = {}".format(seq_len))
|
||||
print(" padded_seq_len = {}".format(padded_seq_len))
|
||||
# DEBUG
|
||||
# print("_prepare_prompt:")
|
||||
# print(" prompt_len = {}".format(prompt_len))
|
||||
# print(" padded_prompt_len = {}".format(padded_prompt_len))
|
||||
# print(" num_computed_tokens = {}".format(num_computed_tokens))
|
||||
# print(" num_prompt_tokens = {}".format(num_prompt_tokens))
|
||||
# print(" seq_len = {}".format(seq_len))
|
||||
# print(" padded_seq_len = {}".format(padded_seq_len))
|
||||
|
||||
# Input tokens
|
||||
input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
|
||||
req_index, num_computed_tokens:padded_seq_len]
|
||||
input_tokens_cpu[prompt_len:] = 0
|
||||
|
||||
print(" input_tokens_cpu.shape = {} val = {}".format(
|
||||
input_tokens_cpu.shape, input_tokens_cpu))
|
||||
# DEBUG
|
||||
# print(" input_tokens_cpu.shape = {} val = {}".format(
|
||||
# input_tokens_cpu.shape, input_tokens_cpu))
|
||||
|
||||
# Input positions
|
||||
input_positions_np = self.input_positions_np[:padded_prompt_len]
|
||||
input_positions_np = self.input_positions_np[
|
||||
self.cur_swap_id][:padded_prompt_len]
|
||||
np.add(num_computed_tokens,
|
||||
self.arange_np[:padded_prompt_len],
|
||||
out=input_positions_np)
|
||||
input_positions_np[prompt_len:] = 0
|
||||
|
||||
print(" input_positions_np.shape = {} val = {}".format(
|
||||
input_positions_np.shape, input_positions_np))
|
||||
# DEBUG
|
||||
# print(" input_positions_np.shape = {} val = {}".format(
|
||||
# input_positions_np.shape, input_positions_np))
|
||||
|
||||
# Slot mapping
|
||||
block_table_np = \
|
||||
@ -198,14 +222,16 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
self.block_size]
|
||||
block_offsets_np = input_positions_np % self.block_size
|
||||
|
||||
slot_mapping_np = self.slot_mapping_np[:padded_prompt_len]
|
||||
slot_mapping_np = self.slot_mapping_np[
|
||||
self.cur_swap_id][:padded_prompt_len]
|
||||
np.add(block_numbers_np * self.block_size,
|
||||
block_offsets_np,
|
||||
out=slot_mapping_np)
|
||||
slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
|
||||
|
||||
print(" slot_mapping_np.shape = {} val = {}".format(
|
||||
slot_mapping_np.shape, slot_mapping_np))
|
||||
# DEBUG
|
||||
# print(" slot_mapping_np.shape = {} val = {}".format(
|
||||
# slot_mapping_np.shape, slot_mapping_np))
|
||||
|
||||
# Block table
|
||||
block_table_cpu = None
|
||||
@ -213,40 +239,47 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
|
||||
block_table_cpu = block_table_cpu[req_index]
|
||||
|
||||
print(" block_table_cpu = {}".format(block_table_cpu))
|
||||
# DEBUG
|
||||
# print(" block_table_cpu = {}".format(block_table_cpu))
|
||||
|
||||
# Context len
|
||||
self.prompt_context_lens_cpu[0] = 0
|
||||
self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0
|
||||
if num_computed_tokens > 0:
|
||||
self.prompt_context_lens_cpu[0] = seq_len
|
||||
self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len
|
||||
|
||||
# Effective query len
|
||||
self.prompt_effective_query_lens_cpu[0] = prompt_len
|
||||
self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len
|
||||
|
||||
# Get final tensors
|
||||
input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
|
||||
input_positions = self.input_positions_cpu[:padded_prompt_len].reshape(
|
||||
1, -1).to(self.device)
|
||||
slot_mapping = self.slot_mapping_cpu[:padded_prompt_len].reshape(
|
||||
1, -1).to(self.device)
|
||||
input_positions = self.input_positions_cpu[
|
||||
self.cur_swap_id][:padded_prompt_len].reshape(1,
|
||||
-1).to(self.device)
|
||||
slot_mapping = self.slot_mapping_cpu[
|
||||
self.cur_swap_id][:padded_prompt_len].reshape(1,
|
||||
-1).to(self.device)
|
||||
block_table = block_table_cpu.reshape(1, -1).to(
|
||||
self.device) if block_table_cpu is not None else None
|
||||
|
||||
context_lens = self.prompt_context_lens_cpu.to(self.device)
|
||||
effective_query_lens = self.prompt_effective_query_lens_cpu.to(
|
||||
context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to(
|
||||
self.device)
|
||||
effective_query_lens = self.prompt_effective_query_lens_cpu[
|
||||
self.cur_swap_id].to(self.device)
|
||||
|
||||
print(" input_tokens.shape = {} val = {}".format(
|
||||
input_tokens.shape, input_tokens))
|
||||
print(" input_positions.shape = {} val = {}".format(
|
||||
input_positions.shape, input_positions))
|
||||
print(" slot_mapping.shape = {} val = {}".format(
|
||||
slot_mapping.shape, slot_mapping))
|
||||
print(" block_table = {}".format(block_table))
|
||||
print(" context_lens.shape = {} val = {}".format(
|
||||
context_lens.shape, context_lens))
|
||||
print(" effective_query_lens.shape = {} val = {}".format(
|
||||
effective_query_lens.shape, effective_query_lens))
|
||||
self.swap_step()
|
||||
|
||||
# DEBUG
|
||||
# print(" input_tokens.shape = {} val = {}".format(
|
||||
# input_tokens.shape, input_tokens))
|
||||
# print(" input_positions.shape = {} val = {}".format(
|
||||
# input_positions.shape, input_positions))
|
||||
# print(" slot_mapping.shape = {} val = {}".format(
|
||||
# slot_mapping.shape, slot_mapping))
|
||||
# print(" block_table = {}".format(block_table))
|
||||
# print(" context_lens.shape = {} val = {}".format(
|
||||
# context_lens.shape, context_lens))
|
||||
# print(" effective_query_lens.shape = {} val = {}".format(
|
||||
# effective_query_lens.shape, effective_query_lens))
|
||||
|
||||
# Attn metadata
|
||||
attn_metadata = PallasMetadata(
|
||||
@ -275,78 +308,91 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
# Init [0 .. batch_size - 1]
|
||||
req_indices_np = self.arange_np[:padded_batch_size]
|
||||
|
||||
print("_prepare_decode:")
|
||||
print(" batch_size = {}".format(batch_size))
|
||||
print(" padded_batch_size = {}".format(padded_batch_size))
|
||||
print(" req_indices_np.shape = {} val = {}".format(
|
||||
req_indices_np.shape, req_indices_np))
|
||||
# DEBUG
|
||||
# print("_prepare_decode:")
|
||||
# print(" batch_size = {}".format(batch_size))
|
||||
# print(" padded_batch_size = {}".format(padded_batch_size))
|
||||
# print(" req_indices_np.shape = {} val = {}".format(
|
||||
# req_indices_np.shape, req_indices_np))
|
||||
|
||||
# Input positions
|
||||
input_positions_np = self.input_positions_np[:padded_batch_size]
|
||||
input_positions_np = self.input_positions_np[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
|
||||
0,
|
||||
out=input_positions_np)
|
||||
input_positions_np[batch_size:] = 0
|
||||
input_positions_cpu = self.input_positions_cpu[:padded_batch_size]
|
||||
input_positions_cpu = self.input_positions_cpu[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
|
||||
print(" input_positions_cpu.shape = {} data = {}".format(
|
||||
input_positions_cpu.shape, input_positions_cpu))
|
||||
# DEBUG
|
||||
# print(" input_positions_cpu.shape = {} data = {}".format(
|
||||
# input_positions_cpu.shape, input_positions_cpu))
|
||||
|
||||
# Input tokens
|
||||
token_indices_np = (
|
||||
input_positions_np +
|
||||
req_indices_np * self.input_batch.token_ids_cpu.shape[1])
|
||||
input_tokens_cpu = self.input_ids_cpu[:padded_batch_size]
|
||||
input_tokens_cpu = self.input_ids_cpu[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||
0,
|
||||
torch.from_numpy(token_indices_np),
|
||||
out=input_tokens_cpu)
|
||||
input_tokens_cpu[batch_size:] = 0
|
||||
|
||||
print(" token_indices_np.shape = {} val = {}".format(
|
||||
token_indices_np.shape, token_indices_np))
|
||||
|
||||
print(" input_tokens_cpu.shape = {} data = {}".format(
|
||||
input_tokens_cpu.shape, input_tokens_cpu))
|
||||
# DEBUG
|
||||
# print(" token_indices_np.shape = {} val = {}".format(
|
||||
# token_indices_np.shape, token_indices_np))
|
||||
# print(" input_tokens_cpu.shape = {} data = {}".format(
|
||||
# input_tokens_cpu.shape, input_tokens_cpu))
|
||||
|
||||
# Slot mapping
|
||||
block_table_indices_np = (
|
||||
req_indices_np * self.max_num_blocks_per_req +
|
||||
input_positions_np // self.block_size)
|
||||
|
||||
print(
|
||||
" block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
|
||||
.format(block_table_indices_np.shape, block_table_indices_np,
|
||||
self.max_num_blocks_per_req))
|
||||
# DEBUG
|
||||
# print(
|
||||
# " block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
|
||||
# .format(block_table_indices_np.shape, block_table_indices_np,
|
||||
# self.max_num_blocks_per_req))
|
||||
|
||||
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
|
||||
|
||||
print(" block_table_cpu.shape = {} data = {}".format(
|
||||
block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
|
||||
# DEBUG
|
||||
# print(" block_table_cpu.shape = {} data = {}".format(
|
||||
# block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
|
||||
|
||||
block_numbers_np = block_table_cpu.flatten(
|
||||
)[block_table_indices_np].numpy()
|
||||
|
||||
print(" block_numbers_np.shape = {} data = {}".format(
|
||||
block_numbers_np.shape, block_numbers_np))
|
||||
# DEBUG
|
||||
# print(" block_numbers_np.shape = {} data = {}".format(
|
||||
# block_numbers_np.shape, block_numbers_np))
|
||||
|
||||
block_offsets_np = input_positions_np % self.block_size
|
||||
|
||||
print(" block_offsets_np.shape = {} data = {}".format(
|
||||
block_offsets_np.shape, block_offsets_np))
|
||||
# DEBUG
|
||||
# print(" block_offsets_np.shape = {} data = {}".format(
|
||||
# block_offsets_np.shape, block_offsets_np))
|
||||
|
||||
slot_mapping_np = self.slot_mapping_np[:padded_batch_size]
|
||||
slot_mapping_np = self.slot_mapping_np[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
np.add(block_numbers_np * self.block_size,
|
||||
block_offsets_np,
|
||||
out=slot_mapping_np)
|
||||
slot_mapping_np[batch_size:] = _PAD_SLOT_ID
|
||||
|
||||
print(" slot_mapping_np.shape = {} data = {}".format(
|
||||
slot_mapping_np.shape, slot_mapping_np))
|
||||
# DEBUG
|
||||
# print(" slot_mapping_np.shape = {} data = {}".format(
|
||||
# slot_mapping_np.shape, slot_mapping_np))
|
||||
|
||||
block_table_cpu = block_table_cpu[:padded_batch_size]
|
||||
|
||||
# Context lens
|
||||
context_lens_np = self.decode_context_lens_np[:padded_batch_size]
|
||||
context_lens_np = self.decode_context_lens_np[
|
||||
self.cur_swap_id][:padded_batch_size]
|
||||
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
|
||||
1,
|
||||
out=context_lens_np)
|
||||
@ -355,14 +401,18 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
# Get final tensors
|
||||
input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
|
||||
input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
|
||||
slot_mapping = self.slot_mapping_cpu[:padded_batch_size].reshape(
|
||||
-1, 1).to(self.device)
|
||||
slot_mapping = self.slot_mapping_cpu[
|
||||
self.cur_swap_id][:padded_batch_size].reshape(-1,
|
||||
1).to(self.device)
|
||||
block_table = block_table_cpu.to(self.device)
|
||||
context_lens = self.decode_context_lens_cpu[:padded_batch_size].to(
|
||||
self.device)
|
||||
context_lens = self.decode_context_lens_cpu[
|
||||
self.cur_swap_id][:padded_batch_size].to(self.device)
|
||||
|
||||
print(" context_lens.shape = {} val = {}".format(
|
||||
context_lens.shape, context_lens))
|
||||
self.swap_step()
|
||||
|
||||
# DEBUG
|
||||
# print(" context_lens.shape = {} val = {}".format(
|
||||
# context_lens.shape, context_lens))
|
||||
|
||||
# Attn metadata
|
||||
attn_metadata = PallasMetadata(
|
||||
@ -399,9 +449,7 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
num_prompts = len(pd_info.prompt_req_ids)
|
||||
num_decodes = len(pd_info.decode_req_ids)
|
||||
decode_data = None
|
||||
prompt_sampled_token_ids = []
|
||||
decode_sampled_token_ids = []
|
||||
sampled_token_ids = []
|
||||
sampled_token_ids = [0] * self.input_batch.num_reqs
|
||||
|
||||
# Run each prompt individually
|
||||
is_first = True
|
||||
@ -446,12 +494,14 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
|
||||
# Get output token
|
||||
token_id = selected_token_ids_cpu[prompt_len - 1].item()
|
||||
prompt_sampled_token_ids.append(token_id)
|
||||
sampled_token_ids[req_index] = token_id
|
||||
|
||||
# DEBUG
|
||||
# print(
|
||||
# " -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
|
||||
# .format(token_id, prompt_len, req_id, req_index,
|
||||
# selected_token_ids_cpu))
|
||||
|
||||
print(
|
||||
" -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
|
||||
.format(token_id, prompt_len, req_id, req_index,
|
||||
selected_token_ids_cpu))
|
||||
# Add output token to the request
|
||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||
self.input_batch.num_tokens[req_index] += 1
|
||||
@ -488,17 +538,12 @@ class TPUModelRunner(ModelRunnerBase):
|
||||
seq_len = req_state.num_computed_tokens + 1
|
||||
|
||||
token_id = decode_token_ids_list[i]
|
||||
decode_sampled_token_ids.append(token_id)
|
||||
sampled_token_ids[req_index] = token_id
|
||||
|
||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||
self.input_batch.num_tokens[req_index] += 1
|
||||
req_state.output_token_ids.append(token_id)
|
||||
|
||||
# Create the final sampled token id list. This must match the actual
|
||||
# batch index positions, so we put decodes first and then prompts.
|
||||
sampled_token_ids.extend(decode_sampled_token_ids)
|
||||
sampled_token_ids.extend(prompt_sampled_token_ids)
|
||||
|
||||
# Create output
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user