[Bugfix] Multi-sequence broken (#11898)
Signed-off-by: Andy Lo <andy@mistral.ai>
parent 132a132100
commit 18fd4a8331
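For context: the regression appears when a single request asks for several parallel samples (n > 1). A minimal sketch of the affected usage, assuming the standard vLLM offline API and a small placeholder model:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration
    params = SamplingParams(n=4, temperature=1.0, seed=42)
    # One request fans out into n child sequences internally; this commit
    # fixes how those children are tracked, seeded, and reassembled.
    outputs = llm.generate(["Hello, my name is"], params)
    for completion in outputs[0].outputs:
        print(completion.text)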
@@ -31,7 +31,7 @@ def test_random_sample_with_seed(
     sampling_params = SamplingParams(
         # Parameters to ensure sufficient randomness
-        temperature=2.0,
+        temperature=3.0,
         top_p=min(random.random() + 0.3, 1),
         top_k=random.randint(5, 20),
         n=random.randint(1, 10),
@@ -75,3 +75,8 @@ def test_random_sample_with_seed(
     # verify requests with the same seed match
     assert outputs[1] == outputs[4]
     assert outputs[2] == outputs[5]
+
+    # verify generations within the same parallel sampling group differ
+    for output in outputs:
+        for sub_output_a, sub_output_b in combinations(output, 2):
+            assert sub_output_a != sub_output_b
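The added assertions compare every pair of completions within one parallel sampling group via itertools.combinations; a standalone sketch of the check, with placeholder strings standing in for one request's samples:

    from itertools import combinations

    samples = ["text a", "text b", "text c"]  # stand-ins for one group's outputs
    for sub_output_a, sub_output_b in combinations(samples, 2):
        # with high temperature and distinct seeds, no two samples should match
        assert sub_output_a != sub_output_b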
@@ -172,9 +172,9 @@ class RequestOutput:
         if seq_group.request_id in seq_id_to_seq_group:
             group: SequenceGroupBase = seq_id_to_seq_group[
                 seq_group.request_id]
-            assembled_seq_group = group.maybe_assemble_group(seq_group)
             if finished:
                 group.finish_seq(seq_group)
+            assembled_seq_group = group.maybe_assemble_group(seq_group)
             if assembled_seq_group is None:
                 return None
             return cls.from_seq_group(assembled_seq_group, use_cache,
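The move matters because maybe_assemble_group consults the group's record of unfinished sequences: calling it before finish_seq has removed the just-finished sequence can make the group look incomplete and drop the final output. A toy illustration of the ordering dependency (the class is invented for illustration, not vLLM code):

    class ToyGroup:
        def __init__(self, n: int):
            self.to_be_finished = set(range(n))

        def finish_seq(self, seq_id: int) -> None:
            self.to_be_finished.discard(seq_id)

        def maybe_assemble_group(self):
            # assemble only once every sequence is accounted for
            return "assembled" if not self.to_be_finished else None

    group = ToyGroup(n=2)
    group.finish_seq(0)
    # Old order: assemble before finishing seq 1 -> None, output lost.
    assert group.maybe_assemble_group() is None
    # New order: finish first, then assemble -> the last finish succeeds.
    group.finish_seq(1)
    assert group.maybe_assemble_group() == "assembled"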
@@ -815,7 +815,9 @@ class SequenceGroup:
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
-        return 0 if self.first_seq.is_finished() else 1
+        if self.is_single_seq:
+            return 0 if self.first_seq.is_finished() else 1
+        return self.num_seqs() - self.num_finished_seqs()

     def get_seqs(
         self,
@@ -824,7 +826,10 @@ class SequenceGroup:
         if status is None:
             return self.seqs

-        return self.seqs if self.first_seq.status == status else []
+        if self.is_single_seq:
+            return self.seqs if self.first_seq.status == status else []
+
+        return [seq for seq in self.seqs if seq.status == status]

     def is_encoder_decoder(self) -> bool:
         return self.encoder_seq is not None
@@ -833,19 +838,22 @@ class SequenceGroup:
         return self.encoder_seq

     def get_finished_seqs(self) -> List[Sequence]:
-        return self.seqs if self.first_seq.is_finished() else []
+        if self.is_single_seq:
+            return self.seqs if self.first_seq.is_finished() else []
+
+        return [seq for seq in self.seqs if seq.is_finished()]

     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        seq = self.first_seq
-        if not seq.is_finished():
-            seq.data.update_num_computed_tokens(num_new_computed_tokens)
+        for seq in self.seqs:
+            if not seq.is_finished():
+                seq.data.update_num_computed_tokens(num_new_computed_tokens)

     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        seq = self.first_seq
-        if not seq.is_finished():
-            num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
+        for seq in self.seqs:
+            if not seq.is_finished():
+                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens

     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
@@ -860,10 +868,14 @@ class SequenceGroup:
         return len(self.get_seqs(status))

     def num_finished_seqs(self) -> int:
-        return 1 if self.first_seq.is_finished() else 0
+        if self.is_single_seq:
+            return 1 if self.seqs[0].is_finished() else 0
+        return len(self.get_finished_seqs())

     def is_finished(self) -> bool:
-        return self.first_seq.is_finished()
+        if self.is_single_seq:
+            return self.first_seq.is_finished()
+        return all(seq.is_finished() for seq in self.seqs)

     def is_prefill(self) -> bool:
         return self.first_seq.is_prefill()
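All of the SequenceGroup methods above follow one shape: keep the cheap single-sequence fast path, and scan self.seqs only when the group really holds several sequences. A condensed sketch of the pattern under that assumption (ToySeq and ToySeqGroup are illustrative only, not vLLM types):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class ToySeq:
        finished: bool = False

        def is_finished(self) -> bool:
            return self.finished

    @dataclass
    class ToySeqGroup:
        seqs: List[ToySeq] = field(default_factory=list)

        @property
        def is_single_seq(self) -> bool:
            return len(self.seqs) == 1

        def is_finished(self) -> bool:
            if self.is_single_seq:  # fast path, no iteration
                return self.seqs[0].is_finished()
            return all(seq.is_finished() for seq in self.seqs)

    group = ToySeqGroup(seqs=[ToySeq(True), ToySeq(False)])
    assert not group.is_finished()  # one of two sequences still running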
@@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
     @staticmethod
     def add_request(request_id: str, engine, params, **kwargs):
         original_params = params
-        params = original_params.clone()
-        params.n = 1
         group = ParallelSampleSequenceGroup(request_id)
         seqs = []
         for i in range(original_params.n):
             request_id_i = f"{request_id}_parallel_sample_{i}"
             group.seq_id_to_index[request_id_i] = i
+            params = copy.deepcopy(original_params)
+            params.n = 1
+            if params.seed is not None:
+                params.seed += i
             seq_group = engine._add_processed_request(
                 request_id_i,
                 params=params,
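Each child request now receives its own deep copy of the sampling parameters, with the seed offset by the sample index: seeded runs stay reproducible, but the n samples no longer share one seed and collapse into identical outputs. A self-contained sketch of the derivation (ToyParams stands in for SamplingParams):

    import copy
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyParams:
        n: int = 4
        seed: Optional[int] = 1234

    original_params = ToyParams()
    per_sample = []
    for i in range(original_params.n):
        params = copy.deepcopy(original_params)  # never mutate the caller's object
        params.n = 1                             # each child samples exactly once
        if params.seed is not None:
            params.seed += i                     # distinct yet deterministic seeds
        per_sample.append(params)

    assert [p.seed for p in per_sample] == [1234, 1235, 1236, 1237]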
@@ -1432,33 +1446,34 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
             self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:

         # in the streaming mode, we will return the assembled sequence
-        # for the first sequence, and then return None for the rest of
-        # sequences
+        # for the first remaining sequence, and then return None for the
+        # rest of sequences
         if self.streaming:
-            if self.seq_id_to_index[seq_group.request_id] == 0:
+            first_remaining_id = next(iter(self.to_be_finished))
+            if seq_group.request_id == first_remaining_id:
                 return self.assembled_seq_group
             return None

         # in the non-streaming mode, we will return the assembled sequence
-        # once after all sequences finish, and then return None for the
+        # when the last sequence finishes, and then return None for the
         # rest of the time
-        if len(self.to_be_finished) > 0:
-            return None
-
-        assert self.assembled_seq_group is not None
-        params = self.assembled_seq_group.sampling_params
-        assert isinstance(params, SamplingParams)
-        if not self.output_produced:
-            self.output_produced = True
-            if params._real_n is not None:
-                # Get the top-n sequences.
-                n = params._real_n or params.n
-                seqs = self.assembled_seq_group.seqs
-                sorting_key = lambda seq: seq.get_cumulative_logprob()
-                sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-                top_n_seqs = sorted_seqs[:n]
-                self.assembled_seq_group.seqs = top_n_seqs
-            return self.assembled_seq_group
-        if self.output_produced:
-            return None
+        if (len(self.to_be_finished) == 1
+                and seq_group.request_id in self.to_be_finished
+                and seq_group.is_finished()):
+            assert self.assembled_seq_group is not None
+            params = self.assembled_seq_group.sampling_params
+            assert isinstance(params, SamplingParams)
+            if not self.output_produced:
+                self.output_produced = True
+                if params._real_n is not None:
+                    # Get the top-n sequences.
+                    n = params._real_n or params.n
+                    seqs = self.assembled_seq_group.seqs
+                    sorting_key = lambda seq: seq.get_cumulative_logprob()
+                    sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+                    top_n_seqs = sorted_seqs[:n]
+                    self.assembled_seq_group.seqs = top_n_seqs
+                return self.assembled_seq_group
+            if self.output_produced:
+                return None
+        return None
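Inside the non-streaming branch, the assembled group is trimmed to the top-n sequences by cumulative log probability (the _real_n case, where more sequences were sampled than the caller asked for). A small sketch of that selection step with a toy sequence type; the interpretation of _real_n here is an assumption:

    from dataclasses import dataclass

    @dataclass
    class ToySeq:  # carries only what the sort needs
        text: str
        cumulative_logprob: float

        def get_cumulative_logprob(self) -> float:
            return self.cumulative_logprob

    seqs = [ToySeq("a", -4.0), ToySeq("b", -1.5), ToySeq("c", -3.0)]
    n = 2  # e.g. params._real_n when extra candidates were sampled
    sorted_seqs = sorted(seqs, key=lambda s: s.get_cumulative_logprob(), reverse=True)
    top_n_seqs = sorted_seqs[:n]
    assert [s.text for s in top_n_seqs] == ["b", "c"]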