[Perf] Support stream interval for reducing host overhead (#27869)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

parent f9f3b596f3
commit 5d6ce2b960
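
This commit threads a new stream_interval option from SchedulerConfig through EngineArgs into the V1 OutputProcessor, so streamed RequestOutputs can batch several tokens per update instead of one per generated token. A minimal usage sketch, assuming this commit is applied; the model name is illustrative, and the CLI equivalent is the new --stream-interval flag registered below:

# Minimal sketch, assuming this commit is applied; the model name is illustrative.
from vllm.engine.arg_utils import AsyncEngineArgs

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",
    stream_interval=8,  # buffer up to 8 tokens per streamed RequestOutput
)
# CLI equivalent (vllm serve exposes the engine args, including this new flag):
#   vllm serve facebook/opt-125m --stream-interval 8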
@@ -49,10 +49,15 @@ def _ref_convert_id_to_token(
 @pytest.mark.parametrize(
     "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
 )
+@pytest.mark.parametrize("stream_interval", [1, 5, 10])
 def test_incremental_detokenization(
-    request_output_kind: RequestOutputKind, dummy_test_vectors
+    request_output_kind: RequestOutputKind,
+    stream_interval: int,
+    dummy_test_vectors,
 ):
-    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
+    output_processor = OutputProcessor(
+        dummy_test_vectors.tokenizer, log_stats=False, stream_interval=stream_interval
+    )
     engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
 
     # Make N requests.
@@ -104,9 +109,18 @@ def test_incremental_detokenization(
             if request_id not in gen_strings:
                 gen_strings[request_id] = new_text
                 gen_tokens[request_id] = new_tokens
+                if request_output_kind == RequestOutputKind.DELTA:
+                    assert len(new_tokens) == 1, f"{len(new_tokens)=}"
             else:
                 gen_strings[request_id] += new_text
                 gen_tokens[request_id].extend(new_tokens)
+                if (
+                    request_output_kind == RequestOutputKind.DELTA
+                    and not request_output.finished
+                ):
+                    assert len(new_tokens) >= stream_interval, (
+                        f"{len(new_tokens)=}, {stream_interval=}"
+                    )
 
     # Confirmed tracked values matches what we expected.
     for idx, (ref_gen_str, ref_gen_toks) in enumerate(
@@ -142,6 +142,12 @@ class SchedulerConfig:
     speculative decoding and pipeline parallelism.
     """
 
+    stream_interval: int = Field(default=1, ge=1)
+    """The interval (or buffer size) for streaming in terms of token length.
+    A smaller value (1) makes streaming smoother by sending each token immediately,
+    while a larger value (e.g., 10) reduces host overhead and may increase throughput
+    by batching multiple tokens before sending."""
+
     def get_scheduler_cls(self) -> type["SchedulerInterface"]:
         if self.scheduler_cls is None:
             if self.async_scheduling:
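
For illustration, a sketch of setting the new field directly. It assumes SchedulerConfig can be constructed standalone with its defaults and relies only on what the hunk above adds (default 1, constrained by ge=1):

# Sketch only: construct the scheduler config with the new field.
from vllm.config import SchedulerConfig

scheduler_config = SchedulerConfig(stream_interval=10)
assert scheduler_config.stream_interval == 10
# Field(default=1, ge=1): the default keeps per-token streaming, and values
# below 1 are rejected by validation rather than disabling streaming.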
@@ -558,6 +558,8 @@ class EngineArgs:
 
     async_scheduling: bool | None = SchedulerConfig.async_scheduling
 
+    stream_interval: int = SchedulerConfig.stream_interval
+
     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
 
     kv_offloading_size: float | None = CacheConfig.kv_offloading_size
@@ -1067,6 +1069,9 @@ class EngineArgs:
         scheduler_group.add_argument(
             "--async-scheduling", **scheduler_kwargs["async_scheduling"]
         )
+        scheduler_group.add_argument(
+            "--stream-interval", **scheduler_kwargs["stream_interval"]
+        )
 
         # Compilation arguments
         compilation_kwargs = get_kwargs(CompilationConfig)
@@ -1562,6 +1567,7 @@ class EngineArgs:
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
             async_scheduling=self.async_scheduling,
+            stream_interval=self.stream_interval,
         )
 
         if not model_config.is_multimodal_model and self.default_mm_loras:
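
A hedged end-to-end sketch of the new flag going through the CLI parser; FlexibleArgumentParser is the parser vLLM's CLI helpers are built on, though its import location can differ between versions:

# Sketch, assuming this commit: the new flag parses like any other engine arg.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser  # import path may vary by version

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--stream-interval", "10"])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.stream_interval == 10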
@@ -120,8 +120,9 @@ class AsyncLLM(EngineClient):
         )
 
         # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
+        stream_interval = self.vllm_config.scheduler_config.stream_interval
         self.output_processor = OutputProcessor(
-            self.tokenizer, log_stats=self.log_stats
+            self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
         )
         endpoint = self.observability_config.otlp_traces_endpoint
         if endpoint is not None:
@@ -96,8 +96,9 @@ class LLMEngine:
         )
 
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
+        stream_interval = self.vllm_config.scheduler_config.stream_interval
         self.output_processor = OutputProcessor(
-            self.tokenizer, log_stats=self.log_stats
+            self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
         )
         endpoint = self.observability_config.otlp_traces_endpoint
         if endpoint is not None:
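
As the updated unit test does, the OutputProcessor can also be constructed directly with the new keyword. A sketch, assuming tokenizer files for the (illustrative) model are available:

# Sketch of direct construction, mirroring the updated test's usage.
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.output_processor import OutputProcessor

tokenizer = get_tokenizer("facebook/opt-125m")  # illustrative model
output_processor = OutputProcessor(tokenizer, log_stats=False, stream_interval=4)
# stream_interval defaults to 1, so existing call sites keep per-token streaming.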
@@ -104,6 +104,7 @@ class RequestState:
         arrival_time: float,
         queue: RequestOutputCollector | None,
         log_stats: bool,
+        stream_interval: int,
         top_p: float | None = None,
         n: int | None = None,
         temperature: float | None = None,
@@ -131,6 +132,10 @@ class RequestState:
 
         self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None
 
+        # Stream Interval
+        self.stream_interval = stream_interval
+        self.sent_tokens_offset = 0  # Offset of sent tokens
+
     @classmethod
     def from_new_request(
         cls,
@@ -141,6 +146,7 @@ class RequestState:
         request_index: int,
         queue: RequestOutputCollector | None,
         log_stats: bool,
+        stream_interval: int,
     ) -> "RequestState":
         if sampling_params := request.sampling_params:
             if not sampling_params.detokenize:
@@ -188,6 +194,7 @@ class RequestState:
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
+            stream_interval=stream_interval,
         )
 
     def make_request_output(
@@ -205,6 +212,29 @@ class RequestState:
             # Only the final output is required in FINAL_ONLY mode.
             return None
 
+        if self.stream_interval > 1:
+            assert self.detokenizer is not None
+
+            # Send output request only when
+            # 1. It has finished, or
+            # 2. It is the first token, or
+            # 3. It has reached the stream interval number of tokens
+            if not (
+                finished
+                or self.sent_tokens_offset == 0
+                or len(self.detokenizer.output_token_ids) - self.sent_tokens_offset
+                >= self.stream_interval
+            ):
+                return None
+
+            if self.output_kind == RequestOutputKind.DELTA:
+                # Send tokens from the offset in DELTA mode, otherwise all
+                # tokens are sent.
+                new_token_ids = self.detokenizer.output_token_ids[
+                    self.sent_tokens_offset :
+                ]
+                self.sent_tokens_offset = len(self.detokenizer.output_token_ids)
+
         request_id = self.request_id
         if pooling_output is not None:
             return self._new_request_output(
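
The gate above sends an output only when the request has finished, nothing has been sent yet, or at least stream_interval new tokens have accumulated since the last send. A standalone sketch of that decision rule (not the vLLM implementation) showing the resulting DELTA chunk sizes, assuming one generated token per engine step:

# Standalone sketch of the send gate; mirrors the condition added above but
# does not import vLLM. Assumes one generated token per engine step.
def simulate_chunks(num_tokens: int, stream_interval: int) -> list[int]:
    chunks: list[int] = []
    sent_offset = 0  # tokens already sent to the client
    for generated in range(1, num_tokens + 1):
        finished = generated == num_tokens
        pending = generated - sent_offset
        if finished or sent_offset == 0 or pending >= stream_interval:
            chunks.append(pending)
            sent_offset = generated
    return chunks

# First chunk is a single token (nothing sent yet), later non-final chunks
# carry at least stream_interval tokens, and the final chunk flushes the rest.
assert simulate_chunks(12, 5) == [1, 5, 5, 1]
assert sum(simulate_chunks(512, 10)) == 512
assert len(simulate_chunks(512, 10)) == 52  # vs. 512 updates with interval 1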
@@ -310,9 +340,12 @@ class RequestState:
 class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""
 
-    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
+    def __init__(
+        self, tokenizer: AnyTokenizer, log_stats: bool, stream_interval: int = 1
+    ):
         self.log_stats = log_stats
         self.tokenizer = tokenizer
+        self.stream_interval = stream_interval
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates(log_stats)
@@ -385,6 +418,7 @@ class OutputProcessor:
             request_index=request_index,
             queue=queue,
             log_stats=self.log_stats,
+            stream_interval=self.stream_interval,
         )
         self.request_states[request_id] = req_state
         if parent_req: