[Bugfix] Multi-sequence broken (#11898)

Signed-off-by: Andy Lo <andy@mistral.ai>
Andy Lo 2025-01-21 19:51:35 +00:00 committed by GitHub
parent 132a132100
commit 18fd4a8331
3 changed files with 59 additions and 39 deletions
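Context for the change: a request with SamplingParams(n > 1) is fanned out into n child requests, and before this fix the children shared a single cloned params object (so a fixed seed could yield n identical completions) while output assembly keyed on the child with index 0 instead of whichever children were still pending. A minimal sketch of the behaviour the updated test pins down, using vLLM's public API; the model name, prompt, and exact parameter values are only placeholders:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # any small model works; name is illustrative
    params = SamplingParams(n=4, seed=1234, temperature=1.0, max_tokens=16)
    (request_output,) = llm.generate(["The weather today is"], params)

    texts = [candidate.text for candidate in request_output.outputs]
    assert len(texts) == 4      # one candidate per parallel sample
    assert len(set(texts)) > 1  # seeded candidates are expected to differ (what the updated test asserts)

With the per-child seed offset introduced further down, the candidates stay reproducible across runs yet differ from one another.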


@@ -31,7 +31,7 @@ def test_random_sample_with_seed(
     sampling_params = SamplingParams(
         # Parameters to ensure sufficient randomness
-        temperature=2.0,
+        temperature=3.0,
         top_p=min(random.random() + 0.3, 1),
         top_k=random.randint(5, 20),
         n=random.randint(1, 10),
@@ -75,3 +75,8 @@ def test_random_sample_with_seed(
     # verify requests with the same seed match
     assert outputs[1] == outputs[4]
     assert outputs[2] == outputs[5]
+
+    # verify generations within the same parallel sampling group differ
+    for output in outputs:
+        for sub_output_a, sub_output_b in combinations(output, 2):
+            assert sub_output_a != sub_output_b
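The new assertions iterate over every pair of candidates with itertools.combinations; the corresponding import is presumably added at the top of the test module (it is not visible in this hunk). A self-contained sketch of the pairwise-distinct check, with plain strings standing in for the objects the test compares:

    from itertools import combinations

    # Pairwise-distinct check over one request's candidate completions.
    candidates = ["rain", "sun", "snow"]
    for a, b in combinations(candidates, 2):
        assert a != b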


@@ -172,9 +172,9 @@ class RequestOutput:
         if seq_group.request_id in seq_id_to_seq_group:
             group: SequenceGroupBase = seq_id_to_seq_group[
                 seq_group.request_id]
+            assembled_seq_group = group.maybe_assemble_group(seq_group)
             if finished:
                 group.finish_seq(seq_group)
-            assembled_seq_group = group.maybe_assemble_group(seq_group)
             if assembled_seq_group is None:
                 return None
             return cls.from_seq_group(assembled_seq_group, use_cache,
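The only change here is call order: the reworked ParallelSampleSequenceGroup.maybe_assemble_group (further down in this commit) decides whether to emit output by checking that the finishing child is still listed in to_be_finished, so it has to run before finish_seq removes that entry. A toy sketch of the order dependency; the Group class here is a hypothetical stand-in, not vLLM's:

    class Group:
        def __init__(self, ids):
            self.to_be_finished = {i: True for i in ids}

        def maybe_assemble(self, rid):
            # emit only while the finishing child is still tracked
            if len(self.to_be_finished) == 1 and rid in self.to_be_finished:
                return "assembled"
            return None

        def finish(self, rid):
            self.to_be_finished.pop(rid, None)

    g = Group(["req_0"])
    assert g.maybe_assemble("req_0") == "assembled"  # new order: assemble, then finish
    g.finish("req_0")
    assert g.maybe_assemble("req_0") is None         # finishing first would have produced no output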


@@ -815,7 +815,9 @@ class SequenceGroup:
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
-        return 0 if self.first_seq.is_finished() else 1
+        if self.is_single_seq:
+            return 0 if self.first_seq.is_finished() else 1
+        return self.num_seqs() - self.num_finished_seqs()

     def get_seqs(
         self,
@@ -824,8 +826,11 @@ class SequenceGroup:
         if status is None:
             return self.seqs

-        return self.seqs if self.first_seq.status == status else []
+        if self.is_single_seq:
+            return self.seqs if self.first_seq.status == status else []
+
+        return [seq for seq in self.seqs if seq.status == status]

     def is_encoder_decoder(self) -> bool:
         return self.encoder_seq is not None
@@ -833,17 +838,20 @@ class SequenceGroup:
         return self.encoder_seq

     def get_finished_seqs(self) -> List[Sequence]:
-        return self.seqs if self.first_seq.is_finished() else []
+        if self.is_single_seq:
+            return self.seqs if self.first_seq.is_finished() else []
+
+        return [seq for seq in self.seqs if seq.is_finished()]

     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        seq = self.first_seq
-        if not seq.is_finished():
-            seq.data.update_num_computed_tokens(num_new_computed_tokens)
+        for seq in self.seqs:
+            if not seq.is_finished():
+                seq.data.update_num_computed_tokens(num_new_computed_tokens)

     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        seq = self.first_seq
-        if not seq.is_finished():
-            num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
+        for seq in self.seqs:
+            if not seq.is_finished():
+                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens
@@ -860,10 +868,14 @@ class SequenceGroup:
         return len(self.get_seqs(status))

     def num_finished_seqs(self) -> int:
-        return 1 if self.first_seq.is_finished() else 0
+        if self.is_single_seq:
+            return 1 if self.seqs[0].is_finished() else 0
+
+        return len(self.get_finished_seqs())

     def is_finished(self) -> bool:
-        return self.first_seq.is_finished()
+        if self.is_single_seq:
+            return self.first_seq.is_finished()
+
+        return all(seq.is_finished() for seq in self.seqs)

     def is_prefill(self) -> bool:
         return self.first_seq.is_prefill()
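These SequenceGroup accessors keep the is_single_seq fast path but now aggregate over every child sequence for parallel-sampling groups, instead of reporting only first_seq. A toy sketch of the aggregate accounting, using a hypothetical Seq stub rather than vLLM's Sequence:

    class Seq:
        def __init__(self, finished: bool):
            self.finished = finished

        def is_finished(self) -> bool:
            return self.finished

    seqs = [Seq(True), Seq(False), Seq(False)]
    num_finished = sum(s.is_finished() for s in seqs)    # 1, not just the first sequence
    group_finished = all(s.is_finished() for s in seqs)  # False until every child ends
    max_running = len(seqs) - num_finished               # 2 sequences still schedulable
    assert (num_finished, group_finished, max_running) == (1, False, 2)
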
@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
     @staticmethod
     def add_request(request_id: str, engine, params, **kwargs):
         original_params = params
-        params = original_params.clone()
-        params.n = 1
         group = ParallelSampleSequenceGroup(request_id)
         seqs = []
         for i in range(original_params.n):
             request_id_i = f"{request_id}_parallel_sample_{i}"
             group.seq_id_to_index[request_id_i] = i
+            params = copy.deepcopy(original_params)
+            params.n = 1
+            if params.seed is not None:
+                params.seed += i
             seq_group = engine._add_processed_request(
                 request_id_i,
                 params=params,
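Each child request now derives its own params: a deep copy of the original with n forced to 1, and, for seeded requests, the seed offset by the child index so the n samples are reproducible but not identical. With the old single clone, every child shared one params object and therefore one seed. A small stand-alone sketch with a stand-in Params dataclass (used here instead of SamplingParams to keep the example dependency-free):

    import copy
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Params:  # hypothetical stand-in for SamplingParams
        n: int = 1
        seed: Optional[int] = None

    original = Params(n=4, seed=1234)
    children = []
    for i in range(original.n):
        p = copy.deepcopy(original)  # independent object per child request
        p.n = 1                      # each child samples exactly one sequence
        if p.seed is not None:
            p.seed += i              # distinct, deterministic seed per child
        children.append(p)

    assert [c.seed for c in children] == [1234, 1235, 1236, 1237]
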
@@ -1432,20 +1446,20 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
             self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:

         # in the streaming mode, we will return the assembled sequence
-        # for the first sequence, and then return None for the rest of
-        # sequences
+        # for the first remaining sequence, and then return None for the
+        # rest of sequences
         if self.streaming:
-            if self.seq_id_to_index[seq_group.request_id] == 0:
+            first_remaining_id = next(iter(self.to_be_finished))
+            if seq_group.request_id == first_remaining_id:
                 return self.assembled_seq_group
             return None

         # in the non-streaming mode, we will return the assembled sequence
-        # once after all sequences finish, and then return None for the
+        # when the last sequences finishes, and then return None for the
         # rest of the time
-        if len(self.to_be_finished) > 0:
-            return None
+        if (len(self.to_be_finished) == 1
+                and seq_group.request_id in self.to_be_finished
+                and seq_group.is_finished()):
             assert self.assembled_seq_group is not None
             params = self.assembled_seq_group.sampling_params
             assert isinstance(params, SamplingParams)
@@ -1462,3 +1476,4 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
                 return self.assembled_seq_group
             if self.output_produced:
                 return None
+        return None
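In streaming mode the assembled group is now returned for the first child that is still pending rather than always for index 0, and in non-streaming mode it is returned exactly when the last pending child finishes. The streaming branch leans on the fact that Python dicts preserve insertion order, so next(iter(to_be_finished)) is the first remaining child id:

    # Insertion-ordered dicts make next(iter(d)) the "first remaining" key.
    to_be_finished = {"req_parallel_sample_0": "seq0", "req_parallel_sample_1": "seq1"}
    assert next(iter(to_be_finished)) == "req_parallel_sample_0"
    to_be_finished.pop("req_parallel_sample_0")
    assert next(iter(to_be_finished)) == "req_parallel_sample_1"  # takes over once child 0 is done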