[Perf] Support stream interval for reducing host overhead (#27869)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

Parent: f9f3b596f3
Commit: 5d6ce2b960
@@ -49,10 +49,15 @@ def _ref_convert_id_to_token(
 @pytest.mark.parametrize(
     "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
 )
+@pytest.mark.parametrize("stream_interval", [1, 5, 10])
 def test_incremental_detokenization(
-    request_output_kind: RequestOutputKind, dummy_test_vectors
+    request_output_kind: RequestOutputKind,
+    stream_interval: int,
+    dummy_test_vectors,
 ):
-    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
+    output_processor = OutputProcessor(
+        dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
+    )
     engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
 
     # Make N requests.
@@ -104,9 +109,18 @@ def test_incremental_detokenization(
             if request_id not in gen_strings:
                 gen_strings[request_id] = new_text
                 gen_tokens[request_id] = new_tokens
+                if request_output_kind == RequestOutputKind.DELTA:
+                    assert len(new_tokens) == 1, f"{len(new_tokens)=}"
             else:
                 gen_strings[request_id] += new_text
                 gen_tokens[request_id].extend(new_tokens)
+                if (
+                    request_output_kind == RequestOutputKind.DELTA
+                    and not request_output.finished
+                ):
+                    assert len(new_tokens) >= stream_interval, (
+                        f"{len(new_tokens)=}, {stream_interval=}"
+                    )
 
     # Confirmed tracked values matches what we expected.
     for idx, (ref_gen_str, ref_gen_toks) in enumerate(
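The new keyword mirrors the test setup above. A minimal construction sketch follows; the import paths and the "gpt2" tokenizer name are assumptions based on current vLLM, not part of this diff:

# Sketch only: import paths and the tokenizer name are assumptions.
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.output_processor import OutputProcessor

tokenizer = get_tokenizer("gpt2")  # any HF tokenizer works as a placeholder
# Buffer up to 5 tokens per streamed output instead of emitting every token.
output_processor = OutputProcessor(tokenizer, log_stats=False, stream_interval=5)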
@@ -142,6 +142,12 @@ class SchedulerConfig:
     speculative decoding and pipeline parallelism.
     """
 
+    stream_interval: int = Field(default=1, ge=1)
+    """The interval (or buffer size) for streaming in terms of token length.
+    A smaller value (1) makes streaming smoother by sending each token immediately,
+    while a larger value (e.g., 10) reduces host overhead and may increase throughput
+    by batching multiple tokens before sending."""
+
     def get_scheduler_cls(self) -> type["SchedulerInterface"]:
         if self.scheduler_cls is None:
             if self.async_scheduling:
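A quick sketch of the new field's behavior, assuming SchedulerConfig stays importable from vllm.config and keeps its pydantic validation; the ge=1 constraint above rejects values below 1:

from vllm.config import SchedulerConfig

cfg = SchedulerConfig(stream_interval=10)
print(cfg.stream_interval)  # 10

# Values below 1 fail validation because of ge=1 in the Field above.
# SchedulerConfig(stream_interval=0)  # raises a validation error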
@@ -558,6 +558,8 @@ class EngineArgs:
 
     async_scheduling: bool | None = SchedulerConfig.async_scheduling
 
+    stream_interval: int = SchedulerConfig.stream_interval
+
     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
 
     kv_offloading_size: float | None = CacheConfig.kv_offloading_size
@@ -1067,6 +1069,9 @@ class EngineArgs:
         scheduler_group.add_argument(
             "--async-scheduling", **scheduler_kwargs["async_scheduling"]
         )
+        scheduler_group.add_argument(
+            "--stream-interval", **scheduler_kwargs["stream_interval"]
+        )
 
         # Compilation arguments
         compilation_kwargs = get_kwargs(CompilationConfig)
@@ -1562,6 +1567,7 @@ class EngineArgs:
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
             async_scheduling=self.async_scheduling,
+            stream_interval=self.stream_interval,
         )
 
         if not model_config.is_multimodal_model and self.default_mm_loras:
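Hypothetical usage sketch for the new engine arg; the model name is a placeholder, and the CLI spelling follows the add_argument call above (--stream-interval):

from vllm.engine.arg_utils import EngineArgs

# Python-level: set the field directly on EngineArgs.
args = EngineArgs(model="facebook/opt-125m", stream_interval=10)  # placeholder model
print(args.stream_interval)  # 10

# CLI-level, per the parser wiring above (example command, not from this diff):
#   vllm serve facebook/opt-125m --stream-interval 10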
@@ -120,8 +120,9 @@ class AsyncLLM(EngineClient):
         )
 
         # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
+        stream_interval = self.vllm_config.scheduler_config.stream_interval
         self.output_processor = OutputProcessor(
-            self.tokenizer, log_stats=self.log_stats
+            self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
         )
         endpoint = self.observability_config.otlp_traces_endpoint
         if endpoint is not None:
@@ -96,8 +96,9 @@ class LLMEngine:
         )
 
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
+        stream_interval = self.vllm_config.scheduler_config.stream_interval
         self.output_processor = OutputProcessor(
-            self.tokenizer, log_stats=self.log_stats
+            self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
         )
         endpoint = self.observability_config.otlp_traces_endpoint
         if endpoint is not None:
@@ -104,6 +104,7 @@ class RequestState:
         arrival_time: float,
         queue: RequestOutputCollector | None,
         log_stats: bool,
+        stream_interval: int,
         top_p: float | None = None,
         n: int | None = None,
         temperature: float | None = None,
@@ -131,6 +132,10 @@ class RequestState:
 
         self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None
 
+        # Stream Interval
+        self.stream_interval = stream_interval
+        self.sent_tokens_offset = 0  # Offset of sent tokens
+
     @classmethod
     def from_new_request(
         cls,
@@ -141,6 +146,7 @@ class RequestState:
         request_index: int,
         queue: RequestOutputCollector | None,
         log_stats: bool,
+        stream_interval: int,
     ) -> "RequestState":
         if sampling_params := request.sampling_params:
             if not sampling_params.detokenize:
@@ -188,6 +194,7 @@ class RequestState:
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
+            stream_interval=stream_interval,
         )
 
     def make_request_output(
@@ -205,6 +212,29 @@ class RequestState:
             # Only the final output is required in FINAL_ONLY mode.
             return None
 
+        if self.stream_interval > 1:
+            assert self.detokenizer is not None
+
+            # Send output request only when
+            # 1. It has finished, or
+            # 2. It is the first token, or
+            # 3. It has reached the stream interval number of tokens
+            if not (
+                finished
+                or self.sent_tokens_offset == 0
+                or len(self.detokenizer.output_token_ids) - self.sent_tokens_offset
+                >= self.stream_interval
+            ):
+                return None
+
+        if self.output_kind == RequestOutputKind.DELTA:
+            # Send tokens from the offset in DELTA mode, otherwise all
+            # tokens are sent.
+            new_token_ids = self.detokenizer.output_token_ids[
+                self.sent_tokens_offset :
+            ]
+            self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
+
         request_id = self.request_id
         if pooling_output is not None:
             return self._new_request_output(
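To make the gating rule above concrete, here is a standalone sketch (plain Python, not vLLM code) of when a DELTA output would be emitted: on the first token, whenever at least stream_interval new tokens have accumulated since the last send, and on the final token:

def stream_points(total_tokens: int, stream_interval: int) -> list[int]:
    """Token counts at which an output would be sent, mirroring the rule above."""
    sent_tokens_offset = 0
    emitted = []
    for n in range(1, total_tokens + 1):
        finished = n == total_tokens
        if (
            finished
            or sent_tokens_offset == 0
            or n - sent_tokens_offset >= stream_interval
        ):
            emitted.append(n)
            sent_tokens_offset = n  # in DELTA mode everything up to n is now sent
    return emitted

# With 12 tokens and stream_interval=5: first token, then batches of 5, then the tail.
print(stream_points(12, 5))  # [1, 6, 11, 12]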
@@ -310,9 +340,12 @@ class RequestState:
 class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""
 
-    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
+    def __init__(
+        self, tokenizer: AnyTokenizer, log_stats: bool, stream_interval: int = 1
+    ):
         self.log_stats = log_stats
         self.tokenizer = tokenizer
+        self.stream_interval = stream_interval
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates(log_stats)
@@ -385,6 +418,7 @@ class OutputProcessor:
             request_index=request_index,
             queue=queue,
             log_stats=self.log_stats,
+            stream_interval=self.stream_interval,
         )
         self.request_states[request_id] = req_state
         if parent_req: