[V1] Adding min tokens/repetition/presence/frequency penalties to V1 sampler (#10681)
Signed-off-by: Sourashis Roy <sroy@roblox.com>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent 7492a36207
commit dcb1a944d4
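For orientation, the short sketch below (not part of the diff) shows how the request-level knobs this commit wires into the V1 sampler are set from vLLM's public API. The parameter names mirror SamplingParams as used in the tests further down; the model name, prompt, and penalty values are placeholders, and the VLLM_USE_V1 opt-in follows what the engine test below does.

import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine, as the tests below do

from vllm import LLM, SamplingParams

# Penalize repeated tokens and require at least four generated tokens before a
# stop token may end the sequence. Values are illustrative, not recommendations.
sampling_params = SamplingParams(
    min_tokens=4,
    presence_penalty=1.0,
    frequency_penalty=1.0,
    repetition_penalty=1.2,
    stop_token_ids=[1001, 1002],
)

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["Hello, my name is"], sampling_params)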
tests/v1/engine/test_engine_core.py (modified)
@@ -139,3 +139,41 @@ def test_engine_core(monkeypatch):
         engine_core.abort_requests([req2.request_id, req0.request_id])
         assert len(engine_core.scheduler.waiting) == 0
         assert len(engine_core.scheduler.running) == 0
+
+
+def test_engine_core_advanced_sampling(monkeypatch):
+    """
+    A basic end-to-end test to verify that the engine functions correctly
+    when additional sampling parameters, such as min_tokens and
+    presence_penalty, are set.
+    """
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        """Setup the EngineCore."""
+        engine_args = EngineArgs(model=MODEL_NAME)
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT)
+        executor_class = AsyncLLM._get_executor_cls(vllm_config)
+
+        engine_core = EngineCore(vllm_config=vllm_config,
+                                 executor_class=executor_class,
+                                 usage_context=UsageContext.UNKNOWN_CONTEXT)
+        """Test basic request lifecycle."""
+        # First request.
+        request: EngineCoreRequest = make_request()
+        request.sampling_params = SamplingParams(
+            min_tokens=4,
+            presence_penalty=1.0,
+            frequency_penalty=1.0,
+            repetition_penalty=0.1,
+            stop_token_ids=[1001, 1002],
+        )
+        engine_core.add_request(request)
+        assert len(engine_core.scheduler.waiting) == 1
+        assert len(engine_core.scheduler.running) == 0
+        # Loop through until they are all done.
+        while len(engine_core.step()) > 0:
+            pass
+
+        assert len(engine_core.scheduler.waiting) == 0
+        assert len(engine_core.scheduler.running) == 0
tests/v1/sample/__init__.py (new file, 0 lines)
tests/v1/sample/test_sampler.py (new file, 331 lines)
@@ -0,0 +1,331 @@
from typing import List, Set, Tuple

import numpy as np
import pytest
import torch

from vllm.utils import make_tensor_with_pad
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64


def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
    fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
    return fake_logits


def _create_penalty_tensor(batch_size: int, penalty_value: float,
                           device: torch.device) -> torch.Tensor:
    return torch.full((batch_size, ),
                      fill_value=penalty_value,
                      dtype=torch.float,
                      device=device)


def _create_prompt_tokens_tensor(
    prompt_token_ids: List[List[int]],
    vocab_size: int,
    device: torch.device,
) -> torch.Tensor:
    return make_tensor_with_pad(
        prompt_token_ids,
        pad=vocab_size,
        device=device,
        dtype=torch.int64,
        pin_memory=False,
    )


def _create_default_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
    vocab_size: int,
    device: torch.device,
) -> SamplingMetadata:
    output_token_ids: List[List[int]] = []
    prompt_token_ids: List[List[int]] = []
    for _ in range(batch_size):
        output_token_ids.append(
            np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
        prompt_token_ids.append(
            np.random.randint(0,
                              vocab_size,
                              size=np.random.randint(
                                  1, MAX_NUM_PROMPT_TOKENS)).tolist())
    fake_sampling_metadata = SamplingMetadata(
        temperature=torch.full((batch_size, ), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=torch.empty(batch_size, ),
        top_k=torch.empty(batch_size, ),
        no_top_p=True,
        no_top_k=True,
        generators={},
        max_num_logprobs=VOCAB_SIZE,
        prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                      vocab_size, device),
        output_token_ids=output_token_ids,
        frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
        presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
        repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
        no_penalties=True,
        min_tokens=[],
        stop_token_ids=[],
    )
    return fake_sampling_metadata


def _generate_min_token_penalties_and_stop_tokens(
        num_output_tokens: int, batch_size: int, vocab_size: int,
        batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
    """
    Generates and returns a list of minimum token penalties (`min_tokens`)
    and a corresponding list of stop token IDs (`stop_token_ids`) for each
    batch.

    If a batch index is included in `batch_indices_for_min_token_penalty`,
    a higher `min_tokens` value is assigned (within a randomized range),
    and a random set of stop token IDs is created. Otherwise, a lower
    `min_tokens` value is assigned, and the stop token IDs set is empty.
    """
    stop_token_ids: List[Set[int]] = []
    min_tokens: List[int] = []
    for index in range(batch_size):
        if index in batch_indices_for_min_token_penalty:
            min_tokens.append(
                np.random.randint(num_output_tokens + 1,
                                  2 * num_output_tokens))
            stop_token_ids.append(
                set(
                    np.random.randint(0, vocab_size - 1)
                    for _ in range(np.random.randint(0, vocab_size))))

        else:
            min_tokens.append(np.random.randint(0, num_output_tokens))
            stop_token_ids.append(set())
    return (min_tokens, stop_token_ids)


def _create_weighted_output_token_list(
        batch_size: int,
        vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """
    Creates an output token list where each token occurs a distinct
    number of times.

    For each batch, a random subset of token IDs is selected from the
    vocabulary. The selected tokens are then added to the output token
    list, each with a different frequency.

    Returns:
        Tuple[List[List[int]], List[List[int]]]:
            - The first element is the output token list, where each sublist
              corresponds to a batch and contains tokens with weighted
              frequencies.
            - The second element is a list of distinct token IDs for each
              batch, ordered by their frequency in the corresponding output
              list.
    """
    output_token_ids: List[List[int]] = []
    sorted_token_ids_in_output: List[List[int]] = []
    for _ in range(batch_size):
        distinct_token_ids = np.random.choice(vocab_size,
                                              size=np.random.randint(1, 10),
                                              replace=False).tolist()
        sorted_token_ids_in_output.append(distinct_token_ids)
        output_token_ids_for_batch = []
        for index, token_id in enumerate(distinct_token_ids):
            output_token_ids_for_batch.extend(
                [token_id for _ in range(index + 1)])
        output_token_ids.append(output_token_ids_for_batch)
    return (output_token_ids, sorted_token_ids_in_output)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
    """
    Tests that if the number of output tokens is less than
    SamplingParams.min_tokens then we will set the logits for
    the stop token ids to -inf.
    """
    torch.set_default_device(device)
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    batch_indices_for_min_token_penalty = np.random.randint(
        0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
    min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
        batch_indices_for_min_token_penalty)
    sampling_metadata.min_tokens = min_tokens
    sampling_metadata.stop_token_ids = stop_token_ids
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    for batch_idx in range(batch_size):
        for vocab in range(VOCAB_SIZE):
            # Verify that the logprobs for stop token ids is set
            # to -inf.
            logprob_index = torch.where(
                sampler_output.logprob_token_ids[batch_idx] ==
                vocab)[0].item()
            if vocab in stop_token_ids[batch_idx]:
                assert sampler_output.logprobs[batch_idx][
                    logprob_index] == -float("inf")
            else:
                assert sampler_output.logprobs[batch_idx][
                    logprob_index] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
def test_sampler_presence_penalty(device: str, batch_size: int,
                                  presence_penalty: float):
    """
    Test to verify that if presence penalty is enabled then tokens
    are penalized as per their presence in the existing output.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    output_token_ids = sampling_metadata.output_token_ids
    sampling_metadata.presence_penalties = _create_penalty_tensor(
        batch_size, presence_penalty, torch.device(device))
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    for batch_idx in range(batch_size):
        # The logprobs in the SamplerOutput are arranged in descending order.
        # Since all tokens initially have the same logprobs, the non-penalized
        # tokens will appear at the beginning, while the penalized tokens
        # will appear at the end of the list.
        penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
            VOCAB_SIZE - 1]
        penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
        non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
        non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
        assert non_penalized_log_prod > penalized_log_prod
        if presence_penalty > 0:
            # If `presence_penalty` is set to a value greater than 0, it
            # indicates a preference for new tokens over those already
            # present in the output.
            # Verify that the penalized token ID exists in the output, while the
            # non-penalized token ID does not.
            assert penalized_token_id in output_token_ids[batch_idx]
            assert non_penalized_token_id not in output_token_ids[batch_idx]
        elif presence_penalty < 0:
            # If `presence_penalty` is set to a value less than 0, it indicates
            # a preference for existing tokens over new ones. Verify that the
            # non-penalized token ID exists in the output, while the penalized
            # token ID does not.
            assert non_penalized_token_id in output_token_ids[batch_idx]
            assert penalized_token_id not in output_token_ids[batch_idx]


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
def test_sampler_frequency_penalty(device: str, batch_size: int,
                                   frequency_penalty: float):
    """
    Test to verify that if frequency penalty is enabled then tokens are
    penalized as per their frequency of occurrence.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    sampling_metadata.frequency_penalties = _create_penalty_tensor(
        batch_size, frequency_penalty, torch.device(device))
    output_token_ids, sorted_token_ids_in_output = \
        _create_weighted_output_token_list(batch_size, VOCAB_SIZE)
    sampling_metadata.output_token_ids = output_token_ids
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    for batch_idx in range(batch_size):
        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
        non_penalized_token_id = logprobs_token_ids[0]
        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
        distinct_sorted_token_ids_in_output = \
            sorted_token_ids_in_output[batch_idx]
        most_frequent_token_id = distinct_sorted_token_ids_in_output[
            len(distinct_sorted_token_ids_in_output) - 1]
        if frequency_penalty > 0:
            # If `frequency_penalty` is set to > 0, it indicates
            # a preference for new tokens over existing ones. Verify that the
            # non-penalized token ID is not present in the output, while the
            # most penalized token is the one that occurs most frequently in
            # the output.
            assert non_penalized_token_id \
                not in distinct_sorted_token_ids_in_output
            assert penalized_token_id == most_frequent_token_id
        elif frequency_penalty < 0:
            # If `frequency_penalty` is set to < 0, it indicates
            # a preference for existing tokens over new ones. Verify that the
            # non-penalized token ID is the one that occurs most frequently
            # in the output, while the penalized token ID is one that has not
            # yet appeared.
            assert non_penalized_token_id == most_frequent_token_id
            assert penalized_token_id \
                not in distinct_sorted_token_ids_in_output


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
def test_sampler_repetition_penalty(device: str, batch_size: int,
                                    repetition_penalty: float):
    """
    Test to verify that when the repetition penalty is enabled, tokens
    are penalized based on their presence in the prompt or the existing
    output.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    sampling_metadata.repetition_penalties = _create_penalty_tensor(
        batch_size, repetition_penalty, torch.device(device))
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    for batch_idx in range(batch_size):
        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
        non_penalized_token_id = logprobs_token_ids[0]
        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
        prompt_tokens = sampling_metadata.prompt_token_ids[
            batch_idx][:].tolist()
        output_tokens = sampling_metadata.output_token_ids[batch_idx]
        if repetition_penalty > 1.0:
            # If `repetition_penalty` > 1.0, verify that the non-penalized
            # token ID has not been seen before, while the penalized token ID
            # exists either in the prompt or the output.
            assert (non_penalized_token_id not in prompt_tokens and \
                    non_penalized_token_id not in output_tokens)
            assert (penalized_token_id in prompt_tokens or \
                    penalized_token_id in output_tokens)
        elif repetition_penalty < 1.0:
            # If `repetition_penalty` < 1.0, verify that the penalized
            # token ID has not been seen before, while the non-penalized
            # token ID exists either in the prompt or the output.
            assert (penalized_token_id not in prompt_tokens and \
                    penalized_token_id not in output_tokens)
            assert (non_penalized_token_id in prompt_tokens or \
                    non_penalized_token_id in output_tokens)
tests/v1/worker/__init__.py (new file, 0 lines)
tests/v1/worker/test_gpu_input_batch.py (new file, 224 lines)
@@ -0,0 +1,224 @@
from typing import Dict, List, Set, Tuple

import numpy as np
import pytest
import torch

from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
MAX_PROMPT_SIZE = 100
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64


def _remove_requests(
        input_batch: InputBatch, batch_size: int,
        reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]:
    """
    Remove some requests randomly from the batch and returns a Tuple
    of 1) set of request removed 2) indices of the requests removed
    ordered in descending order
    """

    num_reqs_to_remove = np.random.randint(0, batch_size)
    req_indices_to_remove: Set[int] = set()
    for _ in range(num_reqs_to_remove):
        req_index_to_remove = np.random.randint(0, batch_size)
        req_indices_to_remove.add(req_index_to_remove)

    req_indices_to_remove_list = list(req_indices_to_remove)
    req_indices_to_remove_list.sort(reverse=True)
    req_ids_to_remove: Set[str] = set()
    for index in req_indices_to_remove:
        input_batch.remove_request(reqs[index].req_id)
        req_ids_to_remove.add(reqs[index].req_id)
    return (req_ids_to_remove, req_indices_to_remove_list)


def _construct_expected_sampling_metadata(
        reqs: List[CachedRequestState], req_ids_retained: Set[int],
        req_id_index_in_input_batch: Dict[str, int],
        device: torch.device) -> SamplingMetadata:
    """
    Constructs and returns the expected SamplingMetadata for this
    batch.
    """
    num_reqs = len(req_ids_retained)
    output_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
    prompt_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
    presence_penalties = [0.0 for _ in range(num_reqs)]
    frequency_penalties = [0.0 for _ in range(num_reqs)]
    repetition_penalties = [1.0 for _ in range(num_reqs)]
    top_k = [0 for _ in range(num_reqs)]
    top_p = [0.0 for _ in range(num_reqs)]
    temperature = [0.0 for _ in range(num_reqs)]
    stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
    min_tokens = [0 for _ in range(num_reqs)]
    for req in reqs:
        if req.req_id not in req_ids_retained:
            continue
        index_in_input_batch = req_id_index_in_input_batch[req.req_id]
        output_token_ids[index_in_input_batch] = req.output_token_ids
        prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
        presence_penalties[
            index_in_input_batch] = req.sampling_params.presence_penalty
        frequency_penalties[
            index_in_input_batch] = req.sampling_params.frequency_penalty
        repetition_penalties[
            index_in_input_batch] = req.sampling_params.repetition_penalty
        top_k[index_in_input_batch] = req.sampling_params.top_k
        top_p[index_in_input_batch] = req.sampling_params.top_p
        temperature[index_in_input_batch] = req.sampling_params.temperature
        stop_token_ids[
            index_in_input_batch] = req.sampling_params.all_stop_token_ids
        min_tokens[index_in_input_batch] = req.sampling_params.min_tokens

    return SamplingMetadata(
        temperature=torch.tensor(temperature, dtype=torch.float, device=device),
        all_greedy=False,
        all_random=True,
        top_p=torch.tensor(top_p, dtype=torch.float, device=device),
        top_k=torch.tensor(top_k, dtype=torch.int, device=device),
        no_top_p=all(x == 1.0 for x in top_p),
        no_top_k=all(x == 0 for x in top_k),
        generators={},
        max_num_logprobs=0,
        prompt_token_ids=make_tensor_with_pad(
            prompt_token_ids,
            pad=VOCAB_SIZE,
            device=torch.device(device),
            dtype=torch.int64,
        ),
        frequency_penalties=torch.tensor(
            frequency_penalties, dtype=torch.float,
            device=device),
        presence_penalties=torch.tensor(
            presence_penalties, dtype=torch.float,
            device=device),
        repetition_penalties=torch.tensor(
            repetition_penalties, dtype=torch.float,
            device=device),
        output_token_ids=output_token_ids,
        min_tokens=min_tokens,
        stop_token_ids=stop_token_ids,
        no_penalties=(all(x == 0 for x in presence_penalties) and \
                      all(x == 0 for x in frequency_penalties) and \
                      all(x == 1 for x in repetition_penalties))
    )


def _create_sampling_params():
    return SamplingParams(top_k=np.random.randint(1, 10),
                          top_p=np.random.uniform(0.0, 1.0),
                          presence_penalty=np.random.uniform(-2.0, 2.0),
                          repetition_penalty=np.random.uniform(0.0, 2.0),
                          frequency_penalty=np.random.uniform(-2.0, 2.0),
                          min_tokens=np.random.randint(1, 10),
                          stop_token_ids=[
                              np.random.randint(0, VOCAB_SIZE)
                              for _ in range(np.random.randint(10))
                          ])


def _construct_cached_request_state(req_id_suffix: int):
    prompt_token_ids = [
        np.random.randint(0, VOCAB_SIZE)
        for _ in range(np.random.randint(0, MAX_PROMPT_SIZE))
    ]
    output_token_ids = [
        np.random.randint(0, VOCAB_SIZE)
        for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
    ]
    return CachedRequestState(req_id=f"req_id_{req_id_suffix}",
                              prompt_token_ids=prompt_token_ids,
                              prompt=None,
                              sampling_params=_create_sampling_params(),
                              mm_inputs=[],
                              mm_positions=[],
                              block_ids=[],
                              generator=None,
                              num_computed_tokens=len(output_token_ids),
                              output_token_ids=output_token_ids)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
    """
    Tests the logic for managing sampling metadata in the InputBatch.

    This test involves adding a set of requests to the InputBatch,
    followed by removing a subset of them. Afterward, the batch is compacted,
    and the `make_sampling_metadata` method is invoked on the batch. The
    output of `make_sampling_metadata` is then compared against the expected
    results to ensure correctness.
    """
    input_batch: InputBatch = InputBatch(max_num_reqs=batch_size,
                                         max_model_len=1024,
                                         max_num_blocks_per_req=10,
                                         device=torch.device(device),
                                         pin_memory=is_pin_memory_available(),
                                         vocab_size=1024)
    reqs: List[CachedRequestState] = []
    req_id_reqs = {}
    req_id_output_token_ids = {}
    # Add requests
    for req_index in range(batch_size):
        req: CachedRequestState = _construct_cached_request_state(req_index)
        input_batch.add_request(req, req_index)
        reqs.append(req)
        req_id_reqs[req.req_id] = req
        req_id_output_token_ids[req.req_id] = req.output_token_ids

    # Remove some requests
    req_ids_to_remove, req_indices_to_remove = _remove_requests(
        input_batch, batch_size, reqs)
    req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove

    # Compact the input batch
    input_batch.condense(req_indices_to_remove)

    # Generate the sampling metadata
    sampling_metadata = input_batch.make_sampling_metadata(
        req_id_output_token_ids, skip_copy=False)

    # Create expected output.
    expected_sampling_metadata = _construct_expected_sampling_metadata(
        reqs,
        req_ids_retained,
        input_batch.req_id_to_index,
        device=torch.device(device))

    # Assert the actual and expected output.
    assert torch.allclose(expected_sampling_metadata.temperature,
                          sampling_metadata.temperature)
    assert torch.allclose(expected_sampling_metadata.top_p,
                          sampling_metadata.top_p)
    assert torch.allclose(expected_sampling_metadata.top_k,
                          sampling_metadata.top_k)
    assert torch.allclose(expected_sampling_metadata.frequency_penalties,
                          sampling_metadata.frequency_penalties)
    assert torch.allclose(expected_sampling_metadata.presence_penalties,
                          sampling_metadata.presence_penalties)
    assert torch.allclose(expected_sampling_metadata.repetition_penalties,
                          sampling_metadata.repetition_penalties)
    assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
                          sampling_metadata.prompt_token_ids)
    assert (expected_sampling_metadata.output_token_ids ==
            sampling_metadata.output_token_ids)
    assert (
        expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens)
    assert (expected_sampling_metadata.stop_token_ids ==
            sampling_metadata.stop_token_ids)
    assert (expected_sampling_metadata.no_penalties ==
            sampling_metadata.no_penalties)
    assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p)
    assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k)
vllm/model_executor/layers/sampler.py (modified)
@@ -11,6 +11,7 @@ import torch
 import torch.nn as nn
 
 import vllm.envs as envs
+from vllm.model_executor.layers.utils import apply_penalties
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)
@@ -258,11 +259,11 @@ class Sampler(nn.Module):
 
         # Apply presence and frequency penalties.
         if do_penalties:
-            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
-                                      sampling_tensors.output_tokens,
-                                      sampling_tensors.presence_penalties,
-                                      sampling_tensors.frequency_penalties,
-                                      sampling_tensors.repetition_penalties)
+            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                     sampling_tensors.output_tokens,
+                                     sampling_tensors.presence_penalties,
+                                     sampling_tensors.frequency_penalties,
+                                     sampling_tensors.repetition_penalties)
 
         # Use float32 to apply temperature scaling.
         # Use in-place division to avoid creating a new tensor.
@@ -336,23 +337,6 @@
         return self.should_modify_greedy_probs_inplace
 
 
-def _get_bin_counts_and_mask(
-    tokens: torch.Tensor,
-    vocab_size: int,
-    num_seqs: int,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Compute the bin counts for the tokens.
-    # vocab_size + 1 for padding.
-    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
-                             dtype=torch.long,
-                             device=tokens.device)
-    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
-    bin_counts = bin_counts[:, :vocab_size]
-    mask = bin_counts > 0
-
-    return bin_counts, mask
-
-
 def _apply_min_tokens_penalty(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
@@ -400,29 +384,6 @@ def _apply_min_tokens_penalty(
     return logits
 
 
-def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
-                     output_tokens_tensor: torch.Tensor,
-                     presence_penalties: torch.Tensor,
-                     frequency_penalties: torch.Tensor,
-                     repetition_penalties: torch.Tensor) -> torch.Tensor:
-    num_seqs, vocab_size = logits.shape
-    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
-                                              num_seqs)
-    output_bin_counts, output_mask = _get_bin_counts_and_mask(
-        output_tokens_tensor, vocab_size, num_seqs)
-
-    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
-    logits = torch.where(logits > 0, logits / repetition_penalties,
-                         logits * repetition_penalties)
-
-    # We follow the definition in OpenAI API.
-    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
-    return logits
-
-
 def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
vllm/model_executor/layers/utils.py (new file, 57 lines)
@@ -0,0 +1,57 @@
"""Utility methods for model layers."""
from typing import Tuple

import torch


def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask


def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor
    logits : The input logits tensor of shape [num_seqs, vocab_size]
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)
    repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
        1, vocab_size)
    logits[logits > 0] /= torch.where(prompt_mask | output_mask,
                                      repetition_penalties, 1.0)[logits > 0]
    logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
                                       repetition_penalties, 1.0)[logits <= 0]
    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits
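A small CPU-side illustration (not from the commit) of what apply_penalties above does to a single row of logits, using a toy vocabulary of 8 and made-up token ids. The call signature and the vocab_size padding convention are the ones defined in the new file; the numbers are arbitrary.

import torch

from vllm.model_executor.layers.utils import apply_penalties

vocab_size = 8
logits = torch.zeros(1, vocab_size)             # one sequence, all logits start at 0.0
prompt = torch.tensor([[1, 2, vocab_size]])     # tokens 1 and 2 in the prompt, padded with vocab_size
output = torch.tensor([[2, 2, 3, vocab_size]])  # token 2 appeared twice, token 3 once

# apply_penalties modifies logits in place and returns the same tensor.
penalized = apply_penalties(
    logits,
    prompt,
    output,
    presence_penalties=torch.tensor([0.5]),
    frequency_penalties=torch.tensor([0.1]),
    repetition_penalties=torch.tensor([1.2]),
)
# Token 2: 0.0 - 0.1 * 2 - 0.5 = -0.7 (frequency + presence; repetition leaves a 0.0 logit unchanged)
# Token 3: 0.0 - 0.1 * 1 - 0.5 = -0.6
# Token 1: prompt-only, so only the repetition penalty applies (a no-op on a 0.0 logit)
print(penalized[0])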
vllm/v1/sample/metadata.py (modified)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, List, Optional, Set
 
 import torch
 
@@ -19,3 +19,13 @@ class SamplingMetadata:
     generators: Dict[int, torch.Generator]
 
     max_num_logprobs: int
+
+    no_penalties: bool
+    prompt_token_ids: Optional[torch.Tensor]
+    frequency_penalties: torch.Tensor
+    presence_penalties: torch.Tensor
+    repetition_penalties: torch.Tensor
+
+    output_token_ids: List[List[int]]
+    min_tokens: List[int]
+    stop_token_ids: List[Set[int]]
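The no_penalties flag added here acts as a batch-level fast path: it is expected to be True only when every request keeps the default penalties (0.0 presence, 0.0 frequency, 1.0 repetition), which lets the sampler and the InputBatch skip the penalty math and the prompt-token copy. A rough equivalent of that check, written against the three tensors above (a sketch matching the expectation in the tests, not code from the diff):

import torch

def no_penalties_for(presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> bool:
    # True when every request uses the defaults, so penalties can be skipped.
    return bool(torch.all(presence_penalties == 0.0)
                and torch.all(frequency_penalties == 0.0)
                and torch.all(repetition_penalties == 1.0))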
vllm/v1/sample/sampler.py (modified)
@@ -1,9 +1,11 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict
+from typing import Dict, List, Set, Tuple
 
 import torch
 import torch.nn as nn
 
+from vllm.model_executor.layers.utils import apply_penalties
+from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -17,9 +19,18 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
+        _apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
+                                   sampling_metadata.stop_token_ids,
+                                   sampling_metadata.min_tokens)
+        if not sampling_metadata.no_penalties:
+            assert sampling_metadata.prompt_token_ids is not None
+            _apply_penalties(logits, sampling_metadata.prompt_token_ids,
+                             sampling_metadata.presence_penalties,
+                             sampling_metadata.frequency_penalties,
+                             sampling_metadata.repetition_penalties,
+                             sampling_metadata.output_token_ids)
         logits = self.apply_temperature(logits, sampling_metadata.temperature)
         logits = self.apply_top_k_top_p(logits, sampling_metadata)
 
         probs = self.get_probs(logits)
         sampled = self.sample(probs, sampling_metadata)
         # Use int32 to reduce the tensor size.
@@ -157,3 +168,53 @@
     # Re-sort the probabilities.
     logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
     return logits
+
+
+def _apply_min_token_penalties(logits: torch.Tensor,
+                               output_token_ids: List[List[int]],
+                               stop_token_ids: List[Set[int]],
+                               min_tokens: List[int]):
+    """
+    Applies minimum token penalty by setting the logits of the stop tokens
+    to -inf.
+    """
+    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
+    for index, min_token in enumerate(min_tokens):
+        if (len(output_token_ids[index]) < min_token):
+            for stop_token_id in stop_token_ids[index]:
+                min_tokens_logits_to_penalize.append((index, stop_token_id))
+    if min_tokens_logits_to_penalize:
+        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
+
+
+def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
+                     presence_penalties: torch.Tensor,
+                     frequency_penalties: torch.Tensor,
+                     repetition_penalties: torch.Tensor,
+                     output_token_ids: List[List[int]]):
+    """
+    Applies presence, frequency and repetition penalties to the logits.
+    """
+    _, vocab_size = logits.shape
+    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
+                                          logits.device)
+    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
+                           presence_penalties, frequency_penalties,
+                           repetition_penalties)
+
+
+def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
+                        device: torch.device) -> torch.Tensor:
+    """
+    Convert the different list data structures to tensors.
+    """
+    output_tokens_tensor = make_tensor_with_pad(
+        output_token_ids,
+        # Use the value of vocab_size as a pad since we don't have a
+        # token_id of this value.
+        pad=vocab_size,
+        device="cpu",
+        dtype=torch.int64,
+        pin_memory=is_pin_memory_available(),
+    )
+    return output_tokens_tensor.to(device, non_blocking=True)
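To make the min-tokens rule above concrete, the standalone sketch below (not from the diff) reproduces the core of _apply_min_token_penalties for two toy requests: while a request has generated fewer than min_tokens tokens, the logits of its stop tokens are forced to -inf so it cannot terminate early.

import torch

logits = torch.zeros(2, 6)                 # two requests, toy vocabulary of 6
output_token_ids = [[1, 2], [1, 2, 3, 4]]  # 2 and 4 tokens generated so far
stop_token_ids = [{5}, {5}]                # token 5 would stop the sequence
min_tokens = [3, 3]

to_penalize = []
for index, min_token in enumerate(min_tokens):
    if len(output_token_ids[index]) < min_token:
        to_penalize.extend((index, t) for t in stop_token_ids[index])
if to_penalize:
    logits[tuple(zip(*to_penalize))] = -float("inf")

print(logits[0, 5].item(), logits[1, 5].item())  # -inf for request 0, 0.0 for request 1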
vllm/v1/worker/gpu_input_batch.py (modified)
@@ -43,12 +43,14 @@ class InputBatch:
         max_num_blocks_per_req: int,
         device: torch.device,
         pin_memory: bool,
+        vocab_size: int,
     ):
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.device = device
         self.pin_memory = pin_memory
+        self.vocab_size = vocab_size
 
         self.req_ids: List[Optional[str]] = [None] * max_num_reqs
         self.req_id_to_index: Dict[str, int] = {}
@@ -63,6 +65,7 @@ class InputBatch:
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
+        self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
 
         # Attention-related.
         self.block_table = torch.zeros(
@@ -110,6 +113,50 @@ class InputBatch:
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()
 
+        # Frequency penalty related data structures
+        self.frequency_penalties = torch.empty((max_num_reqs, ),
+                                               dtype=torch.float,
+                                               device=device)
+        self.frequency_penalties_cpu_tensor = torch.empty(
+            (max_num_reqs, ),
+            dtype=torch.float,
+            device="cpu",
+            pin_memory=pin_memory)
+        self.frequency_penalties_cpu = \
+            self.frequency_penalties_cpu_tensor.numpy()
+        self.frequency_penalties_reqs: Set[str] = set()
+
+        # Presence penalty related data structures
+        self.presence_penalties = torch.empty((max_num_reqs, ),
+                                              dtype=torch.float,
+                                              device=device)
+        self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
+                                                         dtype=torch.float,
+                                                         device="cpu",
+                                                         pin_memory=pin_memory)
+        self.presence_penalties_cpu = \
+            self.presence_penalties_cpu_tensor.numpy()
+        self.presence_penalties_reqs: Set[str] = set()
+
+        # Repetition penalty related data structures
+        self.repetition_penalties = torch.empty((max_num_reqs, ),
+                                                dtype=torch.float,
+                                                device=device)
+        self.repetition_penalties_cpu_tensor = torch.empty(
+            (max_num_reqs, ),
+            dtype=torch.float,
+            device="cpu",
+            pin_memory=pin_memory)
+        self.repetition_penalties_cpu = \
+            self.repetition_penalties_cpu_tensor.numpy()
+        self.repetition_penalties_reqs: Set[str] = set()
+
+        self.min_tokens: List[int] = [0] * max_num_reqs
+        self.stop_token_ids: List[Set[int]] = [
+            set() for _ in range(max_num_reqs)
+        ]
+        self.prompt_token_ids: Optional[torch.Tensor] = None
+
         # req_index -> generator
         # NOTE(woosuk): The indices of the requests that do not have their own
         # generator should not be included in the dictionary.
@@ -133,6 +180,7 @@ class InputBatch:
 
         # Copy the prompt token ids and output token ids.
         num_prompt_tokens = len(request.prompt_token_ids)
+        self.num_prompt_tokens[req_index] = num_prompt_tokens
         self.token_ids_cpu[
             req_index, :num_prompt_tokens] = request.prompt_token_ids
         start_idx = num_prompt_tokens
@@ -157,6 +205,20 @@ class InputBatch:
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
             self.top_k_reqs.add(req_id)
+        self.frequency_penalties_cpu[req_index] = \
+            sampling_params.frequency_penalty
+        if sampling_params.frequency_penalty != 0.0:
+            self.frequency_penalties_reqs.add(req_id)
+        self.presence_penalties_cpu[req_index] = \
+            sampling_params.presence_penalty
+        if sampling_params.presence_penalty != 0.0:
+            self.presence_penalties_reqs.add(req_id)
+        self.repetition_penalties_cpu[req_index] = \
+            sampling_params.repetition_penalty
+        if sampling_params.repetition_penalty != 1.0:
+            self.repetition_penalties_reqs.add(req_id)
+        self.min_tokens[req_index] = sampling_params.min_tokens
+        self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids
 
         # NOTE(woosuk): self.generators should not include the requests that
         # do not have their own generator.
@@ -179,6 +241,9 @@ class InputBatch:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
+        self.frequency_penalties_reqs.discard(req_id)
+        self.presence_penalties_reqs.discard(req_id)
+        self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
@@ -191,6 +256,9 @@ class InputBatch:
         self.random_reqs.clear()
         self.top_p_reqs.clear()
         self.top_k_reqs.clear()
+        self.frequency_penalties_reqs.clear()
+        self.presence_penalties_reqs.clear()
+        self.repetition_penalties_reqs.clear()
         self.generators.clear()
         self.num_logprobs.clear()
         self.prompt_logprob_reqs.clear()
@@ -224,6 +292,8 @@ class InputBatch:
             # block_table_cpu.
             self.token_ids_cpu[empty_index] = self.token_ids_cpu[
                 last_req_index]
+            self.num_prompt_tokens[empty_index] = \
+                self.num_prompt_tokens[last_req_index]
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
             self.block_table_cpu[empty_index] = self.block_table_cpu[
@@ -232,6 +302,15 @@ class InputBatch:
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
             self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
+            self.frequency_penalties_cpu[empty_index] = \
+                self.frequency_penalties_cpu[last_req_index]
+            self.presence_penalties_cpu[empty_index] = \
+                self.presence_penalties_cpu[last_req_index]
+            self.repetition_penalties_cpu[empty_index] = \
+                self.repetition_penalties_cpu[last_req_index]
+            self.min_tokens[empty_index] = self.min_tokens[last_req_index]
+            self.stop_token_ids[empty_index] = \
+                self.stop_token_ids[last_req_index]
             generator = self.generators.pop(last_req_index, None)
             if generator is not None:
                 self.generators[empty_index] = generator
@@ -241,6 +320,7 @@ class InputBatch:
 
     def make_sampling_metadata(
         self,
+        req_id_output_token_ids: Dict[str, List[int]],
         skip_copy: bool = False,
     ) -> SamplingMetadata:
         if not skip_copy:
@@ -250,6 +330,37 @@ class InputBatch:
                 self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
             self.top_k[:self.num_reqs].copy_(
                 self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
+            if not self.no_penalties:
+                # Since syncing these tensors is expensive only copy them
+                # if necessary i.e. if there are requests which require
+                # penalties to be applied during sampling.
+                self.frequency_penalties[:self.num_reqs].copy_(
+                    self.frequency_penalties_cpu_tensor[:self.num_reqs],
+                    non_blocking=True)
+                self.presence_penalties[:self.num_reqs].copy_(
+                    self.presence_penalties_cpu_tensor[:self.num_reqs],
+                    non_blocking=True)
+                self.repetition_penalties[:self.num_reqs].copy_(
+                    self.repetition_penalties_cpu_tensor[:self.num_reqs],
+                    non_blocking=True)
+                # The prompt tokens are used only for applying penalties during
+                # the sampling process. Hence copy these tensors only when
+                # there are requests which need penalties to be applied.
+                self.prompt_token_ids = self._make_prompt_token_ids_tensor()
+
+        output_token_ids: List[List[int]] = []
+
+        for req_id in self.req_ids[:self.num_reqs]:
+            assert req_id is not None
+            # Currently we create a tensor for output_token_ids from scratch
+            # at each step. However, for the penalties computation what we
+            # need is stats about the token ids present in the output. This
+            # stats can be maintained incrementally instead of computing it
+            # from scratch at each step.
+            # TODO - Replace this with incremental update to output token
+            # statistics.
+            output_token_ids.append(req_id_output_token_ids[req_id])
+
         return SamplingMetadata(
             temperature=self.temperature[:self.num_reqs],
             all_greedy=self.all_greedy,
@@ -260,8 +371,33 @@ class InputBatch:
             no_top_k=self.no_top_k,
             generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
+            prompt_token_ids=self.prompt_token_ids,
+            frequency_penalties=self.frequency_penalties[:self.num_reqs],
+            presence_penalties=self.presence_penalties[:self.num_reqs],
+            repetition_penalties=self.repetition_penalties[:self.num_reqs],
+            output_token_ids=output_token_ids,
+            min_tokens=self.min_tokens[:self.num_reqs],
+            stop_token_ids=self.stop_token_ids[:self.num_reqs],
+            no_penalties=self.no_penalties,
         )
 
+    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
+        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
+        prompt_token_ids_cpu_tensor = torch.empty(
+            (self.num_reqs, max_prompt_len),
+            device="cpu",
+            dtype=torch.int64,
+            pin_memory=self.pin_memory)
+        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
+        prompt_token_ids[:] = (
+            self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
+        # Use the value of vocab_size as a pad since we don't have a
+        # token_id of this value.
+        for i in range(self.num_reqs):
+            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
+        return prompt_token_ids_cpu_tensor.to(device=self.device,
+                                              non_blocking=True)
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
@@ -282,6 +418,12 @@ class InputBatch:
     def no_top_k(self) -> bool:
         return len(self.top_k_reqs) == 0
 
+    @property
+    def no_penalties(self) -> bool:
+        return (len(self.presence_penalties_reqs) == 0
+                and len(self.frequency_penalties_reqs) == 0
+                and len(self.repetition_penalties_reqs) == 0)
+
     @property
     def max_num_logprobs(self) -> int:
         return max(self.num_logprobs.values()) if self.num_logprobs else 0
vllm/v1/worker/gpu_model_runner.py (modified)
@@ -105,6 +105,7 @@ class GPUModelRunner:
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
             pin_memory=self.pin_memory,
+            vocab_size=model_config.get_vocab_size(),
         )
 
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
@@ -383,7 +384,12 @@
                 or scheduler_output.scheduled_resumed_reqs):
             skip_copy = False
         # Create the sampling metadata.
-        sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
+        req_id_output_token_ids: Dict[str, List[int]] = \
+            {req_id: req.output_token_ids \
+            for req_id, req in self.requests.items()}
+
+        sampling_metadata = self.input_batch.make_sampling_metadata(
+            req_id_output_token_ids, skip_copy)
         return sampling_metadata
 
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):