Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Performance] Optimize get_seqs (#7051)
This commit is contained in:
parent 6a11fdfbb8
commit 6ce01f3066
@@ -700,5 +700,5 @@ class BlockSpaceManagerV1(BlockSpaceManager):
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         if self.enable_caching:
-            for seq in seq_group.seqs_dict.values():
+            for seq in seq_group.get_seqs():
                 self.compute_full_blocks_in_seq(seq)
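This call site switches to get_seqs(), which, as the SequenceGroup hunks below show, now returns a cached list instead of materializing dict.values() into a fresh list on every call. A minimal sketch of the list-alongside-dict pattern the commit adopts (a simplified stand-in, not the real vllm SequenceGroup):

from typing import List, Optional

class Seq:
    # Hypothetical stand-in for vllm's Sequence; only the fields used here.
    def __init__(self, seq_id: int, status: str = "RUNNING") -> None:
        self.seq_id = seq_id
        self.status = status

class SeqGroupSketch:
    def __init__(self, seqs: List[Seq]) -> None:
        self.seqs = seqs                              # cached list, returned as-is
        self.seqs_dict = {s.seq_id: s for s in seqs}  # kept for O(1) lookup by id

    def get_seqs(self, status: Optional[str] = None) -> List[Seq]:
        if status is None:
            return self.seqs  # fast path: plain attribute read, no allocation
        return [s for s in self.seqs if s.status == status]

Returning the internal list directly trades defensive copying for speed; callers are expected to treat it as read-only.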
@@ -444,6 +444,7 @@ class SequenceGroup:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
+        self.seqs = seqs
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.metrics = RequestMetrics(arrival_time=arrival_time,
@@ -458,25 +459,24 @@ class SequenceGroup:
         self.prompt_adapter_request = prompt_adapter_request
         self.encoder_seq = encoder_seq
         self.trace_headers = trace_headers
-        self._first_seq = next(iter(self.seqs_dict.values()))

     @property
     def prompt(self) -> Optional[str]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return self._first_seq.prompt
+        return self.seqs[0].prompt

     @property
     def prompt_token_ids(self) -> List[int]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return self._first_seq.prompt_token_ids
+        return self.seqs[0].prompt_token_ids

     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
         # All sequences in the group should have the same multi-modal data.
         # We use the multi-modal data of an arbitrary sequence.
-        return self._first_seq.multi_modal_data
+        return self.seqs[0].multi_modal_data

     @property
     def lora_int_id(self) -> int:
@@ -512,7 +512,7 @@ class SequenceGroup:
         # in TPOT, rather than recalculating TTFT (since from the )
         # POV of the user, there is simply a long generation delay.
         if (self.metrics.first_token_time is None
-                and self.get_seqs()[0].get_output_len() == 1):
+                and self.seqs[0].get_output_len() == 1):
             self.metrics.first_token_time = time

     def maybe_set_first_scheduled_time(self, time: float) -> None:
@@ -548,9 +548,9 @@ class SequenceGroup:
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
-        return list(self.seqs_dict.values()) if status is None else [
-            seq for seq in self.seqs_dict.values() if seq.status == status
-        ]
+        if status is None:
+            return self.seqs
+        return [seq for seq in self.seqs if seq.status == status]

     def is_encoder_decoder(self) -> bool:
         return self.encoder_seq is not None
@@ -559,22 +559,20 @@ class SequenceGroup:
         return self.encoder_seq

     def get_unfinished_seqs(self) -> List[Sequence]:
-        return [
-            seq for seq in self.seqs_dict.values() if not seq.is_finished()
-        ]
+        return [seq for seq in self.seqs if not seq.is_finished()]

     def get_finished_seqs(self) -> List[Sequence]:
-        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
+        return [seq for seq in self.seqs if seq.is_finished()]

     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        for seq in self.seqs_dict.values():
+        for seq in self.seqs:
             if not seq.is_finished():
                 seq.data.update_num_computed_tokens(num_new_computed_tokens)

     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        for seq in self.get_seqs():
+        for seq in self.seqs:
             if not seq.is_finished():
                 num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens
@@ -583,7 +581,7 @@ class SequenceGroup:
         # Optimization. We don't need to call get_seqs if we don't need to
         # filter by states.
         if status is None:
-            return len(self.seqs_dict)
+            return len(self.seqs)

         return len(self.get_seqs(status))

@@ -602,23 +600,25 @@ class SequenceGroup:
         if seq.seq_id in self.seqs_dict:
             raise ValueError(f"Sequence {seq.seq_id} already exists.")
         self.seqs_dict[seq.seq_id] = seq
+        self.seqs.append(seq)

     def remove(self, seq_id: int) -> None:
-        if seq_id not in self.seqs_dict:
+        seq = self.seqs_dict.pop(seq_id, None)
+        if seq is None:
             raise ValueError(f"Sequence {seq_id} not found.")
-        del self.seqs_dict[seq_id]
+        self.seqs.remove(seq)

     def is_finished(self) -> bool:
-        return all(seq.is_finished() for seq in self.get_seqs())
+        return all(seq.is_finished() for seq in self.seqs)

     def is_prefill(self) -> bool:
         # Every sequence should be in the same stage.
-        return self.get_seqs()[0].is_prefill()
+        return self.seqs[0].is_prefill()

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
-                f"num_seqs={len(self.seqs_dict)})")
+                f"num_seqs={len(self.seqs)})")


 class SequenceGroupMetadata:
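Taken together, these hunks keep self.seqs and self.seqs_dict in sync through add() and remove(), so the unfiltered get_seqs() path is a plain attribute read. A rough timeit sketch of the saving (illustrative only; absolute numbers vary by machine):

import timeit

d = {i: object() for i in range(8)}  # stand-in for a small seqs_dict
cached = list(d.values())            # stand-in for the cached self.seqs

per_call_copy = timeit.timeit(lambda: list(d.values()), number=1_000_000)
cached_read = timeit.timeit(lambda: cached, number=1_000_000)

# The cached path skips a list allocation and copy on every call, so it
# should come out several times faster on the no-filter fast path.
print(f"list(d.values()): {per_call_copy:.3f}s  cached: {cached_read:.3f}s")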
@@ -40,7 +40,7 @@ class Detokenizer:
         assert prms is not None

         # We can pick any sequence for the prompt.
-        seq = next(iter(seq_group.seqs_dict.values()))
+        seq = seq_group.get_seqs()[0]
         # Only prompt, without the generated token.
         all_token_ids = seq.get_token_ids()
         prompt_token_ids = all_token_ids[:-1]
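The detokenizer only needs some sequence that carries the prompt, and the swap is behavior-preserving: self.seqs and self.seqs_dict are built from the same iterable, and Python dicts preserve insertion order (guaranteed since 3.7), so both expressions pick the same first sequence. A small sanity check, with Seq as a hypothetical stand-in for vllm's Sequence:

class Seq:
    def __init__(self, seq_id: int) -> None:
        self.seq_id = seq_id

seqs = [Seq(7), Seq(3), Seq(5)]
seqs_dict = {s.seq_id: s for s in seqs}

# Insertion order is preserved, so the old expression (first dict value)
# and the new one (first list element) return the very same object.
assert next(iter(seqs_dict.values())) is seqs[0]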