[core] move parallel sampling out from vllm core (#9302)

youkaichao authored 2024-10-21 17:31:44 -07:00; committed by GitHub
parent ef7faad1b8
commit 76a5e13270
4 changed files with 222 additions and 29 deletions

View File

@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
The tokens from multiple samples are flattened into a single stream,
with an index to indicate which sample each token belongs to.
"""
prompt = "What is an LLM?"
n = 3
max_tokens = 5
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
stream=True)
chunks: List[List[str]] = [[] for i in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
text = chunk.choices[0].text
chunks[index].append(text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == n
for chunk in chunks:
assert len(chunk) == max_tokens
print("".join(chunk))
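For contrast with the streaming test above, here is a minimal sketch of the non-streaming case: with n=3 and stream=False the three samples come back as separate choices on one response, each carrying its own index. The base URL, API key, and model name below are placeholders for a running vLLM OpenAI-compatible server, not values taken from this commit.

```python
import asyncio

import openai


async def main() -> None:
    # Placeholder endpoint and credentials for a local vLLM server.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    completion = await client.completions.create(
        model="zephyr-lora",      # placeholder model name
        prompt="What is an LLM?",
        n=3,                      # parallel sampling: three samples
        max_tokens=5,
        stream=False,             # non-streaming: choices arrive together
    )
    # Each sample is a separate choice, distinguished by its index.
    for choice in sorted(completion.choices, key=lambda c: c.index):
        print(choice.index, choice.text)


asyncio.run(main())
```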
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",

View File

@ -44,8 +44,10 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceGroupOutput, SequenceStatus)
ParallelSampleSequenceGroup, Sequence,
SequenceGroup, SequenceGroupBase,
SequenceGroupMetadata, SequenceGroupOutput,
SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@ -474,6 +476,8 @@ class LLMEngine:
),
))
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@ -642,7 +646,10 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
) -> None:
) -> SequenceGroup:
"""Add a processed request to the engine's request pool.
Return the created sequence group.
"""
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
@ -696,6 +703,8 @@ class LLMEngine:
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group)
return seq_group
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()
@ -711,7 +720,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
) -> Optional[SequenceGroup]:
...
@overload
@ -725,7 +734,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
) -> Optional[SequenceGroup]:
...
@deprecate_kwargs(
@ -744,7 +753,7 @@ class LLMEngine:
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
) -> Optional[SequenceGroup]:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
@ -788,6 +797,22 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if isinstance(params, SamplingParams) and params.n > 1:
ParallelSampleSequenceGroup.add_request(
request_id,
self,
params,
prompt=prompt,
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
inputs=inputs,
)
return None
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
@ -818,7 +843,7 @@ class LLMEngine:
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")
self._add_processed_request(
return self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
@ -1135,7 +1160,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)
@ -1175,7 +1202,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)
@ -1194,7 +1223,10 @@ class LLMEngine:
continue
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs,
)
if request_output:
ctx.request_outputs.append(request_output)
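The n > 1 branch added to add_request above delegates to ParallelSampleSequenceGroup.add_request, which fans one user request out into n single-sample sub-requests. Below is a minimal sketch of that fan-out, reduced to the parts that need no running engine (the per-sample engine.add_request calls are omitted); the request id is a made-up example and only SamplingParams from vllm is assumed.

```python
import copy

from vllm import SamplingParams

request_id = "cmpl-example"  # made-up request id, for illustration only
original_params = SamplingParams(n=3, max_tokens=5)

# Each sub-request samples exactly once; all other settings are shared.
single_sample_params = copy.deepcopy(original_params)
single_sample_params.n = 1

# One derived request id per sample, matching the naming scheme used by
# ParallelSampleSequenceGroup.add_request.
sub_request_ids = [
    f"{request_id}_parallel_sample_{i}" for i in range(original_params.n)
]
print(sub_request_ids)
# ['cmpl-example_parallel_sample_0', 'cmpl-example_parallel_sample_1',
#  'cmpl-example_parallel_sample_2']
```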

View File

@ -1,13 +1,13 @@
import time
from dataclasses import dataclass
from typing import List, Optional
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)
SequenceGroup, SequenceGroupBase, SequenceStatus)
@dataclass
@ -114,14 +114,28 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup,
use_cache: bool) -> Optional["RequestOutput"]:
def from_seq_group(
cls, seq_group: SequenceGroup, use_cache: bool,
seq_id_to_seq_group: Dict[str, SequenceGroupBase]
) -> Optional["RequestOutput"]:
finished = seq_group.is_finished()
if seq_group.request_id in seq_id_to_seq_group:
group: SequenceGroupBase = seq_id_to_seq_group[
seq_group.request_id]
if finished:
group.finish_seq(seq_group)
assembled_seq_group = group.maybe_assemble_group(seq_group)
if assembled_seq_group is None:
return None
return cls.from_seq_group(assembled_seq_group, use_cache,
seq_id_to_seq_group)
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")
finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None
@ -136,15 +150,7 @@ class RequestOutput:
outputs=[],
finished=False)
seqs = seq_group.get_seqs()
if len(seqs) == 1:
top_n_seqs = seqs
else:
# Get the top-n sequences.
n = sampling_params._real_n or sampling_params.n
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
top_n_seqs = seq_group.get_seqs()
# Create the outputs.
# NOTE: We need to omit logprobs here explicitly because the sequence
@ -208,7 +214,7 @@ class RequestOutput:
else:
output = CompletionOutput(
seqs.index(seq), output_text, [output_token_ids]
top_n_seqs.index(seq), output_text, [output_token_ids]
if isinstance(output_token_ids, int) else output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
@ -309,10 +315,13 @@ class EmbeddingRequestOutput:
class RequestOutputFactory:
@staticmethod
def create(seq_group: SequenceGroup, use_cache: bool = False):
def create(seq_group: SequenceGroup,
seq_id_to_seq_group: Dict[str, SequenceGroupBase],
use_cache: bool = False):
# Determine the output type based on the sequence group's contents:
if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None:
return EmbeddingRequestOutput.from_seq_group(seq_group)
else:
return RequestOutput.from_seq_group(seq_group, use_cache)
return RequestOutput.from_seq_group(seq_group, use_cache,
seq_id_to_seq_group)
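A self-contained sketch of the gating that RequestOutput.from_seq_group now applies via seq_id_to_seq_group, using a stub class rather than vLLM's real types: outputs for the per-sample sub-requests are suppressed (None) until the whole group can be assembled, and only then is a single output produced for the original request.

```python
from typing import Dict, Optional


class StubGroup:
    """Stand-in for SequenceGroupBase, used only for this illustration."""

    def __init__(self, n: int) -> None:
        self.pending = n
        self.assembled = f"<assembled group of {n} samples>"

    def finish_seq(self, sub_request_id: str) -> None:
        # Record that one sub-request has finished.
        self.pending -= 1

    def maybe_assemble_group(self) -> Optional[str]:
        # Non-streaming behaviour: nothing to emit until all samples finish.
        return self.assembled if self.pending == 0 else None


group = StubGroup(n=3)
seq_id_to_seq_group: Dict[str, StubGroup] = {
    f"req_parallel_sample_{i}": group for i in range(3)
}

# As the sub-requests finish one by one, only the last one yields the
# assembled group; the earlier ones are suppressed, mirroring
# from_seq_group returning None for a not-yet-complete parallel group.
for sub_id in seq_id_to_seq_group:
    group.finish_seq(sub_id)
    print(sub_id, "->", group.maybe_assemble_group())
```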

View File

@ -4,7 +4,7 @@ import enum
from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cached_property, reduce
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
@ -17,7 +17,7 @@ from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING:
@ -1401,3 +1401,121 @@ class ExecuteModelRequest(
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback)
@dataclass
class SequenceGroupBase:
group_id: str # the original request id before splitting
assembled_seq_group: Optional[SequenceGroup] = None
# seq id to a unique index inside this group
seq_id_to_index: Dict[str, int] = field(default_factory=dict)
# seq ids to be finished
to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
# seq id to finished sequences
finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
streaming: bool = False
output_produced: bool = False
@staticmethod
def add_request(request_id: str, engine, params, *args, **kwargs):
"""When we are ready to add a request with request_id and params
into the engine, we can split the request into multiple requests.
"""
raise NotImplementedError
def finish_seq(self, seq: SequenceGroup):
"""The sequence group `seq` has finished; record the information.
"""
del self.to_be_finished[seq.request_id]
self.finished_reqs[seq.request_id] = seq
def maybe_assemble_group(
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
"""Assemble the sequence group, either to produce the final
output or to add the request back into the engine.
"""
raise NotImplementedError
class ParallelSampleSequenceGroup(SequenceGroupBase):
@staticmethod
def add_request(request_id: str, engine, params, **kwargs):
original_params = params
params = copy.deepcopy(original_params)
params.n = 1
group = ParallelSampleSequenceGroup(request_id)
seqs = []
for i in range(original_params.n):
request_id_i = f"{request_id}_parallel_sample_{i}"
group.seq_id_to_index[request_id_i] = i
seq_group = engine.add_request(
request_id_i,
params=params,
**kwargs,
) # type: ignore
assert seq_group is not None
engine.seq_id_to_seq_group[request_id_i] = group
group.to_be_finished[request_id_i] = seq_group
seqs.append(seq_group.seqs[0])
# for parallel sampling, the `assembled_seq_group` is always
# available, since we have all the sequences ready, and they
# will not change.
group.assembled_seq_group = SequenceGroup(
request_id=request_id,
seqs=seqs,
arrival_time=seq_group.arrival_time,
sampling_params=original_params,
lora_request=seq_group.lora_request,
embeddings=seq_group.embeddings,
pooling_params=seq_group.pooling_params,
encoder_seq=seq_group.encoder_seq,
trace_headers=seq_group.trace_headers,
prompt_adapter_request=seq_group.prompt_adapter_request,
priority=seq_group.priority,
)
group.streaming = params.output_kind == RequestOutputKind.DELTA
group.output_produced = False
def maybe_assemble_group(
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
# In streaming mode, return the assembled sequence group for the
# first sequence, and return None for the rest of the sequences.
if self.streaming:
if self.seq_id_to_index[seq_group.request_id] == 0:
return self.assembled_seq_group
return None
# In non-streaming mode, return the assembled sequence group once
# after all sequences have finished, and return None thereafter.
if len(self.to_be_finished) > 0:
return None
assert self.assembled_seq_group is not None
params = self.assembled_seq_group.sampling_params
assert isinstance(params, SamplingParams)
if not self.output_produced:
self.output_produced = True
if params._real_n is not None:
# Get the top-n sequences.
n = params._real_n or params.n
seqs = self.assembled_seq_group.seqs
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
self.assembled_seq_group.seqs = top_n_seqs
return self.assembled_seq_group
if self.output_produced:
return None
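User-visible behaviour is meant to stay the same: a request with n > 1 still yields n completions, only the splitting into sub-requests and the reassembly now live in this wrapper instead of the core engine path. A minimal end-to-end sketch with the offline API, assuming vLLM is installed and a small model such as facebook/opt-125m can be loaded:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # small model, for illustration only
sampling_params = SamplingParams(n=3, max_tokens=5)

# One prompt, three parallel samples; internally the request is split into
# three single-sample sub-requests and reassembled before being returned.
outputs = llm.generate(["What is an LLM?"], sampling_params)
for request_output in outputs:
    for completion in request_output.outputs:
        print(completion.index, completion.text)
```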