mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:04:58 +08:00
[Bugfix] Multi-sequence broken (#11898)
Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
parent
132a132100
commit
18fd4a8331
@ -31,7 +31,7 @@ def test_random_sample_with_seed(
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
# Parameters to ensure sufficient randomness
|
||||
temperature=2.0,
|
||||
temperature=3.0,
|
||||
top_p=min(random.random() + 0.3, 1),
|
||||
top_k=random.randint(5, 20),
|
||||
n=random.randint(1, 10),
|
||||
@ -75,3 +75,8 @@ def test_random_sample_with_seed(
|
||||
# verify requests with the same seed match
|
||||
assert outputs[1] == outputs[4]
|
||||
assert outputs[2] == outputs[5]
|
||||
|
||||
# verify generations within the same parallel sampling group differ
|
||||
for output in outputs:
|
||||
for sub_output_a, sub_output_b in combinations(output, 2):
|
||||
assert sub_output_a != sub_output_b
|
||||
|
||||
@ -172,9 +172,9 @@ class RequestOutput:
|
||||
if seq_group.request_id in seq_id_to_seq_group:
|
||||
group: SequenceGroupBase = seq_id_to_seq_group[
|
||||
seq_group.request_id]
|
||||
assembled_seq_group = group.maybe_assemble_group(seq_group)
|
||||
if finished:
|
||||
group.finish_seq(seq_group)
|
||||
assembled_seq_group = group.maybe_assemble_group(seq_group)
|
||||
if assembled_seq_group is None:
|
||||
return None
|
||||
return cls.from_seq_group(assembled_seq_group, use_cache,
|
||||
|
||||
@ -815,7 +815,9 @@ class SequenceGroup:
|
||||
def get_max_num_running_seqs(self) -> int:
|
||||
"""The maximum number of sequences running in parallel in the remaining
|
||||
lifetime of the request."""
|
||||
return 0 if self.first_seq.is_finished() else 1
|
||||
if self.is_single_seq:
|
||||
return 0 if self.first_seq.is_finished() else 1
|
||||
return self.num_seqs() - self.num_finished_seqs()
|
||||
|
||||
def get_seqs(
|
||||
self,
|
||||
@ -824,7 +826,10 @@ class SequenceGroup:
|
||||
if status is None:
|
||||
return self.seqs
|
||||
|
||||
return self.seqs if self.first_seq.status == status else []
|
||||
if self.is_single_seq:
|
||||
return self.seqs if self.first_seq.status == status else []
|
||||
|
||||
return [seq for seq in self.seqs if seq.status == status]
|
||||
|
||||
def is_encoder_decoder(self) -> bool:
|
||||
return self.encoder_seq is not None
|
||||
@ -833,19 +838,22 @@ class SequenceGroup:
|
||||
return self.encoder_seq
|
||||
|
||||
def get_finished_seqs(self) -> List[Sequence]:
|
||||
return self.seqs if self.first_seq.is_finished() else []
|
||||
if self.is_single_seq:
|
||||
return self.seqs if self.first_seq.is_finished() else []
|
||||
|
||||
return [seq for seq in self.seqs if seq.is_finished()]
|
||||
|
||||
def update_num_computed_tokens(self, num_new_computed_tokens: int):
|
||||
"""Update number of tokens computed so far."""
|
||||
seq = self.first_seq
|
||||
if not seq.is_finished():
|
||||
seq.data.update_num_computed_tokens(num_new_computed_tokens)
|
||||
for seq in self.seqs:
|
||||
if not seq.is_finished():
|
||||
seq.data.update_num_computed_tokens(num_new_computed_tokens)
|
||||
|
||||
def get_num_uncomputed_tokens(self) -> int:
|
||||
num_uncomputed_tokens = 0
|
||||
seq = self.first_seq
|
||||
if not seq.is_finished():
|
||||
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
|
||||
for seq in self.seqs:
|
||||
if not seq.is_finished():
|
||||
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
|
||||
return num_uncomputed_tokens
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||
@ -860,10 +868,14 @@ class SequenceGroup:
|
||||
return len(self.get_seqs(status))
|
||||
|
||||
def num_finished_seqs(self) -> int:
|
||||
return 1 if self.first_seq.is_finished() else 0
|
||||
if self.is_single_seq:
|
||||
return 1 if self.seqs[0].is_finished() else 0
|
||||
return len(self.get_finished_seqs())
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return self.first_seq.is_finished()
|
||||
if self.is_single_seq:
|
||||
return self.first_seq.is_finished()
|
||||
return all(seq.is_finished() for seq in self.seqs)
|
||||
|
||||
def is_prefill(self) -> bool:
|
||||
return self.first_seq.is_prefill()
|
||||
@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
|
||||
@staticmethod
|
||||
def add_request(request_id: str, engine, params, **kwargs):
|
||||
original_params = params
|
||||
params = original_params.clone()
|
||||
params.n = 1
|
||||
group = ParallelSampleSequenceGroup(request_id)
|
||||
seqs = []
|
||||
for i in range(original_params.n):
|
||||
request_id_i = f"{request_id}_parallel_sample_{i}"
|
||||
group.seq_id_to_index[request_id_i] = i
|
||||
params = copy.deepcopy(original_params)
|
||||
params.n = 1
|
||||
if params.seed is not None:
|
||||
params.seed += i
|
||||
seq_group = engine._add_processed_request(
|
||||
request_id_i,
|
||||
params=params,
|
||||
@ -1432,33 +1446,34 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
|
||||
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
|
||||
|
||||
# in the streaming mode, we will return the assembled sequence
|
||||
# for the first sequence, and then return None for the rest of
|
||||
# sequences
|
||||
# for the first remaining sequence, and then return None for the
|
||||
# rest of sequences
|
||||
if self.streaming:
|
||||
if self.seq_id_to_index[seq_group.request_id] == 0:
|
||||
first_remaining_id = next(iter(self.to_be_finished))
|
||||
if seq_group.request_id == first_remaining_id:
|
||||
return self.assembled_seq_group
|
||||
return None
|
||||
|
||||
# in the non-streaming mode, we will return the assembled sequence
|
||||
# once after all sequences finish, and then return None for the
|
||||
# when the last sequences finishes, and then return None for the
|
||||
# rest of the time
|
||||
|
||||
if len(self.to_be_finished) > 0:
|
||||
return None
|
||||
|
||||
assert self.assembled_seq_group is not None
|
||||
params = self.assembled_seq_group.sampling_params
|
||||
assert isinstance(params, SamplingParams)
|
||||
if not self.output_produced:
|
||||
self.output_produced = True
|
||||
if params._real_n is not None:
|
||||
# Get the top-n sequences.
|
||||
n = params._real_n or params.n
|
||||
seqs = self.assembled_seq_group.seqs
|
||||
sorting_key = lambda seq: seq.get_cumulative_logprob()
|
||||
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
|
||||
top_n_seqs = sorted_seqs[:n]
|
||||
self.assembled_seq_group.seqs = top_n_seqs
|
||||
return self.assembled_seq_group
|
||||
if self.output_produced:
|
||||
return None
|
||||
if (len(self.to_be_finished) == 1
|
||||
and seq_group.request_id in self.to_be_finished
|
||||
and seq_group.is_finished()):
|
||||
assert self.assembled_seq_group is not None
|
||||
params = self.assembled_seq_group.sampling_params
|
||||
assert isinstance(params, SamplingParams)
|
||||
if not self.output_produced:
|
||||
self.output_produced = True
|
||||
if params._real_n is not None:
|
||||
# Get the top-n sequences.
|
||||
n = params._real_n or params.n
|
||||
seqs = self.assembled_seq_group.seqs
|
||||
sorting_key = lambda seq: seq.get_cumulative_logprob()
|
||||
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
|
||||
top_n_seqs = sorted_seqs[:n]
|
||||
self.assembled_seq_group.seqs = top_n_seqs
|
||||
return self.assembled_seq_group
|
||||
if self.output_produced:
|
||||
return None
|
||||
return None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user