Implement stop strings and best_of (#114)
Parent: c3442c1f6f
Commit: f746ced08d
@@ -80,7 +80,7 @@ class BlockSpaceManager:
     def can_allocate(self, seq_group: SequenceGroup) -> bool:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
-        seq = seq_group.seqs[0]
+        seq = seq_group.get_seqs()[0]
         num_required_blocks = len(seq.logical_token_blocks)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         # Use watermark to avoid frequent cache eviction.
@@ -88,7 +88,7 @@ class BlockSpaceManager:
 
     def allocate(self, seq_group: SequenceGroup) -> None:
         # NOTE: Here we assume that all sequences in the group have the same prompt.
-        seq = seq_group.seqs[0]
+        seq = seq_group.get_seqs()[0]
 
         # Allocate new physical token blocks that will store the prompt tokens.
         block_table: BlockTable = []
@@ -99,7 +99,7 @@ class BlockSpaceManager:
             block_table.append(block)
 
         # Assign the block table for each sequence.
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             self.block_tables[seq.seq_id] = block_table.copy()
 
     def can_append_slot(self, seq_group: SequenceGroup) -> bool:
@@ -147,7 +147,7 @@ class BlockSpaceManager:
         # NOTE: Here, we assume that the physical blocks are only shared by
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             block_table = self.block_tables[seq.seq_id]
@@ -168,7 +168,7 @@ class BlockSpaceManager:
     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             new_block_table: BlockTable = []
@@ -199,7 +199,7 @@ class BlockSpaceManager:
     def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             new_block_table: BlockTable = []
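
Note: the call sites above switch from the raw `seqs` attribute to a `SequenceGroup.get_seqs()` accessor that can also filter by status. The real method lives in cacheflow/sequence.py and is not shown in this excerpt; the following is only a hedged sketch of the shape these callers assume.

# Hedged sketch of the accessor assumed by the call sites above; the actual
# implementation in cacheflow/sequence.py may differ in detail.
import enum
from typing import List, Optional


class SequenceStatus(enum.Enum):
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED = enum.auto()


class SequenceGroupSketch:
    def __init__(self, seqs: List["Sequence"]) -> None:
        self.seqs = seqs

    def get_seqs(self, status: Optional[SequenceStatus] = None) -> List["Sequence"]:
        # With no filter, return every sequence in the group.
        if status is None:
            return self.seqs
        # Otherwise return only the sequences currently in `status`.
        return [seq for seq in self.seqs if seq.status == status]
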
@@ -73,8 +73,6 @@ class Scheduler:
         self.waiting: List[SequenceGroup] = []
         # Sequence groups in the RUNNING state.
         self.running: List[SequenceGroup] = []
-        # Mapping: request_id -> num_steps.
-        self.num_steps: Dict[str, int] = {}
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []
 
@@ -84,7 +82,6 @@ class Scheduler:
 
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
-        assert seq_group.request_id not in self.num_steps
         self.waiting.append(seq_group)
 
     def has_unfinished_seqs(self) -> bool:
@@ -178,7 +175,7 @@ class Scheduler:
                     break
 
                 # If the number of batched tokens exceeds the limit, stop.
-                num_prompt_tokens = seq_group.seqs[0].get_len()
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                 if (num_batched_tokens + num_prompt_tokens
                     > self.scheduler_config.max_num_batched_tokens):
                     break
@@ -278,15 +275,8 @@ class Scheduler:
     ) -> List[SequenceGroup]:
         # Update the running sequences and free blocks.
         for seq_group in self.running:
-            request_id = seq_group.request_id
-            self.num_steps[request_id] += 1
-            stop_token_ids = seq_group.sampling_params.stop_token_ids
-
-            # Process beam search results before processing the next tokens.
-            for seq in seq_group.seqs:
-                if seq.status == SequenceStatus.FINISHED:
-                    continue
-
+            # Process beam search results before processing the new tokens.
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
                 if seq.seq_id != output.parent_seq_id:
                     # The sequence is a fork of the parent sequence (beam search).
@@ -297,43 +287,27 @@ class Scheduler:
                     parent_seq.fork(seq)
                     self.block_manager.fork(parent_seq, seq)
 
-            # Process the next tokens.
-            for seq in seq_group.seqs:
-                if seq.status == SequenceStatus.FINISHED:
-                    continue
-
+            # Process the new tokens.
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token(output.output_token, output.logprobs)
+        return self.running.copy()
 
-                # Check if the sequence has generated a stop token.
-                if output.output_token in stop_token_ids:
-                    self._free_seq(seq)
-                    continue
+    def free_seq(self, seq: Sequence) -> None:
+        seq.status = SequenceStatus.FINISHED
+        self.block_manager.free(seq)
 
-                # Check if the sequence has reached the maximum number of steps.
-                max_num_steps = seq_group.sampling_params.max_tokens
-                if self.num_steps[request_id] == max_num_steps:
-                    self._free_seq(seq)
-                    continue
-
-        # Update the running sequences.
-        updated = self.running.copy()
-        running: List[SequenceGroup] = []
-        for seq_group in self.running:
-            if seq_group.is_finished():
-                self._free_seq_group(seq_group)
-            else:
-                running.append(seq_group)
-        self.running = running
-        return updated
+    def free_finished_seq_groups(self) -> None:
+        self.running = [
+            seq_group for seq_group in self.running
+            if not seq_group.is_finished()
+        ]
 
     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             seq.status = SequenceStatus.RUNNING
-        if seq_group.request_id not in self.num_steps:
-            self.num_steps[seq_group.request_id] = 0
 
     def _append_slot(
         self,
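
Note: with this hunk, update() only appends sampled tokens and returns the running groups; stopping decisions move to the server, which calls free_seq per finished sequence and then free_finished_seq_groups to prune groups. A self-contained toy of the pruning step (not the scheduler's own code):

# Toy illustration of what free_finished_seq_groups does: keep only the
# sequence groups that still contain unfinished sequences.
from dataclasses import dataclass
from typing import List


@dataclass
class ToySeqGroup:
    request_id: str
    finished: bool

    def is_finished(self) -> bool:
        return self.finished


running: List[ToySeqGroup] = [ToySeqGroup("req-0", False), ToySeqGroup("req-1", True)]
running = [g for g in running if not g.is_finished()]
print([g.request_id for g in running])  # ['req-0']
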
@@ -403,13 +377,6 @@ class Scheduler:
         self._swap_out(seq_group, blocks_to_swap_out)
         self.swapped.append(seq_group)
 
-    def _free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
-        self.block_manager.free(seq)
-
-    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
-        del self.num_steps[seq_group.request_id]
-
     def _swap_in(
         self,
         seq_group: SequenceGroup,
@@ -123,6 +123,7 @@ if __name__ == "__main__":
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)
 
-    server = FastAPIServer(
-        args.use_ray, *server_configs, distributed_init_method, stage_devices)
+    server = FastAPIServer(args.use_ray, *server_configs,
+                           distributed_init_method, stage_devices,
+                           log_stats=not args.disable_log_stats)
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
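
Note: the new call reads `args.disable_log_stats`, so the frontend's argument parser must define that flag. The flag definition is not part of this excerpt; the snippet below is only an assumed sketch of its shape.

# Hedged sketch of the assumed --disable-log-stats argparse flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable-log-stats", action="store_true",
                    help="Disable periodic logging of server statistics.")
args = parser.parse_args([])          # empty argv, for illustration only
print(not args.disable_log_stats)     # True -> stats logging enabled by default
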
@@ -283,20 +283,20 @@ def _sample_from_prompt(
 ) -> List[int]:
     if sampling_params.use_beam_search:
         # Beam search.
-        beam_width = sampling_params.n
+        beam_width = sampling_params.best_of
         _, next_token_ids = torch.topk(prob, beam_width)
         next_token_ids = next_token_ids.tolist()
     elif sampling_params.temperature == 0.0:
         # Greedy sampling.
-        assert sampling_params.n == 1
+        assert sampling_params.best_of == 1
        next_token_id = torch.argmax(prob)
        next_token_ids = [next_token_id.item()]
     else:
         # Random sampling.
-        # Sample n tokens for the prompt.
-        n = sampling_params.n
+        # Sample `best_of` tokens for the prompt.
+        num_seqs = sampling_params.best_of
         next_token_ids = torch.multinomial(
-            prob, num_samples=n, replacement=True)
+            prob, num_samples=num_seqs, replacement=True)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids
 
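
Note: `best_of` now controls how many candidate continuations are drawn at the prompt step. A self-contained illustration of the two branches above (this is not the server's code path, just standalone PyTorch):

# How `best_of` governs the number of candidates drawn from the prompt distribution.
import torch

prob = torch.softmax(torch.randn(32000), dim=-1)   # fake next-token distribution
best_of = 4

# Random sampling: draw `best_of` tokens with replacement.
sampled = torch.multinomial(prob, num_samples=best_of, replacement=True)
# Beam search: take the `best_of` most likely tokens instead.
_, beam = torch.topk(prob, best_of)
print(sampled.tolist(), beam.tolist())
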
@@ -308,7 +308,7 @@ def _sample_from_generation_tokens(
     seq_logprobs: List[float],
     sampling_params: SamplingParams,
 ) -> Tuple[List[int], List[int]]:
-    # NOTE(woosuk): sampling_params.n can be greater than
+    # NOTE(woosuk): sampling_params.best_of can be greater than
     # len(seq_ids) because some sequences in the group might have
     # been already terminated.
     if sampling_params.use_beam_search:
@@ -372,7 +372,7 @@ def _sample(
         seq_ids, sampling_params = seq_group
         if i < input_metadata.num_prompts:
             # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == sampling_params.n
+            assert len(seq_ids) == sampling_params.best_of
             prob = probs[idx]
             logprob = logprobs[idx]
             idx += 1
@@ -397,7 +397,7 @@ def _sample(
 
             # Sample the next tokens.
             seq_logprobs = [
-                input_metadata.seq_data[seq_id].cumulative_logprobs
+                input_metadata.seq_data[seq_id].cumulative_logprob
                 for seq_id in seq_ids]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)
@@ -1,6 +1,4 @@
-from typing import Dict, List, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import Dict, List
 
 from cacheflow.sequence import SequenceGroup
 
@@ -9,20 +7,23 @@ class CompletionOutput:
 
     def __init__(
         self,
+        index: int,
         text: str,
         token_ids: List[int],
-        cumulative_logprobs: float,
+        cumulative_logprob: float,
         logprobs: List[Dict[int, float]],
     ) -> None:
+        self.index = index
         self.text = text
         self.token_ids = token_ids
-        self.cumulative_logprobs = cumulative_logprobs
+        self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
 
     def __repr__(self) -> str:
-        return (f"CompletionOutput(output={self.text!r}, "
+        return (f"CompletionOutput(index={self.index}, "
+                f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
-                f"cumulative_logprobs={self.cumulative_logprobs}, "
+                f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs})")
 
 
@@ -43,31 +44,32 @@ class RequestOutput:
         self.done = done
 
     @staticmethod
-    def from_seq_group(
-        seq_group: SequenceGroup,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    ) -> "RequestOutput":
-        outputs: List[CompletionOutput] = []
+    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+        # Get the top-n sequences.
+        n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
-        for seq in seqs:
-            output_token_ids = seq.data.output_token_ids
-            output_str = tokenizer.decode(output_token_ids,
-                                          skip_special_tokens=True)
-            seq_logprobs = seq.data.cumulative_logprobs
+        assert n <= len(seqs)
+        sorted_seqs = sorted(
+            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
+        top_n_seqs = sorted_seqs[:n]
+
+        # Create the outputs.
+        outputs: List[CompletionOutput] = []
+        for seq in top_n_seqs:
             logprobs = seq.output_logprobs
             if seq_group.sampling_params.logprobs == 0:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
                 logprobs = {}
-            output = CompletionOutput(output_str, output_token_ids,
-                                      seq_logprobs, logprobs)
+            output = CompletionOutput(seqs.index(seq), seq.output_text,
+                                      seq.get_output_token_ids(),
+                                      seq.get_cumulative_logprob(), logprobs)
             outputs.append(output)
 
         # Every sequence in the sequence group should have the same prompt.
-        prompt = seqs[0].prompt
-        prompt_token_ids = seqs[0].data.prompt_token_ids
+        prompt = top_n_seqs[0].prompt
+        prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
         return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
                              outputs, seq_group.is_finished())
 
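
Note: from_seq_group now selects the top `n` of the `best_of` candidate sequences by cumulative log-probability. A plain-Python toy of that selection rule (no server objects involved):

# (output_text, cumulative_logprob) pairs, pretending best_of = 3.
candidates = [
    ("candidate A", -1.2),
    ("candidate B", -0.4),
    ("candidate C", -2.7),
]
n = 2
top_n = sorted(candidates, key=lambda c: c[1], reverse=True)[:n]
print(top_n)  # [('candidate B', -0.4), ('candidate A', -1.2)]
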
@@ -1,5 +1,5 @@
 """Sampling parameters for text generation."""
-from typing import Set
+from typing import List, Optional, Union
 
 
 class SamplingParams:
@@ -10,8 +10,12 @@ class SamplingParams:
     In addition, we support beam search, which is not supported by OpenAI.
 
     Args:
-        n: Number of output sequences to generate from the given prompt. This is
-            regarded as the beam width when using beam search.
+        n: Number of output sequences to return for the given prompt.
+        best_of: Number of output sequences that are generated from the prompt.
+            From these `best_of` sequences, the top `n` sequences are returned.
+            `best_of` must be greater than or equal to `n`. This is treated as
+            the beam width when `use_beam_search` is True. By default, `best_of`
+            is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -28,7 +32,10 @@ class SamplingParams:
         top_k: Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
         use_beam_search: Whether to use beam search instead of sampling.
-        stop_token_ids: Set of token IDs that indicate the end of a sequence.
+        stop: List of strings that stop the generation when they are generated.
+            The returned output will not contain the stop strings.
+        ignore_eos: Whether to ignore the EOS token and continue generating
+            tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
     """
@@ -36,24 +43,28 @@ class SamplingParams:
     def __init__(
         self,
         n: int = 1,
+        best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         use_beam_search: bool = False,
-        stop_token_ids: Set[int] = set(),
+        stop: Union[str, List[str]] = [],
+        ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: int = 0,
     ) -> None:
         self.n = n
+        self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
         self.use_beam_search = use_beam_search
-        self.stop_token_ids = stop_token_ids
+        self.stop = [stop] if isinstance(stop, str) else list(stop)
+        self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
 
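
Note: a usage sketch based on the signature above: sample five candidates, return the best two, and stop at the first blank line. The stop string is excluded from the returned text.

from cacheflow.sampling_params import SamplingParams

params = SamplingParams(
    n=2,
    best_of=5,
    temperature=0.8,
    top_p=0.95,
    stop=["\n\n"],       # stop strings are not included in the returned output
    max_tokens=128,
)
print(params)
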
@@ -67,6 +78,9 @@ class SamplingParams:
     def _verify_args(self) -> None:
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
+        if self.best_of < self.n:
+            raise ValueError(f"best_of must be greater than or equal to n, "
+                             f"got n={self.n} and best_of={self.best_of}.")
         if not -2.0 <= self.presence_penalty <= 2.0:
             raise ValueError("presence_penalty must be in [-2, 2], got "
                              f"{self.presence_penalty}.")
@@ -89,8 +103,9 @@ class SamplingParams:
             f"logprobs must be non-negative, got {self.logprobs}.")
 
     def _verity_beam_search(self) -> None:
-        if self.n == 1:
-            raise ValueError("n must be greater than 1 when using beam search.")
+        if self.best_of == 1:
+            raise ValueError("best_of must be greater than 1 when using beam "
+                             f"search. Got {self.best_of}.")
         if self.temperature > 0.0:
             raise ValueError("temperature must be 0 when using beam search.")
         if self.top_p < 1.0:
@@ -99,8 +114,9 @@ class SamplingParams:
             raise ValueError("top_k must be -1 when using beam search.")
 
     def _verify_greedy_sampling(self) -> None:
-        if self.n > 1:
-            raise ValueError("n must be 1 when using greedy sampling.")
+        if self.best_of > 1:
+            raise ValueError("best_of must be 1 when using greedy sampling."
+                             f"Got {self.best_of}.")
         if self.top_p < 1.0:
             raise ValueError("top_p must be 1 when using greedy sampling.")
         if self.top_k != -1:
@@ -108,12 +124,14 @@ class SamplingParams:
 
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
+                f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k},"
                 f"use_beam_search={self.use_beam_search}, "
-                f"stop_token_ids={self.stop_token_ids}, "
+                f"stop={self.stop}, "
+                f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs})")
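
Note: a hedged illustration of the new check added to _verify_args. The method is invoked explicitly here so the example does not depend on exactly when the constructor runs validation.

from cacheflow.sampling_params import SamplingParams

params = SamplingParams(n=2, best_of=4)   # valid: best_of >= n
params.best_of = 1                        # force an inconsistent state for the demo
try:
    params._verify_args()
except ValueError as e:
    print(e)  # best_of must be greater than or equal to n, got n=2 and best_of=1.
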
@@ -22,11 +22,18 @@ class SequenceData:
         self.prompt_token_ids = prompt_token_ids
 
         self.output_token_ids: List[int] = []
-        self.cumulative_logprobs = 0.0
+        self.cumulative_logprob = 0.0
+
+    def append_token(self, token_id: int, logprob: float) -> None:
+        self.output_token_ids.append(token_id)
+        self.cumulative_logprob += logprob
 
     def get_len(self) -> int:
         return len(self.output_token_ids) + len(self.prompt_token_ids)
 
+    def get_output_len(self) -> int:
+        return len(self.output_token_ids)
+
     def get_token_ids(self) -> List[int]:
         return self.prompt_token_ids + self.output_token_ids
 
@@ -37,9 +44,9 @@ class SequenceData:
 
     def __repr__(self) -> str:
         return (f"SequenceData("
                 f"prompt={self.prompt}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"output_token_ids={self.output_token_ids})")
+                f"output_token_ids={self.output_token_ids}, "
+                f"cumulative_logprob={self.cumulative_logprob})")
 
 
 class Sequence:
@@ -57,6 +64,7 @@ class Sequence:
 
         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
+        self.output_text = ""
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the prompt token ids.
@@ -88,18 +96,26 @@ class Sequence:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.output_token_ids.append(token_id)
-        self.data.cumulative_logprobs += logprobs[token_id]
+        self.data.append_token(token_id, logprobs[token_id])
 
     def get_len(self) -> int:
         return self.data.get_len()
 
+    def get_output_len(self) -> int:
+        return self.data.get_output_len()
+
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()
 
     def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()
 
+    def get_output_token_ids(self) -> List[int]:
+        return self.data.output_token_ids
+
+    def get_cumulative_logprob(self) -> float:
+        return self.data.cumulative_logprob
+
     def fork(self, child_seq: 'Sequence') -> 'Sequence':
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
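
Note: a small usage sketch of the new SequenceData helpers. The import path is assumed to be cacheflow.sequence (the module the server imports Sequence from); the constructor takes the prompt token IDs, matching the SequenceData(prompt_token_ids) call shown above.

from cacheflow.sequence import SequenceData

data = SequenceData([1, 2, 3])
data.append_token(7, logprob=-0.25)
data.append_token(9, logprob=-1.50)
print(data.get_output_len())      # 2
print(data.cumulative_logprob)    # -1.75
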
@@ -13,7 +13,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.tokenizer_utils import get_tokenizer
-from cacheflow.sequence import Sequence, SequenceGroup
+from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
 
@@ -49,7 +49,6 @@ class LLMServer:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
-
         self._verify_args()
 
         self.tokenizer = get_tokenizer(model_config.model)
@@ -124,15 +123,11 @@ class LLMServer:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seqs: List[Sequence] = []
-        for _ in range(sampling_params.n):
+        for _ in range(sampling_params.best_of):
             seq_id = next(self.seq_counter)
             seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
             seqs.append(seq)
 
-        # FIXME(woosuk)
-        # Add the EOS token to the stop token list.
-        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
-
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                   arrival_time)
@@ -157,18 +152,65 @@ class LLMServer:
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
         )
-        # Update the scheduler.
-        updated_seq_groups = self.scheduler.update(output)
+        # Update the scheduler with the model outputs.
+        seq_groups = self.scheduler.update(output)
 
+        # Decode the sequences.
+        self._decode_sequences(seq_groups)
+        # Stop the sequences that meet the stopping criteria.
+        self._stop_sequences(seq_groups)
+        # Free the finished sequence groups.
+        self.scheduler.free_finished_seq_groups()
+
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in updated_seq_groups:
-            # TODO(woosuk): Batch-decode the outputs for speedup.
-            request_output = RequestOutput.from_seq_group(seq_group,
-                                                          self.tokenizer)
+        for seq_group in seq_groups:
+            request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs
 
+    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Batch-decode the sequence outputs.
+        seqs: List[Sequence] = []
+        for seq_group in seq_groups:
+            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
+        output_tokens_per_seq = []
+        for seq in seqs:
+            output_tokens_per_seq.append(seq.get_output_token_ids())
+        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
+                                                   skip_special_tokens=True)
+        # Update the sequences with the output texts.
+        for seq, output_text in zip(seqs, output_texts):
+            seq.output_text = output_text
+
+    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Stop the sequences.
+        for seq_group in seq_groups:
+            sampling_params = seq_group.sampling_params
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                # Check if the sequence has generated a stop string.
+                stopped = False
+                for stop_str in sampling_params.stop:
+                    if seq.output_text.endswith(stop_str):
+                        # Truncate the output text so that the stop string is
+                        # not included in the output.
+                        seq.output_text = seq.output_text[:-len(stop_str)]
+                        self.scheduler.free_seq(seq)
+                        stopped = True
+                        break
+                if stopped:
+                    continue
+
+                # Check if the sequence has reached max_tokens.
+                if seq.get_output_len() == sampling_params.max_tokens:
+                    self.scheduler.free_seq(seq)
+                    continue
+                # Check if the sequence has generated the EOS token.
+                if not sampling_params.ignore_eos:
+                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
+                        self.scheduler.free_seq(seq)
+                        continue
+
     def _run_workers(
         self,
         method: str,
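
Note: the key behavior of _stop_sequences is that the matched stop string is cut from the returned text. A plain-Python toy of that truncation rule (no server objects involved):

output_text = "The answer is 42.\nUser:"
stop = ["\nUser:", "###"]

for stop_str in stop:
    if output_text.endswith(stop_str):
        # Drop the stop string itself from the output, as _stop_sequences does.
        output_text = output_text[:-len(stop_str)]
        break
print(repr(output_text))  # 'The answer is 42.'
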
@@ -15,9 +15,9 @@ def main(args: argparse.Namespace):
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
-         SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
+         SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
         ("It is only with the heart that one can see rightly",
-         SamplingParams(n=3, use_beam_search=True, temperature=0.0)),
+         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]
 
     # Run the server.
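
Note: illustrative only, not part of the committed test_prompts list: a prompt/params pair in the same style as those above that exercises the new stop-string support.

from cacheflow.sampling_params import SamplingParams

extra_prompt = (
    "Write a haiku about GPUs.",
    SamplingParams(n=1, temperature=0.7, stop=["\n\n"], max_tokens=64),
)
print(extra_prompt[1])
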