Implement presence and frequency penalties (#95)

Woosuk Kwon 2023-05-10 23:39:12 -07:00 committed by GitHub
parent 9f88db35da
commit 55f8b0a5de
9 changed files with 215 additions and 82 deletions

View File

@@ -3,11 +3,12 @@ import time
from typing import Dict, List, Optional, Tuple
from cacheflow.core.block_manager import BlockSpaceManager
from cacheflow.logger import init_logger
from cacheflow.core.policy import PolicyFactory
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceOutputs, SequenceStatus)
from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs,
SequenceStatus)
logger = init_logger(__name__)
@@ -246,27 +247,17 @@ class Scheduler:
group_id = seq_group.group_id
is_prompt = group_id in prompt_group_ids
input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
seq_data: Dict[int, SequenceData] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt:
input_tokens[seq_id] = seq.get_token_ids()
else:
input_tokens[seq_id] = [seq.get_last_token_id()]
seq_logprobs[seq_id] = seq.cumulative_logprobs
# NOTE(woosuk): Sequences in the same group have the same
# sequence length
seq_len = seq.get_len()
seq_group_metadata = SequenceGroupMetadata(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
context_len=seq_len,
seq_logprobs=seq_logprobs,
seq_data=seq_data,
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
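
Taken together, this scheduler change replaces the per-sequence input_tokens, context_len, and seq_logprobs fields of SequenceGroupMetadata with a single seq_data map. A minimal sketch of what the new metadata carries (the ids and token values are hypothetical; only the constructor arguments visible in this diff are used):

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceData, SequenceGroupMetadata

# Hypothetical ids and tokens, for illustration only.
metadata = SequenceGroupMetadata(
    group_id=0,
    is_prompt=True,
    seq_data={0: SequenceData(prompt_token_ids=[11, 12, 13])},
    sampling_params=SamplingParams.from_dict({"presence_penalty": 0.5}),
    block_tables={0: [0]},
)
# The worker now derives what it previously read from input_tokens,
# context_len, and seq_logprobs directly from SequenceData (see the worker
# diff further down).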

View File

@@ -96,7 +96,7 @@ class FastAPIServer:
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
seqs.append(seq)
arrival_time = time.time()

View File

@@ -35,10 +35,11 @@ class SimpleFrontend:
sampling_params: SamplingParams,
) -> None:
token_ids = self.tokenizer.encode(prompt)
self._add_query(token_ids, sampling_params)
self._add_query(prompt, token_ids, sampling_params)
def _add_query(
self,
prompt: str,
token_ids: List[int],
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
@@ -48,7 +49,7 @@ class SimpleFrontend:
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
seqs.append(seq)
group_id = next(self.seq_group_counter)

View File

@@ -1,17 +1,18 @@
from typing import List, Dict, Tuple
from typing import Dict, List, Tuple
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceData
class InputMetadata:
def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params).
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData.
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
@@ -19,7 +20,7 @@ class InputMetadata:
block_tables: torch.Tensor,
) -> None:
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
@@ -39,6 +40,7 @@ class InputMetadata:
assert context_lens.shape[0] == self.num_generation_tokens
def __repr__(self) -> str:
# Print only useful metadata.
return (f'InputMetadata('
f'num_valid_tokens={self.num_valid_tokens}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '

View File

@@ -1,5 +1,6 @@
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
@@ -31,6 +32,16 @@ class Sampler(nn.Module):
# Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size]
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties(input_metadata)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(
logits, output_tokens, presence_penalties, frequency_penalties,
self.vocab_size)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
@@ -43,16 +54,14 @@ class Sampler(nn.Module):
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p).
# Compute the log probabilities (before applying top-p and top-k).
logprobs = torch.log(probs)
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
probs = _apply_top_p_top_k(probs, p, k)
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
# Sample the next tokens.
return _sample(probs, logprobs, input_metadata)
@@ -72,6 +81,93 @@ def _prune_hidden_states(
return hidden_states[last_token_indicies]
def _get_penalties(
input_metadata: InputMetadata,
) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
if i < input_metadata.num_prompts:
# A prompt input.
presence_penalties.append(p)
frequency_penalties.append(f)
else:
# A generation token.
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
return presence_penalties, frequency_penalties
def _get_output_tokens(
input_metadata: InputMetadata,
) -> List[List[int]]:
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group
if i < input_metadata.num_prompts:
# A prompt input.
# NOTE: While the prompt input usually has no output tokens,
# it may have output tokens in the case of recomputation.
seq_id = seq_ids[0]
seq_data = input_metadata.seq_data[seq_id]
output_tokens.append(seq_data.output_token_ids)
else:
# A generation token.
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
output_tokens.append(seq_data.output_token_ids)
return output_tokens
def _apply_penalties(
logits: torch.Tensor,
output_tokens: List[List[int]],
presence_penalties: List[float],
frequency_penalties: List[float],
vocab_size: int,
) -> torch.Tensor:
num_seqs = logits.shape[0]
# Collect the indices of sequences that have non-zero penalties.
indices = []
for i in range(num_seqs):
if not output_tokens[i]:
continue
p = presence_penalties[i]
f = frequency_penalties[i]
if p == 0.0 and f == 0.0:
continue
indices.append(i)
# Return early if all sequences have zero penalties.
if not indices:
return logits
bin_counts = []
for i in indices:
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
bin_counts = np.stack(bin_counts, axis=0)
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
device=logits.device)
frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(
frequency_penalties, dtype=logits.dtype, device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(
presence_penalties, dtype=logits.dtype, device=logits.device)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
return logits
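
For reference, _apply_penalties follows the OpenAI parameter definition linked in the comment above: each logit is reduced by frequency_penalty times the number of times that token has already been generated, plus presence_penalty once if the token has appeared at all. A toy sketch of the same arithmetic with made-up numbers:

import torch

# One sequence, a vocabulary of 4 tokens, previously generated tokens [2, 2, 3].
logits = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
output_tokens = torch.tensor([2, 2, 3])
presence_penalty, frequency_penalty = 0.5, 0.1

counts = torch.bincount(output_tokens, minlength=4).to(logits.dtype)
logits -= frequency_penalty * counts                         # per-occurrence penalty
logits -= presence_penalty * (counts > 0).to(logits.dtype)   # one-time penalty
print(logits)  # tensor([[1.0000, 1.0000, 0.3000, 0.4000]])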
def _get_temperatures(
input_metadata: InputMetadata,
) -> List[float]:
@@ -121,10 +217,11 @@ def _get_top_p_top_k(
def _apply_top_p_top_k(
probs: torch.Tensor,
p: torch.Tensor,
k: torch.Tensor,
top_ps: List[float],
top_ks: List[int],
) -> torch.Tensor:
# TODO(woosuk): Optimize.
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
# Apply top-p.
@@ -286,7 +383,8 @@ def _sample(
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
input_metadata.seq_data[seq_id].cumulative_logprobs
for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)

View File

@@ -6,6 +6,8 @@ class SamplingParams:
def __init__(
self,
n: int,
presence_penalty: float,
frequency_penalty: float,
temperature: float,
top_p: float,
top_k: int,
@@ -16,6 +18,12 @@
) -> None:
if n < 1:
raise ValueError(f"n must be at least 1, got {n}.")
if not -2.0 <= presence_penalty <= 2.0:
raise ValueError(
f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
if not -2.0 <= frequency_penalty <= 2.0:
raise ValueError(
f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
if temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {temperature}.")
@@ -57,6 +65,8 @@ class SamplingParams:
"top_k must be -1 when using greedy sampling.")
self.n = n
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
@@ -67,6 +77,8 @@ class SamplingParams:
def __repr__(self) -> str:
return (f"SamplingParams(n={self.n}, "
f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
f"temperature={self.temperature}, "
f"top_p={self.top_p}, "
f"top_k={self.top_k},"
@@ -77,13 +89,18 @@ class SamplingParams:
@classmethod
def from_dict(cls, d: Dict) -> "SamplingParams":
return cls(
n=d.get("n", 1),
temperature=d.get("temperature", 1.0),
top_p=d.get("top_p", 1.0),
top_k=d.get("top_k", -1),
use_beam_search=d.get("use_beam_search", False),
stop_token_ids=set(d.get("stop_token_ids", set())),
max_num_steps=d.get("max_num_steps", 16),
num_logprobs=d.get("num_logprobs", 0),
sampling_params = cls(
n=d.pop("n", 1),
presence_penalty=d.pop("presence_penalty", 0.0),
frequency_penalty=d.pop("frequency_penalty", 0.0),
temperature=d.pop("temperature", 1.0),
top_p=d.pop("top_p", 1.0),
top_k=d.pop("top_k", -1),
use_beam_search=d.pop("use_beam_search", False),
stop_token_ids=set(d.pop("stop_token_ids", set())),
max_num_steps=d.pop("max_num_steps", 16),
num_logprobs=d.pop("num_logprobs", 0),
)
if d:
raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
return sampling_params
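
A usage sketch of the updated from_dict (defaults as shown in the diff above; the misspelled key in the second call is deliberate):

from cacheflow.sampling_params import SamplingParams

# Penalties default to 0.0 (disabled) and must lie in [-2.0, 2.0].
params = SamplingParams.from_dict({
    "temperature": 0.8,
    "top_p": 0.95,
    "presence_penalty": 0.2,
    "frequency_penalty": 0.1,
})
print(params)

# Because from_dict() now pops each recognized key, a leftover (e.g. misspelled)
# key raises instead of being silently ignored.
try:
    SamplingParams.from_dict({"presence_penlty": 0.2})
except ValueError as e:
    print(e)  # Unrecognized keys in dict: dict_keys(['presence_penlty'])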

View File

@@ -13,26 +13,55 @@ class SequenceStatus(enum.Enum):
FINISHED = enum.auto()
class SequenceData:
def __init__(
self,
prompt_token_ids: List[int],
) -> None:
self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = []
self.cumulative_logprobs = 0.0
def get_len(self) -> int:
return len(self.output_token_ids) + len(self.prompt_token_ids)
def get_token_ids(self) -> List[int]:
return self.prompt_token_ids + self.output_token_ids
def get_last_token_id(self) -> int:
if not self.output_token_ids:
return self.prompt_token_ids[-1]
return self.output_token_ids[-1]
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt={self.prompt}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"output_token_ids={self.output_token_ids})")
class Sequence:
def __init__(
self,
seq_id: int,
prompt: str,
prompt_token_ids: List[int],
block_size: int,
) -> None:
self.seq_id = seq_id
self.prompt = prompt
self.block_size = block_size
self.prompt_len = len(prompt_token_ids)
self.data = SequenceData(prompt_token_ids)
self.output_logprobs: List[Dict[int, float]] = []
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens(prompt_token_ids)
self._append_tokens_to_blocks(prompt_token_ids)
self.status = SequenceStatus.WAITING
# Used for beam search.
self.output_logprobs: List[Dict[int, float]] = []
self.cumulative_logprobs = 0.0
def _append_logical_block(self) -> None:
block = LogicalTokenBlock(
@@ -41,7 +70,7 @@ class Sequence:
)
self.logical_token_blocks.append(block)
def _append_tokens(self, token_ids: List[int]) -> None:
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
while token_ids:
if not self.logical_token_blocks:
self._append_logical_block()
@@ -57,26 +86,24 @@ class Sequence:
def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
assert token_id in logprobs
self._append_tokens([token_id])
self._append_tokens_to_blocks([token_id])
self.output_logprobs.append(logprobs)
self.cumulative_logprobs += logprobs[token_id]
self.data.output_token_ids.append(token_id)
self.data.cumulative_logprobs += logprobs[token_id]
def get_len(self) -> int:
return sum(block.num_tokens for block in self.logical_token_blocks)
return self.data.get_len()
def get_token_ids(self) -> List[int]:
token_ids: List[int] = []
for block in self.logical_token_blocks:
token_ids.extend(block.get_token_ids())
return token_ids
return self.data.get_token_ids()
def get_last_token_id(self) -> int:
return self.logical_token_blocks[-1].get_last_token_id()
return self.data.get_last_token_id()
def fork(self, child_seq: 'Sequence') -> 'Sequence':
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
child_seq.cumulative_logprobs = self.cumulative_logprobs
child_seq.data = copy.deepcopy(self.data)
def __repr__(self) -> str:
return (f'Sequence(seq_id={self.seq_id}, '
@@ -128,17 +155,13 @@ class SequenceGroupMetadata:
self,
group_id: int,
is_prompt: bool,
input_tokens: Dict[int, List[int]], # Seq id -> token ids.
context_len: int,
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
seq_data: Dict[int, SequenceData], # Seq id -> sequence data.
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers.
) -> None:
self.group_id = group_id
self.is_prompt = is_prompt
self.input_tokens = input_tokens
self.context_len = context_len
self.seq_logprobs = seq_logprobs
self.seq_data = seq_data
self.sampling_params = sampling_params
self.block_tables = block_tables
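
As a quick illustration of the new SequenceData container (values below are made up): the prompt/output split is what lets the sampler penalize only tokens the model itself has generated, and get_last_token_id() falls back to the last prompt token before any output exists.

from cacheflow.sequence import SequenceData

data = SequenceData(prompt_token_ids=[11, 12, 13])
assert data.get_last_token_id() == 13    # no output yet: last prompt token

data.output_token_ids.append(42)         # normally done via Sequence.append_token
assert data.get_token_ids() == [11, 12, 13, 42]
assert data.get_len() == 4
assert data.get_last_token_id() == 42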

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple, Optional
from typing import Dict, List, Optional, Tuple
import torch
@@ -8,8 +8,8 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
initialize_all_reduce_launcher,
get_tensor_model_parallel_world_size)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroupMetadata
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
SequenceOutputs)
from cacheflow.worker.cache_engine import CacheEngine
@@ -72,7 +72,6 @@ class Worker:
self.cache_events = self.cache_engine.events
self.gpu_cache = self.cache_engine.gpu_cache
def init_distributed_environment(self,
distributed_init_method: str,
rank: int,
@@ -96,7 +95,6 @@ class Worker:
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
seq_logprobs: Dict[int, float] = {}
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
@@ -107,15 +105,15 @@ class Worker:
if not seq_group_metadata.is_prompt:
continue
seq_ids = list(seq_group_metadata.input_tokens.keys())
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
seq_logprobs.update(seq_group_metadata.seq_logprobs)
# Use any sequence in the group.
seq_id = seq_ids[0]
prompt_tokens = seq_group_metadata.input_tokens[seq_id]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
@@ -141,27 +139,26 @@ class Worker:
if seq_group_metadata.is_prompt:
continue
seq_ids = list(seq_group_metadata.input_tokens.keys())
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
seq_logprobs.update(seq_group_metadata.seq_logprobs)
for seq_id in seq_ids:
assert len(seq_group_metadata.input_tokens[seq_id]) == 1
generation_token = seq_group_metadata.input_tokens[seq_id][0]
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append(generation_token)
position = seq_group_metadata.context_len - 1
context_len = seq_data.get_len()
position = context_len - 1
input_positions.append(position)
block_table = seq_group_metadata.block_tables[seq_id]
generation_block_tables.append(block_table)
max_context_len = max(
max_context_len, seq_group_metadata.context_len)
max_context_len = max(max_context_len, context_len)
max_num_blocks_per_seq = max(
max_num_blocks_per_seq, len(block_table))
context_lens.append(seq_group_metadata.context_len)
context_lens.append(context_len)
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
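
To make the decode-path bookkeeping above concrete, here is a toy walk-through of the position and block arithmetic with made-up numbers; the final slot value is an assumption about how the slot mapping is formed, since the actual append falls outside the hunk shown:

# block_size = 8, a sequence of 11 tokens mapped to physical blocks [7, 3].
block_size = 8
context_len = 11                     # seq_data.get_len()
position = context_len - 1           # index of the token fed in this step
block_table = [7, 3]

block_number = block_table[position // block_size]  # block_table[1] -> 3
block_offset = position % block_size                # 10 % 8 -> 2
slot = block_number * block_size + block_offset     # presumably 3 * 8 + 2 -> 26
print(position, block_number, block_offset, slot)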
@@ -188,9 +185,13 @@ class Worker:
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=torch.int, device='cuda')
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
input_metadata = InputMetadata(
seq_groups=seq_groups,
seq_logprobs=seq_logprobs,
seq_data=seq_data,
prompt_lens=prompt_lens,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,

View File

@@ -11,8 +8,8 @@ def main(args: argparse.Namespace):
# Test the following inputs.
test_inputs = [
("A robot may not injure a human being", {}), # Use default parameters.
("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}),
("To be or not to be,", {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}),
("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}),
("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
]
while True: