Implement presence and frequency penalties (#95)
commit 55f8b0a5de
parent 9f88db35da
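
The new sampler code applies both penalties to the logits before temperature scaling, following the OpenAI parameter-details definition cited in the diff below. As a minimal sketch of that per-token rule (the function name and arguments here are illustrative only, not part of the cacheflow API): a token's logit drops by frequency_penalty times the number of times it already appears in the output, plus presence_penalty once if it appears at all.

    # Hedged sketch of the per-token rule; not cacheflow code.
    def penalized_logit(logit: float, count: int,
                        presence_penalty: float, frequency_penalty: float) -> float:
        # count: how many times this token already occurs in the generated output.
        logit -= frequency_penalty * count
        logit -= presence_penalty * (1.0 if count > 0 else 0.0)
        return logit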
@@ -3,11 +3,12 @@ import time
 from typing import Dict, List, Optional, Tuple
 
 from cacheflow.core.block_manager import BlockSpaceManager
-from cacheflow.logger import init_logger
 from cacheflow.core.policy import PolicyFactory
+from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
-                                SequenceOutputs, SequenceStatus)
+from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
+                                SequenceGroupMetadata, SequenceOutputs,
+                                SequenceStatus)
 
 logger = init_logger(__name__)
 
@@ -246,27 +247,17 @@ class Scheduler:
             group_id = seq_group.group_id
             is_prompt = group_id in prompt_group_ids
 
-            input_tokens: Dict[int, List[int]] = {}
-            seq_logprobs: Dict[int, float] = {}
+            seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
+                seq_data[seq_id] = seq.data
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
-                if is_prompt:
-                    input_tokens[seq_id] = seq.get_token_ids()
-                else:
-                    input_tokens[seq_id] = [seq.get_last_token_id()]
-                    seq_logprobs[seq_id] = seq.cumulative_logprobs
-                # NOTE(woosuk): Sequences in the same group have the same
-                # sequence length
-                seq_len = seq.get_len()
 
             seq_group_metadata = SequenceGroupMetadata(
                 group_id=group_id,
                 is_prompt=is_prompt,
-                input_tokens=input_tokens,
-                context_len=seq_len,
-                seq_logprobs=seq_logprobs,
+                seq_data=seq_data,
                 sampling_params=self.sampling_params[group_id],
                 block_tables=block_tables,
             )
@@ -96,7 +96,7 @@ class FastAPIServer:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
             seqs.append(seq)
 
         arrival_time = time.time()
@@ -35,10 +35,11 @@ class SimpleFrontend:
         sampling_params: SamplingParams,
     ) -> None:
         token_ids = self.tokenizer.encode(prompt)
-        self._add_query(token_ids, sampling_params)
+        self._add_query(prompt, token_ids, sampling_params)
 
     def _add_query(
         self,
+        prompt: str,
        token_ids: List[int],
         sampling_params: SamplingParams,
         arrival_time: Optional[float] = None,
@@ -48,7 +49,7 @@ class SimpleFrontend:
         seqs: List[Sequence] = []
         for _ in range(sampling_params.n):
             seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
             seqs.append(seq)
 
         group_id = next(self.seq_group_counter)
@@ -1,17 +1,18 @@
-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import SequenceData
 
 
 class InputMetadata:
 
     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
+        seq_groups: List[Tuple[List[int], SamplingParams]],  # List of (seq_ids, sampling_params).
+        seq_data: Dict[int, SequenceData],  # Seq_id -> SequenceData.
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
@@ -19,7 +20,7 @@ class InputMetadata:
         block_tables: torch.Tensor,
     ) -> None:
         self.seq_groups = seq_groups
-        self.seq_logprobs = seq_logprobs
+        self.seq_data = seq_data
         self.prompt_lens = prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
@@ -39,6 +40,7 @@ class InputMetadata:
         assert context_lens.shape[0] == self.num_generation_tokens
 
     def __repr__(self) -> str:
+        # Print only useful metadata.
         return (f'InputMetadata('
                 f'num_valid_tokens={self.num_valid_tokens}, '
                 f'num_prompt_tokens={self.num_prompt_tokens}, '
@@ -1,5 +1,6 @@
 from typing import Dict, List, Tuple
 
+import numpy as np
 import torch
 import torch.nn as nn
 
@@ -31,6 +32,16 @@ class Sampler(nn.Module):
         # Remove paddings in vocab (if any).
         logits = logits[:, :self.vocab_size]
 
+        # Apply presence and frequency penalties.
+        output_tokens = _get_output_tokens(input_metadata)
+        assert len(output_tokens) == logits.shape[0]
+        presence_penalties, frequency_penalties = _get_penalties(input_metadata)
+        assert len(presence_penalties) == logits.shape[0]
+        assert len(frequency_penalties) == logits.shape[0]
+        logits = _apply_penalties(
+            logits, output_tokens, presence_penalties, frequency_penalties,
+            self.vocab_size)
+
         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
         assert len(temperatures) == logits.shape[0]
@@ -43,16 +54,14 @@ class Sampler(nn.Module):
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities (before applying top-p).
+        # Compute the log probabilities (before applying top-p and top-k).
         logprobs = torch.log(probs)
 
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
         if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
-            p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-            k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
-            probs = _apply_top_p_top_k(probs, p, k)
+            probs = _apply_top_p_top_k(probs, top_ps, top_ks)
 
         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
@@ -72,6 +81,93 @@ def _prune_hidden_states(
     return hidden_states[last_token_indicies]
 
 
+def _get_penalties(
+    input_metadata: InputMetadata,
+) -> Tuple[List[float], List[float]]:
+    # Collect the presence and frequency penalties.
+    presence_penalties: List[float] = []
+    frequency_penalties: List[float] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        p = sampling_params.presence_penalty
+        f = sampling_params.frequency_penalty
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            presence_penalties.append(p)
+            frequency_penalties.append(f)
+        else:
+            # A generation token.
+            presence_penalties += [p] * len(seq_ids)
+            frequency_penalties += [f] * len(seq_ids)
+    return presence_penalties, frequency_penalties
+
+
+def _get_output_tokens(
+    input_metadata: InputMetadata,
+) -> List[List[int]]:
+    output_tokens: List[List[int]] = []
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, _ = seq_group
+        if i < input_metadata.num_prompts:
+            # A prompt input.
+            # NOTE: While the prompt input usually has no output tokens,
+            # it may have output tokens in the case of recomputation.
+            seq_id = seq_ids[0]
+            seq_data = input_metadata.seq_data[seq_id]
+            output_tokens.append(seq_data.output_token_ids)
+        else:
+            # A generation token.
+            for seq_id in seq_ids:
+                seq_data = input_metadata.seq_data[seq_id]
+                output_tokens.append(seq_data.output_token_ids)
+    return output_tokens
+
+
+def _apply_penalties(
+    logits: torch.Tensor,
+    output_tokens: List[List[int]],
+    presence_penalties: List[float],
+    frequency_penalties: List[float],
+    vocab_size: int,
+) -> torch.Tensor:
+    num_seqs = logits.shape[0]
+    # Collect the indices of sequences that have non-zero penalties.
+    indices = []
+    for i in range(num_seqs):
+        if not output_tokens[i]:
+            continue
+        p = presence_penalties[i]
+        f = frequency_penalties[i]
+        if p == 0.0 and f == 0.0:
+            continue
+        indices.append(i)
+
+    # Return early if all sequences have zero penalties.
+    if not indices:
+        return logits
+
+    bin_counts = []
+    for i in indices:
+        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
+    bin_counts = np.stack(bin_counts, axis=0)
+    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
                                                 device=logits.device)
+
+    frequency_penalties = [frequency_penalties[i] for i in indices]
+    frequency_penalties = torch.tensor(
+        frequency_penalties, dtype=logits.dtype, device=logits.device)
+    presence_penalties = [presence_penalties[i] for i in indices]
+    presence_penalties = torch.tensor(
+        presence_penalties, dtype=logits.dtype, device=logits.device)
+
+    # We follow the definition in OpenAI API.
+    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
+    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
+    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
+    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
+    return logits
+
+
 def _get_temperatures(
     input_metadata: InputMetadata,
 ) -> List[float]:
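
A quick standalone check of the vectorized penalty math added above (NumPy/PyTorch only; the toy numbers are made up and this is not the cacheflow API): with one sequence that has already generated token 5 twice and token 7 once, a frequency penalty of 0.5 and a presence penalty of 0.2 should lower those logits by 1.2 and 0.7 respectively.

    import numpy as np
    import torch

    vocab_size = 10
    logits = torch.zeros(1, vocab_size)
    output_tokens = [[5, 5, 7]]
    frequency_penalty, presence_penalty = 0.5, 0.2

    # Same bincount-based counting as _apply_penalties above.
    counts = torch.from_numpy(
        np.bincount(output_tokens[0], minlength=vocab_size)).to(logits.dtype)
    logits[0] -= frequency_penalty * counts
    logits[0] -= presence_penalty * (counts > 0).to(logits.dtype)
    print(logits[0, 5].item(), logits[0, 7].item())  # -1.2 -0.7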
@@ -121,10 +217,11 @@ def _get_top_p_top_k(
 
 def _apply_top_p_top_k(
     probs: torch.Tensor,
-    p: torch.Tensor,
-    k: torch.Tensor,
+    top_ps: List[float],
+    top_ks: List[int],
 ) -> torch.Tensor:
-    # TODO(woosuk): Optimize.
+    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
+    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
 
     # Apply top-p.
@@ -286,7 +383,8 @@ def _sample(
 
         # Sample the next tokens.
         seq_logprobs = [
-            input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
+            input_metadata.seq_data[seq_id].cumulative_logprobs
+            for seq_id in seq_ids]
         parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
             seq_ids, prob, logprob, seq_logprobs, sampling_params)
 
@@ -6,6 +6,8 @@ class SamplingParams:
     def __init__(
         self,
         n: int,
+        presence_penalty: float,
+        frequency_penalty: float,
         temperature: float,
         top_p: float,
         top_k: int,
@@ -16,6 +18,12 @@ class SamplingParams:
     ) -> None:
         if n < 1:
             raise ValueError(f"n must be at least 1, got {n}.")
+        if not -2.0 <= presence_penalty <= 2.0:
+            raise ValueError(
+                f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
+        if not -2.0 <= frequency_penalty <= 2.0:
+            raise ValueError(
+                f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
         if temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {temperature}.")
@@ -57,6 +65,8 @@ class SamplingParams:
                 "top_k must be -1 when using greedy sampling.")
 
         self.n = n
+        self.presence_penalty = presence_penalty
+        self.frequency_penalty = frequency_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -67,6 +77,8 @@ class SamplingParams:
 
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
+                f"presence_penalty={self.presence_penalty}, "
+                f"frequency_penalty={self.frequency_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k},"
@@ -77,13 +89,18 @@ class SamplingParams:
 
     @classmethod
     def from_dict(cls, d: Dict) -> "SamplingParams":
-        return cls(
-            n=d.get("n", 1),
-            temperature=d.get("temperature", 1.0),
-            top_p=d.get("top_p", 1.0),
-            top_k=d.get("top_k", -1),
-            use_beam_search=d.get("use_beam_search", False),
-            stop_token_ids=set(d.get("stop_token_ids", set())),
-            max_num_steps=d.get("max_num_steps", 16),
-            num_logprobs=d.get("num_logprobs", 0),
+        sampling_params = cls(
+            n=d.pop("n", 1),
+            presence_penalty=d.pop("presence_penalty", 0.0),
+            frequency_penalty=d.pop("frequency_penalty", 0.0),
+            temperature=d.pop("temperature", 1.0),
+            top_p=d.pop("top_p", 1.0),
+            top_k=d.pop("top_k", -1),
+            use_beam_search=d.pop("use_beam_search", False),
+            stop_token_ids=set(d.pop("stop_token_ids", set())),
+            max_num_steps=d.pop("max_num_steps", 16),
+            num_logprobs=d.pop("num_logprobs", 0),
         )
+        if d:
+            raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
+        return sampling_params
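
A hedged usage sketch of the reworked from_dict (the keys and defaults are the ones visible in the hunk above): because d.pop() consumes every recognized key, anything left in the dict afterwards is rejected, and the caller's dict is mutated in the process.

    from cacheflow.sampling_params import SamplingParams

    params = SamplingParams.from_dict({
        "temperature": 0.8,
        "top_k": 5,
        "presence_penalty": 0.2,
    })
    # A misspelled key such as "presence_pnlty" is no longer silently ignored;
    # from_dict now raises ValueError listing the unrecognized keys.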
@@ -13,26 +13,55 @@ class SequenceStatus(enum.Enum):
     FINISHED = enum.auto()
 
 
+class SequenceData:
+
+    def __init__(
+        self,
+        prompt_token_ids: List[int],
+    ) -> None:
+        self.prompt_token_ids = prompt_token_ids
+
+        self.output_token_ids: List[int] = []
+        self.cumulative_logprobs = 0.0
+
+    def get_len(self) -> int:
+        return len(self.output_token_ids) + len(self.prompt_token_ids)
+
+    def get_token_ids(self) -> List[int]:
+        return self.prompt_token_ids + self.output_token_ids
+
+    def get_last_token_id(self) -> int:
+        if not self.output_token_ids:
+            return self.prompt_token_ids[-1]
+        return self.output_token_ids[-1]
+
+    def __repr__(self) -> str:
+        return (f"SequenceData("
+                f"prompt={self.prompt}, "
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"output_token_ids={self.output_token_ids})")
+
+
 class Sequence:
 
     def __init__(
         self,
         seq_id: int,
+        prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
     ) -> None:
         self.seq_id = seq_id
+        self.prompt = prompt
         self.block_size = block_size
 
-        self.prompt_len = len(prompt_token_ids)
+        self.data = SequenceData(prompt_token_ids)
+        self.output_logprobs: List[Dict[int, float]] = []
         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens(prompt_token_ids)
+        self._append_tokens_to_blocks(prompt_token_ids)
 
         self.status = SequenceStatus.WAITING
-        # Used for beam search.
-        self.output_logprobs: List[Dict[int, float]] = []
-        self.cumulative_logprobs = 0.0
 
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
@@ -41,7 +70,7 @@ class Sequence:
         )
         self.logical_token_blocks.append(block)
 
-    def _append_tokens(self, token_ids: List[int]) -> None:
+    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
         while token_ids:
             if not self.logical_token_blocks:
                 self._append_logical_block()
@@ -57,26 +86,24 @@ class Sequence:
 
     def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
         assert token_id in logprobs
-        self._append_tokens([token_id])
+        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.cumulative_logprobs += logprobs[token_id]
+        self.data.output_token_ids.append(token_id)
+        self.data.cumulative_logprobs += logprobs[token_id]
 
     def get_len(self) -> int:
-        return sum(block.num_tokens for block in self.logical_token_blocks)
+        return self.data.get_len()
 
     def get_token_ids(self) -> List[int]:
-        token_ids: List[int] = []
-        for block in self.logical_token_blocks:
-            token_ids.extend(block.get_token_ids())
-        return token_ids
+        return self.data.get_token_ids()
 
     def get_last_token_id(self) -> int:
-        return self.logical_token_blocks[-1].get_last_token_id()
+        return self.data.get_last_token_id()
 
     def fork(self, child_seq: 'Sequence') -> 'Sequence':
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
-        child_seq.cumulative_logprobs = self.cumulative_logprobs
+        child_seq.data = copy.deepcopy(self.data)
 
     def __repr__(self) -> str:
         return (f'Sequence(seq_id={self.seq_id}, '
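
A minimal sketch of the new SequenceData bookkeeping, using only the fields and methods added above: prompt tokens are fixed at construction, generated tokens and the cumulative log probability grow as sampling proceeds, and Sequence now delegates its length and token queries to this object.

    from cacheflow.sequence import SequenceData

    data = SequenceData(prompt_token_ids=[1, 2, 3])
    data.output_token_ids.append(7)      # what Sequence.append_token() does internally
    data.cumulative_logprobs += -0.35
    assert data.get_len() == 4
    assert data.get_token_ids() == [1, 2, 3, 7]
    assert data.get_last_token_id() == 7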
@@ -128,17 +155,13 @@ class SequenceGroupMetadata:
         self,
         group_id: int,
         is_prompt: bool,
-        input_tokens: Dict[int, List[int]],  # Seq id -> token ids.
-        context_len: int,
-        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
+        seq_data: Dict[int, SequenceData],  # Seq id -> sequence data.
         sampling_params: SamplingParams,
-        block_tables: Dict[int, List[int]],  # Seq id -> List of physical block numbers.
+        block_tables: Dict[int, List[int]],  # Seq id -> list of physical block numbers.
     ) -> None:
         self.group_id = group_id
         self.is_prompt = is_prompt
-        self.input_tokens = input_tokens
-        self.context_len = context_len
-        self.seq_logprobs = seq_logprobs
+        self.seq_data = seq_data
         self.sampling_params = sampling_params
         self.block_tables = block_tables
 
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Tuple
 
 import torch
 
@@ -8,8 +8,8 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
     initialize_all_reduce_launcher,
     get_tensor_model_parallel_world_size)
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import SequenceGroupMetadata
-from cacheflow.sequence import SequenceOutputs
+from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
+                                SequenceOutputs)
 from cacheflow.worker.cache_engine import CacheEngine
 
 
@@ -72,7 +72,6 @@ class Worker:
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
 
-
     def init_distributed_environment(self,
                                      distributed_init_method: str,
                                      rank: int,
@@ -96,7 +95,6 @@ class Worker:
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        seq_logprobs: Dict[int, float] = {}
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
@@ -107,15 +105,15 @@ class Worker:
             if not seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             # Use any sequence in the group.
             seq_id = seq_ids[0]
 
-            prompt_tokens = seq_group_metadata.input_tokens[seq_id]
+            seq_data = seq_group_metadata.seq_data[seq_id]
+            prompt_tokens = seq_data.get_token_ids()
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
@@ -141,27 +139,26 @@ class Worker:
             if seq_group_metadata.is_prompt:
                 continue
 
-            seq_ids = list(seq_group_metadata.input_tokens.keys())
+            seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
-            seq_logprobs.update(seq_group_metadata.seq_logprobs)
 
             for seq_id in seq_ids:
-                assert len(seq_group_metadata.input_tokens[seq_id]) == 1
-                generation_token = seq_group_metadata.input_tokens[seq_id][0]
+                seq_data = seq_group_metadata.seq_data[seq_id]
+                generation_token = seq_data.get_last_token_id()
                 input_tokens.append(generation_token)
 
-                position = seq_group_metadata.context_len - 1
+                context_len = seq_data.get_len()
+                position = context_len - 1
                 input_positions.append(position)
 
                 block_table = seq_group_metadata.block_tables[seq_id]
                 generation_block_tables.append(block_table)
 
-                max_context_len = max(
-                    max_context_len, seq_group_metadata.context_len)
+                max_context_len = max(max_context_len, context_len)
                 max_num_blocks_per_seq = max(
                     max_num_blocks_per_seq, len(block_table))
-                context_lens.append(seq_group_metadata.context_len)
+                context_lens.append(context_len)
 
                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
@@ -188,9 +185,13 @@ class Worker:
         block_tables_tensor = torch.tensor(
             padded_block_tables, dtype=torch.int, device='cuda')
 
+        seq_data: Dict[int, SequenceData] = {}
+        for seq_group_metadata in seq_group_metadata_list:
+            seq_data.update(seq_group_metadata.seq_data)
+
         input_metadata = InputMetadata(
             seq_groups=seq_groups,
-            seq_logprobs=seq_logprobs,
+            seq_data=seq_data,
             prompt_lens=prompt_lens,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
@@ -11,8 +11,8 @@ def main(args: argparse.Namespace):
     # Test the following inputs.
     test_inputs = [
         ("A robot may not injure a human being", {}),  # Use default parameters.
-        ("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
-        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}),
+        ("To be or not to be,", {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}),
+        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}),
         ("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
     ]
     while True: