[core] simplify seq group code (#9569)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Authored by youkaichao on 2024-10-24 00:16:44 -07:00; committed via GitHub.
parent 3770071eb4
commit 4fdc581f9e
6 changed files with 61 additions and 565 deletions
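
In short: requests with SamplingParams.n > 1 are now expanded into n independent single-sequence groups before they reach the scheduler, so SequenceGroup, the scheduler, and the single-step output processor can assume one sequence per group and drop their multi-sequence bookkeeping. A condensed paraphrase of the central simplification, based only on the SequenceGroup hunks further down (the names come from the diff, the framing does not):

def get_max_num_running_seqs_before(group) -> int:
    # Old behavior: a group could hold up to n parallel sequences, so the
    # scheduler had to budget for all of them.
    n = group.sampling_params.n
    if n > group.num_seqs():
        return n                           # prompt stage: not yet forked into n sequences
    return group.num_unfinished_seqs()     # generation stage: count live sequences

def get_max_num_running_seqs_after(group) -> int:
    # New behavior: every group that reaches the scheduler carries exactly one sequence.
    return 0 if group.first_seq.is_finished() else 1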

View File

@@ -4,7 +4,6 @@ from unittest.mock import MagicMock
import pytest # noqa
from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler
from vllm.sequence import Logprob, SequenceGroup
@@ -347,158 +346,6 @@ def test_prompt_limit_exceed():
assert out.ignored_seq_groups[0] == seq_group
def test_swap():
"""Verify swapping works with chunked prefill requests"""
block_size = 4
max_seqs = 30
max_model_len = 200
max_num_batched_tokens = 30
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16
cache_config.num_gpu_blocks = 16
scheduler = Scheduler(scheduler_config, cache_config, None)
_, seq_group = create_dummy_prompt("1",
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler.add_seq_group(seq_group)
_, out = schedule_and_update_computed_tokens(scheduler)
# The request is chunked.
# prefill scheduled now.
assert len(out.scheduled_seq_groups) == 1
assert out.num_prefill_groups == 1
assert seq_group.is_prefill()
assert out.num_batched_tokens == max_num_batched_tokens
# The last request should be swapped out.
scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots):
return seq_group.request_id != "1"
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
# The running prefill is now swapped.
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 0
assert out.num_batched_tokens == 0
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
# Add 1 more task. Swap should be prioritized over new prefill.
_, seq_group = create_dummy_prompt("2", prompt_length=60)
scheduler.add_seq_group(seq_group)
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def test_running_prefill_prioritized_over_swap():
block_size = 4
max_seqs = 30
max_model_len = 200
max_num_batched_tokens = 30
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 32
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
_, seq_group = create_dummy_prompt("1",
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler.add_seq_group(seq_group)
_, out = schedule_and_update_computed_tokens(scheduler)
# The request is chunked.
# prefill scheduled now.
assert len(out.scheduled_seq_groups) == 1
assert out.num_prefill_groups == 1
assert seq_group.is_prefill()
assert out.num_batched_tokens == max_num_batched_tokens
# The request should be swapped out.
scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots):
return seq_group.request_id != "1"
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
# The running prefill is now swapped.
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 0
assert out.num_batched_tokens == 0
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
# Add 1 more task. Swap is not possible, so prefill is running.
scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
_, seq_group2 = create_dummy_prompt("2",
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group2)
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert out.scheduled_seq_groups[0].seq_group == seq_group2
# Now although swap is possible, running prefill is prioritized.
scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert not seq_group2.is_prefill()
assert out.scheduled_seq_groups[0].seq_group == seq_group2
append_new_token(seq_group2, 1)
# Decoding is prioritized.
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 1
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert not seq_group2.is_prefill()
assert out.scheduled_seq_groups[0].seq_group == seq_group2
append_new_token(seq_group2, 1)
# Since we abort the sequence group, we can finally swap.
scheduler.abort_seq_group(seq_group2.request_id)
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def test_chunked_prefill_preempt():
"""Verify preempt works with chunked prefill requests"""
block_size = 4
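
Both tests deleted above built multi-sequence groups (best_of=2) so the scheduler would swap them out rather than recompute them. With n > 1 handled outside SequenceGroup, a group can have at most one running sequence, preemption falls back to recomputation, and these swap-driven scenarios are no longer reachable from the chunked-prefill tests. A simplified illustration of that preemption choice; it is a sketch of the decision the scheduler is understood to make, not its actual method:

from enum import Enum

class PreemptionMode(Enum):
    SWAP = 1
    RECOMPUTE = 2

def choose_preemption_mode(seq_group) -> PreemptionMode:
    # Single-sequence groups are cheap to re-prefill, so they are recomputed.
    # After this change get_max_num_running_seqs() is at most 1, which makes
    # the SWAP branch effectively dead for these tests.
    if seq_group.get_max_num_running_seqs() == 1:
        return PreemptionMode.RECOMPUTE
    return PreemptionMode.SWAP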

View File

@@ -10,7 +10,7 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup, SequenceStatus
from vllm.sequence import SequenceGroup
from .utils import (append_new_token, append_new_token_seq_group,
create_dummy_prompt, get_sequence_groups,
@@ -296,55 +296,6 @@ def test_scheduler_delay_factor():
append_new_token(out, 1)
def test_swapped_out_prioritized():
block_size = 4
scheduler = initialize_scheduler(max_num_seqs=6,
block_size=block_size,
num_cpu_blocks=64,
num_gpu_blocks=64)
# best_of=2 * 3 == 6 sequences.
for i in range(3):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler.add_seq_group(seq_group)
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# prefill scheduled now.
assert len(out.scheduled_seq_groups) == 3
append_new_token(out, 1)
# The last request should be swapped out.
scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots):
return seq_group.request_id != "2"
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 2
assert out.num_batched_tokens == 2
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
append_new_token(out, 1)
# Add 1 more task. Swap should be prioritized over prefill.
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler.add_seq_group(seq_group)
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
append_new_token(out, 1)
assert len(out.scheduled_seq_groups) == 3
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 3
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def initialize_scheduler(
*,
max_num_seqs=1000,
@@ -646,60 +597,6 @@ def test_decode_schedule_preempted():
assert output.blocks_to_copy == []
def test_decode_swap_beam_search():
"""
Test best_of > 1 swap out blocks
"""
block_size = 4
scheduler = initialize_scheduler(block_size=block_size,
num_gpu_blocks=64,
num_cpu_blocks=64)
curr_loras = None
budget = create_token_budget()
for i in range(3):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler._allocate_and_set_running(seq_group)
scheduler._add_seq_group_to_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
budget.add_num_batched_tokens(
seq_group.request_id, seq_group.num_seqs(SequenceStatus.RUNNING))
# The last request should be swapped out.
scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots):
return seq_group.request_id != "2"
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
scheduler.block_manager.swap_out = MagicMock()
expected_swap_mapping = [("5", "7")]
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
output = scheduler._schedule_running(budget, curr_loras)
remaining_running = scheduler.running
assert len(remaining_running) == 0
assert len(output.decode_seq_groups) == 2
assert len(output.prefill_seq_groups) == 0
assert output.decode_seq_groups[0].seq_group.request_id == "0"
assert output.decode_seq_groups[1].seq_group.request_id == "1"
assert len(output.preempted) == 0
assert len(output.swapped_out) == 1
# Budget should reflect preempted requests.
assert budget.num_batched_tokens == 2
# since there are 2 sequences, 2 should be subtracted.
assert budget.num_curr_seqs == 4
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == expected_swap_mapping
# Nothing is copied.
assert output.blocks_to_copy == []
def test_schedule_decode_blocks_to_copy_update():
"""
Verify blocks_to_copy is updated.
@@ -736,105 +633,6 @@ def test_schedule_decode_blocks_to_copy_update():
assert output.blocks_to_copy == [(2, 3)]
def test_schedule_swapped_simple():
block_size = 4
scheduler = initialize_scheduler(block_size=block_size)
curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = []
_, seq_group = create_dummy_prompt("1",
prompt_length=4,
best_of=2,
block_size=block_size)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(4, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget()
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 0
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 2
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0
# swap in is the reverse of swap out
blocks_to_swap_in_reverse = []
for swapin, swapout in output.blocks_to_swap_in:
blocks_to_swap_in_reverse.append((swapout, swapin))
assert blocks_to_swap_out == blocks_to_swap_in_reverse
def test_schedule_swapped_max_token_budget():
block_size = 4
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=32,
num_gpu_blocks=32)
curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(2):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget(token_budget=1)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 2
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0
# Verify num_batched_tokens are respected.
budget = create_token_budget(token_budget=1)
add_token_budget(budget, 1, 0)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 0
assert len(output.decode_seq_groups) == 0
assert len(output.prefill_seq_groups) == 0
def test_schedule_swapped_max_seqs():
block_size = 4
scheduler = initialize_scheduler(block_size=block_size,
num_cpu_blocks=64,
num_gpu_blocks=64)
curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(4):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
block_size=4)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget(max_num_seqs=2)
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 2
assert budget.num_batched_tokens == 2
assert budget.num_curr_seqs == 2
assert len(output.decode_seq_groups) == 2
assert len(output.prefill_seq_groups) == 0
# Verify num_curr_seqs are respected.
output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 2
assert budget.num_batched_tokens == 2
assert budget.num_curr_seqs == 2
assert len(output.decode_seq_groups) == 0
assert len(output.prefill_seq_groups) == 0
def test_schedule_swapped_max_loras():
block_size = 4
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
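
The scheduler tests removed here relied on the same best_of=2 trick to create multi-sequence groups. The public API is unchanged: multiple samples are still requested through SamplingParams, and the engine now fans such a request out into n single-sequence groups internally (see the LLMEngine and ParallelSampleSequenceGroup hunks below). An illustrative usage sketch; the model name is an arbitrary small example:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(n=2, max_tokens=16)      # n > 1 takes the parallel-sample path
outputs = llm.generate(["Hello, my name is"], params)
for completion in outputs[0].outputs:            # both samples come back on one request output
    print(completion.text)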

View File

@@ -290,7 +290,7 @@ def scheduler_running_outputs_builder():
def scheduled_seq_group_builder():
return ScheduledSequenceGroup(SequenceGroup("", [], -1),
return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup),
token_chunk_size=0)
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
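
The builder above can no longer construct a throwaway SequenceGroup("", [], -1): the new __init__ reads seqs[0] to set first_seq, so an empty sequence list would raise. SequenceGroup.__new__(SequenceGroup) allocates the instance without running __init__, which is enough here because the pooled placeholder is fully overwritten before it is used. A generic sketch of that Python idiom, unrelated to vLLM types:

class Point:
    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y

p = Point.__new__(Point)   # allocate the instance without calling __init__
print(hasattr(p, "x"))     # False: no attributes have been set yet
p.x, p.y = 1.0, 2.0        # the object pool fills in real values before the instance is used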

View File

@@ -647,10 +647,24 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> SequenceGroup:
) -> Optional[SequenceGroup]:
"""Add a processed request to the engine's request pool.
Return the created sequence group, or None if the request is expanded into parallel-sample child requests.
"""
if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
processed_inputs=processed_inputs,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
return None
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
@@ -721,7 +735,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Optional[SequenceGroup]:
) -> None:
...
@overload
@@ -735,7 +749,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> Optional[SequenceGroup]:
) -> None:
...
@deprecate_kwargs(
@@ -754,7 +768,7 @@ class LLMEngine:
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> Optional[SequenceGroup]:
) -> None:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
@@ -798,22 +812,6 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
prompt=prompt,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
inputs=inputs,
)
return None
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
@@ -844,7 +842,7 @@ class LLMEngine:
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")
return self._add_processed_request(
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
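
Net effect of the LLMEngine changes: add_request now always returns None, and the n > 1 special case moves down into _add_processed_request, so user inputs are preprocessed exactly once and ParallelSampleSequenceGroup fans the already-processed request out into n children. A rough control-flow sketch distilled from the hunks above; parameter lists are trimmed, and preprocess() stands in for the engine's real input-processing step:

def add_request(self, request_id, prompt, params) -> None:
    processed_inputs = preprocess(prompt)        # placeholder for tokenization etc.
    self._add_processed_request(request_id=request_id,
                                processed_inputs=processed_inputs,
                                params=params)

def _add_processed_request(self, request_id, processed_inputs, params):
    if isinstance(params, SamplingParams) and params.n > 1:
        # Fan out into n single-sample children; each child re-enters
        # _add_processed_request with the same processed inputs (and,
        # presumably, its own params with n forced to 1).
        ParallelSampleSequenceGroup.add_request(
            request_id, self, params, processed_inputs=processed_inputs)
        return None
    ...  # build one single-sequence SequenceGroup and hand it to the scheduler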

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import List
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@@ -6,9 +6,8 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup,
SequenceGroupOutput)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
@@ -114,104 +113,22 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
if sampling_params.n == 1:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
seq = seq_group.seqs[0]
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count,
sampling_params,
lora_req=seq_group.lora_request,
)
if seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
# TODO: Add support for async for beam search
assert not is_async
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
for sample in samples:
# Guard against a KeyError which can occur if the request was
# aborted while the output was generated
if (child_list :=
parent_child_dict.get(sample.parent_seq_id)) is not None:
child_list.append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
for parent in parent_seqs:
child_samples: List[SequenceOutput] = parent_child_dict[
parent.seq_id]
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
for scheduler in self.scheduler:
scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count,
sampling_params,
lora_req=seq_group.lora_request,
)
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
sample = outputs.samples[0]
seq = seq_group.first_seq
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)
else:
new_char_count = 0
self.stop_checker.maybe_stop_sequence(
seq,
new_char_count,
sampling_params,
lora_req=seq_group.lora_request,
)
if seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)

View File

@@ -681,6 +681,7 @@ class SequenceGroup:
) -> None:
self.request_id = request_id
self.seqs = seqs
self.first_seq = seqs[0]
self.arrival_time = arrival_time
self.is_single_seq = len(seqs) == 1
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@@ -705,15 +706,11 @@ class SequenceGroup:
@property
def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return self.seqs[0].prompt
return self.first_seq.prompt
@property
def prompt_token_ids(self) -> List[int]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return self.seqs[0].prompt_token_ids
return self.first_seq.prompt_token_ids
@property
def encoder_prompt(self) -> Optional[str]:
@@ -733,17 +730,11 @@ class SequenceGroup:
@property
def multi_modal_data(self) -> "MultiModalDataDict":
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return self.seqs[0].multi_modal_data
return self.first_seq.multi_modal_data
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
# As with multi-modal data, all sequences in the group should have the
# same processor kwargs (i.e., mm_processor_kwargs are optionally
# provided per request; note that these are independent of whether the
# model is decoder-only or an encoder-decoder).
return self.seqs[0].mm_processor_kwargs
return self.first_seq.mm_processor_kwargs
@property
def lora_int_id(self) -> int:
@@ -808,7 +799,7 @@ class SequenceGroup:
# in TPOT, rather than recalculating TTFT (since, from the
# POV of the user, there is simply a long generation delay).
if (self.metrics.first_token_time is None
and self.seqs[0].get_output_len() == 1):
and self.first_seq.get_output_len() == 1):
self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None:
@@ -825,18 +816,7 @@ class SequenceGroup:
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if self.sampling_params:
n = self.sampling_params.n
assert isinstance(n, int)
if n > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `n` sequences
# running.
return n
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
return 0 if self.first_seq.is_finished() else 1
def get_seqs(
self,
@@ -845,10 +825,7 @@ class SequenceGroup:
if status is None:
return self.seqs
if self.is_single_seq:
return self.seqs if self.seqs[0].status == status else []
return [seq for seq in self.seqs if seq.status == status]
return self.seqs if self.first_seq.status == status else []
def is_encoder_decoder(self) -> bool:
return self.encoder_seq is not None
@@ -856,29 +833,20 @@ class SequenceGroup:
def get_encoder_seq(self) -> Optional[Sequence]:
return self.encoder_seq
def get_unfinished_seqs(self) -> List[Sequence]:
if self.is_single_seq:
return self.seqs if not self.seqs[0].is_finished() else []
return [seq for seq in self.seqs if not seq.is_finished()]
def get_finished_seqs(self) -> List[Sequence]:
if self.is_single_seq:
return self.seqs if self.seqs[0].is_finished() else []
return [seq for seq in self.seqs if seq.is_finished()]
return self.seqs if self.first_seq.is_finished() else []
def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
for seq in self.seqs:
if not seq.is_finished():
seq.data.update_num_computed_tokens(num_new_computed_tokens)
seq = self.first_seq
if not seq.is_finished():
seq.data.update_num_computed_tokens(num_new_computed_tokens)
def get_num_uncomputed_tokens(self) -> int:
num_uncomputed_tokens = 0
for seq in self.seqs:
if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
seq = self.first_seq
if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
@@ -892,46 +860,14 @@
return len(self.get_seqs(status))
def num_unfinished_seqs(self) -> int:
if self.is_single_seq:
return 1 if not self.seqs[0].is_finished() else 0
return len(self.get_unfinished_seqs())
def num_finished_seqs(self) -> int:
if self.is_single_seq:
return 1 if self.seqs[0].is_finished() else 0
return len(self.get_finished_seqs())
def find(self, seq_id: int) -> Sequence:
if seq_id not in self.seqs_dict:
raise ValueError(f"Sequence {seq_id} not found.")
return self.seqs_dict[seq_id]
def add(self, seq: Sequence) -> None:
if seq.seq_id in self.seqs_dict:
raise ValueError(f"Sequence {seq.seq_id} already exists.")
self.seqs_dict[seq.seq_id] = seq
self.seqs.append(seq)
self.is_single_seq = len(self.seqs) == 1
def remove(self, seq_id: int) -> None:
seq = self.seqs_dict.pop(seq_id, None)
if seq is None:
raise ValueError(f"Sequence {seq_id} not found.")
self.seqs.remove(seq)
self.is_single_seq = len(self.seqs) == 1
return 1 if self.first_seq.is_finished() else 0
def is_finished(self) -> bool:
if self.is_single_seq:
return self.seqs[0].is_finished()
return all(seq.is_finished() for seq in self.seqs)
return self.first_seq.is_finished()
def is_prefill(self) -> bool:
# Every sequence should be in the same stage.
return self.seqs[0].is_prefill()
return self.first_seq.is_prefill()
def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, "
@@ -1455,7 +1391,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
for i in range(original_params.n):
request_id_i = f"{request_id}_parallel_sample_{i}"
group.seq_id_to_index[request_id_i] = i
seq_group = engine.add_request(
seq_group = engine._add_processed_request(
request_id_i,
params=params,
**kwargs,
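
The final hunk swaps engine.add_request for engine._add_processed_request, so each parallel-sample child reuses the already-processed inputs instead of going back through input preprocessing, and each child gets a derived request id that maps it to its slot in the parent group. A hedged reconstruction of the surrounding fan-out; the deep copy of the sampling params with n = 1 is assumed, since it falls outside the visible lines:

import copy

def add_parallel_sample_request(request_id, engine, original_params, **kwargs):
    child_params = copy.deepcopy(original_params)
    child_params.n = 1                                       # assumed: each child samples once
    for i in range(original_params.n):
        request_id_i = f"{request_id}_parallel_sample_{i}"   # derived child id, as in the diff
        # Register the child directly with the engine; kwargs carry the
        # processed inputs, so nothing is re-tokenized for the children.
        engine._add_processed_request(request_id_i, params=child_params, **kwargs)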