mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-09 03:04:34 +08:00
[Misc] Misc code simplifications (#26450)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
a83ff278d6
commit
ddcbc2f334
@ -1474,7 +1474,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
affected_req_ids.add(request.request_id)
|
affected_req_ids.add(request.request_id)
|
||||||
|
|
||||||
return (affected_req_ids, total_affected_tokens)
|
return affected_req_ids, total_affected_tokens
|
||||||
|
|
||||||
def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
|
def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
|
||||||
total_requests_to_reschedule = 0
|
total_requests_to_reschedule = 0
|
||||||
|
|||||||
@ -59,8 +59,7 @@ def check_stop(
|
|||||||
sampling_params = request.sampling_params
|
sampling_params = request.sampling_params
|
||||||
assert sampling_params is not None
|
assert sampling_params is not None
|
||||||
|
|
||||||
min_tokens = sampling_params.min_tokens
|
if request.num_output_tokens < sampling_params.min_tokens:
|
||||||
if request.num_output_tokens < min_tokens:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
last_token_id = request.output_token_ids[-1]
|
last_token_id = request.output_token_ids[-1]
|
||||||
|
|||||||
@ -147,22 +147,20 @@ class RejectionSampler(nn.Module):
|
|||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
metadata: SpecDecodeMetadata,
|
metadata: SpecDecodeMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
has_penalties = not sampling_metadata.no_penalties
|
||||||
any_penalties_or_bad_words = (
|
any_penalties_or_bad_words = (
|
||||||
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
|
sampling_metadata.bad_words_token_ids or has_penalties
|
||||||
)
|
)
|
||||||
|
|
||||||
output_token_ids = sampling_metadata.output_token_ids
|
output_token_ids = sampling_metadata.output_token_ids
|
||||||
if any_penalties_or_bad_words:
|
if any_penalties_or_bad_words:
|
||||||
output_token_ids = self._combine_outputs_with_spec_tokens(
|
output_token_ids = self._combine_outputs_with_spec_tokens(
|
||||||
sampling_metadata.output_token_ids,
|
output_token_ids,
|
||||||
sampling_metadata.spec_token_ids,
|
sampling_metadata.spec_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate indices of target logits.
|
# Calculate indices of target logits.
|
||||||
if (
|
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
|
||||||
sampling_metadata.allowed_token_ids_mask is not None
|
|
||||||
or not sampling_metadata.no_penalties
|
|
||||||
):
|
|
||||||
num_requests = len(sampling_metadata.output_token_ids)
|
num_requests = len(sampling_metadata.output_token_ids)
|
||||||
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
|
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
|
||||||
original_indices = torch.arange(num_requests, device="cpu")
|
original_indices = torch.arange(num_requests, device="cpu")
|
||||||
@ -180,18 +178,15 @@ class RejectionSampler(nn.Module):
|
|||||||
logits.masked_fill_(token_mask, float("-inf"))
|
logits.masked_fill_(token_mask, float("-inf"))
|
||||||
|
|
||||||
# Apply bad words exclusion.
|
# Apply bad words exclusion.
|
||||||
if sampling_metadata.bad_words_token_ids:
|
if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
|
||||||
apply_bad_words_with_drafts(
|
apply_bad_words_with_drafts(
|
||||||
logits,
|
logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
|
||||||
sampling_metadata.bad_words_token_ids,
|
|
||||||
output_token_ids,
|
|
||||||
metadata.num_draft_tokens,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def apply_penalties(
|
def apply_penalties(
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
metadata: SpecDecodeMetadata,
|
metadata: SpecDecodeMetadata,
|
||||||
@ -218,8 +213,8 @@ class RejectionSampler(nn.Module):
|
|||||||
)
|
)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _combine_outputs_with_spec_tokens(
|
def _combine_outputs_with_spec_tokens(
|
||||||
self,
|
|
||||||
output_token_ids: list[list[int]],
|
output_token_ids: list[list[int]],
|
||||||
spec_token_ids: Optional[list[list[int]]] = None,
|
spec_token_ids: Optional[list[list[int]]] = None,
|
||||||
) -> list[list[int]]:
|
) -> list[list[int]]:
|
||||||
|
|||||||
@ -120,8 +120,8 @@ class Sampler(nn.Module):
|
|||||||
)
|
)
|
||||||
return sampler_output
|
return sampler_output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def apply_temperature(
|
def apply_temperature(
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
temp: torch.Tensor,
|
temp: torch.Tensor,
|
||||||
all_random: bool,
|
all_random: bool,
|
||||||
@ -132,7 +132,8 @@ class Sampler(nn.Module):
|
|||||||
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
||||||
return logits.div_(temp.unsqueeze(dim=1))
|
return logits.div_(temp.unsqueeze(dim=1))
|
||||||
|
|
||||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
@staticmethod
|
||||||
|
def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
|
||||||
return logits.argmax(dim=-1).view(-1)
|
return logits.argmax(dim=-1).view(-1)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
@ -191,11 +192,12 @@ class Sampler(nn.Module):
|
|||||||
)
|
)
|
||||||
return sampled, processed_logprobs
|
return sampled, processed_logprobs
|
||||||
|
|
||||||
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
@staticmethod
|
||||||
|
def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
|
||||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def gather_logprobs(
|
def gather_logprobs(
|
||||||
self,
|
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
token_ids: torch.Tensor,
|
token_ids: torch.Tensor,
|
||||||
@ -238,8 +240,8 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _combine_outputs_with_spec_tokens(
|
def _combine_outputs_with_spec_tokens(
|
||||||
self,
|
|
||||||
output_token_ids: list[list[int]],
|
output_token_ids: list[list[int]],
|
||||||
spec_token_ids: Optional[list[list[int]]] = None,
|
spec_token_ids: Optional[list[list[int]]] = None,
|
||||||
) -> list[list[int]]:
|
) -> list[list[int]]:
|
||||||
@ -257,8 +259,9 @@ class Sampler(nn.Module):
|
|||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
predict_bonus_token: bool,
|
predict_bonus_token: bool,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
bad_words_token_ids = sampling_metadata.bad_words_token_ids
|
||||||
any_penalties_or_bad_words = (
|
any_penalties_or_bad_words = (
|
||||||
sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
|
bool(bad_words_token_ids) or not sampling_metadata.no_penalties
|
||||||
)
|
)
|
||||||
|
|
||||||
output_token_ids = sampling_metadata.output_token_ids
|
output_token_ids = sampling_metadata.output_token_ids
|
||||||
@ -266,7 +269,7 @@ class Sampler(nn.Module):
|
|||||||
# Combine base outputs with spec tokens when speculative decoding
|
# Combine base outputs with spec tokens when speculative decoding
|
||||||
# is enabled.
|
# is enabled.
|
||||||
output_token_ids = self._combine_outputs_with_spec_tokens(
|
output_token_ids = self._combine_outputs_with_spec_tokens(
|
||||||
sampling_metadata.output_token_ids,
|
output_token_ids,
|
||||||
sampling_metadata.spec_token_ids,
|
sampling_metadata.spec_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -275,14 +278,8 @@ class Sampler(nn.Module):
|
|||||||
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
|
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
|
||||||
|
|
||||||
# Apply bad words exclusion.
|
# Apply bad words exclusion.
|
||||||
if sampling_metadata.bad_words_token_ids:
|
if bad_words_token_ids:
|
||||||
apply_bad_words(
|
apply_bad_words(logits, bad_words_token_ids, output_token_ids)
|
||||||
logits,
|
|
||||||
sampling_metadata.bad_words_token_ids,
|
|
||||||
output_token_ids
|
|
||||||
if output_token_ids is not None
|
|
||||||
else sampling_metadata.output_token_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply logits processors which can impact greedy sampling.
|
# Apply logits processors which can impact greedy sampling.
|
||||||
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
|
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
|
||||||
@ -292,22 +289,21 @@ class Sampler(nn.Module):
|
|||||||
logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
|
logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def apply_penalties(
|
def apply_penalties(
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
output_token_ids: Optional[list[list[int]]] = None,
|
output_token_ids: list[list[int]],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not sampling_metadata.no_penalties:
|
if sampling_metadata.no_penalties:
|
||||||
assert sampling_metadata.prompt_token_ids is not None
|
return logits
|
||||||
logits = apply_all_penalties(
|
|
||||||
logits,
|
assert sampling_metadata.prompt_token_ids is not None
|
||||||
sampling_metadata.prompt_token_ids,
|
return apply_all_penalties(
|
||||||
sampling_metadata.presence_penalties,
|
logits,
|
||||||
sampling_metadata.frequency_penalties,
|
sampling_metadata.prompt_token_ids,
|
||||||
sampling_metadata.repetition_penalties,
|
sampling_metadata.presence_penalties,
|
||||||
output_token_ids
|
sampling_metadata.frequency_penalties,
|
||||||
if output_token_ids is not None
|
sampling_metadata.repetition_penalties,
|
||||||
else sampling_metadata.output_token_ids,
|
output_token_ids,
|
||||||
)
|
)
|
||||||
return logits
|
|
||||||
|
|||||||
@ -62,10 +62,9 @@ class CachedRequestState:
|
|||||||
"provided via prompt_embeds, and its ID is unknown."
|
"provided via prompt_embeds, and its ID is unknown."
|
||||||
)
|
)
|
||||||
return self.prompt_token_ids[idx]
|
return self.prompt_token_ids[idx]
|
||||||
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
|
if idx - self.num_prompt_tokens < len(self.output_token_ids):
|
||||||
return self.output_token_ids[idx - self.num_prompt_tokens]
|
return self.output_token_ids[idx - self.num_prompt_tokens]
|
||||||
else:
|
return -1
|
||||||
return -1
|
|
||||||
|
|
||||||
|
|
||||||
class InputBatch:
|
class InputBatch:
|
||||||
@ -770,14 +769,13 @@ class InputBatch:
|
|||||||
not self.no_penalties
|
not self.no_penalties
|
||||||
or self.logits_processing_needs_token_ids[:num_reqs].any()
|
or self.logits_processing_needs_token_ids[:num_reqs].any()
|
||||||
)
|
)
|
||||||
if needs_prompt_token_ids:
|
# The prompt tokens are used only for applying penalties or
|
||||||
# The prompt tokens are used only for applying penalties or
|
# step pooling during the sampling/pooling process.
|
||||||
# step pooling during the sampling/pooling process.
|
# Hence copy these tensors only when there are requests which
|
||||||
# Hence copy these tensors only when there are requests which
|
# need penalties/step_pooler to be applied.
|
||||||
# need penalties/step_pooler to be applied.
|
prompt_token_ids = (
|
||||||
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
|
||||||
else:
|
)
|
||||||
prompt_token_ids = None
|
|
||||||
|
|
||||||
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||||
if not self.no_allowed_token_ids:
|
if not self.no_allowed_token_ids:
|
||||||
|
|||||||
@ -1996,7 +1996,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Should be called after attention metadata creation. This just pads
|
# Should be called after attention metadata creation. This just pads
|
||||||
# the second ubatch slice out to the total number of tokens
|
# the second ubatch slice out to the total number of tokens
|
||||||
# (num_tokens + padding)
|
# (num_tokens + padding)
|
||||||
def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
|
@staticmethod
|
||||||
|
def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
|
||||||
padded_second_ubatch_slice = slice(
|
padded_second_ubatch_slice = slice(
|
||||||
ubatch_slices[1].token_slice.start, num_total_tokens
|
ubatch_slices[1].token_slice.start, num_total_tokens
|
||||||
)
|
)
|
||||||
@ -2085,12 +2086,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
dict[str, Any],
|
dict[str, Any],
|
||||||
]:
|
]:
|
||||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
is_first_rank = get_pp_group().is_first_rank
|
||||||
|
|
||||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||||
# modal outputs after that to ensure the correct order
|
# modal outputs after that to ensure the correct order
|
||||||
if (
|
if (
|
||||||
self.supports_mm_inputs
|
self.supports_mm_inputs
|
||||||
and get_pp_group().is_first_rank
|
and is_first_rank
|
||||||
and not self.model_config.is_encoder_decoder
|
and not self.model_config.is_encoder_decoder
|
||||||
):
|
):
|
||||||
# Run the multimodal encoder if any.
|
# Run the multimodal encoder if any.
|
||||||
@ -2115,7 +2117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
**self._init_model_kwargs(num_scheduled_tokens),
|
**self._init_model_kwargs(num_scheduled_tokens),
|
||||||
**self._extract_mm_kwargs(scheduler_output),
|
**self._extract_mm_kwargs(scheduler_output),
|
||||||
}
|
}
|
||||||
elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
|
elif self.enable_prompt_embeds and is_first_rank:
|
||||||
# Get the input embeddings for the tokens that are not input embeds,
|
# Get the input embeddings for the tokens that are not input embeds,
|
||||||
# then put them into the appropriate positions.
|
# then put them into the appropriate positions.
|
||||||
# TODO(qthequartermasterman): Since even when prompt embeds are
|
# TODO(qthequartermasterman): Since even when prompt embeds are
|
||||||
@ -2155,7 +2157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
positions = self.positions.gpu[:num_input_tokens]
|
positions = self.positions.gpu[:num_input_tokens]
|
||||||
|
|
||||||
if get_pp_group().is_first_rank:
|
if is_first_rank:
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
else:
|
else:
|
||||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||||
@ -2186,38 +2188,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Sample the next token and get logprobs if needed.
|
# Sample the next token and get logprobs if needed.
|
||||||
sampling_metadata = self.input_batch.sampling_metadata
|
sampling_metadata = self.input_batch.sampling_metadata
|
||||||
if spec_decode_metadata is None:
|
if spec_decode_metadata is None:
|
||||||
sampler_output = self.sampler(
|
return self.sampler(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
|
||||||
# creates a new tensor with separate storage from the original
|
|
||||||
# logits tensor. This means any in-place operations on bonus_logits
|
|
||||||
# won't affect the original logits tensor.
|
|
||||||
assert logits is not None
|
|
||||||
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
|
||||||
sampler_output = self.sampler(
|
|
||||||
logits=bonus_logits,
|
|
||||||
sampling_metadata=sampling_metadata,
|
|
||||||
predict_bonus_token=True,
|
|
||||||
)
|
|
||||||
bonus_token_ids = sampler_output.sampled_token_ids
|
|
||||||
|
|
||||||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
||||||
# separate storage from the original `logits` tensor. Therefore,
|
# creates a new tensor with separate storage from the original
|
||||||
# it is safe to update `target_logits` in place.
|
# logits tensor. This means any in-place operations on bonus_logits
|
||||||
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
# won't affect the original logits tensor.
|
||||||
output_token_ids = self.rejection_sampler(
|
assert logits is not None
|
||||||
spec_decode_metadata,
|
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
||||||
None, # draft_probs
|
sampler_output = self.sampler(
|
||||||
target_logits,
|
logits=bonus_logits,
|
||||||
bonus_token_ids,
|
sampling_metadata=sampling_metadata,
|
||||||
sampling_metadata,
|
predict_bonus_token=True,
|
||||||
)
|
)
|
||||||
sampler_output.sampled_token_ids = output_token_ids
|
bonus_token_ids = sampler_output.sampled_token_ids
|
||||||
self._update_states_after_model_execute(output_token_ids)
|
|
||||||
|
|
||||||
|
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
||||||
|
# separate storage from the original `logits` tensor. Therefore,
|
||||||
|
# it is safe to update `target_logits` in place.
|
||||||
|
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
||||||
|
output_token_ids = self.rejection_sampler(
|
||||||
|
spec_decode_metadata,
|
||||||
|
None, # draft_probs
|
||||||
|
target_logits,
|
||||||
|
bonus_token_ids,
|
||||||
|
sampling_metadata,
|
||||||
|
)
|
||||||
|
sampler_output.sampled_token_ids = output_token_ids
|
||||||
|
self._update_states_after_model_execute(output_token_ids)
|
||||||
return sampler_output
|
return sampler_output
|
||||||
|
|
||||||
def _bookkeeping_sync(
|
def _bookkeeping_sync(
|
||||||
@ -3741,7 +3742,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
decode_cudagraph_batch_sizes = [
|
decode_cudagraph_batch_sizes = [
|
||||||
x
|
x
|
||||||
for x in self.cudagraph_batch_sizes
|
for x in self.cudagraph_batch_sizes
|
||||||
if x <= max_num_tokens and x >= self.uniform_decode_query_len
|
if max_num_tokens >= x >= self.uniform_decode_query_len
|
||||||
]
|
]
|
||||||
compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
|
compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
|
||||||
self._capture_cudagraphs(
|
self._capture_cudagraphs(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user