- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500

Add SPDX license headers to python source files

This commit adds SPDX license headers to python source files as recommended to the project by the Linux Foundation. These headers provide a concise way that is both human and machine readable for communicating license information for each source file. It helps avoid any ambiguity about the license of the code and can also be easily used by tools to help manage license compliance.

The Linux Foundation runs license scans against the codebase to help ensure we are in compliance with the licenses of the code we use, including dependencies. Having these headers in place helps that tool do its job.

More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/

Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500

Check for SPDX headers using pre-commit

Signed-off-by: Russell Bryant <rbryant@redhat.com>
597 lines · 22 KiB · Python
# SPDX-License-Identifier: Apache-2.0

from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import (PyObjectCache, async_tensor_h2d,
                        is_pin_memory_available, make_tensor_with_pad)

_SAMPLING_EPS = 1e-5


@dataclass
class SequenceGroupToSample:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new token to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if
    # it is in a decode stage. query_len <= seq_len if chunked prefill is
    # enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits to compute prompt logprob. Empty if
    # prompt logprob is not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        return len(self.sample_indices) > 0

    def __post_init__(self):
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None


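# NOTE: Illustrative sketch only (not part of upstream vLLM). It shows the
# field values a prefill-stage group might carry once _prepare_seq_groups
# (defined below) has filled it in; the concrete numbers are assumptions
# chosen to satisfy the invariants asserted in __post_init__ above.
def _example_prefill_group() -> SequenceGroupToSample:
    # A single 5-token prompt processed in one step: seq_len == query_len == 5,
    # no prompt logprobs requested, and only the final position is sampled.
    return SequenceGroupToSample(
        seq_ids=[0],
        sampling_params=SamplingParams(),
        seq_data={},  # normally seq_id -> SequenceData; left empty here
        seq_len=5,
        query_len=5,
        generator=None,
        is_prompt=True,
        prompt_logprob_indices=[],
        sample_indices=[0],  # index into the pruned logits for this group
    )

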
def gen_seq_group_to_sample_builder(num_seqs: int):
    return lambda: SequenceGroupToSample(
        seq_ids=[0] * num_seqs,
        sampling_params=None,
        seq_data=None,  # type: ignore
        seq_len=0,
        query_len=0,
        generator=None,
        is_prompt=True,
        prompt_logprob_indices=[],
        sample_indices=[],
    )


class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations"""

    def __init__(self):
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        if num_seqs not in self._seq_group_to_sample_cache:
            self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))

        obj = self._seq_group_to_sample_cache[num_seqs].get_object()
        return obj

    def reset(self):
        for cache in self._seq_group_to_sample_cache.values():
            cache.reset()


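# NOTE: Illustrative sketch only (not part of upstream vLLM). It demonstrates
# the intended reuse pattern: the cache hands out pooled SequenceGroupToSample
# objects keyed by group size, and reset() is called once per scheduler
# iteration so the same objects can be recycled on the next step.
def _example_cache_usage() -> None:
    cache = SamplingMetadataCache()
    first = cache.get_cached_seq_group_to_sample(num_seqs=2)
    cache.reset()  # end of iteration: return pooled objects to the cache
    second = cache.get_cached_seq_group_to_sample(num_seqs=2)
    # After a reset, PyObjectCache is expected to hand back the same pooled
    # object rather than allocating a new one (that is the point of the cache).
    assert first is second

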
class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follows:
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)

    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each set of indices is a 2D tensor of shape (num_indices, 2),
            where the first item means the sample index within the returned
            logits (before pruning padding), and the second item means the
            sample index after pruning using selected_token_indices.
            For example, if the returned logits are [1, 2, 3] and we select
            [1, 2] for sampling, the pruned logits are [2, 3]. In this case,
            the first item is [1, 2] (sample indices within the original
            logits) and the second item is [0, 1] (sample indices within the
            pruned logits).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.
    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: List[int],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None,
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators, cache)
        selected_token_indices = async_tensor_h2d(
            selected_token_indices,
            dtype=torch.long,
            target_device=device,
            pin_memory=pin_memory,
        )
        categorized_sample_indices = {
            t:
            async_tensor_h2d(
                seq_ids,
                dtype=torch.int,
                target_device=device,
                pin_memory=pin_memory,
            )
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices}), ")


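# NOTE: Illustrative sketch only (not part of upstream vLLM). It mirrors the
# usage described in the SamplingMetadata docstring with plain tensors: the
# model's per-token hidden states are pruned down to the rows the sampler
# actually needs, then split by sampling type. Shapes and index values here
# are made up for illustration.
def _example_selected_token_indices() -> None:
    num_tokens, hidden_size = 8, 16
    hidden_states = torch.randn(num_tokens, hidden_size)

    # Suppose only rows 3 and 7 need logits (e.g. the last prompt token of
    # two sequence groups). selected_token_indices prunes the model output.
    selected_token_indices = torch.tensor([3, 7], dtype=torch.long)
    pruned = hidden_states[selected_token_indices]  # shape: (2, hidden_size)

    # categorized_sample_indices then indexes into the *pruned* rows, grouped
    # by sampling type (greedy vs. random, etc.).
    categorized = {
        SamplingType.GREEDY: torch.tensor([0], dtype=torch.int),
        SamplingType.RANDOM: torch.tensor([1], dtype=torch.int),
    }
    greedy_rows = pruned[categorized[SamplingType.GREEDY].long()]
    random_rows = pruned[categorized[SamplingType.RANDOM].long()]
    assert greedy_rows.shape == (1, hidden_size)
    assert random_rows.shape == (1, hidden_size)

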
def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: List[int],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[
        List[SequenceGroupToSample],
        List[int],
        Dict[SamplingType, List[int]],
        int,
]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence groups to batch.
        seq_lens: A list of sequence lengths per sequence group, indexed to
            match seq_group_metadata_list.
        query_lens: A list of query lengths per sequence group. For a
            prefill, this is the number of prompt tokens processed in this
            step, which can be shorter than the full prompt length when
            chunked prefill is enabled.
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.
        cache: An optional cache of SequenceGroupToSample objects reused
            across scheduler iterations.

    Returns:
        seq_groups: A list of sequence groups to sample.
        selected_token_indices: See the definition in `SamplingMetadata`.
        categorized_sample_indices: See the definition in `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the model's output logits for performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> indices to sample/compute prompt logprob within the
    # pruned output logits.
    categorized_sample_indices: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        if cache is not None:
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
                                             if cache is not None else [])
        sample_indices: List[int] = (sample_obj.sample_indices
                                     if cache is not None else [])
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            query_len = query_lens[i] if query_lens is not None and len(
                query_lens) > 0 else 1
            sample_len = len(seq_ids) * query_len if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices, which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(range(logit_idx, logit_idx + sample_len)))
            logit_idx += sample_len

        if cache is not None:
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices),
            )

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)


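# NOTE: Illustrative sketch only (not part of upstream vLLM). It re-traces the
# index bookkeeping of _prepare_seq_groups for a hypothetical batch, using
# plain integers instead of SequenceGroupMetadata objects, so the resulting
# lists can be read off directly. Assumed batch: one prefill group with
# query_len=4 and prompt_logprobs requested, then one single-sequence decode.
def _example_index_bookkeeping() -> None:
    # Group 0: prefill, do_sample=True, prompt_logprobs requested.
    #   prompt_logprob_len = query_len - 1 = 3, sample_len = 1
    # Group 1: decode, do_sample=True.
    #   prompt_logprob_len = 0, sample_len = 1
    #
    # selected_token_indices walks over the raw model output rows:
    #   group 0 contributes rows 0..3 (3 prompt-logprob rows + 1 sample row),
    #   group 1 contributes row 4.
    selected_token_indices = [0, 1, 2, 3, 4]
    # Within the pruned logits, the same groups get:
    prompt_logprob_indices_group0 = [0, 1, 2]
    sample_indices_group0 = [3]
    sample_indices_group1 = [4]
    assert (len(prompt_logprob_indices_group0) + len(sample_indices_group0)
            + len(sample_indices_group1)) == len(selected_token_indices)

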
@dataclass
class SamplingTensors:
    """Tensors for sampling."""

    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple["SamplingTensors", bool, bool, bool]:
        prompt_tokens: List[array] = []
        output_tokens: List[array] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        do_penalties = False
        do_top_p_top_k = False
        do_min_p = False

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            min_p = sampling_params.min_p

            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True

            is_prompt = seq_group.is_prompt
            if is_prompt and sampling_params.prompt_logprobs is not None:
                # These are prompt positions for which we only need logprobs,
                # so they get neutral penalty values below.
                query_len = seq_group.query_len
                assert query_len is not None
                prefill_len = len(seq_group.prompt_logprob_indices)
                temperatures += [temperature] * prefill_len
                top_ps += [top_p] * prefill_len
                top_ks += [top_k] * prefill_len
                min_ps += [min_p] * prefill_len
                presence_penalties += [0] * prefill_len
                frequency_penalties += [0] * prefill_len
                repetition_penalties += [1] * prefill_len

            if seq_group.do_sample:
                sample_lens = len(seq_group.sample_indices)
                assert sample_lens >= len(seq_ids)
                temperatures += [temperature] * sample_lens
                top_ps += [top_p] * sample_lens
                top_ks += [top_k] * sample_lens
                min_ps += [min_p] * sample_lens
                presence_penalties += [p] * sample_lens
                frequency_penalties += [f] * sample_lens
                repetition_penalties += [r] * sample_lens

        if do_penalties:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                sampling_params = seq_group.sampling_params
                if (seq_group.is_prompt
                        and sampling_params.prompt_logprobs is not None):
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                    output_tokens.extend(
                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(seq_data.prompt_token_ids_array)
                        output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures,
            top_ps,
            top_ks,
            min_ps,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            prompt_tokens,
            output_tokens,
            vocab_size,
            device,
            dtype,
        )
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)

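    # NOTE: Illustrative sketch only (not part of upstream vLLM). It shows
    # which of the do_* flags the loop in from_sampling_metadata would raise
    # for two representative SamplingParams, using the same epsilon
    # comparisons; the parameter values are assumptions for illustration.
    @staticmethod
    def _example_flag_computation(vocab_size: int = 32000) -> None:
        # Greedy request: default top-p/top-k/min-p and no penalties, so none
        # of the optional sampling paths would be enabled by this request.
        greedy = SamplingParams(temperature=0.0)
        top_k = min(greedy.top_k, vocab_size)
        top_k = vocab_size if top_k == -1 else top_k
        assert not (greedy.top_p < 1.0 - _SAMPLING_EPS or top_k != vocab_size)
        assert not greedy.min_p > _SAMPLING_EPS

        # Nucleus sampling plus a repetition penalty: this request would turn
        # on both do_top_p_top_k and do_penalties for the whole batch.
        random_req = SamplingParams(temperature=0.8,
                                    top_p=0.9,
                                    repetition_penalty=1.1)
        assert random_req.top_p < 1.0 - _SAMPLING_EPS
        assert abs(random_req.repetition_penalty - 1.0) >= _SAMPLING_EPS
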
    @classmethod
    def from_lists(
        cls,
        temperatures: List[float],
        top_ps: List[float],
        top_ks: List[int],
        min_ps: List[float],
        presence_penalties: List[float],
        frequency_penalties: List[float],
        repetition_penalties: List[float],
        prompt_tokens: List[array],
        output_tokens: List[array],
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()

        do_penalties = prompt_tokens or output_tokens

        if do_penalties:
            prompt_t = make_tensor_with_pad(
                prompt_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
            output_t = make_tensor_with_pad(
                output_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
        else:
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_t = empty_tensor
            output_t = empty_tensor

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        # Because the memory is pinned, we can do non-blocking
        # transfer to device.

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
        )
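

# NOTE: Illustrative sketch only (not part of upstream vLLM). It shows the
# host-to-device pattern that from_lists relies on, using plain torch calls:
# build the tensor in pinned CPU memory, then copy it to the device
# asynchronously. Without pinned memory the non_blocking copy degrades to a
# synchronous one. The function name and values are assumptions.
def _example_pinned_h2d_copy() -> torch.Tensor:
    pin_memory = is_pin_memory_available()
    temperatures_cpu = torch.tensor([1.0, 0.7, 0.0],
                                    device="cpu",
                                    dtype=torch.float,
                                    pin_memory=pin_memory)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return temperatures_cpu.to(device=device, non_blocking=True)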