Implement stop strings and best_of (#114)

Woosuk Kwon 2023-05-21 11:18:00 -07:00 committed by GitHub
parent c3442c1f6f
commit f746ced08d
9 changed files with 162 additions and 116 deletions
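
A minimal usage sketch of the two features added here, assuming the `SamplingParams` interface shown in the diffs below; the concrete values (`stop=["\n###"]`, `max_tokens=64`, and the sampling settings) are illustrative, not part of the commit:

from cacheflow.sampling_params import SamplingParams

# Sample 5 candidate sequences, return the 2 with the highest cumulative
# log probability, and stop generation (without emitting the stop string)
# as soon as the decoded text ends with "\n###".
params = SamplingParams(
    n=2,
    best_of=5,
    temperature=0.8,
    top_p=0.95,
    stop=["\n###"],
    max_tokens=64,
)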

View File

@ -80,7 +80,7 @@ class BlockSpaceManager:
def can_allocate(self, seq_group: SequenceGroup) -> bool:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.seqs[0]
seq = seq_group.get_seqs()[0]
num_required_blocks = len(seq.logical_token_blocks)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
@ -88,7 +88,7 @@ class BlockSpaceManager:
def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same prompt.
seq = seq_group.seqs[0]
seq = seq_group.get_seqs()[0]
# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
@ -99,7 +99,7 @@ class BlockSpaceManager:
block_table.append(block)
# Assign the block table for each sequence.
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
self.block_tables[seq.seq_id] = block_table.copy()
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
@ -147,7 +147,7 @@ class BlockSpaceManager:
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set()
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.status == SequenceStatus.FINISHED:
continue
block_table = self.block_tables[seq.seq_id]
@ -168,7 +168,7 @@ class BlockSpaceManager:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.status == SequenceStatus.FINISHED:
continue
new_block_table: BlockTable = []
@ -199,7 +199,7 @@ class BlockSpaceManager:
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
# GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
if seq.status == SequenceStatus.FINISHED:
continue
new_block_table: BlockTable = []

View File

@ -73,8 +73,6 @@ class Scheduler:
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
# Mapping: request_id -> num_steps.
self.num_steps: Dict[str, int] = {}
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []
@ -84,7 +82,6 @@ class Scheduler:
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
assert seq_group.request_id not in self.num_steps
self.waiting.append(seq_group)
def has_unfinished_seqs(self) -> bool:
@ -178,7 +175,7 @@ class Scheduler:
break
# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.seqs[0].get_len()
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if (num_batched_tokens + num_prompt_tokens
> self.scheduler_config.max_num_batched_tokens):
break
@ -278,15 +275,8 @@ class Scheduler:
) -> List[SequenceGroup]:
# Update the running sequences and free blocks.
for seq_group in self.running:
request_id = seq_group.request_id
self.num_steps[request_id] += 1
stop_token_ids = seq_group.sampling_params.stop_token_ids
# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
@ -297,43 +287,27 @@ class Scheduler:
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)
# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
# Process the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token(output.output_token, output.logprobs)
return self.running.copy()
# Check if the sequence has generated a stop token.
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue
def free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)
# Check if the sequence has reached the maximum number of steps.
max_num_steps = seq_group.sampling_params.max_tokens
if self.num_steps[request_id] == max_num_steps:
self._free_seq(seq)
continue
# Update the running sequences.
updated = self.running.copy()
running: List[SequenceGroup] = []
for seq_group in self.running:
if seq_group.is_finished():
self._free_seq_group(seq_group)
else:
running.append(seq_group)
self.running = running
return updated
def free_finished_seq_groups(self) -> None:
self.running = [
seq_group for seq_group in self.running
if not seq_group.is_finished()
]
def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.seqs:
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.RUNNING
if seq_group.request_id not in self.num_steps:
self.num_steps[seq_group.request_id] = 0
def _append_slot(
self,
@ -403,13 +377,6 @@ class Scheduler:
self._swap_out(seq_group, blocks_to_swap_out)
self.swapped.append(seq_group)
def _free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)
def _free_seq_group(self, seq_group: SequenceGroup) -> None:
del self.num_steps[seq_group.request_id]
def _swap_in(
self,
seq_group: SequenceGroup,

View File

@ -123,6 +123,7 @@ if __name__ == "__main__":
parallel_config = server_configs[2]
distributed_init_method, stage_devices = initialize_cluster(parallel_config)
server = FastAPIServer(
args.use_ray, *server_configs, distributed_init_method, stage_devices)
server = FastAPIServer(args.use_ray, *server_configs,
distributed_init_method, stage_devices,
log_stats=not args.disable_log_stats)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")

View File

@ -283,20 +283,20 @@ def _sample_from_prompt(
) -> List[int]:
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.n
beam_width = sampling_params.best_of
_, next_token_ids = torch.topk(prob, beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert sampling_params.n == 1
assert sampling_params.best_of == 1
next_token_id = torch.argmax(prob)
next_token_ids = [next_token_id.item()]
else:
# Random sampling.
# Sample n tokens for the prompt.
n = sampling_params.n
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial(
prob, num_samples=n, replacement=True)
prob, num_samples=num_seqs, replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
@ -308,7 +308,7 @@ def _sample_from_generation_tokens(
seq_logprobs: List[float],
sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
# NOTE(woosuk): sampling_params.n can be greater than
# NOTE(woosuk): sampling_params.best_of can be greater than
# len(seq_ids) because some sequences in the group might have
# been already terminated.
if sampling_params.use_beam_search:
@ -372,7 +372,7 @@ def _sample(
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.n
assert len(seq_ids) == sampling_params.best_of
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
@ -397,7 +397,7 @@ def _sample(
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprobs
input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)

View File

@ -1,6 +1,4 @@
from typing import Dict, List, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing import Dict, List
from cacheflow.sequence import SequenceGroup
@ -9,20 +7,23 @@ class CompletionOutput:
def __init__(
self,
index: int,
text: str,
token_ids: List[int],
cumulative_logprobs: float,
cumulative_logprob: float,
logprobs: List[Dict[int, float]],
) -> None:
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprobs = cumulative_logprobs
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
def __repr__(self) -> str:
return (f"CompletionOutput(output={self.text!r}, "
return (f"CompletionOutput(index={self.index}, "
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"cumulative_logprobs={self.cumulative_logprobs}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"logprobs={self.logprobs})")
@ -43,31 +44,32 @@ class RequestOutput:
self.done = done
@staticmethod
def from_seq_group(
seq_group: SequenceGroup,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> "RequestOutput":
outputs: List[CompletionOutput] = []
def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
# Get the top-n sequences.
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
for seq in seqs:
output_token_ids = seq.data.output_token_ids
output_str = tokenizer.decode(output_token_ids,
skip_special_tokens=True)
seq_logprobs = seq.data.cumulative_logprobs
assert n <= len(seqs)
sorted_seqs = sorted(
seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
top_n_seqs = sorted_seqs[:n]
# Create the outputs.
outputs: List[CompletionOutput] = []
for seq in top_n_seqs:
logprobs = seq.output_logprobs
if seq_group.sampling_params.logprobs == 0:
# NOTE: We need to take care of this case because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
logprobs = {}
output = CompletionOutput(output_str, output_token_ids,
seq_logprobs, logprobs)
output = CompletionOutput(seqs.index(seq), seq.output_text,
seq.get_output_token_ids(),
seq.get_cumulative_logprob(), logprobs)
outputs.append(output)
# Every sequence in the sequence group should have the same prompt.
prompt = seqs[0].prompt
prompt_token_ids = seqs[0].data.prompt_token_ids
prompt = top_n_seqs[0].prompt
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
outputs, seq_group.is_finished())
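
With the tokenizer removed from `from_seq_group`, the method now ranks all `best_of` sequences of the group by cumulative log probability and wraps only the top `n` as `CompletionOutput`s. A sketch of reading the result on the caller side; `print_request_output` is a hypothetical helper, while the attribute names match the classes above:

from cacheflow.outputs import RequestOutput

def print_request_output(request_output: RequestOutput) -> None:
    # Hypothetical helper: dumps the fields populated by from_seq_group().
    print(f"request {request_output.request_id}: prompt={request_output.prompt!r}")
    for completion in request_output.outputs:
        # `index` is the sequence's position within the original group,
        # not its rank after sorting by cumulative logprob.
        print(f"  [{completion.index}] "
              f"cumulative_logprob={completion.cumulative_logprob:.2f} "
              f"text={completion.text!r}")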

View File

@ -1,5 +1,5 @@
"""Sampling parameters for text generation."""
from typing import Set
from typing import List, Optional, Union
class SamplingParams:
@ -10,8 +10,12 @@ class SamplingParams:
In addition, we support beam search, which is not supported by OpenAI.
Args:
n: Number of output sequences to generate from the given prompt. This is
regarded as the beam width when using beam search.
n: Number of output sequences to return for the given prompt.
best_of: Number of output sequences that are generated from the prompt.
From these `best_of` sequences, the top `n` sequences are returned.
`best_of` must be greater than or equal to `n`. This is treated as
the beam width when `use_beam_search` is True. By default, `best_of`
is set to `n`.
presence_penalty: Float that penalizes new tokens based on whether they
appear in the generated text so far. Values > 0 encourage the model
to use new tokens, while values < 0 encourage the model to repeat
@ -28,7 +32,10 @@ class SamplingParams:
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
use_beam_search: Whether to use beam search instead of sampling.
stop_token_ids: Set of token IDs that indicate the end of a sequence.
stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
ignore_eos: Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
logprobs: Number of log probabilities to return per output token.
"""
@ -36,24 +43,28 @@ class SamplingParams:
def __init__(
self,
n: int = 1,
best_of: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
stop_token_ids: Set[int] = set(),
stop: Union[str, List[str]] = [],
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: int = 0,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.use_beam_search = use_beam_search
self.stop_token_ids = stop_token_ids
self.stop = [stop] if isinstance(stop, str) else list(stop)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs
@ -67,6 +78,9 @@ class SamplingParams:
def _verify_args(self) -> None:
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
if self.best_of < self.n:
raise ValueError(f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
if not -2.0 <= self.presence_penalty <= 2.0:
raise ValueError("presence_penalty must be in [-2, 2], got "
f"{self.presence_penalty}.")
@ -89,8 +103,9 @@ class SamplingParams:
f"logprobs must be non-negative, got {self.logprobs}.")
def _verity_beam_search(self) -> None:
if self.n == 1:
raise ValueError("n must be greater than 1 when using beam search.")
if self.best_of == 1:
raise ValueError("best_of must be greater than 1 when using beam "
f"search. Got {self.best_of}.")
if self.temperature > 0.0:
raise ValueError("temperature must be 0 when using beam search.")
if self.top_p < 1.0:
@ -99,8 +114,9 @@ class SamplingParams:
raise ValueError("top_k must be -1 when using beam search.")
def _verify_greedy_sampling(self) -> None:
if self.n > 1:
raise ValueError("n must be 1 when using greedy sampling.")
if self.best_of > 1:
raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.")
if self.top_p < 1.0:
raise ValueError("top_p must be 1 when using greedy sampling.")
if self.top_k != -1:
@ -108,12 +124,14 @@ class SamplingParams:
def __repr__(self) -> str:
return (f"SamplingParams(n={self.n}, "
f"best_of={self.best_of}, "
f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
f"temperature={self.temperature}, "
f"top_p={self.top_p}, "
f"top_k={self.top_k},"
f"use_beam_search={self.use_beam_search}, "
f"stop_token_ids={self.stop_token_ids}, "
f"stop={self.stop}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
f"logprobs={self.logprobs})")

View File

@ -22,11 +22,18 @@ class SequenceData:
self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = []
self.cumulative_logprobs = 0.0
self.cumulative_logprob = 0.0
def append_token(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id)
self.cumulative_logprob += logprob
def get_len(self) -> int:
return len(self.output_token_ids) + len(self.prompt_token_ids)
def get_output_len(self) -> int:
return len(self.output_token_ids)
def get_token_ids(self) -> List[int]:
return self.prompt_token_ids + self.output_token_ids
@ -37,9 +44,9 @@ class SequenceData:
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt={self.prompt}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"output_token_ids={self.output_token_ids})")
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob})")
class Sequence:
@ -57,6 +64,7 @@ class Sequence:
self.data = SequenceData(prompt_token_ids)
self.output_logprobs: List[Dict[int, float]] = []
self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
@ -88,18 +96,26 @@ class Sequence:
assert token_id in logprobs
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.data.output_token_ids.append(token_id)
self.data.cumulative_logprobs += logprobs[token_id]
self.data.append_token(token_id, logprobs[token_id])
def get_len(self) -> int:
return self.data.get_len()
def get_output_len(self) -> int:
return self.data.get_output_len()
def get_token_ids(self) -> List[int]:
return self.data.get_token_ids()
def get_last_token_id(self) -> int:
return self.data.get_last_token_id()
def get_output_token_ids(self) -> List[int]:
return self.data.output_token_ids
def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob
def fork(self, child_seq: 'Sequence') -> 'Sequence':
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)

View File

@ -13,7 +13,7 @@ from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
from cacheflow.utils import Counter
from cacheflow.worker.worker import Worker
@ -49,7 +49,6 @@ class LLMServer:
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
self._verify_args()
self.tokenizer = get_tokenizer(model_config.model)
@ -124,15 +123,11 @@ class LLMServer:
# Create the sequences.
block_size = self.cache_config.block_size
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
for _ in range(sampling_params.best_of):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seqs.append(seq)
# FIXME(woosuk)
# Add the EOS token to the stop token list.
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
# Create the sequence group.
seq_group = SequenceGroup(request_id, seqs, sampling_params,
arrival_time)
@ -157,18 +152,65 @@ class LLMServer:
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)
# Update the scheduler.
updated_seq_groups = self.scheduler.update(output)
# Update the scheduler with the model outputs.
seq_groups = self.scheduler.update(output)
# Decode the sequences.
self._decode_sequences(seq_groups)
# Stop the sequences that meet the stopping criteria.
self._stop_sequences(seq_groups)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in updated_seq_groups:
# TODO(woosuk): Batch-decode the outputs for speedup.
request_output = RequestOutput.from_seq_group(seq_group,
self.tokenizer)
for seq_group in seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
return request_outputs
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Batch-decode the sequence outputs.
seqs: List[Sequence] = []
for seq_group in seq_groups:
seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
output_tokens_per_seq = []
for seq in seqs:
output_tokens_per_seq.append(seq.get_output_token_ids())
output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
skip_special_tokens=True)
# Update the sequences with the output texts.
for seq, output_text in zip(seqs, output_texts):
seq.output_text = output_text
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Stop the sequences.
for seq_group in seq_groups:
sampling_params = seq_group.sampling_params
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Check if the sequence has generated a stop string.
stopped = False
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
self.scheduler.free_seq(seq)
stopped = True
break
if stopped:
continue
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
self.scheduler.free_seq(seq)
continue
# Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos:
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
self.scheduler.free_seq(seq)
continue
def _run_workers(
self,
method: str,
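
The stop-string rule added in `_stop_sequences` above boils down to an `endswith` check plus truncation. A standalone sketch; `truncate_at_stop` is a hypothetical helper for illustration, whereas the server performs the check in place on `seq.output_text`:

from typing import List, Tuple

def truncate_at_stop(output_text: str, stop: List[str]) -> Tuple[str, bool]:
    # Hypothetical helper mirroring LLMServer._stop_sequences: a sequence is
    # finished as soon as its decoded text ends with any stop string, and the
    # stop string itself is cut from the returned text.
    for stop_str in stop:
        if output_text.endswith(stop_str):
            return output_text[:-len(stop_str)], True
    return output_text, False

# truncate_at_stop("Answer: 42\n###", ["\n###"]) -> ("Answer: 42", True)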

View File

@ -15,9 +15,9 @@ def main(args: argparse.Namespace):
("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?",
SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
("It is only with the heart that one can see rightly",
SamplingParams(n=3, use_beam_search=True, temperature=0.0)),
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
]
# Run the server.