[Perf] Support stream interval for reducing host overhead (#27869)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
elvischenv 2025-11-14 02:21:25 +08:00 committed by GitHub
parent f9f3b596f3
commit 5d6ce2b960
6 changed files with 67 additions and 5 deletions

View File

@@ -49,10 +49,15 @@ def _ref_convert_id_to_token(
 @pytest.mark.parametrize(
     "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
 )
+@pytest.mark.parametrize("stream_interval", [1, 5, 10])
 def test_incremental_detokenization(
-    request_output_kind: RequestOutputKind, dummy_test_vectors
+    request_output_kind: RequestOutputKind,
+    stream_interval: int,
+    dummy_test_vectors,
 ):
-    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
+    output_processor = OutputProcessor(
+        dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
+    )
     engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
 
     # Make N requests.
@@ -104,9 +109,18 @@ def test_incremental_detokenization(
             if request_id not in gen_strings:
                 gen_strings[request_id] = new_text
                 gen_tokens[request_id] = new_tokens
+                if request_output_kind == RequestOutputKind.DELTA:
+                    assert len(new_tokens) == 1, f"{len(new_tokens)=}"
             else:
                 gen_strings[request_id] += new_text
                 gen_tokens[request_id].extend(new_tokens)
+                if (
+                    request_output_kind == RequestOutputKind.DELTA
+                    and not request_output.finished
+                ):
+                    assert len(new_tokens) >= stream_interval, (
+                        f"{len(new_tokens)=}, {stream_interval=}"
+                    )
 
     # Confirmed tracked values matches what we expected.
     for idx, (ref_gen_str, ref_gen_toks) in enumerate(
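
In DELTA mode the new assertions encode a simple chunking pattern: the first token is streamed immediately, intermediate outputs wait until stream_interval tokens have accumulated, and the final output flushes whatever is left. A minimal sketch of that pattern, illustrative only; it assumes one new token per engine step, and the helper name expected_delta_chunks is not part of the test:

def expected_delta_chunks(total_tokens: int, stream_interval: int) -> list[int]:
    # The first token is streamed immediately, so time-to-first-token is unchanged.
    chunks = [1]
    remaining = total_tokens - 1
    # Intermediate chunks are emitted once stream_interval tokens accumulate.
    while remaining >= stream_interval:
        chunks.append(stream_interval)
        remaining -= stream_interval
    # Finishing flushes the remainder, which may be shorter than the interval.
    if remaining:
        chunks.append(remaining)
    return chunks

assert expected_delta_chunks(23, 5) == [1, 5, 5, 5, 5, 2]
assert expected_delta_chunks(23, 1) == [1] * 23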

View File

@@ -142,6 +142,12 @@ class SchedulerConfig:
     speculative decoding and pipeline parallelism.
     """
 
+    stream_interval: int = Field(default=1, ge=1)
+    """The interval (or buffer size) for streaming in terms of token length.
+    A smaller value (1) makes streaming smoother by sending each token immediately,
+    while a larger value (e.g., 10) reduces host overhead and may increase throughput
+    by batching multiple tokens before sending."""
+
     def get_scheduler_cls(self) -> type["SchedulerInterface"]:
         if self.scheduler_cls is None:
             if self.async_scheduling:
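
A minimal sketch of the new knob in isolation, assuming the other SchedulerConfig fields keep their defaults; because the field is declared with Field(default=1, ge=1), values below 1 should be rejected when the config is validated:

from vllm.config import SchedulerConfig

cfg = SchedulerConfig(stream_interval=8)  # buffer roughly 8 tokens per push
print(cfg.stream_interval)                # -> 8

# SchedulerConfig(stream_interval=0) would violate the ge=1 constraint.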

View File

@@ -558,6 +558,8 @@ class EngineArgs:
     async_scheduling: bool | None = SchedulerConfig.async_scheduling
 
+    stream_interval: int = SchedulerConfig.stream_interval
+
     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
     kv_offloading_size: float | None = CacheConfig.kv_offloading_size
@@ -1067,6 +1069,9 @@ class EngineArgs:
         scheduler_group.add_argument(
             "--async-scheduling", **scheduler_kwargs["async_scheduling"]
         )
+        scheduler_group.add_argument(
+            "--stream-interval", **scheduler_kwargs["stream_interval"]
+        )
 
         # Compilation arguments
         compilation_kwargs = get_kwargs(CompilationConfig)
@@ -1562,6 +1567,7 @@ class EngineArgs:
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
             async_scheduling=self.async_scheduling,
+            stream_interval=self.stream_interval,
         )
 
         if not model_config.is_multimodal_model and self.default_mm_loras:
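
Since the field rides the existing EngineArgs plumbing, passing --stream-interval 8 on the CLI and the keyword argument below end up in the same place, SchedulerConfig.stream_interval. A hedged sketch, not a definitive recipe; the model name is a placeholder and its configuration must be resolvable:

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", stream_interval=8)
vllm_config = args.create_engine_config()
assert vllm_config.scheduler_config.stream_interval == 8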

View File

@@ -120,8 +120,9 @@ class AsyncLLM(EngineClient):
         )
         # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
+        stream_interval = self.vllm_config.scheduler_config.stream_interval
         self.output_processor = OutputProcessor(
-            self.tokenizer, log_stats=self.log_stats
+            self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
         )
 
         endpoint = self.observability_config.otlp_traces_endpoint
         if endpoint is not None:
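
End to end, AsyncLLM now hands the configured interval to its OutputProcessor, so streamed outputs arrive roughly every stream_interval generated tokens, plus the immediate first token and the final flush. A hedged sketch rather than a benchmark; the model name, prompt, and request id are placeholders, and a GPU-backed install built from this commit is assumed:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM


async def main() -> None:
    engine = AsyncLLM.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", stream_interval=8)
    )
    try:
        params = SamplingParams(max_tokens=64)
        # Outputs should be pushed in steps of ~8 tokens instead of every step.
        async for output in engine.generate("Hello, my name is", params, "req-0"):
            print(len(output.outputs[0].token_ids), output.finished)
    finally:
        engine.shutdown()


asyncio.run(main())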

View File

@@ -96,8 +96,9 @@ class LLMEngine:
         )
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
+        stream_interval = self.vllm_config.scheduler_config.stream_interval
         self.output_processor = OutputProcessor(
-            self.tokenizer, log_stats=self.log_stats
+            self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
         )
 
         endpoint = self.observability_config.otlp_traces_endpoint
         if endpoint is not None:
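
The synchronous LLMEngine gets the same wiring, so both entry points read the value from SchedulerConfig. For offline use the final text is unchanged; a larger interval mainly reduces how often intermediate RequestOutputs are materialized. A hedged sketch, assuming extra keyword arguments to LLM are forwarded to EngineArgs and using a placeholder model:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", stream_interval=4)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)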

View File

@@ -104,6 +104,7 @@ class RequestState:
         arrival_time: float,
         queue: RequestOutputCollector | None,
         log_stats: bool,
+        stream_interval: int,
         top_p: float | None = None,
         n: int | None = None,
         temperature: float | None = None,
@@ -131,6 +132,10 @@ class RequestState:
         self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None
 
+        # Stream Interval
+        self.stream_interval = stream_interval
+        self.sent_tokens_offset = 0  # Offset of sent tokens
+
     @classmethod
     def from_new_request(
         cls,
@@ -141,6 +146,7 @@ class RequestState:
         request_index: int,
         queue: RequestOutputCollector | None,
         log_stats: bool,
+        stream_interval: int,
     ) -> "RequestState":
         if sampling_params := request.sampling_params:
             if not sampling_params.detokenize:
@@ -188,6 +194,7 @@ class RequestState:
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
+            stream_interval=stream_interval,
         )
 
     def make_request_output(
@@ -205,6 +212,29 @@ class RequestState:
             # Only the final output is required in FINAL_ONLY mode.
             return None
 
+        if self.stream_interval > 1:
+            assert self.detokenizer is not None
+            # Send output request only when
+            # 1. It has finished, or
+            # 2. It is the first token, or
+            # 3. It has reached the stream interval number of tokens
+            if not (
+                finished
+                or self.sent_tokens_offset == 0
+                or len(self.detokenizer.output_token_ids) - self.sent_tokens_offset
+                >= self.stream_interval
+            ):
+                return None
+
+            if self.output_kind == RequestOutputKind.DELTA:
+                # Send tokens from the offset in DELTA mode, otherwise all
+                # tokens are sent.
+                new_token_ids = self.detokenizer.output_token_ids[
+                    self.sent_tokens_offset :
+                ]
+
+            self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
+
         request_id = self.request_id
         if pooling_output is not None:
             return self._new_request_output(
@@ -310,9 +340,12 @@ class RequestState:
 class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""
 
-    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
+    def __init__(
+        self, tokenizer: AnyTokenizer, log_stats: bool, stream_interval: int = 1
+    ):
         self.log_stats = log_stats
         self.tokenizer = tokenizer
+        self.stream_interval = stream_interval
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates(log_stats)
@@ -385,6 +418,7 @@ class OutputProcessor:
             request_index=request_index,
             queue=queue,
             log_stats=self.log_stats,
+            stream_interval=self.stream_interval,
         )
         self.request_states[request_id] = req_state
         if parent_req:
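
The gating added to make_request_output boils down to a small predicate: emit when the request has finished, when nothing has been sent yet (so time-to-first-token should not regress), or when at least stream_interval tokens have accumulated since the last emission. A standalone restatement for clarity; the function name should_emit and its arguments are local to this sketch, not part of the module:

def should_emit(
    finished: bool,
    sent_tokens_offset: int,
    num_output_tokens: int,
    stream_interval: int,
) -> bool:
    if stream_interval <= 1:
        # Interval 1 keeps the previous behaviour: emit on every step.
        return True
    return (
        finished
        or sent_tokens_offset == 0
        or num_output_tokens - sent_tokens_offset >= stream_interval
    )


# With stream_interval=5: the first token is emitted immediately, tokens 2-5
# are buffered, token 6 triggers the next emission, and finishing flushes
# whatever remains.
assert should_emit(False, 0, 1, 5) is True
assert should_emit(False, 1, 4, 5) is False
assert should_emit(False, 1, 6, 5) is True
assert should_emit(True, 1, 3, 5) is True

Note that the same gate applies in CUMULATIVE mode, where each emitted output carries the full token list; only DELTA reslices new_token_ids from sent_tokens_offset, as the diff above shows.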