Add output streaming support to multi-step + async while ensuring RequestOutput obj reuse (#8335)

This commit is contained in:
Alexander Matveev 2024-09-23 18:38:04 -04:00 committed by GitHub
parent 5f7bb58427
commit 1a2aef3e59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 142 additions and 42 deletions

View File

@ -19,7 +19,11 @@ FILTER = "exact_match,strict-match"
RTOL = 0.03 RTOL = 0.03
EXPECTED_VALUE = 0.58 EXPECTED_VALUE = 0.58
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] MORE_ARGS_LIST = [
["--enable-chunked-prefill"], # Chunked
["--num-scheduler-steps", "8"], # MS
["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream
]
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) @pytest.mark.parametrize("more_args", MORE_ARGS_LIST)

View File

@ -960,6 +960,7 @@ class SchedulerConfig:
is_multimodal_model: bool = False, is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None, preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1, num_scheduler_steps: int = 1,
multi_step_stream_outputs: bool = False,
send_delta_data: bool = False) -> None: send_delta_data: bool = False) -> None:
if max_num_batched_tokens is None: if max_num_batched_tokens is None:
if enable_chunked_prefill: if enable_chunked_prefill:
@ -1000,6 +1001,7 @@ class SchedulerConfig:
self.embedding_mode = embedding_mode self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs
self.send_delta_data = send_delta_data self.send_delta_data = send_delta_data
self._verify_args() self._verify_args()

View File

@ -145,6 +145,7 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None max_cpu_loras: Optional[int] = None
device: str = 'auto' device: str = 'auto'
num_scheduler_steps: int = 1 num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = False
ray_workers_use_nsight: bool = False ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0 num_lookahead_slots: int = 0
@ -595,6 +596,10 @@ class EngineArgs:
help=('Maximum number of forward steps per ' help=('Maximum number of forward steps per '
'scheduler call.')) 'scheduler call.'))
parser.add_argument(
'--multi-step-stream-outputs',
action='store_true',
help='If True, then multi-step will stream outputs for every step')
parser.add_argument( parser.add_argument(
'--scheduler-delay-factor', '--scheduler-delay-factor',
type=float, type=float,
@ -999,6 +1004,7 @@ class EngineArgs:
is_multimodal_model=model_config.is_multimodal_model, is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode, preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps, num_scheduler_steps=self.num_scheduler_steps,
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray), and parallel_config.use_ray),
) )

View File

@ -95,7 +95,7 @@ class OutputData(NamedTuple):
class SchedulerContext: class SchedulerContext:
def __init__(self): def __init__(self, multi_step_stream_outputs: bool = False):
self.output_queue: Deque[OutputData] = deque() self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput, self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = [] EmbeddingRequestOutput]] = []
@ -103,6 +103,8 @@ class SchedulerContext:
List[SequenceGroupMetadata]] = None List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None
self.multi_step_stream_outputs: bool = multi_step_stream_outputs
def append_output(self, outputs: List[SamplerOutput], def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool, scheduler_outputs: SchedulerOutputs, is_async: bool,
@ -219,6 +221,7 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None: ) -> None:
logger.info( logger.info(
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
@ -234,8 +237,9 @@ class LLMEngine:
"quantization_param_path=%s, device_config=%s, " "quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, " "decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, " "num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"use_async_output_proc=%s, mm_processor_kwargs=%s)", "enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)",
VLLM_VERSION, VLLM_VERSION,
model_config.model, model_config.model,
speculative_config, speculative_config,
@ -266,8 +270,10 @@ class LLMEngine:
model_config.served_model_name, model_config.served_model_name,
scheduler_config.use_v2_block_manager, scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps, scheduler_config.num_scheduler_steps,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching, cache_config.enable_prefix_caching,
model_config.use_async_output_proc, model_config.use_async_output_proc,
use_cached_outputs,
model_config.mm_processor_kwargs, model_config.mm_processor_kwargs,
) )
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
@ -287,6 +293,7 @@ class LLMEngine:
self.observability_config = observability_config or ObservabilityConfig( self.observability_config = observability_config or ObservabilityConfig(
) )
self.log_stats = log_stats self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer() self.tokenizer = self._init_tokenizer()
@ -379,7 +386,8 @@ class LLMEngine:
] ]
self.scheduler_contexts = [ self.scheduler_contexts = [
SchedulerContext() SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size) for _ in range(self.parallel_config.pipeline_parallel_size)
] ]
@ -998,7 +1006,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output: if request_output:
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
@ -1019,8 +1028,8 @@ class LLMEngine:
for scheduler in self.scheduler: for scheduler in self.scheduler:
scheduler.free_finished_seq_groups() scheduler.free_finished_seq_groups()
# For multi-step, do not create outputs each iteration # For multi-step without streaming, don't create outputs each iteration
if not is_last_step: if not is_last_step and not ctx.multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given) # Immediately process request outputs here (if callback is given)
if (finished_now if (finished_now
and self.process_request_outputs_callback is not None): and self.process_request_outputs_callback is not None):
@ -1037,17 +1046,27 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output: if request_output:
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
# For multi-step with streaming, create outputs each iteration
if not is_last_step and ctx.multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if self.process_request_outputs_callback is not None:
self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
return
for seq_group in scheduler_outputs.ignored_seq_groups: for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params params = seq_group.sampling_params
if params is not None and params.output_kind == ( if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished(): RequestOutputKind.DELTA) and not seq_group.is_finished():
continue continue
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output: if request_output:
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)

View File

@ -66,7 +66,14 @@ class MQLLMEngine:
*args, *args,
log_requests: bool = True, log_requests: bool = True,
**kwargs) -> None: **kwargs) -> None:
self.engine = LLMEngine(*args, **kwargs) # For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and sent over the socket, which frees
# the python object to be reused again.
use_cached_outputs = True
self.engine = LLMEngine(*args,
**kwargs,
use_cached_outputs=use_cached_outputs)
self.log_requests = log_requests self.log_requests = log_requests
self.use_async_sockets = use_async_sockets self.use_async_sockets = use_async_sockets

View File

@ -114,17 +114,28 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod @classmethod
def from_seq_group(cls, def from_seq_group(cls, seq_group: SequenceGroup,
seq_group: SequenceGroup) -> Optional["RequestOutput"]: use_cache: bool) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
if sampling_params is None: if sampling_params is None:
raise ValueError( raise ValueError(
"Sampling parameters are missing for a CompletionRequest.") "Sampling parameters are missing for a CompletionRequest.")
finished = seq_group.is_finished() finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished): not finished):
return None return None
# Init cache (if needed)
if use_cache and seq_group.cached_request_output is None:
seq_group.cached_request_output = RequestOutput( # type: ignore
request_id="",
prompt=None,
prompt_token_ids=[],
prompt_logprobs=None,
outputs=[],
finished=False)
seqs = seq_group.get_seqs() seqs = seq_group.get_seqs()
if len(seqs) == 1: if len(seqs) == 1:
top_n_seqs = seqs top_n_seqs = seqs
@ -149,29 +160,66 @@ class RequestOutput:
outputs = [] outputs = []
include_prompt = True include_prompt = True
for seq in top_n_seqs: for i, seq in enumerate(top_n_seqs):
output_text = seq.get_output_text_to_return( output_text = seq.get_output_text_to_return(
text_buffer_length, delta) text_buffer_length, delta)
output_token_ids = seq.get_output_token_ids_to_return(delta) output_token_ids = seq.get_output_token_ids_to_return(delta)
num_output_tokens = 1 if isinstance(output_token_ids,
int) else len(output_token_ids)
output_logprobs = seq.output_logprobs if include_logprobs else None output_logprobs = seq.output_logprobs if include_logprobs else None
if delta: if delta:
# Slice logprobs delta if applicable # Slice logprobs delta if applicable
if output_logprobs: if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):] output_logprobs = output_logprobs[-num_output_tokens:]
# Don't include prompt if this is after the first output # Don't include prompt if this is after the first output
# containing decode token ids # containing decode token ids
if include_prompt and seq.get_output_len() > len( if include_prompt and seq.get_output_len() > num_output_tokens:
output_token_ids):
include_prompt = False include_prompt = False
outputs.append( if use_cache:
CompletionOutput( # Get cached output object
seqs.index(seq), output_text, output_token_ids, cached_outputs = seq_group.cached_request_output.outputs # type: ignore
if i >= len(cached_outputs):
cached_outputs.append(
CompletionOutput(index=i,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None))
output = cached_outputs[i]
# Init cached output object
assert output.index == i
output.text = output_text
if isinstance(output_token_ids, int):
output.token_ids.clear()
output.token_ids.append(output_token_ids)
else:
output.token_ids = output_token_ids
output.cumulative_logprob = seq.get_cumulative_logprob() \
if include_logprobs else None
output.logprobs = output_logprobs
output.finish_reason = SequenceStatus.get_finished_reason(
seq.status)
output.stop_reason = seq.stop_reason
else:
output = CompletionOutput(
seqs.index(seq), output_text, [output_token_ids]
if isinstance(output_token_ids, int) else output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None, seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs, output_logprobs,
SequenceStatus.get_finished_reason(seq.status), SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason)) seq.stop_reason)
outputs.append(output)
# Every sequence in the sequence group should have the same prompt. # Every sequence in the sequence group should have the same prompt.
if include_prompt: if include_prompt:
@ -188,16 +236,20 @@ class RequestOutput:
prompt_logprobs = None prompt_logprobs = None
finished_time = time.time() if finished else None finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time) seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id,
prompt, init_args = (seq_group.request_id, prompt, prompt_token_ids,
prompt_token_ids, prompt_logprobs, outputs, finished, seq_group.metrics,
prompt_logprobs, seq_group.lora_request, encoder_prompt,
outputs, encoder_prompt_token_ids)
finished,
seq_group.metrics, if use_cache:
lora_request=seq_group.lora_request, request_output = seq_group.cached_request_output
encoder_prompt=encoder_prompt, request_output.__init__(*init_args) # type: ignore
encoder_prompt_token_ids=encoder_prompt_token_ids)
else:
request_output = cls(*init_args)
return request_output
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, " return (f"RequestOutput(request_id={self.request_id}, "
@ -261,10 +313,10 @@ class EmbeddingRequestOutput:
class RequestOutputFactory: class RequestOutputFactory:
@staticmethod @staticmethod
def create(seq_group): def create(seq_group: SequenceGroup, use_cache: bool = False):
# Determine the type based on a condition, for example: # Determine the type based on a condition, for example:
if hasattr(seq_group, if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None: 'embeddings') and seq_group.embeddings is not None:
return EmbeddingRequestOutput.from_seq_group(seq_group) return EmbeddingRequestOutput.from_seq_group(seq_group)
else: else:
return RequestOutput.from_seq_group(seq_group) return RequestOutput.from_seq_group(seq_group, use_cache)

View File

@ -436,7 +436,7 @@ class Sequence:
self.stop_reason: Union[int, str, None] = None self.stop_reason: Union[int, str, None] = None
# These are used to keep track of delta outputs # These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0 self._last_output_token_ids_offset: int = 0
self._last_output_text_offset: int = 0 self._last_output_text_offset: int = 0
# Used for incremental detokenization # Used for incremental detokenization
@ -499,18 +499,26 @@ class Sequence:
return self.output_text[last_offset:length] return self.output_text[last_offset:length]
return "" return ""
def get_output_token_ids_to_return(self, def get_output_token_ids_to_return(
delta: bool) -> GenericSequence[int]: self, delta: bool) -> Union[GenericSequence[int], int]:
"""If delta is True, only new tokens since the last call to """If delta is True, only new tokens since the last call to
this method are returned""" this method are returned"""
if not delta: if not delta:
return self.get_output_token_ids() return self.get_output_token_ids()
length = self.get_output_len()
last_offset = self._last_token_ids_offset output_len = self.get_output_len()
if last_offset < length:
self._last_token_ids_offset = length # Get the number of new tokens
return self.data._output_token_ids[last_offset:] num_new_tokens = output_len - self._last_output_token_ids_offset
return () self._last_output_token_ids_offset = output_len
# Return new tokens
if num_new_tokens == 1:
# Optimization for single decode token case
# (which is what we have most of the time)
return self.data._cached_all_token_ids[-1]
return self.data._cached_all_token_ids[-num_new_tokens:]
def hash_of_block(self, logical_idx: int) -> int: def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size # TODO This can produce incorrect hash when block size > prompt size
@ -671,6 +679,8 @@ class SequenceGroup:
self.encoder_seq = encoder_seq self.encoder_seq = encoder_seq
self.trace_headers = trace_headers self.trace_headers = trace_headers
self.cached_request_output = None
@property @property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt. # All sequences in the group should have the same prompt.