mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 03:35:01 +08:00
Implement presence and frequency penalties (#95)
This commit is contained in:
parent
9f88db35da
commit
55f8b0a5de
@ -3,11 +3,12 @@ import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from cacheflow.core.block_manager import BlockSpaceManager
|
||||
from cacheflow.logger import init_logger
|
||||
from cacheflow.core.policy import PolicyFactory
|
||||
from cacheflow.logger import init_logger
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
|
||||
SequenceOutputs, SequenceStatus)
|
||||
from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceOutputs,
|
||||
SequenceStatus)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -246,27 +247,17 @@ class Scheduler:
|
||||
group_id = seq_group.group_id
|
||||
is_prompt = group_id in prompt_group_ids
|
||||
|
||||
input_tokens: Dict[int, List[int]] = {}
|
||||
seq_logprobs: Dict[int, float] = {}
|
||||
seq_data: Dict[int, List[SequenceData]] = {}
|
||||
block_tables: Dict[int, List[int]] = {}
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
seq_id = seq.seq_id
|
||||
seq_data[seq_id] = seq.data
|
||||
block_tables[seq_id] = self.block_manager.get_block_table(seq)
|
||||
if is_prompt:
|
||||
input_tokens[seq_id] = seq.get_token_ids()
|
||||
else:
|
||||
input_tokens[seq_id] = [seq.get_last_token_id()]
|
||||
seq_logprobs[seq_id] = seq.cumulative_logprobs
|
||||
# NOTE(woosuk): Sequences in the same group have the same
|
||||
# sequence length
|
||||
seq_len = seq.get_len()
|
||||
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
group_id=group_id,
|
||||
is_prompt=is_prompt,
|
||||
input_tokens=input_tokens,
|
||||
context_len=seq_len,
|
||||
seq_logprobs=seq_logprobs,
|
||||
seq_data=seq_data,
|
||||
sampling_params=self.sampling_params[group_id],
|
||||
block_tables=block_tables,
|
||||
)
|
||||
|
||||
@ -96,7 +96,7 @@ class FastAPIServer:
|
||||
seqs: List[Sequence] = []
|
||||
for _ in range(sampling_params.n):
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
|
||||
seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
|
||||
seqs.append(seq)
|
||||
|
||||
arrival_time = time.time()
|
||||
|
||||
@ -35,10 +35,11 @@ class SimpleFrontend:
|
||||
sampling_params: SamplingParams,
|
||||
) -> None:
|
||||
token_ids = self.tokenizer.encode(prompt)
|
||||
self._add_query(token_ids, sampling_params)
|
||||
self._add_query(prompt, token_ids, sampling_params)
|
||||
|
||||
def _add_query(
|
||||
self,
|
||||
prompt: str,
|
||||
token_ids: List[int],
|
||||
sampling_params: SamplingParams,
|
||||
arrival_time: Optional[float] = None,
|
||||
@ -48,7 +49,7 @@ class SimpleFrontend:
|
||||
seqs: List[Sequence] = []
|
||||
for _ in range(sampling_params.n):
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
|
||||
seq = Sequence(seq_id, prompt, token_ids, block_size=self.block_size)
|
||||
seqs.append(seq)
|
||||
|
||||
group_id = next(self.seq_group_counter)
|
||||
|
||||
@ -1,17 +1,18 @@
|
||||
from typing import List, Dict, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import SequenceData
|
||||
|
||||
|
||||
class InputMetadata:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params).
|
||||
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData.
|
||||
prompt_lens: List[int],
|
||||
slot_mapping: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
@ -19,7 +20,7 @@ class InputMetadata:
|
||||
block_tables: torch.Tensor,
|
||||
) -> None:
|
||||
self.seq_groups = seq_groups
|
||||
self.seq_logprobs = seq_logprobs
|
||||
self.seq_data = seq_data
|
||||
self.prompt_lens = prompt_lens
|
||||
self.slot_mapping = slot_mapping
|
||||
self.context_lens = context_lens
|
||||
@ -39,6 +40,7 @@ class InputMetadata:
|
||||
assert context_lens.shape[0] == self.num_generation_tokens
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Print only useful metadata.
|
||||
return (f'InputMetadata('
|
||||
f'num_valid_tokens={self.num_valid_tokens}, '
|
||||
f'num_prompt_tokens={self.num_prompt_tokens}, '
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -31,6 +32,16 @@ class Sampler(nn.Module):
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :self.vocab_size]
|
||||
|
||||
# Apply presence and frequency penalties.
|
||||
output_tokens = _get_output_tokens(input_metadata)
|
||||
assert len(output_tokens) == logits.shape[0]
|
||||
presence_penalties, frequency_penalties = _get_penalties(input_metadata)
|
||||
assert len(presence_penalties) == logits.shape[0]
|
||||
assert len(frequency_penalties) == logits.shape[0]
|
||||
logits = _apply_penalties(
|
||||
logits, output_tokens, presence_penalties, frequency_penalties,
|
||||
self.vocab_size)
|
||||
|
||||
# Apply temperature scaling.
|
||||
temperatures = _get_temperatures(input_metadata)
|
||||
assert len(temperatures) == logits.shape[0]
|
||||
@ -43,16 +54,14 @@ class Sampler(nn.Module):
|
||||
# We use float32 for probabilities and log probabilities.
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities (before applying top-p).
|
||||
# Compute the log probabilities (before applying top-p and top-k).
|
||||
logprobs = torch.log(probs)
|
||||
|
||||
# Apply top-p and top-k truncation.
|
||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||
assert len(top_ps) == len(top_ks) == probs.shape[0]
|
||||
if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
|
||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
||||
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
|
||||
probs = _apply_top_p_top_k(probs, p, k)
|
||||
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
|
||||
|
||||
# Sample the next tokens.
|
||||
return _sample(probs, logprobs, input_metadata)
|
||||
@ -72,6 +81,93 @@ def _prune_hidden_states(
|
||||
return hidden_states[last_token_indicies]
|
||||
|
||||
|
||||
def _get_penalties(
|
||||
input_metadata: InputMetadata,
|
||||
) -> Tuple[List[float], List[float]]:
|
||||
# Collect the presence and frequency penalties.
|
||||
presence_penalties: List[float] = []
|
||||
frequency_penalties: List[float] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
p = sampling_params.presence_penalty
|
||||
f = sampling_params.frequency_penalty
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
presence_penalties.append(p)
|
||||
frequency_penalties.append(f)
|
||||
else:
|
||||
# A generation token.
|
||||
presence_penalties += [p] * len(seq_ids)
|
||||
frequency_penalties += [f] * len(seq_ids)
|
||||
return presence_penalties, frequency_penalties
|
||||
|
||||
|
||||
def _get_output_tokens(
|
||||
input_metadata: InputMetadata,
|
||||
) -> List[List[int]]:
|
||||
output_tokens: List[List[int]] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, _ = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
# NOTE: While the prompt input usually has no output tokens,
|
||||
# it may have output tokens in the case of recomputation.
|
||||
seq_id = seq_ids[0]
|
||||
seq_data = input_metadata.seq_data[seq_id]
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
else:
|
||||
# A generation token.
|
||||
for seq_id in seq_ids:
|
||||
seq_data = input_metadata.seq_data[seq_id]
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
output_tokens: List[List[int]],
|
||||
presence_penalties: List[float],
|
||||
frequency_penalties: List[float],
|
||||
vocab_size: int,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = logits.shape[0]
|
||||
# Collect the indices of sequences that have non-zero penalties.
|
||||
indices = []
|
||||
for i in range(num_seqs):
|
||||
if not output_tokens[i]:
|
||||
continue
|
||||
p = presence_penalties[i]
|
||||
f = frequency_penalties[i]
|
||||
if p == 0.0 and f == 0.0:
|
||||
continue
|
||||
indices.append(i)
|
||||
|
||||
# Return early if all sequences have zero penalties.
|
||||
if not indices:
|
||||
return logits
|
||||
|
||||
bin_counts = []
|
||||
for i in indices:
|
||||
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
|
||||
bin_counts = np.stack(bin_counts, axis=0)
|
||||
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
|
||||
frequency_penalties = [frequency_penalties[i] for i in indices]
|
||||
frequency_penalties = torch.tensor(
|
||||
frequency_penalties, dtype=logits.dtype, device=logits.device)
|
||||
presence_penalties = [presence_penalties[i] for i in indices]
|
||||
presence_penalties = torch.tensor(
|
||||
presence_penalties, dtype=logits.dtype, device=logits.device)
|
||||
|
||||
# We follow the definition in OpenAI API.
|
||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
||||
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
|
||||
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
|
||||
return logits
|
||||
|
||||
|
||||
def _get_temperatures(
|
||||
input_metadata: InputMetadata,
|
||||
) -> List[float]:
|
||||
@ -121,10 +217,11 @@ def _get_top_p_top_k(
|
||||
|
||||
def _apply_top_p_top_k(
|
||||
probs: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
top_ps: List[float],
|
||||
top_ks: List[int],
|
||||
) -> torch.Tensor:
|
||||
# TODO(woosuk): Optimize.
|
||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
||||
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
|
||||
# Apply top-p.
|
||||
@ -286,7 +383,8 @@ def _sample(
|
||||
|
||||
# Sample the next tokens.
|
||||
seq_logprobs = [
|
||||
input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
|
||||
input_metadata.seq_data[seq_id].cumulative_logprobs
|
||||
for seq_id in seq_ids]
|
||||
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
|
||||
seq_ids, prob, logprob, seq_logprobs, sampling_params)
|
||||
|
||||
|
||||
@ -6,6 +6,8 @@ class SamplingParams:
|
||||
def __init__(
|
||||
self,
|
||||
n: int,
|
||||
presence_penalty: float,
|
||||
frequency_penalty: float,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
@ -16,6 +18,12 @@ class SamplingParams:
|
||||
) -> None:
|
||||
if n < 1:
|
||||
raise ValueError(f"n must be at least 1, got {n}.")
|
||||
if not -2.0 <= presence_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
f"presence_penalty must be in [-2, 2], got {presence_penalty}.")
|
||||
if not -2.0 <= frequency_penalty <= 2.0:
|
||||
raise ValueError(
|
||||
f"frequency_penalty must be in [-2, 2], got {frequency_penalty}.")
|
||||
if temperature < 0.0:
|
||||
raise ValueError(
|
||||
f"temperature must be non-negative, got {temperature}.")
|
||||
@ -57,6 +65,8 @@ class SamplingParams:
|
||||
"top_k must be -1 when using greedy sampling.")
|
||||
|
||||
self.n = n
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
@ -67,6 +77,8 @@ class SamplingParams:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SamplingParams(n={self.n}, "
|
||||
f"presence_penalty={self.presence_penalty}, "
|
||||
f"frequency_penalty={self.frequency_penalty}, "
|
||||
f"temperature={self.temperature}, "
|
||||
f"top_p={self.top_p}, "
|
||||
f"top_k={self.top_k},"
|
||||
@ -77,13 +89,18 @@ class SamplingParams:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: Dict) -> "SamplingParams":
|
||||
return cls(
|
||||
n=d.get("n", 1),
|
||||
temperature=d.get("temperature", 1.0),
|
||||
top_p=d.get("top_p", 1.0),
|
||||
top_k=d.get("top_k", -1),
|
||||
use_beam_search=d.get("use_beam_search", False),
|
||||
stop_token_ids=set(d.get("stop_token_ids", set())),
|
||||
max_num_steps=d.get("max_num_steps", 16),
|
||||
num_logprobs=d.get("num_logprobs", 0),
|
||||
sampling_params = cls(
|
||||
n=d.pop("n", 1),
|
||||
presence_penalty=d.pop("presence_penalty", 0.0),
|
||||
frequency_penalty=d.pop("frequency_penalty", 0.0),
|
||||
temperature=d.pop("temperature", 1.0),
|
||||
top_p=d.pop("top_p", 1.0),
|
||||
top_k=d.pop("top_k", -1),
|
||||
use_beam_search=d.pop("use_beam_search", False),
|
||||
stop_token_ids=set(d.pop("stop_token_ids", set())),
|
||||
max_num_steps=d.pop("max_num_steps", 16),
|
||||
num_logprobs=d.pop("num_logprobs", 0),
|
||||
)
|
||||
if d:
|
||||
raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
|
||||
return sampling_params
|
||||
|
||||
@ -13,26 +13,55 @@ class SequenceStatus(enum.Enum):
|
||||
FINISHED = enum.auto()
|
||||
|
||||
|
||||
class SequenceData:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_token_ids: List[int],
|
||||
) -> None:
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
|
||||
self.output_token_ids: List[int] = []
|
||||
self.cumulative_logprobs = 0.0
|
||||
|
||||
def get_len(self) -> int:
|
||||
return len(self.output_token_ids) + len(self.prompt_token_ids)
|
||||
|
||||
def get_token_ids(self) -> List[int]:
|
||||
return self.prompt_token_ids + self.output_token_ids
|
||||
|
||||
def get_last_token_id(self) -> int:
|
||||
if not self.output_token_ids:
|
||||
return self.prompt_token_ids[-1]
|
||||
return self.output_token_ids[-1]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceData("
|
||||
f"prompt={self.prompt}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"output_token_ids={self.output_token_ids})")
|
||||
|
||||
|
||||
class Sequence:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
prompt: str,
|
||||
prompt_token_ids: List[int],
|
||||
block_size: int,
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.prompt = prompt
|
||||
self.block_size = block_size
|
||||
|
||||
self.prompt_len = len(prompt_token_ids)
|
||||
self.data = SequenceData(prompt_token_ids)
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
|
||||
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
||||
# Initialize the logical token blocks with the prompt token ids.
|
||||
self._append_tokens(prompt_token_ids)
|
||||
|
||||
self._append_tokens_to_blocks(prompt_token_ids)
|
||||
self.status = SequenceStatus.WAITING
|
||||
# Used for beam search.
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
self.cumulative_logprobs = 0.0
|
||||
|
||||
def _append_logical_block(self) -> None:
|
||||
block = LogicalTokenBlock(
|
||||
@ -41,7 +70,7 @@ class Sequence:
|
||||
)
|
||||
self.logical_token_blocks.append(block)
|
||||
|
||||
def _append_tokens(self, token_ids: List[int]) -> None:
|
||||
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
|
||||
while token_ids:
|
||||
if not self.logical_token_blocks:
|
||||
self._append_logical_block()
|
||||
@ -57,26 +86,24 @@ class Sequence:
|
||||
|
||||
def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
|
||||
assert token_id in logprobs
|
||||
self._append_tokens([token_id])
|
||||
self._append_tokens_to_blocks([token_id])
|
||||
self.output_logprobs.append(logprobs)
|
||||
self.cumulative_logprobs += logprobs[token_id]
|
||||
self.data.output_token_ids.append(token_id)
|
||||
self.data.cumulative_logprobs += logprobs[token_id]
|
||||
|
||||
def get_len(self) -> int:
|
||||
return sum(block.num_tokens for block in self.logical_token_blocks)
|
||||
return self.data.get_len()
|
||||
|
||||
def get_token_ids(self) -> List[int]:
|
||||
token_ids: List[int] = []
|
||||
for block in self.logical_token_blocks:
|
||||
token_ids.extend(block.get_token_ids())
|
||||
return token_ids
|
||||
return self.data.get_token_ids()
|
||||
|
||||
def get_last_token_id(self) -> int:
|
||||
return self.logical_token_blocks[-1].get_last_token_id()
|
||||
return self.data.get_last_token_id()
|
||||
|
||||
def fork(self, child_seq: 'Sequence') -> 'Sequence':
|
||||
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
|
||||
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
|
||||
child_seq.cumulative_logprobs = self.cumulative_logprobs
|
||||
child_seq.data = copy.deepcopy(self.data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f'Sequence(seq_id={self.seq_id}, '
|
||||
@ -128,17 +155,13 @@ class SequenceGroupMetadata:
|
||||
self,
|
||||
group_id: int,
|
||||
is_prompt: bool,
|
||||
input_tokens: Dict[int, List[int]], # Seq id -> token ids.
|
||||
context_len: int,
|
||||
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
|
||||
seq_data: Dict[int, SequenceData], # Seq id -> sequence data.
|
||||
sampling_params: SamplingParams,
|
||||
block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
|
||||
block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers.
|
||||
) -> None:
|
||||
self.group_id = group_id
|
||||
self.is_prompt = is_prompt
|
||||
self.input_tokens = input_tokens
|
||||
self.context_len = context_len
|
||||
self.seq_logprobs = seq_logprobs
|
||||
self.seq_data = seq_data
|
||||
self.sampling_params = sampling_params
|
||||
self.block_tables = block_tables
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -8,8 +8,8 @@ from cacheflow.model_executor.parallel_utils.parallel_state import (
|
||||
initialize_all_reduce_launcher,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import SequenceGroupMetadata
|
||||
from cacheflow.sequence import SequenceOutputs
|
||||
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
|
||||
SequenceOutputs)
|
||||
from cacheflow.worker.cache_engine import CacheEngine
|
||||
|
||||
|
||||
@ -72,7 +72,6 @@ class Worker:
|
||||
self.cache_events = self.cache_engine.events
|
||||
self.gpu_cache = self.cache_engine.gpu_cache
|
||||
|
||||
|
||||
def init_distributed_environment(self,
|
||||
distributed_init_method: str,
|
||||
rank: int,
|
||||
@ -96,7 +95,6 @@ class Worker:
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]] = []
|
||||
seq_logprobs: Dict[int, float] = {}
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
@ -107,15 +105,15 @@ class Worker:
|
||||
if not seq_group_metadata.is_prompt:
|
||||
continue
|
||||
|
||||
seq_ids = list(seq_group_metadata.input_tokens.keys())
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
seq_groups.append((seq_ids, sampling_params))
|
||||
seq_logprobs.update(seq_group_metadata.seq_logprobs)
|
||||
|
||||
# Use any sequence in the group.
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
prompt_tokens = seq_group_metadata.input_tokens[seq_id]
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
prompt_len = len(prompt_tokens)
|
||||
prompt_lens.append(prompt_len)
|
||||
|
||||
@ -141,27 +139,26 @@ class Worker:
|
||||
if seq_group_metadata.is_prompt:
|
||||
continue
|
||||
|
||||
seq_ids = list(seq_group_metadata.input_tokens.keys())
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
seq_groups.append((seq_ids, sampling_params))
|
||||
seq_logprobs.update(seq_group_metadata.seq_logprobs)
|
||||
|
||||
for seq_id in seq_ids:
|
||||
assert len(seq_group_metadata.input_tokens[seq_id]) == 1
|
||||
generation_token = seq_group_metadata.input_tokens[seq_id][0]
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append(generation_token)
|
||||
|
||||
position = seq_group_metadata.context_len - 1
|
||||
context_len = seq_data.get_len()
|
||||
position = context_len - 1
|
||||
input_positions.append(position)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
generation_block_tables.append(block_table)
|
||||
|
||||
max_context_len = max(
|
||||
max_context_len, seq_group_metadata.context_len)
|
||||
max_context_len = max(max_context_len, context_len)
|
||||
max_num_blocks_per_seq = max(
|
||||
max_num_blocks_per_seq, len(block_table))
|
||||
context_lens.append(seq_group_metadata.context_len)
|
||||
context_lens.append(context_len)
|
||||
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
@ -188,9 +185,13 @@ class Worker:
|
||||
block_tables_tensor = torch.tensor(
|
||||
padded_block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_data.update(seq_group_metadata.seq_data)
|
||||
|
||||
input_metadata = InputMetadata(
|
||||
seq_groups=seq_groups,
|
||||
seq_logprobs=seq_logprobs,
|
||||
seq_data=seq_data,
|
||||
prompt_lens=prompt_lens,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
context_lens=context_lens_tensor,
|
||||
|
||||
@ -11,8 +11,8 @@ def main(args: argparse.Namespace):
|
||||
# Test the following inputs.
|
||||
test_inputs = [
|
||||
("A robot may not injure a human being", {}), # Use default parameters.
|
||||
("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
|
||||
("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}),
|
||||
("To be or not to be,", {"temperature": 0.8, "top_k": 5, "presence_penalty": 0.2}),
|
||||
("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95, "frequency_penalty": 0.1}),
|
||||
("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
|
||||
]
|
||||
while True:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user