[Bugfix] Multi-sequence broken (#11898)

Signed-off-by: Andy Lo <andy@mistral.ai>
Authored by Andy Lo on 2025-01-21 19:51:35 +00:00; committed by GitHub
parent 132a132100
commit 18fd4a8331
3 changed files with 59 additions and 39 deletions


@@ -31,7 +31,7 @@ def test_random_sample_with_seed(
     sampling_params = SamplingParams(
         # Parameters to ensure sufficient randomness
-        temperature=2.0,
+        temperature=3.0,
         top_p=min(random.random() + 0.3, 1),
         top_k=random.randint(5, 20),
         n=random.randint(1, 10),
@@ -75,3 +75,8 @@ def test_random_sample_with_seed(
     # verify requests with the same seed match
     assert outputs[1] == outputs[4]
     assert outputs[2] == outputs[5]
+
+    # verify generations within the same parallel sampling group differ
+    for output in outputs:
+        for sub_output_a, sub_output_b in combinations(output, 2):
+            assert sub_output_a != sub_output_b
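The new assertions check that completions within a single parallel-sampling request differ from one another. A minimal sketch of the usage path this exercises, assuming the standard offline LLM/SamplingParams API; the model name and prompt below are placeholders:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(n=4, temperature=1.0, seed=1234, max_tokens=16)

# One prompt fans out into n=4 sequences inside a single request.
(request_output, ) = llm.generate(["The capital of France is"], params)

texts = [completion.text for completion in request_output.outputs]
assert len(texts) == 4  # one completion per parallel sample
# After this fix, the seeded parallel samples are expected to differ from
# each other while the request as a whole stays reproducible across runs.
print(texts)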


@@ -172,9 +172,9 @@ class RequestOutput:
         if seq_group.request_id in seq_id_to_seq_group:
             group: SequenceGroupBase = seq_id_to_seq_group[
                 seq_group.request_id]
-            assembled_seq_group = group.maybe_assemble_group(seq_group)
             if finished:
                 group.finish_seq(seq_group)
+            assembled_seq_group = group.maybe_assemble_group(seq_group)
             if assembled_seq_group is None:
                 return None
             return cls.from_seq_group(assembled_seq_group, use_cache,
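Reordering these two calls matters because, with parallel sampling, maybe_assemble_group decides what to surface based on which child requests are still pending; calling it before finish_seq meant the just-finished child was still counted as pending. A simplified, hypothetical model of that bookkeeping (not the real vLLM classes):

from typing import Optional


class ToyGroup:
    """Toy stand-in for the parallel-sampling group's bookkeeping."""

    def __init__(self, child_ids):
        self.to_be_finished = set(child_ids)
        self.assembled = object()  # stands in for the assembled SequenceGroup

    def finish_seq(self, child_id: str) -> None:
        self.to_be_finished.discard(child_id)

    def maybe_assemble_group(self) -> Optional[object]:
        # Non-streaming rule from this diff: emit only once all children finished.
        return None if self.to_be_finished else self.assembled


group = ToyGroup(["req_parallel_sample_0", "req_parallel_sample_1"])
group.finish_seq("req_parallel_sample_0")
assert group.maybe_assemble_group() is None      # one child still pending

group.finish_seq("req_parallel_sample_1")        # remove the finished child first...
assert group.maybe_assemble_group() is not None  # ...then assembly succeeds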


@@ -815,7 +815,9 @@ class SequenceGroup:
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
-        return 0 if self.first_seq.is_finished() else 1
+        if self.is_single_seq:
+            return 0 if self.first_seq.is_finished() else 1
+        return self.num_seqs() - self.num_finished_seqs()
 
     def get_seqs(
         self,
@@ -824,7 +826,10 @@
         if status is None:
             return self.seqs
 
-        return self.seqs if self.first_seq.status == status else []
+        if self.is_single_seq:
+            return self.seqs if self.first_seq.status == status else []
+
+        return [seq for seq in self.seqs if seq.status == status]
 
     def is_encoder_decoder(self) -> bool:
         return self.encoder_seq is not None
@@ -833,19 +838,22 @@
         return self.encoder_seq
 
     def get_finished_seqs(self) -> List[Sequence]:
-        return self.seqs if self.first_seq.is_finished() else []
+        if self.is_single_seq:
+            return self.seqs if self.first_seq.is_finished() else []
+
+        return [seq for seq in self.seqs if seq.is_finished()]
 
     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        seq = self.first_seq
-        if not seq.is_finished():
-            seq.data.update_num_computed_tokens(num_new_computed_tokens)
+        for seq in self.seqs:
+            if not seq.is_finished():
+                seq.data.update_num_computed_tokens(num_new_computed_tokens)
 
     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        seq = self.first_seq
-        if not seq.is_finished():
-            num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
+        for seq in self.seqs:
+            if not seq.is_finished():
+                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens
 
     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
@@ -860,10 +868,14 @@ class SequenceGroup:
         return len(self.get_seqs(status))
 
     def num_finished_seqs(self) -> int:
-        return 1 if self.first_seq.is_finished() else 0
+        if self.is_single_seq:
+            return 1 if self.seqs[0].is_finished() else 0
+        return len(self.get_finished_seqs())
 
     def is_finished(self) -> bool:
-        return self.first_seq.is_finished()
+        if self.is_single_seq:
+            return self.first_seq.is_finished()
+        return all(seq.is_finished() for seq in self.seqs)
 
     def is_prefill(self) -> bool:
         return self.first_seq.is_prefill()
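These SequenceGroup changes replace single-sequence fast paths, which assumed first_seq speaks for the whole group, with loops over every child sequence, so a group with n > 1 only counts as finished when all of its sequences are. A tiny hypothetical illustration of the new semantics:

# Hypothetical stand-ins: True means "this sequence is finished".
seq_finished = [True, False, True]   # a 3-sequence group, one still running

old_is_finished = seq_finished[0]    # old behavior: only first_seq consulted -> True
new_is_finished = all(seq_finished)  # new behavior: every sequence must finish -> False
num_finished = sum(seq_finished)     # new num_finished_seqs-style count -> 2

assert (old_is_finished, new_is_finished, num_finished) == (True, False, 2)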
@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
     @staticmethod
     def add_request(request_id: str, engine, params, **kwargs):
         original_params = params
-        params = original_params.clone()
-        params.n = 1
         group = ParallelSampleSequenceGroup(request_id)
         seqs = []
         for i in range(original_params.n):
             request_id_i = f"{request_id}_parallel_sample_{i}"
             group.seq_id_to_index[request_id_i] = i
+            params = copy.deepcopy(original_params)
+            params.n = 1
+            if params.seed is not None:
+                params.seed += i
             seq_group = engine._add_processed_request(
                 request_id_i,
                 params=params,
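The per-child cloning above is the core of the seeding fix: every parallel sample now gets its own copy of the request parameters with n=1 and, when a seed is given, a distinct seed + i, so samples differ from one another while staying reproducible. A small sketch of that behavior in isolation (hypothetical values):

import copy

from vllm import SamplingParams

parent = SamplingParams(n=4, temperature=1.0, seed=42)

children = []
for i in range(parent.n):
    child = copy.deepcopy(parent)  # never mutate the caller's params
    child.n = 1                    # each child request samples a single sequence
    if child.seed is not None:
        child.seed += i            # distinct, but deterministic, per child
    children.append(child)

assert [child.seed for child in children] == [42, 43, 44, 45]
assert parent.n == 4 and parent.seed == 42  # original params left untouched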
@@ -1432,33 +1446,34 @@
             self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
 
         # in the streaming mode, we will return the assembled sequence
-        # for the first sequence, and then return None for the rest of
-        # sequences
+        # for the first remaining sequence, and then return None for the
+        # rest of sequences
         if self.streaming:
-            if self.seq_id_to_index[seq_group.request_id] == 0:
+            first_remaining_id = next(iter(self.to_be_finished))
+            if seq_group.request_id == first_remaining_id:
                 return self.assembled_seq_group
             return None
 
         # in the non-streaming mode, we will return the assembled sequence
-        # once after all sequences finish, and then return None for the
+        # when the last sequences finishes, and then return None for the
         # rest of the time
-        if (len(self.to_be_finished) == 1
-                and seq_group.request_id in self.to_be_finished
-                and seq_group.is_finished()):
-            assert self.assembled_seq_group is not None
-            params = self.assembled_seq_group.sampling_params
-            assert isinstance(params, SamplingParams)
-            if not self.output_produced:
-                self.output_produced = True
-                if params._real_n is not None:
-                    # Get the top-n sequences.
-                    n = params._real_n or params.n
-                    seqs = self.assembled_seq_group.seqs
-                    sorting_key = lambda seq: seq.get_cumulative_logprob()
-                    sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-                    top_n_seqs = sorted_seqs[:n]
-                    self.assembled_seq_group.seqs = top_n_seqs
-                return self.assembled_seq_group
-            if self.output_produced:
-                return None
+        if len(self.to_be_finished) > 0:
+            return None
+
+        assert self.assembled_seq_group is not None
+        params = self.assembled_seq_group.sampling_params
+        assert isinstance(params, SamplingParams)
+        if not self.output_produced:
+            self.output_produced = True
+            if params._real_n is not None:
+                # Get the top-n sequences.
+                n = params._real_n or params.n
+                seqs = self.assembled_seq_group.seqs
+                sorting_key = lambda seq: seq.get_cumulative_logprob()
+                sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+                top_n_seqs = sorted_seqs[:n]
+                self.assembled_seq_group.seqs = top_n_seqs
+            return self.assembled_seq_group
+        if self.output_produced:
+            return None
         return None
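The streaming branch changed from "emit for the child with index 0" to "emit for the first still-pending child", so streamed output keeps flowing after child 0 finishes. A hypothetical, simplified comparison of the two rules (in the real code, to_be_finished maps request ids to SequenceGroup objects):

seq_id_to_index = {
    "req_parallel_sample_0": 0,
    "req_parallel_sample_1": 1,
    "req_parallel_sample_2": 2,
}
# Child 0 has already finished and was removed from the pending set.
to_be_finished = ["req_parallel_sample_1", "req_parallel_sample_2"]

incoming_id = "req_parallel_sample_1"
old_rule = seq_id_to_index[incoming_id] == 0          # False: the stream goes silent
new_rule = incoming_id == next(iter(to_be_finished))  # True: keep streaming
assert (old_rule, new_rule) == (False, True)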