Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 20:04:27 +08:00)

[core] simplify seq group code (#9569)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>

commit 4fdc581f9e (parent 3770071eb4)
@@ -4,7 +4,6 @@ from unittest.mock import MagicMock

import pytest  # noqa

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler
from vllm.sequence import Logprob, SequenceGroup
@@ -347,158 +346,6 @@ def test_prompt_limit_exceed():
    assert out.ignored_seq_groups[0] == seq_group


def test_swap():
    """Verify swapping works with chunked prefill requests"""
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 16
    cache_config.num_gpu_blocks = 16
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked.
    # prefill scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The last request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # The running prefill is now swapped.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out != []
    assert out.blocks_to_swap_in == []

    # Add 1 more task. Swap should be prioritized over new prefill.
    _, seq_group = create_dummy_prompt("2", prompt_length=60)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # 3 decodes. It is swapped in.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in != []
    assert out.blocks_to_swap_out == []


def test_running_prefill_prioritized_over_swap():
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 32
    cache_config.num_gpu_blocks = 32
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked.
    # prefill scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # The running prefill is now swapped.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out != []
    assert out.blocks_to_swap_in == []

    # Add 1 more task. Swap is not possible, so prefill is running.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER

    _, seq_group2 = create_dummy_prompt("2",
                                        prompt_length=60,
                                        block_size=block_size)
    scheduler.add_seq_group(seq_group2)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # 3 decodes. It is swapped in.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in == []
    assert out.blocks_to_swap_out == []
    assert out.scheduled_seq_groups[0].seq_group == seq_group2

    # Now although swap is possible, running prefill is prioritized.
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # 3 decodes. It is swapped in.
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in == []
    assert out.blocks_to_swap_out == []
    assert not seq_group2.is_prefill()
    assert out.scheduled_seq_groups[0].seq_group == seq_group2
    append_new_token(seq_group2, 1)

    # Decoding is prioritized.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    # 3 decodes. It is swapped in.
    assert out.num_batched_tokens == 1
    assert out.blocks_to_swap_in == []
    assert out.blocks_to_swap_out == []
    assert not seq_group2.is_prefill()
    assert out.scheduled_seq_groups[0].seq_group == seq_group2
    append_new_token(seq_group2, 1)

    # Since we abort the sequence group, we can finally swap.
    scheduler.abort_seq_group(seq_group2.request_id)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_batched_tokens == 30
    assert out.blocks_to_swap_in != []
    assert out.blocks_to_swap_out == []


def test_chunked_prefill_preempt():
    """Verify preempt works with chunked prefill requests"""
    block_size = 4
@@ -10,7 +10,7 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup, SequenceStatus
from vllm.sequence import SequenceGroup

from .utils import (append_new_token, append_new_token_seq_group,
                    create_dummy_prompt, get_sequence_groups,
@@ -296,55 +296,6 @@ def test_scheduler_delay_factor():
    append_new_token(out, 1)


def test_swapped_out_prioritized():
    block_size = 4
    scheduler = initialize_scheduler(max_num_seqs=6,
                                     block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    # best_of=2 * 3 == 6 sequences.
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           best_of=2,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    # prefill scheduled now.
    assert len(out.scheduled_seq_groups) == 3
    append_new_token(out, 1)

    # The last request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "2"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 2
    assert out.num_batched_tokens == 2
    assert out.blocks_to_swap_out != []
    assert out.blocks_to_swap_in == []
    append_new_token(out, 1)

    # Add 1 more task. Swap should be prioritized over prefill.
    _, seq_group = create_dummy_prompt(str(i),
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    append_new_token(out, 1)
    assert len(out.scheduled_seq_groups) == 3
    # 3 decodes. It is swapped in.
    assert out.num_batched_tokens == 3
    assert out.blocks_to_swap_in != []
    assert out.blocks_to_swap_out == []


def initialize_scheduler(
    *,
    max_num_seqs=1000,
@@ -646,60 +597,6 @@ def test_decode_schedule_preempted():
    assert output.blocks_to_copy == []


def test_decode_swap_beam_search():
    """
    Test best_of > 1 swap out blocks
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_gpu_blocks=64,
                                     num_cpu_blocks=64)
    curr_loras = None
    budget = create_token_budget()
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           best_of=2,
                                           block_size=block_size)
        scheduler._allocate_and_set_running(seq_group)
        scheduler._add_seq_group_to_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        budget.add_num_seqs(seq_group.request_id,
                            seq_group.get_max_num_running_seqs())
        budget.add_num_batched_tokens(
            seq_group.request_id, seq_group.num_seqs(SequenceStatus.RUNNING))

    # The last request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        return seq_group.request_id != "2"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)
    scheduler.block_manager.swap_out = MagicMock()
    expected_swap_mapping = [("5", "7")]
    scheduler.block_manager.swap_out.return_value = expected_swap_mapping

    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 2
    assert len(output.prefill_seq_groups) == 0
    assert output.decode_seq_groups[0].seq_group.request_id == "0"
    assert output.decode_seq_groups[1].seq_group.request_id == "1"
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 1
    # Budget should reflect preempted requests.
    assert budget.num_batched_tokens == 2
    # since there are 2 sequences, 2 should be subtracted.
    assert budget.num_curr_seqs == 4
    # Both should be preempted, not swapped.
    assert output.blocks_to_swap_out == expected_swap_mapping
    # Nothing is copied.
    assert output.blocks_to_copy == []


def test_schedule_decode_blocks_to_copy_update():
    """
    Verify blocks_to_copy is updated.
@@ -736,105 +633,6 @@ def test_schedule_decode_blocks_to_copy_update():
    assert output.blocks_to_copy == [(2, 3)]


def test_schedule_swapped_simple():
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=4,
                                       best_of=2,
                                       block_size=block_size)
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(4, seq_group, 1)
    scheduler._swap_out(seq_group, blocks_to_swap_out)
    scheduler._add_seq_group_to_swapped(seq_group)

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 0
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    # swap in is the reverse of swap out
    blocks_to_swap_in_reverse = []
    for swapin, swapout in output.blocks_to_swap_in:
        blocks_to_swap_in_reverse.append((swapout, swapin))
    assert blocks_to_swap_out == blocks_to_swap_in_reverse


def test_schedule_swapped_max_token_budget():
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    budget = create_token_budget(token_budget=1)
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0

    # Verify num_batched_tokens are respected.
    budget = create_token_budget(token_budget=1)
    add_token_budget(budget, 1, 0)
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0


def test_schedule_swapped_max_seqs():
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []
    for i in range(4):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=4)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    budget = create_token_budget(max_num_seqs=2)
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 2
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 2
    assert len(output.prefill_seq_groups) == 0

    # Verify num_curr_seqs are respected.
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 2
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0


def test_schedule_swapped_max_loras():
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
@@ -290,7 +290,7 @@ def scheduler_running_outputs_builder():


def scheduled_seq_group_builder():
    return ScheduledSequenceGroup(SequenceGroup("", [], -1),
    return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup),
                                  token_chunk_size=0)
    # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
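The builder above now creates its pooled placeholder with `SequenceGroup.__new__(SequenceGroup)` instead of constructing one from dummy arguments. A minimal sketch of the idea, using a hypothetical `Record` class rather than the vLLM types:

# Sketch only: why `Cls.__new__(Cls)` suits object-pool placeholders.
# __new__ allocates the instance but skips __init__, so no dummy constructor
# arguments are needed; the pooled object is overwritten before real use.

class Record:
    def __init__(self, request_id: str, payload: list):
        # Normal construction stores real data.
        self.request_id = request_id
        self.payload = payload


def record_builder() -> Record:
    # Placeholder instance: attributes are filled in when the slot is reused.
    return Record.__new__(Record)


placeholder = record_builder()
placeholder.request_id = "req-0"  # set at reuse time
placeholder.payload = []
assert isinstance(placeholder, Record)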
@@ -647,10 +647,24 @@ class LLMEngine:
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> SequenceGroup:
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.
        Return the created sequence group.
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs)
        # Create the sequences.
        block_size = self.cache_config.block_size
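With this change, `_add_processed_request` only returns a `SequenceGroup` for single-sample requests; when `params.n > 1` the request is expanded into one group per sample and the method returns `None`. A hypothetical sketch of the fan-out pattern (the `_parallel_sample_{i}` id suffix mirrors the `ParallelSampleSequenceGroup` hunk further below; everything else is illustrative, not the vLLM API):

# Hypothetical sketch of the fan-out: a request asking for n > 1 samples is
# split into n independent single-sequence requests whose ids derive from the
# parent id, and the parent call returns None instead of a group.

from typing import Optional


def add_processed_request(request_id: str, n: int) -> Optional[str]:
    if n > 1:
        for i in range(n):
            child_id = f"{request_id}_parallel_sample_{i}"
            add_processed_request(child_id, n=1)  # each child is a 1-seq group
        return None  # nothing to return for the parent request
    # n == 1: create and return the (single-sequence) group, here just its id.
    return request_id


assert add_processed_request("abc", n=3) is None
assert add_processed_request("xyz", n=1) == "xyz"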
@@ -721,7 +735,7 @@ class LLMEngine:
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
    ) -> None:
        ...

    @overload
@@ -735,7 +749,7 @@ class LLMEngine:
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
    ) -> None:
        ...

    @deprecate_kwargs(
@@ -754,7 +768,7 @@ class LLMEngine:
        priority: int = 0,
        *,
        inputs: Optional[PromptType] = None,  # DEPRECATED
    ) -> Optional[SequenceGroup]:
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
@@ -798,22 +812,6 @@ class LLMEngine:
        >>> # continue the request processing
        >>> ...
        """

        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                prompt=prompt,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
                inputs=inputs,
            )
            return None

        if inputs is not None:
            prompt = inputs
        assert prompt is not None and params is not None
@@ -844,7 +842,7 @@ class LLMEngine:
            processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
                "mm_processor_kwargs")

        return self._add_processed_request(
        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import List

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@@ -6,9 +6,8 @@ from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, Sequence,
                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
                           SequenceStatus)
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup,
                           SequenceGroupOutput)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter
@@ -114,104 +113,22 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                                        outputs: SequenceGroupOutput,
                                        is_async: bool) -> None:
        sampling_params = seq_group.sampling_params
        if sampling_params.n == 1:
            # only have one output sample
            sample = outputs.samples[0]
            # only have one sequence
            seq = seq_group.seqs[0]
            if not is_async:
                seq.append_token_id(sample.output_token, sample.logprobs)
            if sampling_params.detokenize and self.detokenizer:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)
            else:
                new_char_count = 0
            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count,
                sampling_params,
                lora_req=seq_group.lora_request,
            )
            if seq.is_finished():
                for scheduler in self.scheduler:
                    scheduler.free_seq(seq)
            return

        # TODO: Add support for async for beam search
        assert not is_async

        # Process samples
        samples = outputs.samples
        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        parent_child_dict: Dict[int, List[SequenceOutput]] = {
            parent_seq.seq_id: []
            for parent_seq in parent_seqs
        }
        for sample in samples:
            # Guard against a KeyError which can occur if the request was
            # aborted while the output was generated
            if (child_list :=
                    parent_child_dict.get(sample.parent_seq_id)) is not None:
                child_list.append(sample)
        # List of (child, parent)
        child_seqs: List[Tuple[Sequence, Sequence]] = []

        # Process the child samples for each parent sequence
        for parent in parent_seqs:
            child_samples: List[SequenceOutput] = parent_child_dict[
                parent.seq_id]
            if len(child_samples) == 0:
                # This parent sequence has no children samples. Remove
                # the parent sequence from the sequence group since it will
                # not be used in the future iterations.
                parent.status = SequenceStatus.FINISHED_ABORTED
                seq_group.remove(parent.seq_id)
                for scheduler in self.scheduler:
                    scheduler.free_seq(parent)
                continue
            # Fork the parent sequence if there are multiple child samples.
            for child_sample in child_samples[:-1]:
                new_child_seq_id: int = next(self.seq_counter)
                child = parent.fork(new_child_seq_id)
                child.append_token_id(child_sample.output_token,
                                      child_sample.logprobs)
                child_seqs.append((child, parent))
            # Continue the parent sequence for the last child sample.
            # We reuse the parent sequence here to reduce redundant memory
            # copies, especially when using non-beam search sampling methods.
            last_child_sample = child_samples[-1]
            parent.append_token_id(last_child_sample.output_token,
                                   last_child_sample.logprobs)
            child_seqs.append((parent, parent))

        for seq, _ in child_seqs:
            if sampling_params.detokenize and self.detokenizer:
                new_char_count = self.detokenizer.decode_sequence_inplace(
                    seq, sampling_params)
            else:
                new_char_count = 0
            self.stop_checker.maybe_stop_sequence(
                seq,
                new_char_count,
                sampling_params,
                lora_req=seq_group.lora_request,
            )

        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    for scheduler in self.scheduler:
                        scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        # NOTE: we need to fork the new sequences before freeing the
        # old sequences.
        for seq, parent in child_seqs:
            if seq is parent and seq.is_finished():
                for scheduler in self.scheduler:
                    scheduler.free_seq(seq)
        return
        sample = outputs.samples[0]
        seq = seq_group.first_seq
        if not is_async:
            seq.append_token_id(sample.output_token, sample.logprobs)
        if sampling_params.detokenize and self.detokenizer:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)
        else:
            new_char_count = 0
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count,
            sampling_params,
            lora_req=seq_group.lora_request,
        )
        if seq.is_finished():
            for scheduler in self.scheduler:
                scheduler.free_seq(seq)
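Because parallel sampling is now split into per-sample groups when the request is added, every `SequenceGroup` that reaches this output processor holds exactly one sequence, which is why the fork/beam-search bookkeeping above can be deleted. A rough, hypothetical sketch of the remaining per-step flow (illustrative names, not the vLLM API):

# Hypothetical per-step flow for a one-sequence group: append the sampled
# token, detokenize incrementally, run stop checks, and free the sequence
# from every scheduler once it finishes.

def process_step(seq, sample, params, detokenizer, stop_checker, schedulers):
    seq.append_token_id(sample.output_token, sample.logprobs)
    new_char_count = (detokenizer.decode_sequence_inplace(seq, params)
                      if params.detokenize and detokenizer else 0)
    stop_checker.maybe_stop_sequence(seq, new_char_count, params)
    if seq.is_finished():
        for scheduler in schedulers:
            scheduler.free_seq(seq)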
vllm/sequence.py
@@ -681,6 +681,7 @@ class SequenceGroup:
    ) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]
        self.arrival_time = arrival_time
        self.is_single_seq = len(seqs) == 1
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@@ -705,15 +706,11 @@ class SequenceGroup:

    @property
    def prompt(self) -> Optional[str]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt
        return self.first_seq.prompt

    @property
    def prompt_token_ids(self) -> List[int]:
        # All sequences in the group should have the same prompt.
        # We use the prompt of an arbitrary sequence.
        return self.seqs[0].prompt_token_ids
        return self.first_seq.prompt_token_ids

    @property
    def encoder_prompt(self) -> Optional[str]:
@@ -733,17 +730,11 @@ class SequenceGroup:

    @property
    def multi_modal_data(self) -> "MultiModalDataDict":
        # All sequences in the group should have the same multi-modal data.
        # We use the multi-modal data of an arbitrary sequence.
        return self.seqs[0].multi_modal_data
        return self.first_seq.multi_modal_data

    @property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        # As with multi-modal data, all sequences in the group should have the
        # same processor kwargs (i.e., mm_processor_kwargs are optionally
        # provided per request; note that these are independent of whether the
        # model is decoder-only or an encoder-decoder).
        return self.seqs[0].mm_processor_kwargs
        return self.first_seq.mm_processor_kwargs

    @property
    def lora_int_id(self) -> int:
@@ -808,7 +799,7 @@ class SequenceGroup:
            # in TPOT, rather than recalculating TTFT (since from the
            # POV of the user, there is simply a long generation delay).
            if (self.metrics.first_token_time is None
                    and self.seqs[0].get_output_len() == 1):
                    and self.first_seq.get_output_len() == 1):
                self.metrics.first_token_time = time

    def maybe_set_first_scheduled_time(self, time: float) -> None:
@@ -825,18 +816,7 @@ class SequenceGroup:
    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params:
            n = self.sampling_params.n
            assert isinstance(n, int)
            if n > self.num_seqs():
                # At prompt stage, the sequence group is not yet filled up
                # and only has one sequence running. However, in the
                # generation stage, we will have `n` sequences
                # running.
                return n
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()
        return 0 if self.first_seq.is_finished() else 1

    def get_seqs(
        self,
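As a sanity check on the collapse above: once every group holds exactly one sequence with `n == 1`, the old branching and the new one-liner agree. A small, hypothetical comparison in plain Python (not the vLLM classes):

# Old logic specialized to a single-sequence group with n == 1 versus the new
# one-liner. Both return 1 while the sequence is alive and 0 once it finishes.

def old_max_running(n: int, num_seqs: int, num_unfinished: int) -> int:
    # Mirrors the removed branch: the prompt stage may still need `n`
    # sequences; otherwise count the unfinished ones.
    if n > num_seqs:
        return n
    return num_unfinished


def new_max_running(first_seq_finished: bool) -> int:
    return 0 if first_seq_finished else 1


# With n == 1 and one sequence per group:
assert old_max_running(n=1, num_seqs=1, num_unfinished=1) == new_max_running(False)
assert old_max_running(n=1, num_seqs=1, num_unfinished=0) == new_max_running(True)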
@@ -845,10 +825,7 @@ class SequenceGroup:
        if status is None:
            return self.seqs

        if self.is_single_seq:
            return self.seqs if self.seqs[0].status == status else []

        return [seq for seq in self.seqs if seq.status == status]
        return self.seqs if self.first_seq.status == status else []

    def is_encoder_decoder(self) -> bool:
        return self.encoder_seq is not None
@@ -856,29 +833,20 @@ class SequenceGroup:
    def get_encoder_seq(self) -> Optional[Sequence]:
        return self.encoder_seq

    def get_unfinished_seqs(self) -> List[Sequence]:
        if self.is_single_seq:
            return self.seqs if not self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if not seq.is_finished()]

    def get_finished_seqs(self) -> List[Sequence]:
        if self.is_single_seq:
            return self.seqs if self.seqs[0].is_finished() else []

        return [seq for seq in self.seqs if seq.is_finished()]
        return self.seqs if self.first_seq.is_finished() else []

    def update_num_computed_tokens(self, num_new_computed_tokens: int):
        """Update number of tokens computed so far."""
        for seq in self.seqs:
            if not seq.is_finished():
                seq.data.update_num_computed_tokens(num_new_computed_tokens)
        seq = self.first_seq
        if not seq.is_finished():
            seq.data.update_num_computed_tokens(num_new_computed_tokens)

    def get_num_uncomputed_tokens(self) -> int:
        num_uncomputed_tokens = 0
        for seq in self.seqs:
            if not seq.is_finished():
                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        seq = self.first_seq
        if not seq.is_finished():
            num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
        return num_uncomputed_tokens

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
@@ -892,46 +860,14 @@ class SequenceGroup:

        return len(self.get_seqs(status))

    def num_unfinished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if not self.seqs[0].is_finished() else 0

        return len(self.get_unfinished_seqs())

    def num_finished_seqs(self) -> int:
        if self.is_single_seq:
            return 1 if self.seqs[0].is_finished() else 0

        return len(self.get_finished_seqs())

    def find(self, seq_id: int) -> Sequence:
        if seq_id not in self.seqs_dict:
            raise ValueError(f"Sequence {seq_id} not found.")
        return self.seqs_dict[seq_id]

    def add(self, seq: Sequence) -> None:
        if seq.seq_id in self.seqs_dict:
            raise ValueError(f"Sequence {seq.seq_id} already exists.")
        self.seqs_dict[seq.seq_id] = seq
        self.seqs.append(seq)
        self.is_single_seq = len(self.seqs) == 1

    def remove(self, seq_id: int) -> None:
        seq = self.seqs_dict.pop(seq_id, None)
        if seq is None:
            raise ValueError(f"Sequence {seq_id} not found.")
        self.seqs.remove(seq)
        self.is_single_seq = len(self.seqs) == 1
        return 1 if self.first_seq.is_finished() else 0

    def is_finished(self) -> bool:
        if self.is_single_seq:
            return self.seqs[0].is_finished()

        return all(seq.is_finished() for seq in self.seqs)
        return self.first_seq.is_finished()

    def is_prefill(self) -> bool:
        # Every sequence should be in the same stage.
        return self.seqs[0].is_prefill()
        return self.first_seq.is_prefill()

    def __repr__(self) -> str:
        return (f"SequenceGroup(request_id={self.request_id}, "
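The pattern running through these `SequenceGroup` hunks is to cache `seqs[0]` as `first_seq` at construction and answer group-level queries from it directly, instead of scanning `self.seqs`. A minimal, hypothetical sketch of that pattern (illustrative classes, not vLLM code):

# Illustrative only: cache the first (and, after this change, only) sequence
# at construction time and answer group-level queries from it directly.

from dataclasses import dataclass
from typing import List


@dataclass
class Seq:
    seq_id: int
    finished: bool = False

    def is_finished(self) -> bool:
        return self.finished


class Group:
    def __init__(self, seqs: List[Seq]) -> None:
        self.seqs = seqs
        self.first_seq = seqs[0]  # cached once; no repeated seqs[0] lookups

    def is_finished(self) -> bool:
        # Single-sequence fast path: no all(...) scan over self.seqs.
        return self.first_seq.is_finished()


group = Group([Seq(seq_id=0)])
assert not group.is_finished()
group.first_seq.finished = True
assert group.is_finished()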
@@ -1455,7 +1391,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
        for i in range(original_params.n):
            request_id_i = f"{request_id}_parallel_sample_{i}"
            group.seq_id_to_index[request_id_i] = i
            seq_group = engine.add_request(
            seq_group = engine._add_processed_request(
                request_id_i,
                params=params,
                **kwargs,