Support logit_bias in v1 Sampler (#13079)
This commit is contained in:
parent 085b7b2d6c
commit 6224a9f620
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple

 import numpy as np
 import pytest
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
     )


+def _create_logit_bias(
+    batch_size: int,
+    vocab_size: int,
+    bias_value: float,
+) -> List[Optional[Dict[int, float]]]:
+    res: List[Optional[Dict[int, float]]] = []
+    for i in range(batch_size):
+        logit_bias = {min(i, vocab_size - 1): bias_value}
+        res.append(logit_bias)
+    return res
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
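For orientation, here is what the new `_create_logit_bias` helper above produces; the concrete numbers are illustrative only and assume `vocab_size` exceeds `batch_size`, as it does in these tests:

    _create_logit_bias(batch_size=3, vocab_size=1024, bias_value=0.5)
    # -> [{0: 0.5}, {1: 0.5}, {2: 0.5}]
    # Request i biases only token id min(i, vocab_size - 1).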
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
         no_penalties=True,
         min_tokens=[],
         stop_token_ids=[],
+        logit_bias=[None] * batch_size,
     )
     return fake_sampling_metadata

@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
         batch_indices_for_min_token_penalty: List[int]
 ) -> Tuple[List[int], List[Set[int]]]:
     """
     Generates and returns a list of minimum token penalties (`min_tokens`)
     and a corresponding list of stop token IDs (`stop_token_ids`) for each
     batch.

     If a batch index is included in `batch_indices_for_min_token_penalty`,
     a higher `min_tokens` value is assigned (within a randomized range),
     and a random set of stop token IDs is created. Otherwise, a lower
     `min_tokens` value is assigned, and the stop token IDs set is empty.
     """
     stop_token_ids: List[Set[int]] = []
     min_tokens: List[int] = []

@@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
         batch_size: int,
         vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
     """
     Creates an output token list where each token occurs a distinct
     number of times.

     For each batch, a random subset of token IDs is selected from the

@@ -129,8 +142,8 @@ def _create_weighted_output_token_list(

     Returns:
         Tuple[List[List[int]], List[List[int]]]:
            - The first element is the output token list, where each sublist
              corresponds to a batch and contains tokens with weighted
              frequencies.
            - The second element is a list of distinct token IDs for each
              batch, ordered by their frequency in the corresponding output

@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 def test_sampler_min_tokens_penalty(device: str, batch_size: int):
     """
     Tests that if the number of output tokens is less than
     SamplingParams.min_tokens then we will set the logits for
     the stop token ids to -inf.
     """

@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
 def test_sampler_repetition_penalty(device: str, batch_size: int,
                                     repetition_penalty: float):
     """
     Test to verify that when the repetition penalty is enabled, tokens
     are penalized based on their presence in the prompt or the existing
     output.
     """
@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
                 penalized_token_id not in output_tokens)
         assert (non_penalized_token_id in prompt_tokens or \
                 non_penalized_token_id in output_tokens)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
+def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
+    """
+    Test to verify that when a logit bias is provided for a token, the
+    bias value is added to that token's logit while all other logits are
+    left unchanged.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.logit_bias = _create_logit_bias(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        bias_value=bias_value,
+    )
+    sampler = Sampler()
+    logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        biased_index = min(batch_idx, VOCAB_SIZE - 1)
+        for token_id in range(VOCAB_SIZE):
+            if biased_index == token_id:
+                assert logits_for_req[token_id] == pytest.approx(bias_value +
+                                                                 1e-2)
+            else:
+                assert logits_for_req[token_id] == pytest.approx(1e-2)
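The user-facing entry point for the behavior exercised by this test is `SamplingParams.logit_bias` (see the `SamplingParams` hunk further down). A minimal usage sketch, assuming the usual top-level `vllm` export; the token ids and bias values here are invented for illustration:

    from vllm import SamplingParams

    # Per-request additive bias: the v1 Sampler adds these values to the
    # corresponding logits before penalties, temperature, and top-k/top-p.
    params = SamplingParams(temperature=0.8, logit_bias={128: 5.0, 64: -5.0})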
@@ -45,9 +45,11 @@ def _remove_requests(


 def _construct_expected_sampling_metadata(
-        reqs: List[CachedRequestState], req_ids_retained: Set[int],
-        req_id_index_in_input_batch: Dict[str, int],
-        device: torch.device) -> SamplingMetadata:
+    reqs: List[CachedRequestState],
+    req_ids_retained: Set[int],
+    req_id_index_in_input_batch: Dict[str, int],
+    device: torch.device,
+) -> SamplingMetadata:
     """
     Constructs and returns the expected SamplingMetadata for this
     batch.

@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
     temperature = [0.0 for _ in range(num_reqs)]
     stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
     min_tokens = [0 for _ in range(num_reqs)]
+    logit_bias = [None] * num_reqs
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
         prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
         presence_penalties[
             index_in_input_batch] = req.sampling_params.presence_penalty
-        frequency_penalties[
-            index_in_input_batch] = req.sampling_params.frequency_penalty
-        repetition_penalties[
-            index_in_input_batch] = req.sampling_params.repetition_penalty
+        frequency_penalties[index_in_input_batch] = (
+            req.sampling_params.frequency_penalty)
+        repetition_penalties[index_in_input_batch] = (
+            req.sampling_params.repetition_penalty)
         top_k[index_in_input_batch] = req.sampling_params.top_k
         top_p[index_in_input_batch] = req.sampling_params.top_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
         stop_token_ids[
             index_in_input_batch] = req.sampling_params.all_stop_token_ids
         min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
+        logit_bias[index_in_input_batch] = req.sampling_params.logit_bias

     return SamplingMetadata(
-        temperature=torch.tensor(temperature, dtype=torch.float, device=device),
+        temperature=torch.tensor(temperature, dtype=torch.float,
+                                 device=device),
         all_greedy=False,
         all_random=True,
         top_p=torch.tensor(top_p, dtype=torch.float, device=device),
@@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata(
         no_top_k=all(x == 0 for x in top_k),
         generators={},
         max_num_logprobs=0,
-        prompt_token_ids= make_tensor_with_pad(
+        prompt_token_ids=make_tensor_with_pad(
             prompt_token_ids,
             pad=VOCAB_SIZE,
             device=torch.device(device),
             dtype=torch.int64,
         ),
-        frequency_penalties=torch.tensor(
-            frequency_penalties, dtype=torch.float,
-            device=device),
-        presence_penalties=torch.tensor(
-            presence_penalties, dtype=torch.float,
-            device=device),
-        repetition_penalties=torch.tensor(
-            repetition_penalties, dtype=torch.float,
-            device=device),
+        frequency_penalties=torch.tensor(frequency_penalties,
+                                         dtype=torch.float,
+                                         device=device),
+        presence_penalties=torch.tensor(presence_penalties,
+                                        dtype=torch.float,
+                                        device=device),
+        repetition_penalties=torch.tensor(repetition_penalties,
+                                          dtype=torch.float,
+                                          device=device),
         output_token_ids=output_token_ids,
         min_tokens=min_tokens,
        stop_token_ids=stop_token_ids,
-        no_penalties=(all(x ==0 for x in presence_penalties) and \
-                      all(x ==0 for x in frequency_penalties) and \
-                      all(x ==1 for x in repetition_penalties))
+        no_penalties=(all(x == 0 for x in presence_penalties)
+                      and all(x == 0 for x in frequency_penalties)
+                      and all(x == 1 for x in repetition_penalties)),
+        logit_bias=logit_bias,
     )


 def _create_sampling_params():
-    return SamplingParams(top_k=np.random.randint(1, 10),
-                          top_p=np.random.uniform(0.0, 1.0),
-                          presence_penalty=np.random.uniform(-2.0, 2.0),
-                          repetition_penalty=np.random.uniform(0.0, 2.0),
-                          frequency_penalty=np.random.uniform(-2.0, 2.0),
-                          min_tokens=np.random.randint(1, 10),
-                          stop_token_ids=[
-                              np.random.randint(0, VOCAB_SIZE)
-                              for _ in range(np.random.randint(10))
-                          ])
+    return SamplingParams(
+        top_k=np.random.randint(1, 10),
+        top_p=np.random.uniform(0.0, 1.0),
+        presence_penalty=np.random.uniform(-2.0, 2.0),
+        repetition_penalty=np.random.uniform(0.0, 2.0),
+        frequency_penalty=np.random.uniform(-2.0, 2.0),
+        min_tokens=np.random.randint(1, 10),
+        stop_token_ids=[
+            np.random.randint(0, VOCAB_SIZE)
+            for _ in range(np.random.randint(10))
+        ],
+        logit_bias={0: np.random.uniform(-3.0, 3.0)},
+    )


 def _construct_cached_request_state(req_id_suffix: int):
@@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int):
         np.random.randint(0, VOCAB_SIZE)
         for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
     ]
-    return CachedRequestState(req_id=f"req_id_{req_id_suffix}",
-                              prompt_token_ids=prompt_token_ids,
-                              prompt=None,
-                              sampling_params=_create_sampling_params(),
-                              mm_inputs=[],
-                              mm_positions=[],
-                              block_ids=[],
-                              generator=None,
-                              num_computed_tokens=len(output_token_ids),
-                              output_token_ids=output_token_ids)
+    return CachedRequestState(
+        req_id=f"req_id_{req_id_suffix}",
+        prompt_token_ids=prompt_token_ids,
+        prompt=None,
+        sampling_params=_create_sampling_params(),
+        mm_inputs=[],
+        mm_positions=[],
+        block_ids=[],
+        generator=None,
+        num_computed_tokens=len(output_token_ids),
+        output_token_ids=output_token_ids,
+    )


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     output of `make_sampling_metadata` is then compared against the expected
     results to ensure correctness.
     """
-    input_batch: InputBatch = InputBatch(max_num_reqs=batch_size,
-                                         max_model_len=1024,
-                                         max_num_blocks_per_req=10,
-                                         device=torch.device(device),
-                                         pin_memory=is_pin_memory_available(),
-                                         vocab_size=1024)
+    input_batch: InputBatch = InputBatch(
+        max_num_reqs=batch_size,
+        max_model_len=1024,
+        max_num_blocks_per_req=10,
+        device=torch.device(device),
+        pin_memory=is_pin_memory_available(),
+        vocab_size=1024,
+    )
     reqs: List[CachedRequestState] = []
     req_id_reqs = {}
     req_id_output_token_ids = {}
@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
                           sampling_metadata.top_p)
     assert torch.allclose(expected_sampling_metadata.top_k,
                           sampling_metadata.top_k)
-    assert torch.allclose(expected_sampling_metadata.frequency_penalties,
-                          sampling_metadata.frequency_penalties)
-    assert torch.allclose(expected_sampling_metadata.presence_penalties,
-                          sampling_metadata.presence_penalties)
-    assert torch.allclose(expected_sampling_metadata.repetition_penalties,
-                          sampling_metadata.repetition_penalties)
+    assert torch.allclose(
+        expected_sampling_metadata.frequency_penalties,
+        sampling_metadata.frequency_penalties,
+    )
+    assert torch.allclose(
+        expected_sampling_metadata.presence_penalties,
+        sampling_metadata.presence_penalties,
+    )
+    assert torch.allclose(
+        expected_sampling_metadata.repetition_penalties,
+        sampling_metadata.repetition_penalties,
+    )
     assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
                           sampling_metadata.prompt_token_ids)
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
-    assert (
-        expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens)
-    assert (expected_sampling_metadata.stop_token_ids ==
-            sampling_metadata.stop_token_ids)
-    assert (expected_sampling_metadata.no_penalties ==
-            sampling_metadata.no_penalties)
-    assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p)
-    assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k)
+    assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
+    assert expected_sampling_metadata.stop_token_ids == \
+        sampling_metadata.stop_token_ids
+    assert expected_sampling_metadata.no_penalties == \
+        sampling_metadata.no_penalties
+    assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
+    assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
+    assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
@@ -243,8 +243,10 @@ class SamplingParams(
         allowed_token_ids: Optional[List[int]] = None,
     ) -> "SamplingParams":
         if logit_bias is not None:
             # Convert token_id to integer
+            # Clamp the bias between -100 and 100 per OpenAI API spec
             logit_bias = {
-                int(token): bias
+                int(token): min(100.0, max(-100.0, bias))
                 for token, bias in logit_bias.items()
             }
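A quick illustrative check of the conversion above (the input dictionary is made up for the example): string token keys are cast to `int`, and bias values are clipped to the OpenAI-compatible range [-100, 100]:

    raw = {"50256": 250.0, "0": -7.5}
    clamped = {int(t): min(100.0, max(-100.0, b)) for t, b in raw.items()}
    assert clamped == {50256: 100.0, 0: -7.5}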
@@ -32,3 +32,5 @@ class SamplingMetadata:
     output_token_ids: List[List[int]]
     min_tokens: List[int]
     stop_token_ids: List[Set[int]]
+
+    logit_bias: List[Optional[Dict[int, float]]]
@@ -37,6 +37,8 @@ class Sampler(nn.Module):

         # Use float32 for the logits.
         logits = logits.to(torch.float32)
+        # Apply logits bias.
+        logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
         # Apply temperature.
@@ -166,3 +168,17 @@ class Sampler(nn.Module):
             sampling_metadata.repetition_penalties,
             sampling_metadata.output_token_ids)
         return logits
+
+    def apply_logits_bias(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        # TODO(houseroad): this implementation is extremely inefficient.
+        # One idea is to implement this as a PyTorch C++ op, and we may
+        # even optimize the logit_bias layout.
+        for i, logit_bias in enumerate(sampling_metadata.logit_bias):
+            if logit_bias:
+                for token_id, bias in logit_bias.items():
+                    logits[i, token_id] += bias
+        return logits
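On the TODO above: one pure-PyTorch way to reduce the per-element GPU indexing is to gather all (row, token_id, bias) triples on the host and apply them with a single `index_put_`. This is only a sketch of the idea under that assumption, not part of the commit, and the function name is hypothetical:

    import torch

    def apply_logits_bias_batched(logits, logit_bias_list):
        rows, cols, vals = [], [], []
        for i, logit_bias in enumerate(logit_bias_list):
            if not logit_bias:
                continue
            for token_id, bias in logit_bias.items():
                rows.append(i)
                cols.append(token_id)
                vals.append(bias)
        if not rows:
            return logits
        row_idx = torch.tensor(rows, device=logits.device)
        col_idx = torch.tensor(cols, device=logits.device)
        bias = torch.tensor(vals, dtype=logits.dtype, device=logits.device)
        # accumulate=True makes this an in-place scatter-add of the biases.
        logits.index_put_((row_idx, col_idx), bias, accumulate=True)
        return logits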
@@ -130,7 +130,7 @@ class InputBatch:
            device="cpu",
            pin_memory=pin_memory)
        self.frequency_penalties_cpu = \
            self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: Set[str] = set()

        # Presence penalty related data structures
@@ -141,8 +141,8 @@ class InputBatch:
            dtype=torch.float,
            device="cpu",
            pin_memory=pin_memory)
-        self.presence_penalties_cpu = \
-            self.presence_penalties_cpu_tensor.numpy()
+        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
+        )
        self.presence_penalties_reqs: Set[str] = set()

        # Repetition penalty related data structures
@@ -155,7 +155,7 @@ class InputBatch:
            device="cpu",
            pin_memory=pin_memory)
        self.repetition_penalties_cpu = \
            self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: Set[str] = set()

        self.min_tokens: List[int] = [0] * max_num_reqs
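A brief aside on the `*_cpu = *_cpu_tensor.numpy()` pattern kept by these hunks: `Tensor.numpy()` on a CPU tensor returns a zero-copy view, so writes through the ndarray update the pinned tensor that is later copied to the GPU. A standalone illustration (not from the diff):

    import torch

    t = torch.zeros(4, dtype=torch.float, pin_memory=torch.cuda.is_available())
    a = t.numpy()          # shares memory with `t`
    a[2] = 0.5
    assert t[2].item() == 0.5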
@@ -180,6 +180,9 @@ class InputBatch:
         # that are currently in the prefill phase.
         self.num_prompt_logprobs: Dict[str, int] = {}

+        self.logit_bias: List[Optional[Dict[int,
+                                             float]]] = [None] * max_num_reqs
+
     def add_request(
         self,
         request: "CachedRequestState",
@@ -220,16 +223,16 @@ class InputBatch:
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
             self.top_k_reqs.add(req_id)
-        self.frequency_penalties_cpu[req_index] = \
-            sampling_params.frequency_penalty
+        self.frequency_penalties_cpu[
+            req_index] = sampling_params.frequency_penalty
         if sampling_params.frequency_penalty != 0.0:
             self.frequency_penalties_reqs.add(req_id)
-        self.presence_penalties_cpu[req_index] = \
-            sampling_params.presence_penalty
+        self.presence_penalties_cpu[
+            req_index] = sampling_params.presence_penalty
         if sampling_params.presence_penalty != 0.0:
             self.presence_penalties_reqs.add(req_id)
-        self.repetition_penalties_cpu[req_index] = \
-            sampling_params.repetition_penalty
+        self.repetition_penalties_cpu[
+            req_index] = sampling_params.repetition_penalty
         if sampling_params.repetition_penalty != 1.0:
             self.repetition_penalties_reqs.add(req_id)
         self.min_tokens[req_index] = sampling_params.min_tokens
@@ -244,6 +247,8 @@ class InputBatch:
             self.num_logprobs[req_id] = sampling_params.logprobs
         if sampling_params.prompt_logprobs is not None:
             self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
+        if sampling_params.logit_bias is not None:
+            self.logit_bias[req_index] = sampling_params.logit_bias

         # Add request lora ID
         if request.lora_request:
@@ -284,6 +289,7 @@ class InputBatch:
                 self.lora_id_to_lora_request.pop(lora_id)
             self.request_lora_mapping[req_index] = 0

+        self.logit_bias[req_index] = None
         return req_index

     def clear(self) -> None:
@@ -302,6 +308,7 @@ class InputBatch:
         self.request_lora_mapping.fill(0)
         self.lora_id_to_lora_request.clear()
         self.lora_id_to_request_ids.clear()
+        self.logit_bias = [None] * self.max_num_reqs

     def condense(self, empty_req_indices: List[int]) -> None:
         if self.num_reqs == 0:
@@ -332,8 +339,8 @@ class InputBatch:
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
             self.num_tokens[empty_index] = num_tokens
-            self.num_prompt_tokens[empty_index] = \
-                self.num_prompt_tokens[last_req_index]
+            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
+                last_req_index]
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
             self.block_table.move_row(last_req_index, empty_index)
@@ -341,15 +348,15 @@ class InputBatch:
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
             self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
-            self.frequency_penalties_cpu[empty_index] = \
-                self.frequency_penalties_cpu[last_req_index]
-            self.presence_penalties_cpu[empty_index] = \
-                self.presence_penalties_cpu[last_req_index]
-            self.repetition_penalties_cpu[empty_index] = \
-                self.repetition_penalties_cpu[last_req_index]
+            self.frequency_penalties_cpu[
+                empty_index] = self.frequency_penalties_cpu[last_req_index]
+            self.presence_penalties_cpu[
+                empty_index] = self.presence_penalties_cpu[last_req_index]
+            self.repetition_penalties_cpu[
+                empty_index] = self.repetition_penalties_cpu[last_req_index]
             self.min_tokens[empty_index] = self.min_tokens[last_req_index]
-            self.stop_token_ids[empty_index] = \
-                self.stop_token_ids[last_req_index]
+            self.stop_token_ids[empty_index] = self.stop_token_ids[
+                last_req_index]
             generator = self.generators.pop(last_req_index, None)
             if generator is not None:
                 self.generators[empty_index] = generator
@@ -357,6 +364,8 @@ class InputBatch:
             self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                 last_req_index]

+            self.logit_bias[empty_index] = self.logit_bias[last_req_index]
+
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1

@@ -378,13 +387,16 @@ class InputBatch:
             # penalties to be applied during sampling.
             self.frequency_penalties[:self.num_reqs].copy_(
                 self.frequency_penalties_cpu_tensor[:self.num_reqs],
-                non_blocking=True)
+                non_blocking=True,
+            )
             self.presence_penalties[:self.num_reqs].copy_(
                 self.presence_penalties_cpu_tensor[:self.num_reqs],
-                non_blocking=True)
+                non_blocking=True,
+            )
             self.repetition_penalties[:self.num_reqs].copy_(
                 self.repetition_penalties_cpu_tensor[:self.num_reqs],
-                non_blocking=True)
+                non_blocking=True,
+            )
             # The prompt tokens are used only for applying penalties during
             # the sampling process. Hence copy these tensors only when
             # there are requests which need penalties to be applied.
@@ -421,6 +433,7 @@ class InputBatch:
             min_tokens=self.min_tokens[:self.num_reqs],
             stop_token_ids=self.stop_token_ids[:self.num_reqs],
             no_penalties=self.no_penalties,
+            logit_bias=self.logit_bias[:self.num_reqs],
         )

     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
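One subtlety worth noting about `logit_bias=self.logit_bias[:self.num_reqs]` above: unlike the tensor slices used for the penalty fields, slicing a Python list creates a new list object (while the per-request dicts themselves are still shared). A minimal illustration:

    full = [{1: 0.5}, None, None]
    window = full[:2]
    assert window is not full and window[0] is full[0]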
@@ -429,10 +442,11 @@ class InputBatch:
             (self.num_reqs, max_prompt_len),
             device="cpu",
             dtype=torch.int64,
-            pin_memory=self.pin_memory)
+            pin_memory=self.pin_memory,
+        )
         prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
-        prompt_token_ids[:] = (
-            self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
+        prompt_token_ids[:] = self.token_ids_cpu[:self.
+                                                 num_reqs, :max_prompt_len]
         # Use the value of vocab_size as a pad since we don't have a
         # token_id of this value.
         for i in range(self.num_reqs):