mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 18:35:41 +08:00
fixes
This commit is contained in:
parent
c2867d5bc1
commit
627efde813
@ -33,11 +33,13 @@ logger = init_logger(__name__)
|
|||||||
# Here we utilize the behavior that out-of-bound index is ignored.
|
# Here we utilize the behavior that out-of-bound index is ignored.
|
||||||
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||||
_PAD_SLOT_ID = 1_000_000_000
|
_PAD_SLOT_ID = 1_000_000_000
|
||||||
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
|
|
||||||
_ENABLE_TOP_P = False
|
|
||||||
# FIXME(woosuk): A temporary hack to support `n > 1`.
|
@dataclass
|
||||||
# This can significantly affect the performance if too large.
|
class PromptDecodeInfo:
|
||||||
_MAX_NUM_SAMPLES = 128
|
prompt_req_ids: List[str]
|
||||||
|
decode_req_ids: List[str]
|
||||||
|
prompt_scheduled_tokens: List[int]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -63,55 +65,42 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
):
|
):
|
||||||
super().__init__(vllm_config, device)
|
super().__init__(vllm_config, device)
|
||||||
|
|
||||||
# Persistent batch.
|
|
||||||
self.input_batch = InputBatch(
|
|
||||||
max_num_reqs=self.max_num_reqs,
|
|
||||||
max_model_len=self.max_model_len,
|
|
||||||
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
|
||||||
device=self.device,
|
|
||||||
pin_memory=self.pin_memory,
|
|
||||||
vocab_size=self.model_config.get_vocab_size(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Request states.
|
|
||||||
self.requests: Dict[str, CachedRequestState] = {}
|
|
||||||
|
|
||||||
# KV caches for forward pass
|
# KV caches for forward pass
|
||||||
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||||
|
|
||||||
# Cache torch/numpy tensors
|
# Cached torch/numpy tensors
|
||||||
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
|
self.input_ids_cpu = torch.empty(self.max_num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu")
|
||||||
pin_memory=self.pin_memory)
|
|
||||||
self.input_ids_np = self.input_ids_cpu.numpy()
|
self.input_ids_np = self.input_ids_cpu.numpy()
|
||||||
|
|
||||||
self.input_positions_cpu = torch.empty(self.max_model_len,
|
self.input_positions_cpu = torch.empty(self.max_num_tokens,
|
||||||
dtype=torch.int64,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu")
|
||||||
pin_memory=self.pin_memory)
|
|
||||||
self.input_positions_np = self.input_positions_cpu.numpy()
|
self.input_positions_np = self.input_positions_cpu.numpy()
|
||||||
|
|
||||||
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
|
self.slot_mapping_cpu = torch.empty(self.max_num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int64,
|
||||||
device="cpu",
|
device="cpu")
|
||||||
pin_memory=self.pin_memory)
|
|
||||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||||
|
|
||||||
self.prompt_context_lens_cpu = torch.zeros((1),
|
self.prompt_context_lens_cpu = torch.empty((1),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
self.prompt_effective_query_lens = torch.zeros((1),
|
self.prompt_effective_query_lens_cpu = torch.empty((1),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
|
|
||||||
self.decode_context_lens_cpu = torch.zeros(self.max_model_len,
|
self.decode_context_lens_cpu = torch.empty(self.max_num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
self.decode_context_lens_np = self.decode_context_lens_cpu.numpy()
|
self.decode_context_lens_np = self.decode_context_lens_cpu.numpy()
|
||||||
|
|
||||||
self.arange_np = np.arange(self.max_model_len, dtype=np.int32)
|
# Range tensor with values [0 .. self.max_num_tokens - 1].
|
||||||
|
# Used to initialize positions / context_lens / seq_lens
|
||||||
|
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
|
||||||
|
|
||||||
|
# Cached lists
|
||||||
self.req_ids = []
|
self.req_ids = []
|
||||||
self.prompt_token_ids = []
|
self.prompt_token_ids = []
|
||||||
self.sampled_token_ids = []
|
self.sampled_token_ids = []
|
||||||
@ -119,7 +108,7 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
def _get_prompts_and_decodes(
|
def _get_prompts_and_decodes(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
):
|
) -> PromptDecodeInfo:
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
assert total_num_scheduled_tokens > 0
|
assert total_num_scheduled_tokens > 0
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
@ -157,10 +146,11 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
# Must be prompt
|
# Must be prompt
|
||||||
assert num_computed_tokens < num_prompt_tokens
|
assert num_computed_tokens < num_prompt_tokens
|
||||||
|
|
||||||
prompt_scheduled_tokens.append(num_scheduled_tokens)
|
|
||||||
prompt_req_ids.append(req_id)
|
prompt_req_ids.append(req_id)
|
||||||
|
prompt_scheduled_tokens.append(num_scheduled_tokens)
|
||||||
|
|
||||||
return prompt_req_ids, decode_req_ids, prompt_scheduled_tokens
|
return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
|
||||||
|
prompt_scheduled_tokens)
|
||||||
|
|
||||||
def _prepare_prompt(self, req_index: int,
|
def _prepare_prompt(self, req_index: int,
|
||||||
num_scheduled_tokens: int) -> PromptData:
|
num_scheduled_tokens: int) -> PromptData:
|
||||||
@ -203,7 +193,7 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
np.add(block_numbers_np * self.block_size,
|
np.add(block_numbers_np * self.block_size,
|
||||||
block_offsets_np,
|
block_offsets_np,
|
||||||
out=slot_mapping_np)
|
out=slot_mapping_np)
|
||||||
slot_mapping_np[:, prompt_len:] = _PAD_SLOT_ID
|
slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
|
||||||
|
|
||||||
# Block table
|
# Block table
|
||||||
block_table_cpu = None
|
block_table_cpu = None
|
||||||
@ -217,7 +207,7 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
self.prompt_context_lens_cpu[0] = seq_len
|
self.prompt_context_lens_cpu[0] = seq_len
|
||||||
|
|
||||||
# Effective query len
|
# Effective query len
|
||||||
self.prompt_effective_query_lens[0] = prompt_len
|
self.prompt_effective_query_lens_cpu[0] = prompt_len
|
||||||
|
|
||||||
# Get final tensors
|
# Get final tensors
|
||||||
input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
|
input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
|
||||||
@ -230,7 +220,7 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
context_lens = self.prompt_context_lens_cpu.reshape(1,
|
context_lens = self.prompt_context_lens_cpu.reshape(1,
|
||||||
-1).to(self.device)
|
-1).to(self.device)
|
||||||
effective_query_lens = self.prompt_effective_query_lens.reshape(
|
effective_query_lens = self.prompt_effective_query_lens_cpu.reshape(
|
||||||
1, -1).to(self.device)
|
1, -1).to(self.device)
|
||||||
|
|
||||||
# Attn metadata
|
# Attn metadata
|
||||||
@ -263,7 +253,7 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
0,
|
0,
|
||||||
out=input_positions_np)
|
out=input_positions_np)
|
||||||
input_positions_np[batch_size:] = 0
|
input_positions_np[batch_size:] = 0
|
||||||
input_positions_cpu = torch.from_numpy(input_positions_np)
|
input_positions_cpu = self.input_positions_cpu[:padded_batch_size]
|
||||||
|
|
||||||
# Input tokens
|
# Input tokens
|
||||||
input_tokens_cpu = self.input_ids_cpu[:padded_batch_size]
|
input_tokens_cpu = self.input_ids_cpu[:padded_batch_size]
|
||||||
@ -334,29 +324,31 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
ensure_decodes_first(self.input_batch)
|
ensure_decodes_first(self.input_batch)
|
||||||
|
|
||||||
# Prepare prompts/decodes info
|
# Prepare prompts/decodes info
|
||||||
prompt_req_ids, decode_req_ids, prompt_scheduled_tokens = self._get_prompts_and_decodes(
|
pd_info = self._get_prompts_and_decodes(scheduler_output)
|
||||||
scheduler_output)
|
|
||||||
|
|
||||||
# Init
|
# Init
|
||||||
decode_token_ids = None
|
num_prompts = len(pd_info.prompt_req_ids)
|
||||||
|
num_decodes = len(pd_info.decode_req_ids)
|
||||||
|
decode_token_ids_list = None
|
||||||
decode_data = None
|
decode_data = None
|
||||||
self.req_ids.clear()
|
self.req_ids.clear()
|
||||||
self.prompt_token_ids.clear()
|
self.prompt_token_ids.clear()
|
||||||
self.sampled_token_ids.clear()
|
self.sampled_token_ids.clear()
|
||||||
|
|
||||||
# Run each prompt
|
# Run each prompt individually
|
||||||
is_first = True
|
is_first = True
|
||||||
for i, req_id in enumerate(prompt_req_ids):
|
for i in range(num_prompts):
|
||||||
req_index = len(decode_req_ids) + i
|
req_id = pd_info.prompt_req_ids[i]
|
||||||
|
req_index = num_decodes + i
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
num_scheduled_tokens = prompt_scheduled_tokens[i]
|
num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
|
||||||
seq_len = req_state.num_computed_tokens + num_scheduled_tokens
|
|
||||||
prompt_len = num_scheduled_tokens
|
prompt_len = num_scheduled_tokens
|
||||||
|
seq_len = req_state.num_computed_tokens + num_scheduled_tokens
|
||||||
|
|
||||||
# Prepare first prompt
|
# Prepare first prompt
|
||||||
if is_first:
|
if is_first:
|
||||||
prompt_data = self._prepare_prompt(req_index,
|
prompt_data = self._prepare_prompt(req_index,
|
||||||
prompt_scheduled_tokens[i])
|
num_scheduled_tokens)
|
||||||
is_first = False
|
is_first = False
|
||||||
|
|
||||||
# Run forward pass
|
# Run forward pass
|
||||||
@ -369,29 +361,36 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
self.kv_caches)
|
self.kv_caches)
|
||||||
|
|
||||||
# In parallel to TPU execution, prepare the next iteration
|
# In parallel to TPU execution, prepare the next iteration
|
||||||
if i < len(prompt_req_ids) - 1:
|
if i < num_prompts - 1:
|
||||||
|
# There is next prompt => prepare it
|
||||||
prompt_data = self._prepare_prompt(
|
prompt_data = self._prepare_prompt(
|
||||||
req_index + 1, prompt_scheduled_tokens[i + 1])
|
req_index + 1, pd_info.prompt_scheduled_tokens[i + 1])
|
||||||
elif i == len(prompt_req_ids) - 1 and len(decode_req_ids) > 0:
|
elif i == num_prompts - 1 and num_decodes > 0:
|
||||||
decode_data = self._prepare_decode(decode_req_ids)
|
# There is next decode => prepare it
|
||||||
|
decode_data = self._prepare_decode(pd_info.decode_req_ids)
|
||||||
|
|
||||||
# Update cached state
|
# Update cached state (if prompt is fully done)
|
||||||
if seq_len >= len(req_state.prompt_token_ids):
|
if seq_len >= len(req_state.prompt_token_ids):
|
||||||
# Transfer sampled tokens from TPU to CPU
|
# Transfer sampled tokens from TPU to CPU
|
||||||
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
|
selected_token_ids_cpu = selected_token_ids.cpu()
|
||||||
|
|
||||||
|
# Get output token
|
||||||
|
token_id = selected_token_ids_cpu[prompt_len - 1].item()
|
||||||
self.prompt_token_ids.append(token_id)
|
self.prompt_token_ids.append(token_id)
|
||||||
|
|
||||||
# Update cached state
|
# Add output token to the request
|
||||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||||
self.input_batch.num_tokens[req_index] += 1
|
self.input_batch.num_tokens[req_index] += 1
|
||||||
req_state.output_token_ids.append(token_id)
|
req_state.output_token_ids.append(token_id)
|
||||||
|
|
||||||
# Run decodes (a single batch)
|
# Run decodes (a single batch)
|
||||||
if len(decode_req_ids) > 0:
|
if num_decodes > 0:
|
||||||
if decode_data is None:
|
|
||||||
decode_data = self._prepare_decode(decode_req_ids)
|
|
||||||
|
|
||||||
# Forward
|
# Prepare decode (if was not yet prepared)
|
||||||
|
if decode_data is None:
|
||||||
|
decode_data = self._prepare_decode(pd_info.decode_req_ids)
|
||||||
|
|
||||||
|
# Run forward pass
|
||||||
with set_forward_context(decode_data.attn_metadata,
|
with set_forward_context(decode_data.attn_metadata,
|
||||||
self.vllm_config):
|
self.vllm_config):
|
||||||
assert self.model is not None
|
assert self.model is not None
|
||||||
@ -401,26 +400,30 @@ class TPUModelRunner(ModelRunnerBase):
|
|||||||
self.kv_caches)
|
self.kv_caches)
|
||||||
|
|
||||||
# Transfer sampled tokens from TPU to CPU
|
# Transfer sampled tokens from TPU to CPU
|
||||||
decode_token_ids = selected_token_ids.cpu().tolist()
|
decode_token_ids_cpu = selected_token_ids.cpu()
|
||||||
|
# Convert to list
|
||||||
|
decode_token_ids_list = decode_token_ids_cpu.tolist()
|
||||||
|
|
||||||
# Update cached state
|
# Update cached state for each decode request
|
||||||
for i, req_id in enumerate(decode_req_ids):
|
for i in range(num_decodes):
|
||||||
|
req_id = pd_info.decode_req_ids[i]
|
||||||
req_index = i
|
req_index = i
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
seq_len = req_state.num_computed_tokens + 1
|
seq_len = req_state.num_computed_tokens + 1
|
||||||
|
|
||||||
token_id = decode_token_ids[i]
|
token_id = decode_token_ids_list[i]
|
||||||
|
|
||||||
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
|
||||||
self.input_batch.num_tokens[req_index] += 1
|
self.input_batch.num_tokens[req_index] += 1
|
||||||
req_state.output_token_ids.append(token_id)
|
req_state.output_token_ids.append(token_id)
|
||||||
|
|
||||||
# Create final req_id => token lists.
|
# Create final req_id => token lists.
|
||||||
# This must match the actual batch index positions
|
# This must match the actual batch index positions,
|
||||||
self.req_ids.extend(decode_req_ids)
|
# so we put decodes first and then prompts.
|
||||||
self.req_ids.extend(prompt_req_ids)
|
self.req_ids.extend(pd_info.decode_req_ids)
|
||||||
if decode_token_ids is not None:
|
self.req_ids.extend(pd_info.prompt_req_ids)
|
||||||
self.sampled_token_ids.extend(decode_token_ids)
|
if decode_token_ids_list is not None:
|
||||||
|
self.sampled_token_ids.extend(decode_token_ids_list)
|
||||||
self.sampled_token_ids.extend(self.prompt_token_ids)
|
self.sampled_token_ids.extend(self.prompt_token_ids)
|
||||||
|
|
||||||
# Create output
|
# Create output
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user