From a7347d9a6d2391734d838ab6a4f3a702e348d9fa Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Sun, 17 Dec 2023 07:03:49 -0800
Subject: [PATCH] Make sampler less blocking (#1889)

---
 vllm/model_executor/layers/sampler.py    | 319 +++++++++----------
 vllm/model_executor/sampling_metadata.py | 187 +++++++++++++
 2 files changed, 309 insertions(+), 197 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 13da9aa38af03..fe88b0ea42936 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -6,13 +6,11 @@
 import torch.nn as nn
 
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
 
-_SAMPLING_EPS = 1e-5
-
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -32,6 +30,7 @@ class Sampler(nn.Module):
     def __init__(self, vocab_size: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()
 
     def forward(
         self,
@@ -47,40 +46,38 @@ class Sampler(nn.Module):
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)
 
+        _, vocab_size = logits.shape
+
         # Apply logits processors (if any).
         logits = _apply_logits_processors(logits, sampling_metadata)
+
+        # Prepare sampling tensors in another stream to overlap
+        # CPU<->GPU data transfer with GPU computation in forward pass.
+        with torch.cuda.stream(self._copy_stream):
+            (sampling_tensors, do_penalties, do_top_p_top_k,
+             do_min_p) = SamplingTensors.from_sampling_metadata(
+                 sampling_metadata, vocab_size, logits.device, logits.dtype)
+
+        torch.cuda.current_stream().wait_stream(self._copy_stream)
+
         # Apply presence and frequency penalties.
-        presence_penalties, frequency_penalties, repetition_penalties = (
-            _get_penalties(sampling_metadata))
-        assert len(presence_penalties) == logits.shape[0]
-        assert len(frequency_penalties) == logits.shape[0]
-        assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, sampling_metadata,
-                                  presence_penalties, frequency_penalties,
-                                  repetition_penalties)
+        if do_penalties:
+            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                      sampling_tensors.output_tokens,
+                                      sampling_tensors.presence_penalties,
+                                      sampling_tensors.frequency_penalties,
+                                      sampling_tensors.repetition_penalties)
 
         # Apply temperature scaling.
-        temperatures = _get_temperatures(sampling_metadata)
-        assert len(temperatures) == logits.shape[0]
-        if any(t != 1.0 for t in temperatures):
-            t = torch.tensor(temperatures,
-                             dtype=logits.dtype,
-                             device=logits.device)
-            # Use in-place division to avoid creating a new tensor.
-            logits.div_(t.unsqueeze(dim=1))
+        # Use in-place division to avoid creating a new tensor.
+        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
 
-        # Apply top-p and top-k truncation.
-        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            sampling_metadata, self.vocab_size)
-        assert len(top_ps) == len(top_ks) == logits.shape[0]
-        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
-        do_top_k = any(k != self.vocab_size for k in top_ks)
-        if do_top_p or do_top_k:
-            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+        if do_top_p_top_k:
+            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
+                                        sampling_tensors.top_ks)
 
-        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
         if do_min_p:
-            logits = _apply_min_p(logits, min_ps)
+            logits = _apply_min_p(logits, sampling_tensors.min_ps)
 
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
@@ -120,32 +117,6 @@ def _prune_hidden_states(
         sampling_metadata.selected_token_indices)
 
 
-def _get_penalties(
-    sampling_metadata: SamplingMetadata
-) -> Tuple[List[float], List[float], List[float]]:
-    # Collect the presence and frequency penalties.
-    presence_penalties: List[float] = []
-    frequency_penalties: List[float] = []
-    repetition_penalties: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        p = sampling_params.presence_penalty
-        f = sampling_params.frequency_penalty
-        r = sampling_params.repetition_penalty
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: We do not apply presence and frequency penalties for the
-            # prompt token positions where we don't sample new tokens.
-            prompt_len = sampling_metadata.prompt_lens[i]
-            presence_penalties += [0] * (prompt_len - 1)
-            frequency_penalties += [0] * (prompt_len - 1)
-            repetition_penalties += [1] * (prompt_len - 1)
-        presence_penalties += [p] * len(seq_ids)
-        frequency_penalties += [f] * len(seq_ids)
-        repetition_penalties += [r] * len(seq_ids)
-    return presence_penalties, frequency_penalties, repetition_penalties
-
-
 def _get_prompt_and_output_tokens(
     sampling_metadata: SamplingMetadata,
 ) -> Tuple[List[List[int]], List[List[int]]]:
@@ -168,25 +139,16 @@ def _get_prompt_and_output_tokens(
 
 
 def _get_bin_counts_and_mask(
-    logits: torch.Tensor,
-    tokens: List[List[int]],
+    tokens: torch.Tensor,
     vocab_size: int,
     num_seqs: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    max_len = max(len(tokens) for tokens in tokens)
-    padded_tokens = [
-        tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
-    ]
-    tokens_tensor = torch.tensor(padded_tokens,
-                                 dtype=torch.long,
-                                 device=logits.device)
-
     # Compute the bin counts for the tokens.
     # vocab_size + 1 for padding.
     bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                              dtype=torch.long,
-                             device=logits.device)
-    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
+                             device=tokens.device)
+    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
     bin_counts = bin_counts[:, :vocab_size]
     mask = bin_counts > 0
 
@@ -217,45 +179,16 @@ def _apply_logits_processors(
     return logits
 
 
-def _apply_penalties(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-    presence_penalties: List[float],
-    frequency_penalties: List[float],
-    repetition_penalties: List[float],
-) -> torch.Tensor:
+def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
+                     output_tokens_tensor: torch.Tensor,
+                     presence_penalties: torch.Tensor,
+                     frequency_penalties: torch.Tensor,
+                     repetition_penalties: torch.Tensor) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
-    for i in range(num_seqs):
-        p = presence_penalties[i]
-        f = frequency_penalties[i]
-        r = repetition_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
-                r - 1.0) < _SAMPLING_EPS:
-            continue
-        break
-    else:
-        # Return early if all sequences have zero penalties.
-        return logits
-
-    prompt_tokens, output_tokens = (
-        _get_prompt_and_output_tokens(sampling_metadata))
-    assert len(prompt_tokens) == logits.shape[0]
-    assert len(output_tokens) == logits.shape[0]
-
-    prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
-        logits, prompt_tokens, vocab_size, num_seqs)
+    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
+                                              num_seqs)
     output_bin_counts, output_mask = _get_bin_counts_and_mask(
-        logits, output_tokens, vocab_size, num_seqs)
-
-    repetition_penalties = torch.tensor(repetition_penalties,
-                                        dtype=logits.dtype,
-                                        device=logits.device)
-    frequency_penalties = torch.tensor(frequency_penalties,
-                                       dtype=logits.dtype,
-                                       device=logits.device)
-    presence_penalties = torch.tensor(presence_penalties,
-                                      dtype=logits.dtype,
-                                      device=logits.device)
+        output_tokens_tensor, vocab_size, num_seqs)
 
     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
     repetition_penalties[~(prompt_mask | output_mask)] = 1.0
@@ -264,109 +197,65 @@
 
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
+    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
     return logits
 
 
-def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
-    # Collect the temperatures for the logits.
-    temperatures: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        temperature = sampling_params.temperature
-        if temperature < _SAMPLING_EPS:
-            # NOTE: Zero temperature means deterministic sampling
-            # (i.e., greedy sampling or beam search).
-            # Set the temperature to 1 to avoid division by zero.
-            temperature = 1.0
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = sampling_metadata.prompt_lens[i]
-            temperatures += [temperature] * (prompt_len - 1)
-        temperatures += [temperature] * len(seq_ids)
-    return temperatures
-
-
-def _get_top_p_top_k_min_p(
-    sampling_metadata: SamplingMetadata,
-    vocab_size: int,
-) -> Tuple[List[float], List[int], List[float]]:
-    top_ps: List[float] = []
-    top_ks: List[int] = []
-    min_ps: List[float] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        top_p = sampling_params.top_p
-        min_p = sampling_params.min_p
-        # k should not be greater than the vocab size.
-        top_k = min(sampling_params.top_k, vocab_size)
-        # k=-1 means no truncation.
-        top_k = vocab_size if top_k == -1 else top_k
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = sampling_metadata.prompt_lens[i]
-            top_ps += [top_p] * (prompt_len - 1)
-            top_ks += [top_k] * (prompt_len - 1)
-            min_ps += [min_p] * (prompt_len - 1)
-        top_ps += [top_p] * len(seq_ids)
-        top_ks += [top_k] * len(seq_ids)
-        min_ps += [min_p] * len(seq_ids)
-    return top_ps, top_ks, min_ps
-
-
 def _apply_top_p_top_k(
     logits: torch.Tensor,
-    top_ps: List[float],
-    top_ks: List[int],
+    p: torch.Tensor,
+    k: torch.Tensor,
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
     logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
 
     # Apply top-p.
     probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1)
-    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    logits_sort[top_p_mask] = -float("inf")
+    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
+    top_p_mask = probs_sum > p.unsqueeze_(dim=1)
 
     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
     top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    logits_sort[top_k_mask] = -float("inf")
+    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
+
+    # Final mask.
+    mask = (top_p_mask | top_k_mask)
+    logits_sort.masked_fill_(mask, -float("inf"))
 
     # Re-sort the probabilities.
-    logits = torch.gather(logits_sort,
-                          dim=-1,
-                          index=torch.argsort(logits_idx, dim=-1))
+    src = torch.arange(logits_idx.shape[-1],
+                       device=logits_idx.device).expand_as(logits_idx)
+    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
+                                                           index=logits_idx,
+                                                           src=src)
+    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
     return logits
 
 
 def _apply_min_p(
     logits: torch.Tensor,
-    min_ps: List[float],
+    min_p: torch.Tensor,
 ) -> torch.Tensor:
     """
     Adapted from
     https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
     """
-    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
     probs = torch.softmax(logits, dim=-1)
     top_probs, _ = probs.max(dim=-1, keepdim=True)
-    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
+    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
     tokens_to_remove = probs < scaled_min_p
-    logits = logits.masked_fill(tokens_to_remove, -float("inf"))
+    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
 
     return logits
 
 
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
-    logprobs: torch.Tensor,
+    samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
-    samples = torch.argmax(logprobs, dim=-1).cpu()
+    samples = samples.tolist()
     sample_idx = 0
     results = []
     for seq_group in selected_seq_groups:
@@ -375,27 +264,19 @@
         assert num_parent_seqs == 1, (
             "Greedy sampling should have only one seq.")
         parent_ids = list(range(num_parent_seqs))
-        next_token_ids = [samples[sample_idx].item()]
+        next_token_ids = [samples[sample_idx]]
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
-    assert sample_idx == logprobs.size(0)
     return results
 
 
 def _random_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     is_prompts: List[bool],
-    probs: torch.Tensor,
+    random_samples: torch.Tensor,
 ) -> List[Tuple[List[int], List[int]]]:
     # Find the maximum best_of value of the prompt phase requests.
-    max_best_of = 1
-    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
-        if is_prompt:
-            seq_ids, sampling_params = seq_group
-            max_best_of = max(max_best_of, sampling_params.best_of)
-    random_samples = torch.multinomial(probs,
-                                       num_samples=max_best_of,
-                                       replacement=True).cpu()
+    random_samples = random_samples.cpu()
     sample_idx = 0
     results = []
     for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
@@ -403,8 +284,6 @@
         num_parent_seqs = len(seq_ids)
         if is_prompt:
             # Prompt phase.
-            assert num_parent_seqs == 1, (
-                "Prompt input should have only one seq.")
             parent_ids = [0] * sampling_params.best_of
             next_token_ids = random_samples[
                 sample_idx, :sampling_params.best_of].tolist()
@@ -415,7 +294,6 @@
                 num_parent_seqs, 0].tolist()
         results.append((next_token_ids, parent_ids))
         sample_idx += num_parent_seqs
-    assert sample_idx == probs.size(0)
     return results
 
 
@@ -472,6 +350,28 @@
     return results
 
+
+# torch.multinomial forces a GPU<->CPU sync.
+# Therefore, we use an optimized implementation instead.
+# Note that we always sample with replacement.
+# probs will be modified in place, but this is fine, as we pass
+# in a copy already.
+def _multinomial(
+    probs: torch.Tensor,
+    num_samples: int,
+):
+    if num_samples > 1:
+        # This is equivalent to torch.repeat_interleave (which also
+        # forces a GPU<->CPU sync).
+        # This allows us to do sampling with replacement by creating
+        # num_samples copies of each row in the tensor, and then
+        # batch sampling the resulting tensor.
+        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
+                                         probs.shape[1]).contiguous().view(
+                                             -1, probs.shape[1])
+    q = torch.empty_like(probs).exponential_(1)
+    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
+
 
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
@@ -485,28 +385,51 @@
         categorized_seq_group_ids[sampling_type].append(i)
 
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
+    sample_metadata = {}
+
+    # Counterintuitively, having two loops here is actually faster.
+    # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type in SamplingType:
-        seq_group_ids = categorized_seq_group_ids[sampling_type]
-        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
-        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
         sample_indices = categorized_sample_indices[sampling_type]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
+        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
+                                          is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            category_logprobs = logprobs[sample_indices]
-            sample_results = _greedy_sample(seq_groups, category_logprobs)
+            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
         elif sampling_type == SamplingType.RANDOM:
-            category_probs = probs[sample_indices]
-            sample_results = _random_sample(seq_groups, is_prompts,
-                                            category_probs)
+            max_best_of = 1
+            for seq_group, is_prompt in zip(seq_groups, is_prompts):
+                if is_prompt:
+                    _, sampling_params = seq_group
+                    max_best_of = max(max_best_of, sampling_params.best_of)
+            multinomial_samples = _multinomial(probs[sample_indices],
+                                               max_best_of)
        elif sampling_type == SamplingType.BEAM:
-            category_logprobs = logprobs[sample_indices]
-            sample_results = _beam_search_sample(seq_groups, is_prompts,
-                                                 sampling_metadata.seq_data,
-                                                 category_logprobs)
+            beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
+
+    # GPU<->CPU sync happens in the loop below.
+
+    for sampling_type in SamplingType:
+        if sampling_type not in sample_metadata:
+            continue
+        seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
+            sampling_type]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, greedy_samples)
+        elif sampling_type == SamplingType.RANDOM:
+            sample_results = _random_sample(seq_groups, is_prompts,
+                                            multinomial_samples)
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups, is_prompts,
+                                                 sampling_metadata.seq_data,
+                                                 beam_search_logprobs)
         sample_results_dict.update(zip(seq_group_ids, sample_results))
 
     sample_results = [
@@ -557,7 +480,7 @@ def _get_logprobs(
         batched_logprobs_query_result = logprobs[[
             batched_logprobs_query_seq_indices,
             batched_logprobs_query_token_indices
-        ]].cpu()
+        ]]
 
         # Batched query for logprobs of topk tokens
         if largest_num_logprobs > 0:
@@ -569,6 +492,8 @@
     else:
         top_logprobs, top_token_ids = None, None
 
+    batched_logprobs_query_result = batched_logprobs_query_result.cpu()
+
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
     result_sample_logprobs: List[SampleLogprobs] = []
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index deb779f537c69..49013ec273787 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -1,9 +1,13 @@
+from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
 import torch
 
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
+from vllm.utils import in_wsl
+
+_SAMPLING_EPS = 1e-5
 
 
 class SamplingMetadata:
@@ -41,3 +45,186 @@ class SamplingMetadata:
             f"prompt_lens={self.prompt_lens}, "
             f"selected_token_indices={self.selected_token_indices}, "
             f"categorized_sample_indices={self.categorized_sample_indices})")
+
+
+@dataclass
+class SamplingTensors:
+    """Tensors for sampling."""
+
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+    presence_penalties: torch.Tensor
+    frequency_penalties: torch.Tensor
+    repetition_penalties: torch.Tensor
+    prompt_tokens: torch.Tensor
+    output_tokens: torch.Tensor
+
+    @classmethod
+    def from_sampling_metadata(
+            cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
+            device: torch.device,
+            dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
+        prompt_tokens: List[List[int]] = []
+        output_tokens: List[List[int]] = []
+        top_ks: List[int] = []
+        temperatures: List[float] = []
+        top_ps: List[float] = []
+        min_ps: List[float] = []
+        presence_penalties: List[float] = []
+        frequency_penalties: List[float] = []
+        repetition_penalties: List[float] = []
+        do_penalties = False
+        do_top_p_top_k = False
+        do_min_p = False
+        for i, seq_group in enumerate(sampling_metadata.seq_groups):
+            seq_ids, sampling_params = seq_group
+            temperature = sampling_params.temperature
+            p = sampling_params.presence_penalty
+            f = sampling_params.frequency_penalty
+            r = sampling_params.repetition_penalty
+            top_p = sampling_params.top_p
+            min_p = sampling_params.min_p
+            # k should not be greater than the vocab size.
+            top_k = min(sampling_params.top_k, vocab_size)
+            top_k = vocab_size if top_k == -1 else top_k
+            if temperature < _SAMPLING_EPS:
+                # NOTE: Zero temperature means deterministic sampling
+                # (i.e., greedy sampling or beam search).
+                # Set the temperature to 1 to avoid division by zero.
+                temperature = 1.0
+            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
+                                       or top_k != vocab_size):
+                do_top_p_top_k = True
+            if not do_min_p and min_p > _SAMPLING_EPS:
+                do_min_p = True
+            if not do_penalties and (abs(p) >= _SAMPLING_EPS
+                                     or abs(f) >= _SAMPLING_EPS
+                                     or abs(r - 1.0) >= _SAMPLING_EPS):
+                do_penalties = True
+            if (i < sampling_metadata.num_prompts
+                    and sampling_params.prompt_logprobs is not None):
+                # For prompt positions where we only compute logprobs
+                # and do not sample a new token.
+                prompt_len = sampling_metadata.prompt_lens[i]
+                temperatures += [temperature] * (prompt_len - 1)
+                top_ps += [top_p] * (prompt_len - 1)
+                top_ks += [top_k] * (prompt_len - 1)
+                min_ps += [min_p] * (prompt_len - 1)
+                presence_penalties += [0] * (prompt_len - 1)
+                frequency_penalties += [0] * (prompt_len - 1)
+                repetition_penalties += [1] * (prompt_len - 1)
+                prompt_tokens.extend([] for _ in range(prompt_len - 1))
+                output_tokens.extend([] for _ in range(prompt_len - 1))
+            for seq_id in seq_ids:
+                seq_data = sampling_metadata.seq_data[seq_id]
+                prompt_tokens.append(seq_data.prompt_token_ids)
+                output_tokens.append(seq_data.output_token_ids)
+            temperatures += [temperature] * len(seq_ids)
+            top_ps += [top_p] * len(seq_ids)
+            top_ks += [top_k] * len(seq_ids)
+            min_ps += [min_p] * len(seq_ids)
+            presence_penalties += [p] * len(seq_ids)
+            frequency_penalties += [f] * len(seq_ids)
+            repetition_penalties += [r] * len(seq_ids)
+
+        sampling_tensors = SamplingTensors.from_lists(
+            temperatures, top_ps, top_ks, min_ps, presence_penalties,
+            frequency_penalties, repetition_penalties, prompt_tokens,
+            output_tokens, vocab_size, device, dtype)
+        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
+
+    @classmethod
+    def from_lists(cls, temperatures: List[float], top_ps: List[float],
+                   top_ks: List[int], min_ps: List[float],
+                   presence_penalties: List[float],
+                   frequency_penalties: List[float],
+                   repetition_penalties: List[float],
+                   prompt_tokens: List[List[int]],
+                   output_tokens: List[List[int]], vocab_size: int,
+                   device: torch.device,
+                   dtype: torch.dtype) -> "SamplingTensors":
+        # Note that the performance will be very bad without
+        # pinned memory.
+        pin_memory = not in_wsl()
+        prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
+        prompt_padded_tokens = [
+            tokens + [vocab_size] * (prompt_max_len - len(tokens))
+            for tokens in prompt_tokens
+        ]
+        output_max_len = max(len(tokens) for tokens in output_tokens)
+        output_padded_tokens = [
+            tokens + [vocab_size] * (output_max_len - len(tokens))
+            for tokens in output_tokens
+        ]
+
+        temperatures_t = torch.tensor(
+            temperatures,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ps_t = torch.tensor(
+            top_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        min_ps_t = torch.tensor(
+            min_ps,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        presence_penalties_t = torch.tensor(
+            presence_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        frequency_penalties_t = torch.tensor(
+            frequency_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        repetition_penalties_t = torch.tensor(
+            repetition_penalties,
+            device="cpu",
+            dtype=dtype,
+            pin_memory=pin_memory,
+        )
+        top_ks_t = torch.tensor(
+            top_ks,
+            device="cpu",
+            dtype=torch.int,
+            pin_memory=pin_memory,
+        )
+        prompt_tensor = torch.tensor(
+            prompt_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        output_tensor = torch.tensor(
+            output_padded_tokens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+        # Because the memory is pinned, we can do non-blocking
+        # transfer to device.
+        return cls(
+            temperatures=temperatures_t.to(device=device, non_blocking=True),
+            top_ps=top_ps_t.to(device=device, non_blocking=True),
+            top_ks=top_ks_t.to(device=device, non_blocking=True),
+            min_ps=min_ps_t.to(device=device, non_blocking=True),
+            presence_penalties=presence_penalties_t.to(device=device,
+                                                       non_blocking=True),
+            frequency_penalties=frequency_penalties_t.to(device=device,
+                                                         non_blocking=True),
+            repetition_penalties=repetition_penalties_t.to(device=device,
+                                                           non_blocking=True),
+            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
+            output_tokens=output_tensor.to(device=device, non_blocking=True),
+        )
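The new Sampler.forward() hides the CPU->GPU transfer of the sampling tensors by issuing it on a dedicated copy stream and making the main stream wait on that stream before consuming the results. A minimal sketch of the same pattern, assuming a CUDA device is available; all names here are illustrative, not part of the patch:

import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    weights = torch.randn(2048, 2048, device="cuda")
    host_params = torch.randn(2048, pin_memory=True)  # pinned host buffer

    out = weights @ weights  # compute queued on the default stream
    with torch.cuda.stream(copy_stream):
        # H2D copy runs on the side stream, overlapping the matmul above.
        dev_params = host_params.to("cuda", non_blocking=True)
    # Order the default stream after the copy before using dev_params.
    torch.cuda.current_stream().wait_stream(copy_stream)
    out = out + dev_params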
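The reworked _get_bin_counts_and_mask takes an already-padded token tensor and builds per-sequence token histograms with a single scatter_add_; the extra vocab_size-th bin absorbs the padding id so variable-length rows can share one rectangular tensor. A toy sketch of the trick (sizes and values are illustrative):

import torch

vocab_size = 4
# Rows padded with the out-of-vocabulary id `vocab_size` (here, 4).
tokens = torch.tensor([[0, 2, 2, 4],
                       [1, 1, 4, 4]])
bin_counts = torch.zeros((2, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin
mask = bin_counts > 0
print(bin_counts)  # tensor([[1, 0, 2, 0], [0, 2, 0, 0]])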
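In _apply_top_p_top_k, the old torch.argsort(logits_idx) is replaced by a scatter_ that builds the inverse permutation directly: logits_idx maps sorted positions to original positions, so writing j at index logits_idx[j] inverts it without paying for a second sort. A quick check of that identity on a toy row (values are illustrative):

import torch

x = torch.tensor([[3.0, 1.0, 2.0]])
vals, idx = x.sort(dim=-1, descending=True)
src = torch.arange(idx.shape[-1]).expand_as(idx)
inv = torch.empty_like(idx).scatter_(dim=-1, index=idx, src=src)
# Gathering the sorted values through the inverse permutation
# restores the original order.
assert torch.equal(torch.gather(vals, dim=-1, index=inv), x)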
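The _multinomial helper relies on the exponential-race (equivalently, Gumbel-max) identity: if E_i ~ Exp(1), then argmax_i p_i / E_i is distributed as Categorical(p_1, ..., p_n), so categorical draws reduce to pure tensor ops with no torch.multinomial and no device sync. A standalone sketch of the same idea, written non-destructively (the patch's version divides in place):

import torch

def multinomial_no_sync(probs: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Replicate each row num_samples times to sample with replacement.
    if num_samples > 1:
        probs = (probs[:, None, :]
                 .expand(-1, num_samples, -1)
                 .contiguous()
                 .view(-1, probs.shape[1]))
    q = torch.empty_like(probs).exponential_(1)  # E_i ~ Exp(1)
    return probs.div(q).argmax(dim=1).view(-1, num_samples)

probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.6, 0.3, 0.1]])
print(multinomial_no_sync(probs, num_samples=4))  # token ids, shape (2, 4)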
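SamplingTensors.from_lists stages every tensor in pinned (page-locked) host memory and only then calls .to(device, non_blocking=True); with ordinary pageable memory the same call degrades to an effectively synchronous copy, which is what the "performance will be very bad" comment warns about. A minimal illustration of the pattern, assuming a CUDA device:

import torch

if torch.cuda.is_available():
    staging = torch.tensor([0.7, 0.9, 1.0],
                           device="cpu",
                           dtype=torch.float32,
                           pin_memory=True)
    on_device = staging.to(device="cuda", non_blocking=True)  # async H2D copy
    # Synchronize before any CPU-side read that depends on on_device.
    torch.cuda.current_stream().synchronize()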