mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 17:47:07 +08:00
Add output streaming support to multi-step + async while ensuring RequestOutput obj reuse (#8335)
This commit is contained in:
parent
5f7bb58427
commit
1a2aef3e59
@ -19,7 +19,11 @@ FILTER = "exact_match,strict-match"
|
|||||||
RTOL = 0.03
|
RTOL = 0.03
|
||||||
EXPECTED_VALUE = 0.58
|
EXPECTED_VALUE = 0.58
|
||||||
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
|
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
|
||||||
MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]]
|
MORE_ARGS_LIST = [
|
||||||
|
["--enable-chunked-prefill"], # Chunked
|
||||||
|
["--num-scheduler-steps", "8"], # MS
|
||||||
|
["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
|
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
|
||||||
|
|||||||
@ -960,6 +960,7 @@ class SchedulerConfig:
|
|||||||
is_multimodal_model: bool = False,
|
is_multimodal_model: bool = False,
|
||||||
preemption_mode: Optional[str] = None,
|
preemption_mode: Optional[str] = None,
|
||||||
num_scheduler_steps: int = 1,
|
num_scheduler_steps: int = 1,
|
||||||
|
multi_step_stream_outputs: bool = False,
|
||||||
send_delta_data: bool = False) -> None:
|
send_delta_data: bool = False) -> None:
|
||||||
if max_num_batched_tokens is None:
|
if max_num_batched_tokens is None:
|
||||||
if enable_chunked_prefill:
|
if enable_chunked_prefill:
|
||||||
@ -1000,6 +1001,7 @@ class SchedulerConfig:
|
|||||||
self.embedding_mode = embedding_mode
|
self.embedding_mode = embedding_mode
|
||||||
self.preemption_mode = preemption_mode
|
self.preemption_mode = preemption_mode
|
||||||
self.num_scheduler_steps = num_scheduler_steps
|
self.num_scheduler_steps = num_scheduler_steps
|
||||||
|
self.multi_step_stream_outputs = multi_step_stream_outputs
|
||||||
self.send_delta_data = send_delta_data
|
self.send_delta_data = send_delta_data
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
|
|
||||||
|
|||||||
@ -145,6 +145,7 @@ class EngineArgs:
|
|||||||
max_cpu_loras: Optional[int] = None
|
max_cpu_loras: Optional[int] = None
|
||||||
device: str = 'auto'
|
device: str = 'auto'
|
||||||
num_scheduler_steps: int = 1
|
num_scheduler_steps: int = 1
|
||||||
|
multi_step_stream_outputs: bool = False
|
||||||
ray_workers_use_nsight: bool = False
|
ray_workers_use_nsight: bool = False
|
||||||
num_gpu_blocks_override: Optional[int] = None
|
num_gpu_blocks_override: Optional[int] = None
|
||||||
num_lookahead_slots: int = 0
|
num_lookahead_slots: int = 0
|
||||||
@ -595,6 +596,10 @@ class EngineArgs:
|
|||||||
help=('Maximum number of forward steps per '
|
help=('Maximum number of forward steps per '
|
||||||
'scheduler call.'))
|
'scheduler call.'))
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--multi-step-stream-outputs',
|
||||||
|
action='store_true',
|
||||||
|
help='If True, then multi-step will stream outputs for every step')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--scheduler-delay-factor',
|
'--scheduler-delay-factor',
|
||||||
type=float,
|
type=float,
|
||||||
@ -999,6 +1004,7 @@ class EngineArgs:
|
|||||||
is_multimodal_model=model_config.is_multimodal_model,
|
is_multimodal_model=model_config.is_multimodal_model,
|
||||||
preemption_mode=self.preemption_mode,
|
preemption_mode=self.preemption_mode,
|
||||||
num_scheduler_steps=self.num_scheduler_steps,
|
num_scheduler_steps=self.num_scheduler_steps,
|
||||||
|
multi_step_stream_outputs=self.multi_step_stream_outputs,
|
||||||
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
|
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
|
||||||
and parallel_config.use_ray),
|
and parallel_config.use_ray),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -95,7 +95,7 @@ class OutputData(NamedTuple):
|
|||||||
|
|
||||||
class SchedulerContext:
|
class SchedulerContext:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, multi_step_stream_outputs: bool = False):
|
||||||
self.output_queue: Deque[OutputData] = deque()
|
self.output_queue: Deque[OutputData] = deque()
|
||||||
self.request_outputs: List[Union[RequestOutput,
|
self.request_outputs: List[Union[RequestOutput,
|
||||||
EmbeddingRequestOutput]] = []
|
EmbeddingRequestOutput]] = []
|
||||||
@ -103,6 +103,8 @@ class SchedulerContext:
|
|||||||
List[SequenceGroupMetadata]] = None
|
List[SequenceGroupMetadata]] = None
|
||||||
self.scheduler_outputs: Optional[SchedulerOutputs] = None
|
self.scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||||
|
|
||||||
|
self.multi_step_stream_outputs: bool = multi_step_stream_outputs
|
||||||
|
|
||||||
def append_output(self, outputs: List[SamplerOutput],
|
def append_output(self, outputs: List[SamplerOutput],
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
scheduler_outputs: SchedulerOutputs, is_async: bool,
|
scheduler_outputs: SchedulerOutputs, is_async: bool,
|
||||||
@ -219,6 +221,7 @@ class LLMEngine:
|
|||||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||||
|
use_cached_outputs: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Initializing an LLM engine (v%s) with config: "
|
"Initializing an LLM engine (v%s) with config: "
|
||||||
@ -234,8 +237,9 @@ class LLMEngine:
|
|||||||
"quantization_param_path=%s, device_config=%s, "
|
"quantization_param_path=%s, device_config=%s, "
|
||||||
"decoding_config=%r, observability_config=%r, "
|
"decoding_config=%r, observability_config=%r, "
|
||||||
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
|
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
|
||||||
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
|
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
|
||||||
"use_async_output_proc=%s, mm_processor_kwargs=%s)",
|
"enable_prefix_caching=%s, use_async_output_proc=%s, "
|
||||||
|
"use_cached_outputs=%s, mm_processor_kwargs=%s)",
|
||||||
VLLM_VERSION,
|
VLLM_VERSION,
|
||||||
model_config.model,
|
model_config.model,
|
||||||
speculative_config,
|
speculative_config,
|
||||||
@ -266,8 +270,10 @@ class LLMEngine:
|
|||||||
model_config.served_model_name,
|
model_config.served_model_name,
|
||||||
scheduler_config.use_v2_block_manager,
|
scheduler_config.use_v2_block_manager,
|
||||||
scheduler_config.num_scheduler_steps,
|
scheduler_config.num_scheduler_steps,
|
||||||
|
scheduler_config.multi_step_stream_outputs,
|
||||||
cache_config.enable_prefix_caching,
|
cache_config.enable_prefix_caching,
|
||||||
model_config.use_async_output_proc,
|
model_config.use_async_output_proc,
|
||||||
|
use_cached_outputs,
|
||||||
model_config.mm_processor_kwargs,
|
model_config.mm_processor_kwargs,
|
||||||
)
|
)
|
||||||
# TODO(woosuk): Print more configs in debug mode.
|
# TODO(woosuk): Print more configs in debug mode.
|
||||||
@ -287,6 +293,7 @@ class LLMEngine:
|
|||||||
self.observability_config = observability_config or ObservabilityConfig(
|
self.observability_config = observability_config or ObservabilityConfig(
|
||||||
)
|
)
|
||||||
self.log_stats = log_stats
|
self.log_stats = log_stats
|
||||||
|
self.use_cached_outputs = use_cached_outputs
|
||||||
|
|
||||||
if not self.model_config.skip_tokenizer_init:
|
if not self.model_config.skip_tokenizer_init:
|
||||||
self.tokenizer = self._init_tokenizer()
|
self.tokenizer = self._init_tokenizer()
|
||||||
@ -379,7 +386,8 @@ class LLMEngine:
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.scheduler_contexts = [
|
self.scheduler_contexts = [
|
||||||
SchedulerContext()
|
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
|
||||||
|
multi_step_stream_outputs)
|
||||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -998,7 +1006,8 @@ class LLMEngine:
|
|||||||
|
|
||||||
seq_group = scheduled_seq_group.seq_group
|
seq_group = scheduled_seq_group.seq_group
|
||||||
seq_group.maybe_set_first_token_time(now)
|
seq_group.maybe_set_first_token_time(now)
|
||||||
request_output = RequestOutputFactory.create(seq_group)
|
request_output = RequestOutputFactory.create(
|
||||||
|
seq_group, use_cache=self.use_cached_outputs)
|
||||||
if request_output:
|
if request_output:
|
||||||
ctx.request_outputs.append(request_output)
|
ctx.request_outputs.append(request_output)
|
||||||
|
|
||||||
@ -1019,8 +1028,8 @@ class LLMEngine:
|
|||||||
for scheduler in self.scheduler:
|
for scheduler in self.scheduler:
|
||||||
scheduler.free_finished_seq_groups()
|
scheduler.free_finished_seq_groups()
|
||||||
|
|
||||||
# For multi-step, do not create outputs each iteration
|
# For multi-step without streaming, don't create outputs each iteration
|
||||||
if not is_last_step:
|
if not is_last_step and not ctx.multi_step_stream_outputs:
|
||||||
# Immediately process request outputs here (if callback is given)
|
# Immediately process request outputs here (if callback is given)
|
||||||
if (finished_now
|
if (finished_now
|
||||||
and self.process_request_outputs_callback is not None):
|
and self.process_request_outputs_callback is not None):
|
||||||
@ -1037,17 +1046,27 @@ class LLMEngine:
|
|||||||
|
|
||||||
seq_group = scheduled_seq_group.seq_group
|
seq_group = scheduled_seq_group.seq_group
|
||||||
seq_group.maybe_set_first_token_time(now)
|
seq_group.maybe_set_first_token_time(now)
|
||||||
request_output = RequestOutputFactory.create(seq_group)
|
request_output = RequestOutputFactory.create(
|
||||||
|
seq_group, use_cache=self.use_cached_outputs)
|
||||||
if request_output:
|
if request_output:
|
||||||
ctx.request_outputs.append(request_output)
|
ctx.request_outputs.append(request_output)
|
||||||
|
|
||||||
|
# For multi-step with streaming, create outputs each iteration
|
||||||
|
if not is_last_step and ctx.multi_step_stream_outputs:
|
||||||
|
# Immediately process request outputs here (if callback is given)
|
||||||
|
if self.process_request_outputs_callback is not None:
|
||||||
|
self.process_request_outputs_callback(ctx.request_outputs)
|
||||||
|
ctx.request_outputs.clear()
|
||||||
|
return
|
||||||
|
|
||||||
for seq_group in scheduler_outputs.ignored_seq_groups:
|
for seq_group in scheduler_outputs.ignored_seq_groups:
|
||||||
params = seq_group.sampling_params
|
params = seq_group.sampling_params
|
||||||
if params is not None and params.output_kind == (
|
if params is not None and params.output_kind == (
|
||||||
RequestOutputKind.DELTA) and not seq_group.is_finished():
|
RequestOutputKind.DELTA) and not seq_group.is_finished():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
request_output = RequestOutputFactory.create(seq_group)
|
request_output = RequestOutputFactory.create(
|
||||||
|
seq_group, use_cache=self.use_cached_outputs)
|
||||||
if request_output:
|
if request_output:
|
||||||
ctx.request_outputs.append(request_output)
|
ctx.request_outputs.append(request_output)
|
||||||
|
|
||||||
|
|||||||
@ -66,7 +66,14 @@ class MQLLMEngine:
|
|||||||
*args,
|
*args,
|
||||||
log_requests: bool = True,
|
log_requests: bool = True,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
self.engine = LLMEngine(*args, **kwargs)
|
# For MQLLMEngine, we can use cached outputs, since each new request
|
||||||
|
# output is immediately pickled and send over the socket, which frees
|
||||||
|
# the python object to be reused again.
|
||||||
|
use_cached_outputs = True
|
||||||
|
|
||||||
|
self.engine = LLMEngine(*args,
|
||||||
|
**kwargs,
|
||||||
|
use_cached_outputs=use_cached_outputs)
|
||||||
self.log_requests = log_requests
|
self.log_requests = log_requests
|
||||||
|
|
||||||
self.use_async_sockets = use_async_sockets
|
self.use_async_sockets = use_async_sockets
|
||||||
|
|||||||
@ -114,17 +114,28 @@ class RequestOutput:
|
|||||||
self.encoder_prompt_token_ids = encoder_prompt_token_ids
|
self.encoder_prompt_token_ids = encoder_prompt_token_ids
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_seq_group(cls,
|
def from_seq_group(cls, seq_group: SequenceGroup,
|
||||||
seq_group: SequenceGroup) -> Optional["RequestOutput"]:
|
use_cache: bool) -> Optional["RequestOutput"]:
|
||||||
sampling_params = seq_group.sampling_params
|
sampling_params = seq_group.sampling_params
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Sampling parameters are missing for a CompletionRequest.")
|
"Sampling parameters are missing for a CompletionRequest.")
|
||||||
|
|
||||||
finished = seq_group.is_finished()
|
finished = seq_group.is_finished()
|
||||||
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
|
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
|
||||||
not finished):
|
not finished):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Init cache (if needed)
|
||||||
|
if use_cache and seq_group.cached_request_output is None:
|
||||||
|
seq_group.cached_request_output = RequestOutput( # type: ignore
|
||||||
|
request_id="",
|
||||||
|
prompt=None,
|
||||||
|
prompt_token_ids=[],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[],
|
||||||
|
finished=False)
|
||||||
|
|
||||||
seqs = seq_group.get_seqs()
|
seqs = seq_group.get_seqs()
|
||||||
if len(seqs) == 1:
|
if len(seqs) == 1:
|
||||||
top_n_seqs = seqs
|
top_n_seqs = seqs
|
||||||
@ -149,29 +160,66 @@ class RequestOutput:
|
|||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
include_prompt = True
|
include_prompt = True
|
||||||
for seq in top_n_seqs:
|
for i, seq in enumerate(top_n_seqs):
|
||||||
output_text = seq.get_output_text_to_return(
|
output_text = seq.get_output_text_to_return(
|
||||||
text_buffer_length, delta)
|
text_buffer_length, delta)
|
||||||
|
|
||||||
output_token_ids = seq.get_output_token_ids_to_return(delta)
|
output_token_ids = seq.get_output_token_ids_to_return(delta)
|
||||||
|
num_output_tokens = 1 if isinstance(output_token_ids,
|
||||||
|
int) else len(output_token_ids)
|
||||||
|
|
||||||
output_logprobs = seq.output_logprobs if include_logprobs else None
|
output_logprobs = seq.output_logprobs if include_logprobs else None
|
||||||
|
|
||||||
if delta:
|
if delta:
|
||||||
# Slice logprobs delta if applicable
|
# Slice logprobs delta if applicable
|
||||||
if output_logprobs:
|
if output_logprobs:
|
||||||
output_logprobs = output_logprobs[-len(output_token_ids):]
|
output_logprobs = output_logprobs[-num_output_tokens:]
|
||||||
# Don't include prompt if this is after the first output
|
# Don't include prompt if this is after the first output
|
||||||
# containing decode token ids
|
# containing decode token ids
|
||||||
if include_prompt and seq.get_output_len() > len(
|
if include_prompt and seq.get_output_len() > num_output_tokens:
|
||||||
output_token_ids):
|
|
||||||
include_prompt = False
|
include_prompt = False
|
||||||
|
|
||||||
outputs.append(
|
if use_cache:
|
||||||
CompletionOutput(
|
# Get cached output object
|
||||||
seqs.index(seq), output_text, output_token_ids,
|
cached_outputs = seq_group.cached_request_output.outputs # type: ignore
|
||||||
|
if i >= len(cached_outputs):
|
||||||
|
cached_outputs.append(
|
||||||
|
CompletionOutput(index=i,
|
||||||
|
text="",
|
||||||
|
token_ids=[],
|
||||||
|
cumulative_logprob=None,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None,
|
||||||
|
stop_reason=None))
|
||||||
|
output = cached_outputs[i]
|
||||||
|
|
||||||
|
# Init cached output object
|
||||||
|
assert output.index == i
|
||||||
|
output.text = output_text
|
||||||
|
|
||||||
|
if isinstance(output_token_ids, int):
|
||||||
|
output.token_ids.clear()
|
||||||
|
output.token_ids.append(output_token_ids)
|
||||||
|
else:
|
||||||
|
output.token_ids = output_token_ids
|
||||||
|
|
||||||
|
output.cumulative_logprob = seq.get_cumulative_logprob() \
|
||||||
|
if include_logprobs else None
|
||||||
|
output.logprobs = output_logprobs
|
||||||
|
output.finish_reason = SequenceStatus.get_finished_reason(
|
||||||
|
seq.status)
|
||||||
|
output.stop_reason = seq.stop_reason
|
||||||
|
|
||||||
|
else:
|
||||||
|
output = CompletionOutput(
|
||||||
|
seqs.index(seq), output_text, [output_token_ids]
|
||||||
|
if isinstance(output_token_ids, int) else output_token_ids,
|
||||||
seq.get_cumulative_logprob() if include_logprobs else None,
|
seq.get_cumulative_logprob() if include_logprobs else None,
|
||||||
output_logprobs,
|
output_logprobs,
|
||||||
SequenceStatus.get_finished_reason(seq.status),
|
SequenceStatus.get_finished_reason(seq.status),
|
||||||
seq.stop_reason))
|
seq.stop_reason)
|
||||||
|
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
# Every sequence in the sequence group should have the same prompt.
|
# Every sequence in the sequence group should have the same prompt.
|
||||||
if include_prompt:
|
if include_prompt:
|
||||||
@ -188,16 +236,20 @@ class RequestOutput:
|
|||||||
prompt_logprobs = None
|
prompt_logprobs = None
|
||||||
finished_time = time.time() if finished else None
|
finished_time = time.time() if finished else None
|
||||||
seq_group.set_finished_time(finished_time)
|
seq_group.set_finished_time(finished_time)
|
||||||
return cls(seq_group.request_id,
|
|
||||||
prompt,
|
init_args = (seq_group.request_id, prompt, prompt_token_ids,
|
||||||
prompt_token_ids,
|
prompt_logprobs, outputs, finished, seq_group.metrics,
|
||||||
prompt_logprobs,
|
seq_group.lora_request, encoder_prompt,
|
||||||
outputs,
|
encoder_prompt_token_ids)
|
||||||
finished,
|
|
||||||
seq_group.metrics,
|
if use_cache:
|
||||||
lora_request=seq_group.lora_request,
|
request_output = seq_group.cached_request_output
|
||||||
encoder_prompt=encoder_prompt,
|
request_output.__init__(*init_args) # type: ignore
|
||||||
encoder_prompt_token_ids=encoder_prompt_token_ids)
|
|
||||||
|
else:
|
||||||
|
request_output = cls(*init_args)
|
||||||
|
|
||||||
|
return request_output
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"RequestOutput(request_id={self.request_id}, "
|
return (f"RequestOutput(request_id={self.request_id}, "
|
||||||
@ -261,10 +313,10 @@ class EmbeddingRequestOutput:
|
|||||||
class RequestOutputFactory:
|
class RequestOutputFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(seq_group):
|
def create(seq_group: SequenceGroup, use_cache: bool = False):
|
||||||
# Determine the type based on a condition, for example:
|
# Determine the type based on a condition, for example:
|
||||||
if hasattr(seq_group,
|
if hasattr(seq_group,
|
||||||
'embeddings') and seq_group.embeddings is not None:
|
'embeddings') and seq_group.embeddings is not None:
|
||||||
return EmbeddingRequestOutput.from_seq_group(seq_group)
|
return EmbeddingRequestOutput.from_seq_group(seq_group)
|
||||||
else:
|
else:
|
||||||
return RequestOutput.from_seq_group(seq_group)
|
return RequestOutput.from_seq_group(seq_group, use_cache)
|
||||||
|
|||||||
@ -436,7 +436,7 @@ class Sequence:
|
|||||||
self.stop_reason: Union[int, str, None] = None
|
self.stop_reason: Union[int, str, None] = None
|
||||||
|
|
||||||
# These are used to keep track of delta outputs
|
# These are used to keep track of delta outputs
|
||||||
self._last_token_ids_offset: int = 0
|
self._last_output_token_ids_offset: int = 0
|
||||||
self._last_output_text_offset: int = 0
|
self._last_output_text_offset: int = 0
|
||||||
|
|
||||||
# Used for incremental detokenization
|
# Used for incremental detokenization
|
||||||
@ -499,18 +499,26 @@ class Sequence:
|
|||||||
return self.output_text[last_offset:length]
|
return self.output_text[last_offset:length]
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def get_output_token_ids_to_return(self,
|
def get_output_token_ids_to_return(
|
||||||
delta: bool) -> GenericSequence[int]:
|
self, delta: bool) -> Union[GenericSequence[int], int]:
|
||||||
"""If delta is True, only new tokens since the last call to
|
"""If delta is True, only new tokens since the last call to
|
||||||
this method are returned"""
|
this method are returned"""
|
||||||
if not delta:
|
if not delta:
|
||||||
return self.get_output_token_ids()
|
return self.get_output_token_ids()
|
||||||
length = self.get_output_len()
|
|
||||||
last_offset = self._last_token_ids_offset
|
output_len = self.get_output_len()
|
||||||
if last_offset < length:
|
|
||||||
self._last_token_ids_offset = length
|
# Get the number of new tokens
|
||||||
return self.data._output_token_ids[last_offset:]
|
num_new_tokens = output_len - self._last_output_token_ids_offset
|
||||||
return ()
|
self._last_output_token_ids_offset = output_len
|
||||||
|
|
||||||
|
# Return new tokens
|
||||||
|
if num_new_tokens == 1:
|
||||||
|
# Optimization for single decode token case
|
||||||
|
# (which is what we have most of the time)
|
||||||
|
return self.data._cached_all_token_ids[-1]
|
||||||
|
|
||||||
|
return self.data._cached_all_token_ids[-num_new_tokens:]
|
||||||
|
|
||||||
def hash_of_block(self, logical_idx: int) -> int:
|
def hash_of_block(self, logical_idx: int) -> int:
|
||||||
# TODO This can produce incorrect hash when block size > prompt size
|
# TODO This can produce incorrect hash when block size > prompt size
|
||||||
@ -671,6 +679,8 @@ class SequenceGroup:
|
|||||||
self.encoder_seq = encoder_seq
|
self.encoder_seq = encoder_seq
|
||||||
self.trace_headers = trace_headers
|
self.trace_headers = trace_headers
|
||||||
|
|
||||||
|
self.cached_request_output = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prompt(self) -> Optional[str]:
|
def prompt(self) -> Optional[str]:
|
||||||
# All sequences in the group should have the same prompt.
|
# All sequences in the group should have the same prompt.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user