Store eos_token_id in Sequence for easy access (#3166)
commit 8999ec3c16
parent 05af6da8d9
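In short, the commit moves the EOS token id onto the Sequence object itself, so code that previously had to fetch the per-request tokenizer (e.g. self.get_tokenizer_for_seq(seq).eos_token_id) can simply read seq.eos_token_id. Below is a minimal sketch of that pattern using simplified stand-in classes, not the real vLLM definitions:

# Sketch only: simplified stand-ins, not the actual vLLM classes.
from typing import List


class Sequence:
    """A sequence that carries its own EOS token id."""

    def __init__(self, seq_id: int, prompt: str,
                 prompt_token_ids: List[int], block_size: int,
                 eos_token_id: int) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.block_size = block_size
        self.eos_token_id = eos_token_id
        self.output_token_ids: List[int] = []

    def get_last_token_id(self) -> int:
        if self.output_token_ids:
            return self.output_token_ids[-1]
        return self.prompt_token_ids[-1]


def check_stop(seq: Sequence, ignore_eos: bool = False) -> bool:
    # Before this change the caller needed a tokenizer lookup; now the
    # id lives on the sequence itself.
    return (not ignore_eos) and seq.get_last_token_id() == seq.eos_token_id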
@@ -54,7 +54,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
     for prompt in prompts:
         hashes[-1].append([])
         prompt_token_ids = tokenizer.encode(prompt)
-        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
+        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
+                       tokenizer.tokenizer.eos_token_id)
 
         num_blocks = len(prompt_token_ids) // block_size
         for idx in range(num_blocks):
@@ -59,10 +59,9 @@ class SchedulerOutputs:
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
     def _sort_by_lora_ids(self) -> bool:
-        self.scheduled_seq_groups = sorted(
-            self.scheduled_seq_groups,
-            key=lambda g: (g.lora_request.lora_int_id
-                           if g.lora_request else 0, g.request_id))
+        self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
+                                           key=lambda g:
+                                           (g.lora_int_id, g.request_id))
 
     @property
     def lora_requests(self) -> Set[LoRARequest]:
@@ -491,8 +491,10 @@ class LLMEngine:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
+        eos_token_id = self.tokenizer.get_lora_tokenizer(
+            lora_request).eos_token_id
         seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                       lora_request)
+                       eos_token_id, lora_request)
 
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
@@ -548,15 +550,13 @@ class LLMEngine:
         if early_stopping is True:
             return True
 
-        current_worst_score = (current_worst_seq.get_beam_search_score(
-            length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(
-                current_worst_seq).eos_token_id))
+        current_worst_score = current_worst_seq.get_beam_search_score(
+            length_penalty=length_penalty,
+            eos_token_id=current_worst_seq.eos_token_id)
         if early_stopping is False:
-            highest_attainable_score = (best_running_seq.get_beam_search_score(
-                length_penalty=length_penalty,
-                eos_token_id=self.get_tokenizer_for_seq(
-                    best_running_seq).eos_token_id))
+            highest_attainable_score = best_running_seq.get_beam_search_score(
+                length_penalty=length_penalty,
+                eos_token_id=best_running_seq.eos_token_id)
         else:
             assert early_stopping == "never"
             if length_penalty > 0.0:
@@ -570,8 +570,7 @@ class LLMEngine:
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
-                        eos_token_id=self.get_tokenizer_for_seq(
-                            best_running_seq).eos_token_id,
+                        eos_token_id=best_running_seq.eos_token_id,
                         seq_len=max_possible_length))
             else:
                 # Otherwise, beam search will prefer shorter sequences. The
@@ -580,8 +579,7 @@ class LLMEngine:
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
-                        eos_token_id=self.get_tokenizer_for_seq(
-                            best_running_seq).eos_token_id))
+                        eos_token_id=best_running_seq.eos_token_id))
         return current_worst_score >= highest_attainable_score
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
@@ -679,8 +677,7 @@ class LLMEngine:
         all_finished_seqs = existing_finished_seqs + new_finished_seqs
         # Sort the finished sequences by their scores.
         all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-            length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
+            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                reverse=True)
         for seq, parent, is_new in all_finished_seqs[:beam_width]:
             if is_new:
@@ -707,8 +704,7 @@ class LLMEngine:
                              if not seq.is_finished()]
         # Sort the running sequences by their scores.
         running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-            length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
+            length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                                 reverse=True)
 
         # Check if we can stop the beam search.
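For context on what these calls compute: get_beam_search_score is a length-penalized cumulative log-probability. The sketch below is an illustrative reimplementation under the assumption that a trailing EOS token is excluded from the effective length; it is not the vLLM source itself.

# Illustrative only: assumed behaviour of a length-penalized beam score.
from typing import List, Optional


def beam_search_score(token_ids: List[int],
                      cumulative_logprob: float,
                      length_penalty: float = 1.0,
                      seq_len: Optional[int] = None,
                      eos_token_id: Optional[int] = None) -> float:
    if seq_len is None:
        seq_len = len(token_ids)
        # Assumption: a final EOS token should not count toward the
        # length normalizer, mirroring the eos_token_id argument above.
        if (eos_token_id is not None and token_ids
                and token_ids[-1] == eos_token_id):
            seq_len -= 1
    return cumulative_logprob / (seq_len**length_penalty)

Sorting candidates by this score with reverse=True, as the hunks above do, ranks the highest (least negative) normalized log-probabilities first.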
@@ -1014,8 +1010,8 @@ class LLMEngine:
             return
 
         # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
-                == self.get_tokenizer_for_seq(seq).eos_token_id):
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == seq.eos_token_id):
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
@@ -516,7 +516,6 @@ def _get_logprobs(
         if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             num_logprobs = sampling_params.prompt_logprobs
-            prompt_len = sampling_metadata.prompt_lens[i]
             prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
             group_prompt_logprobs: PromptLogprobs = [None]
@@ -90,29 +90,30 @@ class RequestOutput:
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
-        if seq_group.sampling_params.use_beam_search:
-            sorting_key = lambda seq: seq.get_beam_search_score(
-                seq_group.sampling_params.length_penalty)
+        if n == 1:
+            top_n_seqs = seqs
         else:
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-        sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-        top_n_seqs = sorted_seqs[:n]
+            if seq_group.sampling_params.use_beam_search:
+                sorting_key = lambda seq: seq.get_beam_search_score(
+                    seq_group.sampling_params.length_penalty)
+            else:
+                sorting_key = lambda seq: seq.get_cumulative_logprob()
+            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+            top_n_seqs = sorted_seqs[:n]
 
         # Create the outputs.
-        outputs: List[CompletionOutput] = []
-        for seq in top_n_seqs:
-            logprobs = seq.output_logprobs
-            if seq_group.sampling_params.logprobs is None:
-                # NOTE: We need to take care of this case because the sequence
-                # always has the logprobs of the sampled tokens even if the
-                # logprobs are not requested.
-                logprobs = None
-            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
-            output = CompletionOutput(seqs.index(seq), seq.output_text,
-                                      seq.get_output_token_ids(),
-                                      seq.get_cumulative_logprob(), logprobs,
-                                      finshed_reason)
-            outputs.append(output)
+        # NOTE: We need omit logprobs here explicitly because the sequence
+        # always has the logprobs of the sampled tokens even if the
+        # logprobs are not requested.
+        include_logprobs = seq_group.sampling_params.logprobs
+        outputs = [
+            CompletionOutput(seqs.index(seq), seq.output_text,
+                             seq.get_output_token_ids(),
+                             seq.get_cumulative_logprob(),
+                             seq.output_logprobs if include_logprobs else None,
+                             SequenceStatus.get_finished_reason(seq.status))
+            for seq in top_n_seqs
+        ]
 
         # Every sequence in the sequence group should have the same prompt.
         prompt = seq_group.prompt
@@ -142,11 +142,13 @@ class Sequence:
         prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
+        eos_token_id: int,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.seq_id = seq_id
         self.prompt = prompt
         self.block_size = block_size
+        self.eos_token_id = eos_token_id
         self.lora_request = lora_request
 
         self.data = SequenceData(prompt_token_ids)
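A minimal usage sketch of the updated constructor, mirroring the test change at the top of this diff; the model name and block size are arbitrary illustration values:

from transformers import AutoTokenizer

from vllm.sequence import Sequence

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
prompt = "Hello, my name is"
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id=0,
               prompt=prompt,
               prompt_token_ids=prompt_token_ids,
               block_size=16,
               eos_token_id=tokenizer.eos_token_id)
assert seq.eos_token_id == tokenizer.eos_token_id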
@@ -362,12 +364,9 @@ class SequenceGroup:
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
-        if status is None:
-            return list(self.seqs_dict.values())
-        else:
-            return [
-                seq for seq in self.seqs_dict.values() if seq.status == status
-            ]
+        return list(self.seqs_dict.values()) if status is None else [
+            seq for seq in self.seqs_dict.values() if seq.status == status
+        ]
 
     def get_unfinished_seqs(self) -> List[Sequence]:
         return [