[core] move parallel sampling out from vllm core (#9302)

parent ef7faad1b8
commit 76a5e13270
tests/entrypoints/openai/test_completion.py
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
     assert "".join(chunks) == single_output


+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
+)
+async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
+    """Streaming for parallel sampling.
+    The tokens from multiple samples, are flattened into a single stream,
+    with an index to indicate which sample the token belongs to.
+    """
+
+    prompt = "What is an LLM?"
+    n = 3
+    max_tokens = 5
+
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=max_tokens,
+                                             n=n,
+                                             stream=True)
+    chunks: List[List[str]] = [[] for i in range(n)]
+    finish_reason_count = 0
+    async for chunk in stream:
+        index = chunk.choices[0].index
+        text = chunk.choices[0].text
+        chunks[index].append(text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    assert finish_reason_count == n
+    for chunk in chunks:
+        assert len(chunk) == max_tokens
+        print("".join(chunk))
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
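For readers outside the test suite, the same flattened stream can be split back into per-sample completions by bucketing each chunk on `choices[0].index`, exactly as the test above does. Below is a minimal client-side sketch; it assumes an OpenAI-compatible vLLM server is already running and reachable through an `openai.AsyncOpenAI` client, and the helper name `demux_parallel_stream` is illustrative rather than anything defined in this PR.

from typing import Dict, List

import openai


async def demux_parallel_stream(client: openai.AsyncOpenAI, model: str,
                                prompt: str, n: int,
                                max_tokens: int = 5) -> Dict[int, str]:
    # Collect a flattened n-way completion stream into per-sample texts.
    buckets: Dict[int, List[str]] = {i: [] for i in range(n)}
    stream = await client.completions.create(model=model,
                                             prompt=prompt,
                                             max_tokens=max_tokens,
                                             n=n,
                                             stream=True)
    async for chunk in stream:
        choice = chunk.choices[0]
        # `index` identifies which of the n samples this delta belongs to.
        buckets[choice.index].append(choice.text)
    return {i: "".join(parts) for i, parts in buckets.items()}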
vllm/engine/llm_engine.py
@@ -44,8 +44,10 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
-                           Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceGroupOutput, SequenceStatus)
+                           ParallelSampleSequenceGroup, Sequence,
+                           SequenceGroup, SequenceGroupBase,
+                           SequenceGroupMetadata, SequenceGroupOutput,
+                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -474,6 +476,8 @@ class LLMEngine:
                 ),
             ))

+        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).

@@ -642,7 +646,10 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> SequenceGroup:
+        """Add a processed request to the engine's request pool.
+        return the created sequence group.
+        """
         self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -696,6 +703,8 @@ class LLMEngine:
             min_cost_scheduler = self.scheduler[costs.index(min(costs))]
             min_cost_scheduler.add_seq_group(seq_group)

+        return seq_group
+
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()

@@ -711,7 +720,7 @@ class LLMEngine:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...

     @overload
@@ -725,7 +734,7 @@ class LLMEngine:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         ...

     @deprecate_kwargs(
@@ -744,7 +753,7 @@ class LLMEngine:
         priority: int = 0,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
-    ) -> None:
+    ) -> Optional[SequenceGroup]:
         """Add a request to the engine's request pool.

         The request is added to the request pool and will be processed by the
@@ -788,6 +797,22 @@ class LLMEngine:
         >>> # continue the request processing
         >>> ...
         """
+
+        if isinstance(params, SamplingParams) and params.n > 1:
+            ParallelSampleSequenceGroup.add_request(
+                request_id,
+                self,
+                params,
+                prompt=prompt,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+                priority=priority,
+                inputs=inputs,
+            )
+            return None
+
         if inputs is not None:
             prompt = inputs
         assert prompt is not None and params is not None
@@ -818,7 +843,7 @@ class LLMEngine:
         processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
             "mm_processor_kwargs")

-        self._add_processed_request(
+        return self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=params,
@@ -1135,7 +1160,9 @@ class LLMEngine:
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)

@@ -1175,7 +1202,9 @@ class LLMEngine:
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs)
             if request_output:
                 ctx.request_outputs.append(request_output)

@@ -1194,7 +1223,10 @@ class LLMEngine:
                 continue

             request_output = RequestOutputFactory.create(
-                seq_group, use_cache=self.use_cached_outputs)
+                seq_group,
+                self.seq_id_to_seq_group,
+                use_cache=self.use_cached_outputs,
+            )
             if request_output:
                 ctx.request_outputs.append(request_output)

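The user-facing behavior of the engine is unchanged by this refactor: a request with `n > 1` still yields a single `RequestOutput` carrying one `CompletionOutput` per sample, but the fan-out now happens in `add_request` (which delegates to `ParallelSampleSequenceGroup` and returns `None` for the parent request) instead of inside the core scheduling loop. A minimal offline sketch using the public `LLM` API follows; the model name is only a placeholder for a small generative model and is not prescribed by this diff.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model for a smoke test
params = SamplingParams(n=3, max_tokens=5, temperature=0.8)

for request_output in llm.generate(["What is an LLM?"], params):
    # One RequestOutput per prompt; its .outputs holds n CompletionOutputs.
    for completion in request_output.outputs:
        print(completion.index, repr(completion.text))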
vllm/outputs.py
@@ -1,13 +1,13 @@
 import time
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union

 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
-                           SequenceGroup, SequenceStatus)
+                           SequenceGroup, SequenceGroupBase, SequenceStatus)


 @dataclass
@@ -114,14 +114,28 @@ class RequestOutput:
         self.encoder_prompt_token_ids = encoder_prompt_token_ids

     @classmethod
-    def from_seq_group(cls, seq_group: SequenceGroup,
-                       use_cache: bool) -> Optional["RequestOutput"]:
+    def from_seq_group(
+            cls, seq_group: SequenceGroup, use_cache: bool,
+            seq_id_to_seq_group: Dict[str, SequenceGroupBase]
+    ) -> Optional["RequestOutput"]:
+        finished = seq_group.is_finished()
+
+        if seq_group.request_id in seq_id_to_seq_group:
+            group: SequenceGroupBase = seq_id_to_seq_group[
+                seq_group.request_id]
+            if finished:
+                group.finish_seq(seq_group)
+            assembled_seq_group = group.maybe_assemble_group(seq_group)
+            if assembled_seq_group is None:
+                return None
+            return cls.from_seq_group(assembled_seq_group, use_cache,
+                                      seq_id_to_seq_group)
+
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
             raise ValueError(
                 "Sampling parameters are missing for a CompletionRequest.")

-        finished = seq_group.is_finished()
         if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                 not finished):
             return None
@@ -136,15 +150,7 @@ class RequestOutput:
             outputs=[],
             finished=False)

-        seqs = seq_group.get_seqs()
-        if len(seqs) == 1:
-            top_n_seqs = seqs
-        else:
-            # Get the top-n sequences.
-            n = sampling_params._real_n or sampling_params.n
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-            top_n_seqs = sorted_seqs[:n]
+        top_n_seqs = seq_group.get_seqs()

         # Create the outputs.
         # NOTE: We need omit logprobs here explicitly because the sequence
@@ -208,7 +214,7 @@ class RequestOutput:

             else:
                 output = CompletionOutput(
-                    seqs.index(seq), output_text, [output_token_ids]
+                    top_n_seqs.index(seq), output_text, [output_token_ids]
                     if isinstance(output_token_ids, int) else output_token_ids,
                     seq.get_cumulative_logprob() if include_logprobs else None,
                     output_logprobs,
@@ -309,10 +315,13 @@ class EmbeddingRequestOutput:
 class RequestOutputFactory:

     @staticmethod
-    def create(seq_group: SequenceGroup, use_cache: bool = False):
+    def create(seq_group: SequenceGroup,
+               seq_id_to_seq_group: Dict[str, SequenceGroupBase],
+               use_cache: bool = False):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
-            return RequestOutput.from_seq_group(seq_group, use_cache)
+            return RequestOutput.from_seq_group(seq_group, use_cache,
+                                                seq_id_to_seq_group)

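The new lookup at the top of `from_seq_group` is keyed on the child request ids that `ParallelSampleSequenceGroup.add_request` registers in `engine.seq_id_to_seq_group` (see the vllm/sequence.py hunk below): a finished child is recorded via `finish_seq`, and no `RequestOutput` is emitted until `maybe_assemble_group` hands back the assembled parent group. The dependency-free sketch below mimics that fan-in rule for the non-streaming case; `ToyGroup` and the literal ids are illustrative stand-ins, not vLLM classes.

from dataclasses import dataclass, field
from typing import Optional, Set


@dataclass
class ToyGroup:
    """Stand-in for the shared group tracked in seq_id_to_seq_group."""
    parent_id: str
    to_be_finished: Set[str] = field(default_factory=set)

    def finish_seq(self, child_id: str) -> None:
        self.to_be_finished.discard(child_id)

    def maybe_assemble_group(self) -> Optional[str]:
        # Non-streaming rule: emit the parent only once every child is done.
        return self.parent_id if not self.to_be_finished else None


children = {f"parent-req_parallel_sample_{i}" for i in range(3)}
group = ToyGroup("parent-req", children)

group.finish_seq("parent-req_parallel_sample_0")
assert group.maybe_assemble_group() is None        # siblings still running
group.finish_seq("parent-req_parallel_sample_1")
group.finish_seq("parent-req_parallel_sample_2")
assert group.maybe_assemble_group() == "parent-req"  # all done: emit once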
vllm/sequence.py (122 changed lines)
@@ -4,7 +4,7 @@ import enum
 from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
@@ -17,7 +17,7 @@ from vllm.inputs.parse import is_encoder_decoder_inputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

 if TYPE_CHECKING:
@@ -1401,3 +1401,121 @@ class ExecuteModelRequest(
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
             async_callback=self.async_callback)
+
+
+@dataclass
+class SequenceGroupBase:
+    group_id: str  # the original request id before splitting
+
+    assembled_seq_group: Optional[SequenceGroup] = None
+
+    # seq id to a unique index inside this group
+    seq_id_to_index: Dict[str, int] = field(default_factory=dict)
+
+    # seq ids to be finished
+    to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
+
+    # seq id to finished sequences
+    finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
+
+    streaming: bool = False
+
+    output_produced: bool = False
+
+    @staticmethod
+    def add_request(request_id: str, engine, params, *args, **kwargs):
+        """When we are ready to add a request with request_id and params
+        into the engine, we can split the request into multiple requests.
+        """
+        raise NotImplementedError
+
+    def finish_seq(self, seq: SequenceGroup):
+        """The sequence `seq` finishes, we should record the information.
+        """
+        del self.to_be_finished[seq.request_id]
+        self.finished_reqs[seq.request_id] = seq
+
+    def maybe_assemble_group(
+            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
+        """Assemble the sequence group, for producing the final
+        output, or adding request in the engine again.
+        """
+        raise NotImplementedError
+
+
+class ParallelSampleSequenceGroup(SequenceGroupBase):
+
+    @staticmethod
+    def add_request(request_id: str, engine, params, **kwargs):
+        original_params = params
+        params = copy.deepcopy(original_params)
+        params.n = 1
+        group = ParallelSampleSequenceGroup(request_id)
+        seqs = []
+        for i in range(original_params.n):
+            request_id_i = f"{request_id}_parallel_sample_{i}"
+            group.seq_id_to_index[request_id_i] = i
+            seq_group = engine.add_request(
+                request_id_i,
+                params=params,
+                **kwargs,
+            )  # type: ignore
+            assert seq_group is not None
+            engine.seq_id_to_seq_group[request_id_i] = group
+            group.to_be_finished[request_id_i] = seq_group
+            seqs.append(seq_group.seqs[0])
+
+        # for parallel sampling, the `assembled_seq_group` is always
+        # available, since we have all the sequences ready, and they
+        # will not change.
+        group.assembled_seq_group = SequenceGroup(
+            request_id=request_id,
+            seqs=seqs,
+            arrival_time=seq_group.arrival_time,
+            sampling_params=original_params,
+            lora_request=seq_group.lora_request,
+            embeddings=seq_group.embeddings,
+            pooling_params=seq_group.pooling_params,
+            encoder_seq=seq_group.encoder_seq,
+            trace_headers=seq_group.trace_headers,
+            prompt_adapter_request=seq_group.prompt_adapter_request,
+            priority=seq_group.priority,
+        )
+
+        group.streaming = params.output_kind == RequestOutputKind.DELTA
+        group.output_produced = False
+
+    def maybe_assemble_group(
+            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
+
+        # in the streaming mode, we will return the assembled sequence
+        # for the first sequence, and then return None for the rest of
+        # sequences
+        if self.streaming:
+            if self.seq_id_to_index[seq_group.request_id] == 0:
+                return self.assembled_seq_group
+            return None
+
+        # in the non-streaming mode, we will return the assembled sequence
+        # once after all sequences finish, and then return None for the
+        # rest of the time
+
+        if len(self.to_be_finished) > 0:
+            return None
+
+        assert self.assembled_seq_group is not None
+        params = self.assembled_seq_group.sampling_params
+        assert isinstance(params, SamplingParams)
+        if not self.output_produced:
+            self.output_produced = True
+            if params._real_n is not None:
+                # Get the top-n sequences.
+                n = params._real_n or params.n
+                seqs = self.assembled_seq_group.seqs
+                sorting_key = lambda seq: seq.get_cumulative_logprob()
+                sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+                top_n_seqs = sorted_seqs[:n]
+                self.assembled_seq_group.seqs = top_n_seqs
+            return self.assembled_seq_group
+        if self.output_produced:
+            return None
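Note that the top-n selection removed from `RequestOutput.from_seq_group` in the vllm/outputs.py hunk above now runs here, once, when the assembled group is produced: if `params._real_n` is set (i.e. more sequences were generated than the caller asked for), the assembled group keeps only the sequences with the highest cumulative log probability. A plain-Python sketch of that trimming with made-up scores; it mirrors the sorting above rather than calling into vLLM.

# Fake (sequence name, cumulative logprob) pairs standing in for Sequence objects.
seqs = [("sample_0", -1.2), ("sample_1", -0.4), ("sample_2", -3.0)]
real_n = 2  # pretend params._real_n == 2

sorted_seqs = sorted(seqs, key=lambda s: s[1], reverse=True)
top_n_seqs = sorted_seqs[:real_n]
assert [name for name, _ in top_n_seqs] == ["sample_1", "sample_0"]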