[BugFix] Fix batch updates for pooling models (#23398)
Signed-off-by: Nick Hill <nhill@redhat.com>
parent 24d0c9e6ed
commit c80c53a30f
@@ -50,6 +50,10 @@ class BatchUpdateBuilder:
         self.added = added or []
         self._is_removed_sorted = False
 
+        # Used to track changes in the pooling case
+        # where we don't populate the added list.
+        self.batch_changed = False
+
     def _ensure_removed_sorted(self) -> None:
         """Sort removed request indices in
         descending order.
@@ -80,6 +84,7 @@ class BatchUpdateBuilder:
             raise RuntimeError("Cannot register new removed request after"
                                " self.removed has been read.")
         self._removed.append(index)
+        self.batch_changed = True
 
     def has_removed(self) -> bool:
         return bool(self._removed)
@@ -98,9 +103,15 @@ class BatchUpdateBuilder:
             return self._removed.pop()
         return None
 
-    def _is_update(self) -> bool:
-        """True if there is a batch state change"""
-        return any((self._removed, self.moved, self.added))
+    def reset(self) -> bool:
+        """Returns True if there were any changes to the batch."""
+        self._is_removed_sorted = False
+        self._removed.clear()
+        self.moved.clear()
+        self.added.clear()
+        batch_changed = self.batch_changed
+        self.batch_changed = False
+        return batch_changed
 
     def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
         """Generate a logitsprocs batch update data structure and reset
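
The new reset() is the pooling-path counterpart of get_and_reset(): it discards any pending added/removed/moved state and reports whether the batch changed since the last call. A minimal standalone sketch of that contract (an illustrative toy class, not the vLLM implementation):

from typing import Any, List, Tuple


class SketchBatchUpdateBuilder:
    """Toy stand-in that mirrors the change-tracking contract above."""

    def __init__(self) -> None:
        self._removed: List[int] = []
        self.added: List[Tuple[Any, ...]] = []
        self.moved: List[Tuple[Any, ...]] = []
        self.batch_changed = False

    def removed_append(self, index: int) -> None:
        # Any mutation marks the batch as changed, as in the diff above.
        self._removed.append(index)
        self.batch_changed = True

    def reset(self) -> bool:
        """Clear pending state and report whether the batch changed."""
        self._removed.clear()
        self.moved.clear()
        self.added.clear()
        batch_changed = self.batch_changed
        self.batch_changed = False
        return batch_changed


builder = SketchBatchUpdateBuilder()
builder.removed_append(3)
assert builder.reset() is True   # a change was recorded and is now consumed
assert builder.reset() is False  # nothing happened since the last reset
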
@@ -114,7 +125,8 @@ class BatchUpdateBuilder:
         """
         # Reset removal-sorting logic
         self._is_removed_sorted = False
-        if not self._is_update():
+        self.batch_changed = False
+        if not any((self._removed, self.moved, self.added)):
             # No update; short-circuit
             return None
         # Build batch state update
@@ -65,8 +65,7 @@ class CachedRequestState:
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
             return self.prompt_token_ids[idx]
-        else:
-            return self.output_token_ids[idx - self.num_prompt_tokens]
+        return self.output_token_ids[idx - self.num_prompt_tokens]
 
 
 class InputBatch:
@@ -261,30 +260,27 @@ class InputBatch:
         Not applicable to pooling models.
         """
 
-        # Detailed added request metadata is only required for non-pooling
-        # models, to support logitsprocs
-        assert request.sampling_params
-
         # Fill the next empty index if there is one.
         if (new_req_index := self.batch_update_builder.pop_removed()) is None:
             # Append to end otherwise.
             new_req_index = self.num_reqs
 
         assert new_req_index < self.max_num_reqs
-        self.batch_update_builder.added.append(
-            (new_req_index, request.sampling_params, request.prompt_token_ids,
-             request.output_token_ids))
+        self.batch_update_builder.batch_changed = True
+        if request.sampling_params:
+            # Detailed added request metadata is only required for non-pooling
+            # models, to support logitsprocs.
+            self.batch_update_builder.added.append(
+                (new_req_index, request.sampling_params,
+                 request.prompt_token_ids, request.output_token_ids))
 
         return new_req_index
 
     def add_request(
         self,
         request: "CachedRequestState",
     ) -> int:
-        if not self.is_pooling_model:
-            # New request index bookkeeping for autoregressive models.
-            req_index = self._register_add_request(request)
-        else:
-            req_index = self.num_reqs
+        req_index = self._register_add_request(request)
 
         req_id = request.req_id
         if req_index == len(self._req_ids):
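
The behavioral core of the hunk above: _register_add_request() is now called for every request, it always flips batch_changed, and it records the detailed tuple consumed by logits processors only when sampling_params is set, i.e. not for pooling requests. A hedged sketch of that gate with illustrative names:

from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple


@dataclass
class SketchBuilder:
    added: List[Tuple[int, Any, List[int], List[int]]] = field(
        default_factory=list)
    batch_changed: bool = False


def register_add_request(builder: SketchBuilder, new_req_index: int,
                         sampling_params: Optional[Any],
                         prompt_token_ids: List[int],
                         output_token_ids: List[int]) -> int:
    # Every add marks the batch as changed.
    builder.batch_changed = True
    if sampling_params:
        # Detailed metadata is only needed to drive logits processors.
        builder.added.append((new_req_index, sampling_params,
                              prompt_token_ids, output_token_ids))
    return new_req_index


b = SketchBuilder()
register_add_request(b, 0, None, [1, 2, 3], [])   # pooling request
register_add_request(b, 1, object(), [4, 5], [])  # sampling request
assert b.batch_changed and len(b.added) == 1
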
@@ -389,7 +385,7 @@ class InputBatch:
             self.logits_processing_needs_token_ids[req_index] = (
                 pooling_params.requires_token_ids)
         else:
-            raise NotImplementedError(request)
+            raise NotImplementedError("Unrecognized request type")
 
         # Add request lora ID
         if request.lora_request:
@@ -419,13 +415,25 @@
         req_index = self.req_id_to_index.pop(req_id, None)
         if req_index is None:
             return None
-        if not self.is_pooling_model:
-            # Autoregressive models require bookkeeping of removed requests to
-            # support logitsprocs.
-            self.batch_update_builder.removed_append(req_index)
+        self.batch_update_builder.removed_append(req_index)
         self._req_ids[req_index] = None
         self.req_output_token_ids[req_index] = None
 
+        # LoRA
+        lora_id = self.request_lora_mapping[req_index]
+        if lora_id != 0:
+            lora_req_ids = self.lora_id_to_request_ids[lora_id]
+            lora_req_ids.discard(req_id)
+            if not lora_req_ids:
+                del self.lora_id_to_request_ids[lora_id]
+                del self.lora_id_to_lora_request[lora_id]
+            self.request_lora_mapping[req_index] = 0
+
+        if self.is_pooling_model:
+            self.pooling_params.pop(req_id, None)
+            return req_index
+
         self.greedy_reqs.discard(req_id)
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
@@ -439,29 +447,14 @@
         self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
-        # LoRA
-        lora_id = self.request_lora_mapping[req_index]
-        if lora_id != 0:
-            lora_req_ids = self.lora_id_to_request_ids[lora_id]
-            lora_req_ids.discard(req_id)
-            if not lora_req_ids:
-                del self.lora_id_to_request_ids[lora_id]
-                del self.lora_id_to_lora_request[lora_id]
-            self.request_lora_mapping[req_index] = 0
-
         self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             # False means we don't fill with -inf.
             self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         self.bad_words_token_ids.pop(req_index, None)
-        self.pooling_params.pop(req_id, None)
         return req_index
 
     def swap_states(self, i1: int, i2: int) -> None:
-        # For autoregressive models, track detailed request reordering info
-        # to support logitsprocs
-        self.batch_update_builder.moved.append(
-            (i1, i2, MoveDirectionality.SWAP))
         old_id_i1 = self._req_ids[i1]
         old_id_i2 = self._req_ids[i2]
         self._req_ids[i1], self._req_ids[i2] =\
@@ -479,18 +472,6 @@
             self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
         self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
             self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
-        self.temperature_cpu[i1], self.temperature_cpu[i2] =\
-            self.temperature_cpu[i2], self.temperature_cpu[i1]
-        self.top_p_cpu[i1], self.top_p_cpu[i2] =\
-            self.top_p_cpu[i2], self.top_p_cpu[i1]
-        self.top_k_cpu[i1], self.top_k_cpu[i2] =\
-            self.top_k_cpu[i2], self.top_k_cpu[i1]
-        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
-            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
-        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
-            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
-        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
-            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
 
         # NOTE: the following is unsafe
         # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
@@ -501,18 +482,41 @@
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp
 
+        self.block_table.swap_row(i1, i2)
+
+        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
+            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
+
+        if self.is_pooling_model:
+            # Sampling and logits parameters don't apply to pooling models.
+            return
+
+        # For autoregressive models, track detailed request reordering info
+        # to support logitsprocs.
+        self.batch_update_builder.moved.append(
+            (i1, i2, MoveDirectionality.SWAP))
+
+        self.temperature_cpu[i1], self.temperature_cpu[i2] = \
+            self.temperature_cpu[i2], self.temperature_cpu[i1]
+        self.top_p_cpu[i1], self.top_p_cpu[i2] = \
+            self.top_p_cpu[i2], self.top_p_cpu[i1]
+        self.top_k_cpu[i1], self.top_k_cpu[i2] = \
+            self.top_k_cpu[i2], self.top_k_cpu[i1]
+        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
+            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
+        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
+            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
+        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
+            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
+
         swap_dict_values(self.generators, i1, i2)
         swap_dict_values(self.bad_words_token_ids, i1, i2)
 
-        self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
-            self.request_lora_mapping[i2], self.request_lora_mapping[i1]
-
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             self.allowed_token_ids_mask_cpu_tensor[i1], \
                 self.allowed_token_ids_mask_cpu_tensor[i2] =\
                     self.allowed_token_ids_mask_cpu_tensor[i2], \
                         self.allowed_token_ids_mask_cpu_tensor[i1]
-        self.block_table.swap_row(i1, i2)
 
     def condense(self) -> None:
         """Slide non-empty requests down into lower, empty indices.
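
After this change, state shared by both model types (token ids, block table rows, the LoRA mapping) is always swapped, and swap_states() returns early for pooling models before touching sampling tensors or the logitsprocs moved list. A toy sketch of that split with illustrative array names:

import numpy as np


def swap_states(i1: int, i2: int, token_ids: np.ndarray,
                temperature: np.ndarray, is_pooling_model: bool) -> None:
    # Swap token-id rows via an explicit copy; an element-wise tuple swap on
    # numpy rows would alias, as the NOTE in the diff warns.
    tmp = token_ids[i1].copy()
    token_ids[i1] = token_ids[i2]
    token_ids[i2] = tmp

    if is_pooling_model:
        # Sampling parameters don't apply to pooling models.
        return

    temperature[i1], temperature[i2] = temperature[i2], temperature[i1]


token_ids = np.array([[1, 2], [3, 4]])
temperature = np.array([0.0, 1.0])
swap_states(0, 1, token_ids, temperature, is_pooling_model=True)
assert token_ids.tolist() == [[3, 4], [1, 2]]
assert temperature.tolist() == [0.0, 1.0]
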
@@ -529,12 +533,6 @@
         """
         num_reqs = self.num_reqs
 
-        if self.is_pooling_model:
-            # Will be contiguous in pooling case, just trim the lists.
-            del self._req_ids[num_reqs:]
-            del self.req_output_token_ids[num_reqs:]
-            return
-
         if not (empty_req_indices := self.batch_update_builder.removed):
             # All removed requests were replaced by added requests, or else no
             # requests were removed at all. No condense() needed
@@ -562,11 +560,6 @@
             # Move active request down into empty request
             # index.
             self.batch_update_builder.pop_removed()
-            # Autoregressive models require detailed tracking of condense
-            # operations to support logitsprocs
-            self.batch_update_builder.moved.append(
-                (last_req_index, empty_index,
-                 MoveDirectionality.UNIDIRECTIONAL))
             req_id = self._req_ids[last_req_index]
             output_token_ids = self.req_output_token_ids[last_req_index]
             assert req_id is not None
@@ -587,6 +580,21 @@
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
             self.block_table.move_row(last_req_index, empty_index)
+
+            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
+                last_req_index]
+
+            if self.is_pooling_model:
+                last_req_index -= 1
+                # Sampling state not used by pooling models.
+                continue
+
+            # Autoregressive models require detailed tracking of condense
+            # operations to support logitsprocs
+            self.batch_update_builder.moved.append(
+                (last_req_index, empty_index,
+                 MoveDirectionality.UNIDIRECTIONAL))
+
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
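
condense() slides the highest-indexed active requests down into the lowest empty slots; after this hunk, pooling models still move block table rows and the LoRA mapping but skip the per-slot sampling state and the logitsprocs move records. A hedged sketch of the sliding idea on plain lists (illustrative, not the vLLM code):

from typing import List, Optional, Tuple


def condense(req_ids: List[Optional[str]], empty_indices: List[int],
             track_moves: bool) -> List[Tuple[int, int]]:
    """Slide non-empty entries down into lower empty indices."""
    moved: List[Tuple[int, int]] = []
    empty_indices.sort(reverse=True)  # pop() then yields the lowest index
    last = len(req_ids) - 1
    while empty_indices:
        while last >= 0 and req_ids[last] is None:
            last -= 1  # skip trailing holes
        empty = empty_indices.pop()
        if empty >= last:
            break  # already contiguous
        req_ids[empty], req_ids[last] = req_ids[last], None
        if track_moves:
            moved.append((last, empty))  # logitsprocs-style bookkeeping
        last -= 1
    return moved


reqs = ["a", None, "b", None, "c"]
moves = condense(reqs, [1, 3], track_moves=True)
assert reqs[:3] == ["a", "c", "b"] and moves == [(4, 1)]
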
@@ -601,9 +609,6 @@
             if generator is not None:
                 self.generators[empty_index] = generator
 
-            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
-                last_req_index]
-
             # TODO convert these to LogitsProcessors
             if self.allowed_token_ids_mask_cpu_tensor is not None:
                 self.allowed_token_ids_mask_cpu_tensor[
@@ -626,8 +631,9 @@
         """Apply any batch updates to sampling metadata."""
 
         if self.is_pooling_model:
-            # Batch changes every step for pooling models.
-            self.sampling_metadata = self._make_sampling_metadata()
+            batch_changed = self.batch_update_builder.reset()
+            if batch_changed:
+                self.sampling_metadata = self._make_sampling_metadata()
             return
 
         # For non-pooling models - generate and apply logitsprocs update;
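
In the pooling branch above, sampling metadata was previously rebuilt unconditionally on every step; with this change, reset() consumes the builder's pending state and the rebuild happens only when the batch composition actually changed. A small sketch of the per-step behavior (illustrative names; the counter stands in for _make_sampling_metadata()):

class SketchInputBatch:
    """Toy stand-in for the pooling branch of the method above."""

    def __init__(self) -> None:
        self.batch_changed = False
        self.metadata_builds = 0  # stands in for _make_sampling_metadata()

    def reset(self) -> bool:
        changed, self.batch_changed = self.batch_changed, False
        return changed

    def refresh_metadata(self) -> None:
        if self.reset():
            self.metadata_builds += 1


batch = SketchInputBatch()
batch.batch_changed = True  # e.g. a request was added or removed this step
batch.refresh_metadata()    # batch changed: metadata rebuilt once
batch.refresh_metadata()    # no change since last step: no rebuild
assert batch.metadata_builds == 1
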
@@ -720,7 +726,8 @@
         )
 
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
-        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
+        num_reqs = self.num_reqs
+        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
         prompt_token_ids_cpu_tensor = torch.empty(
             (self.num_reqs, max_prompt_len),
             device="cpu",
@@ -728,11 +735,10 @@
             pin_memory=self.pin_memory,
         )
         prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
-        prompt_token_ids[:] = self.token_ids_cpu[:self.
-                                                 num_reqs, :max_prompt_len]
+        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
         # Use the value of vocab_size as a pad since we don't have a
         # token_id of this value.
-        for i in range(self.num_reqs):
+        for i in range(num_reqs):
             prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
         return prompt_token_ids_cpu_tensor.to(device=self.device,
                                               non_blocking=True)
@@ -1489,10 +1489,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for raw_output, seq_len, prompt_len in zip(
                 raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
 
-            if seq_len == prompt_len:
-                pooler_output.append(raw_output.data)
-            else:
-                pooler_output.append(None)
+            output = raw_output.data if seq_len == prompt_len else None
+            pooler_output.append(output)
 
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -1522,7 +1520,7 @@
         # Prepare the decoder inputs.
         (attn_metadata, logits_indices, spec_decode_metadata,
          num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-         max_query_len) = (self._prepare_inputs(scheduler_output))
+         max_query_len) = self._prepare_inputs(scheduler_output)
 
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE