mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 02:09:07 +08:00
[V1] Logit processors for rejection sampler (#19482)
Signed-off-by: southfreebird <yvorott@gmail.com> Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com> Signed-off-by: Sergei Skvortsov <yvorott@gmail.com> Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
0c824fc46f
commit
6ebaf43ee4
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
from typing import Union
|
from typing import Any, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -25,6 +25,7 @@ from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
|||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.v1.sample.logits_processor import (
|
from vllm.v1.sample.logits_processor import (
|
||||||
STR_POOLING_REJECTS_LOGITSPROCS,
|
STR_POOLING_REJECTS_LOGITSPROCS,
|
||||||
|
STR_SPEC_DEC_REJECTS_LOGITSPROCS,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -205,6 +206,7 @@ def test_custom_logitsprocs_req(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
|
@pytest.mark.parametrize("model_scenario", ["pooling", "spec_dec"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"logitproc_source",
|
"logitproc_source",
|
||||||
[
|
[
|
||||||
@ -213,11 +215,12 @@ def test_custom_logitsprocs_req(monkeypatch):
|
|||||||
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
|
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_pooling_rejects_custom_logitsprocs(
|
def test_rejects_custom_logitsprocs(
|
||||||
monkeypatch, logitproc_source: CustomLogitprocSource
|
monkeypatch, model_scenario: str, logitproc_source: CustomLogitprocSource
|
||||||
):
|
):
|
||||||
"""Validate that vLLM engine initialization properly rejects custom
|
"""Validate that vLLM engine initialization properly rejects custom
|
||||||
logitsprocs when the model is a pooling model.
|
logitsprocs when the model is a pooling model or speculative decoding
|
||||||
|
is enabled.
|
||||||
|
|
||||||
Use `LLM` entrypoint. We expect `LLM` initialization to fail before the
|
Use `LLM` entrypoint. We expect `LLM` initialization to fail before the
|
||||||
logitproc is actually loaded.
|
logitproc is actually loaded.
|
||||||
@ -241,8 +244,32 @@ def test_pooling_rejects_custom_logitsprocs(
|
|||||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||||
random.seed(40)
|
random.seed(40)
|
||||||
|
|
||||||
|
test_params: dict[str, dict[str, Any]] = {
|
||||||
|
"pooling": {
|
||||||
|
"runner": "pooling",
|
||||||
|
"model": POOLING_MODEL_NAME,
|
||||||
|
"error_message": STR_POOLING_REJECTS_LOGITSPROCS,
|
||||||
|
"speculative_config": None,
|
||||||
|
},
|
||||||
|
"spec_dec": {
|
||||||
|
"runner": "auto",
|
||||||
|
"model": MODEL_NAME,
|
||||||
|
"error_message": STR_SPEC_DEC_REJECTS_LOGITSPROCS,
|
||||||
|
"speculative_config": {"model": "ngram", "num_speculative_tokens": 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config = test_params[model_scenario]
|
||||||
|
|
||||||
|
llm_kwargs: dict[str, Any] = {
|
||||||
|
"runner": config["runner"],
|
||||||
|
"model": config["model"],
|
||||||
|
"gpu_memory_utilization": 0.1,
|
||||||
|
"speculative_config": config["speculative_config"],
|
||||||
|
}
|
||||||
|
|
||||||
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
|
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
|
||||||
# Scenario: vLLM loads a pooling model and ignores a logitproc that is
|
# Scenario: vLLM loads a model and ignores a logitproc that is
|
||||||
# available at a preconfigured entrypoint
|
# available at a preconfigured entrypoint
|
||||||
|
|
||||||
# Patch in dummy logitproc entrypoint
|
# Patch in dummy logitproc entrypoint
|
||||||
@ -254,30 +281,20 @@ def test_pooling_rejects_custom_logitsprocs(
|
|||||||
# although they should ignore the entrypoint patch anyway
|
# although they should ignore the entrypoint patch anyway
|
||||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
|
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(**llm_kwargs)
|
||||||
runner="pooling",
|
|
||||||
model=POOLING_MODEL_NAME,
|
|
||||||
gpu_memory_utilization=0.1,
|
|
||||||
)
|
|
||||||
# Require that no logitsprocs have been loaded
|
# Require that no logitsprocs have been loaded
|
||||||
worker = llm.llm_engine.model_executor.driver_worker.worker
|
worker = llm.llm_engine.model_executor.driver_worker.worker
|
||||||
assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0
|
assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0
|
||||||
return
|
return
|
||||||
|
|
||||||
kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
|
|
||||||
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
|
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
|
||||||
# Scenario: load logitproc based on fully-qualified class name (FQCN)
|
# Scenario: load logitproc based on fully-qualified class name (FQCN)
|
||||||
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
|
llm_kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
|
||||||
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
|
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
|
||||||
# Scenario: load logitproc from provided class object
|
# Scenario: load logitproc from provided class object
|
||||||
kwargs["logits_processors"] = [DummyLogitsProcessor]
|
llm_kwargs["logits_processors"] = [DummyLogitsProcessor]
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS):
|
with pytest.raises(ValueError, match=config["error_message"]):
|
||||||
# Require that loading a pooling model alongside the logitproc raises
|
# Require that loading a model alongside the logitproc raises
|
||||||
# the appropriate exception.
|
# the appropriate exception.
|
||||||
LLM(
|
LLM(**llm_kwargs)
|
||||||
runner="pooling",
|
|
||||||
model=POOLING_MODEL_NAME,
|
|
||||||
gpu_memory_utilization=0.1,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from tests.v1.sample.utils import create_allowed_token_ids
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
@ -21,7 +22,9 @@ def rejection_sampler():
|
|||||||
|
|
||||||
|
|
||||||
def create_logits_tensor(
|
def create_logits_tensor(
|
||||||
output_token_ids: list[list[int]], vocab_size: int = 100
|
output_token_ids: list[list[int]],
|
||||||
|
vocab_size: int = 100,
|
||||||
|
token_idx_to_override: Optional[int] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Helper function to create logits tensor that
|
"""Helper function to create logits tensor that
|
||||||
will produce desired token ids on argmax"""
|
will produce desired token ids on argmax"""
|
||||||
@ -33,15 +36,25 @@ def create_logits_tensor(
|
|||||||
for j, token_id in enumerate(tokens):
|
for j, token_id in enumerate(tokens):
|
||||||
logits[start_loc + j, token_id] = 100.0
|
logits[start_loc + j, token_id] = 100.0
|
||||||
start_loc += len(tokens)
|
start_loc += len(tokens)
|
||||||
|
if token_idx_to_override:
|
||||||
|
logits[:, token_idx_to_override] = 99.0
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def create_sampling_metadata(
|
def create_sampling_metadata(
|
||||||
all_greedy: bool,
|
all_greedy: bool,
|
||||||
|
output_token_ids: Optional[list[list[int]]] = None,
|
||||||
|
prompt_token_ids: Optional[torch.Tensor] = None,
|
||||||
|
spec_token_ids: Optional[torch.Tensor] = None,
|
||||||
temperature: Optional[torch.Tensor] = None,
|
temperature: Optional[torch.Tensor] = None,
|
||||||
top_k: Optional[torch.Tensor] = None,
|
top_k: Optional[torch.Tensor] = None,
|
||||||
top_p: Optional[torch.Tensor] = None,
|
top_p: Optional[torch.Tensor] = None,
|
||||||
generators: Optional[dict[int, Any]] = None,
|
generators: Optional[dict[int, Any]] = None,
|
||||||
|
frequency_penalties: Optional[list[float]] = None,
|
||||||
|
presence_penalties: Optional[list[float]] = None,
|
||||||
|
repetition_penalties: Optional[list[float]] = None,
|
||||||
|
bad_words_token_ids: Optional[dict[int, list[list[int]]]] = None,
|
||||||
|
allowed_token_ids_mask: Optional[torch.Tensor] = None,
|
||||||
) -> SamplingMetadata:
|
) -> SamplingMetadata:
|
||||||
"""Create a v1 sampling metadata object with all_greedy set
|
"""Create a v1 sampling metadata object with all_greedy set
|
||||||
to the given value. Either all greedy or all random sampling
|
to the given value. Either all greedy or all random sampling
|
||||||
@ -53,6 +66,21 @@ def create_sampling_metadata(
|
|||||||
else:
|
else:
|
||||||
assert temperature is not None
|
assert temperature is not None
|
||||||
|
|
||||||
|
if any([frequency_penalties, presence_penalties, repetition_penalties]):
|
||||||
|
no_penalties = False
|
||||||
|
|
||||||
|
assert output_token_ids
|
||||||
|
assert len(output_token_ids) > 0
|
||||||
|
|
||||||
|
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE)
|
||||||
|
presence_penalties = torch.tensor(presence_penalties, device=DEVICE)
|
||||||
|
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE)
|
||||||
|
else:
|
||||||
|
no_penalties = True
|
||||||
|
frequency_penalties = torch.tensor([])
|
||||||
|
presence_penalties = torch.tensor([])
|
||||||
|
repetition_penalties = torch.tensor([])
|
||||||
|
|
||||||
return SamplingMetadata(
|
return SamplingMetadata(
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
all_greedy=all_greedy,
|
all_greedy=all_greedy,
|
||||||
@ -61,14 +89,15 @@ def create_sampling_metadata(
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
generators=generators,
|
generators=generators,
|
||||||
max_num_logprobs=0,
|
max_num_logprobs=0,
|
||||||
no_penalties=False,
|
no_penalties=no_penalties,
|
||||||
prompt_token_ids=None,
|
prompt_token_ids=prompt_token_ids,
|
||||||
frequency_penalties=torch.tensor([]),
|
frequency_penalties=frequency_penalties,
|
||||||
presence_penalties=torch.tensor([]),
|
presence_penalties=presence_penalties,
|
||||||
repetition_penalties=torch.tensor([]),
|
repetition_penalties=repetition_penalties,
|
||||||
output_token_ids=[],
|
output_token_ids=[] if output_token_ids is None else output_token_ids,
|
||||||
allowed_token_ids_mask=None,
|
spec_token_ids=[] if spec_token_ids is None else spec_token_ids,
|
||||||
bad_words_token_ids={},
|
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||||
|
bad_words_token_ids={} if bad_words_token_ids is None else bad_words_token_ids,
|
||||||
logitsprocs=LogitsProcessors(),
|
logitsprocs=LogitsProcessors(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -611,3 +640,136 @@ def test_top_p(rejection_sampler, top_p):
|
|||||||
unmasked_indices=top_p_indices,
|
unmasked_indices=top_p_indices,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
########################### Tests for Logit Processors ###################
|
||||||
|
def test_frequency_penalties(rejection_sampler):
    """Check the rejection sampler's output when frequency penalties are set.

    Three requests are sampled greedily; the middle one carries no draft
    tokens. Token 15 is given the second-highest logit on every row, so it
    is the token expected to show up wherever a draft token gets rejected.
    """
    draft_tokens = [[1, 1, 1], [], [1, 1, 1]]
    # The last entry of each row (1, 7 and 1) is the bonus token.
    target_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]]
    request_count = len(draft_tokens)

    logits = create_logits_tensor(target_tokens, token_idx_to_override=15)
    device = logits.device

    prompts = torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE)
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=draft_tokens,
        prompt_token_ids=prompts,
        frequency_penalties=[1.5, 1.5, 0.7],
        # Neutral values: only the frequency penalty should matter here.
        presence_penalties=[0.0] * request_count,
        repetition_penalties=[1.0] * request_count,
    )

    bonus_token_tensor = torch.tensor(
        [row[-1] for row in target_tokens], device=device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(draft_tokens, device=device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )

    # -1 marks positions that hold no token in the returned tensor.
    expected = torch.tensor(
        [[1, 15, -1, -1], [7, -1, -1, -1], [1, 1, 15, -1]],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(output, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bad_words(rejection_sampler):
    """Check the rejection sampler's output when bad-words lists are set.

    Token 2 is banned for the first two requests only. Token 15 carries the
    second-highest logit on every row, so it is the token expected to be
    produced where the banned token would otherwise have been chosen.
    """
    draft_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
    target_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]

    logits = create_logits_tensor(target_tokens, token_idx_to_override=15)
    device = logits.device

    # Requests 0 and 1 ban the single-token word [2]; the last request is
    # deliberately left without any bad words.
    banned = {0: [[2]], 1: [[2]]}
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=draft_tokens,
        bad_words_token_ids=banned,
    )

    bonus_token_tensor = torch.tensor(
        [row[-1] for row in target_tokens], device=device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(draft_tokens, device=device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )

    # -1 marks positions that hold no token in the returned tensor.
    expected = torch.tensor(
        [[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(output, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowed_token_ids(rejection_sampler):
    """Check the rejection sampler's output when a token-id mask is set.

    The mask comes from `create_allowed_token_ids`, which only fills in
    even-indexed rows. With a window of 5 tokens this masks out:
      request 0: tokens 0-4
      request 1: nothing (odd rows are skipped by the helper)
      request 2: tokens 2-6
    """
    draft_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
    target_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]
    window = 5

    # Token 15 gets the second-highest logit everywhere, so it is what the
    # sampler is expected to choose when the top token is masked out.
    logits = create_logits_tensor(target_tokens, token_idx_to_override=15)
    device = logits.device

    vocab_size = logits.size()[1]
    mask = create_allowed_token_ids(
        batch_size=len(target_tokens),
        vocab_size=vocab_size,
        num_allowed_token_ids=window,
        device=device,
    )
    metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[], [], []],
        spec_token_ids=draft_tokens,
        allowed_token_ids_mask=mask,
    )

    bonus_token_tensor = torch.tensor(
        [row[-1] for row in target_tokens], device=device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(draft_tokens, device=device)

    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        sampling_metadata=metadata,
    )

    # -1 marks positions that hold no token in the returned tensor.
    expected = torch.tensor(
        [[15, -1, -1, -1], [10, 5, 10, -1], [7, 10, 12, 5]],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(output, expected)
|
||||||
|
|||||||
@ -1,12 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.v1.sample.utils import create_allowed_token_ids
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||||
@ -51,26 +50,6 @@ def _create_prompt_tokens_tensor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _create_allowed_token_ids(
|
|
||||||
batch_size: int,
|
|
||||||
vocab_size: int,
|
|
||||||
num_allowed_token_ids: int,
|
|
||||||
device: torch.device,
|
|
||||||
) -> Optional[torch.Tensor]:
|
|
||||||
mask: Optional[torch.Tensor] = None
|
|
||||||
for i in range(batch_size):
|
|
||||||
if i % 2 == 1:
|
|
||||||
continue
|
|
||||||
if mask is None:
|
|
||||||
mask = torch.zeros(
|
|
||||||
(batch_size, vocab_size), dtype=torch.bool, device=device
|
|
||||||
)
|
|
||||||
start = min(i, vocab_size - 1)
|
|
||||||
end = min(i + num_allowed_token_ids, vocab_size - 1)
|
|
||||||
mask[i, start:end] = True
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
def _create_bad_words_token_ids(
|
def _create_bad_words_token_ids(
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
@ -173,6 +152,7 @@ def _create_default_sampling_metadata(
|
|||||||
prompt_token_ids, vocab_size, device
|
prompt_token_ids, vocab_size, device
|
||||||
),
|
),
|
||||||
output_token_ids=output_token_ids,
|
output_token_ids=output_token_ids,
|
||||||
|
spec_token_ids=[[] for _ in range(batch_size)],
|
||||||
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||||
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||||
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
|
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
|
||||||
@ -241,7 +221,9 @@ def test_sampler_presence_penalty(
|
|||||||
)
|
)
|
||||||
sampling_metadata.no_penalties = False
|
sampling_metadata.no_penalties = False
|
||||||
sampler = Sampler()
|
sampler = Sampler()
|
||||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
logits = sampler.apply_penalties(
|
||||||
|
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
|
||||||
|
)
|
||||||
logits = logits.cpu()
|
logits = logits.cpu()
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
# Since all tokens initially have the same logits, the non-penalized
|
# Since all tokens initially have the same logits, the non-penalized
|
||||||
@ -293,7 +275,9 @@ def test_sampler_frequency_penalty(
|
|||||||
sampling_metadata.output_token_ids = output_token_ids
|
sampling_metadata.output_token_ids = output_token_ids
|
||||||
sampling_metadata.no_penalties = False
|
sampling_metadata.no_penalties = False
|
||||||
sampler = Sampler()
|
sampler = Sampler()
|
||||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
logits = sampler.apply_penalties(
|
||||||
|
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
|
||||||
|
)
|
||||||
logits = logits.cpu()
|
logits = logits.cpu()
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||||
@ -343,7 +327,9 @@ def test_sampler_repetition_penalty(
|
|||||||
)
|
)
|
||||||
sampling_metadata.no_penalties = False
|
sampling_metadata.no_penalties = False
|
||||||
sampler = Sampler()
|
sampler = Sampler()
|
||||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
logits = sampler.apply_penalties(
|
||||||
|
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
|
||||||
|
)
|
||||||
logits = logits.cpu()
|
logits = logits.cpu()
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||||
@ -394,7 +380,7 @@ def test_sampler_allowed_token_ids(
|
|||||||
sampling_metadata = _create_default_sampling_metadata(
|
sampling_metadata = _create_default_sampling_metadata(
|
||||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||||
)
|
)
|
||||||
mask = _create_allowed_token_ids(
|
mask = create_allowed_token_ids(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
vocab_size=VOCAB_SIZE,
|
vocab_size=VOCAB_SIZE,
|
||||||
num_allowed_token_ids=num_allowed_token_ids,
|
num_allowed_token_ids=num_allowed_token_ids,
|
||||||
@ -402,7 +388,9 @@ def test_sampler_allowed_token_ids(
|
|||||||
)
|
)
|
||||||
sampling_metadata.allowed_token_ids_mask = mask
|
sampling_metadata.allowed_token_ids_mask = mask
|
||||||
sampler = Sampler()
|
sampler = Sampler()
|
||||||
logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
|
logits = sampler.apply_logits_processors(
|
||||||
|
fake_logits, sampling_metadata, predict_bonus_token=False
|
||||||
|
)
|
||||||
logits = logits.cpu()
|
logits = logits.cpu()
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
logits_for_req = logits[batch_idx]
|
logits_for_req = logits[batch_idx]
|
||||||
@ -444,7 +432,9 @@ def test_sampler_bad_words(
|
|||||||
sampling_metadata, VOCAB_SIZE
|
sampling_metadata, VOCAB_SIZE
|
||||||
)
|
)
|
||||||
sampler = Sampler()
|
sampler = Sampler()
|
||||||
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
|
logits = sampler.apply_logits_processors(
|
||||||
|
fake_logits, sampling_metadata, predict_bonus_token=False
|
||||||
|
)
|
||||||
logits = logits.cpu()
|
logits = logits.cpu()
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(batch_size):
|
||||||
logits_for_req = logits[batch_idx]
|
logits_for_req = logits[batch_idx]
|
||||||
|
|||||||
@ -215,3 +215,23 @@ def fake_apply_logitsprocs(
|
|||||||
for processor in test_fakes.get_logitsprocs():
|
for processor in test_fakes.get_logitsprocs():
|
||||||
logits = processor.apply(logits)
|
logits = processor.apply(logits)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def create_allowed_token_ids(
    batch_size: int,
    vocab_size: int,
    num_allowed_token_ids: int,
    device: torch.device,
) -> Optional[torch.Tensor]:
    """Build a boolean `(batch_size, vocab_size)` token mask for tests.

    Every even-indexed row `i` has the half-open column window
    `[i, i + num_allowed_token_ids)` set to True, with both bounds clipped
    to `vocab_size - 1`. Odd-indexed rows are left entirely False.

    Returns:
        The mask on `device`, or None when `batch_size` yields no rows.
    """
    if batch_size <= 0:
        # No rows to visit, so no mask is ever allocated.
        return None

    mask = torch.zeros((batch_size, vocab_size), dtype=torch.bool, device=device)
    upper_bound = vocab_size - 1
    # Stepping by two visits exactly the even rows.
    for row in range(0, batch_size, 2):
        lo = min(row, upper_bound)
        hi = min(row + num_allowed_token_ids, upper_bound)
        mask[row, lo:hi] = True
    return mask
|
||||||
|
|||||||
@ -170,6 +170,7 @@ def _construct_expected_sampling_metadata(
|
|||||||
repetition_penalties, dtype=torch.float, device=device
|
repetition_penalties, dtype=torch.float, device=device
|
||||||
),
|
),
|
||||||
output_token_ids=output_token_ids,
|
output_token_ids=output_token_ids,
|
||||||
|
spec_token_ids=[[] for _ in range(len(output_token_ids))],
|
||||||
no_penalties=(
|
no_penalties=(
|
||||||
all(x == 0 for x in presence_penalties)
|
all(x == 0 for x in presence_penalties)
|
||||||
and all(x == 0 for x in frequency_penalties)
|
and all(x == 0 for x in frequency_penalties)
|
||||||
|
|||||||
@ -37,6 +37,12 @@ STR_POOLING_REJECTS_LOGITSPROCS = (
|
|||||||
"Pooling models do not support custom logits processors."
|
"Pooling models do not support custom logits processors."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Error message when the user tries to initialize vLLM with speculative
# decoding enabled and custom logitsprocs
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
    "Custom logits processors are not supported when speculative decoding is enabled."
)
|
||||||
|
|
||||||
LOGITSPROCS_GROUP = "vllm.logits_processors"
|
LOGITSPROCS_GROUP = "vllm.logits_processors"
|
||||||
|
|
||||||
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
|
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
|
||||||
@ -185,6 +191,17 @@ def build_logitsprocs(
|
|||||||
" do not support logits processors."
|
" do not support logits processors."
|
||||||
)
|
)
|
||||||
return LogitsProcessors()
|
return LogitsProcessors()
|
||||||
|
|
||||||
|
# Check if speculative decoding is enabled.
|
||||||
|
if vllm_config.speculative_config:
|
||||||
|
if custom_logitsprocs:
|
||||||
|
raise ValueError(STR_SPEC_DEC_REJECTS_LOGITSPROCS)
|
||||||
|
logger.warning(
|
||||||
|
"min_p, logit_bias, and min_tokens parameters won't currently work "
|
||||||
|
"with speculative decoding enabled."
|
||||||
|
)
|
||||||
|
return LogitsProcessors()
|
||||||
|
|
||||||
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
|
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
|
||||||
return LogitsProcessors(
|
return LogitsProcessors(
|
||||||
ctor(vllm_config, device, is_pin_memory)
|
ctor(vllm_config, device, is_pin_memory)
|
||||||
|
|||||||
@ -40,3 +40,6 @@ class SamplingMetadata:
|
|||||||
|
|
||||||
# Loaded logits processors
|
# Loaded logits processors
|
||||||
logitsprocs: LogitsProcessors
|
logitsprocs: LogitsProcessors
|
||||||
|
|
||||||
|
# Speculative token ids
|
||||||
|
spec_token_ids: Optional[list[list[int]]] = None
|
||||||
|
|||||||
@ -33,3 +33,20 @@ def apply_bad_words(
|
|||||||
) -> None:
|
) -> None:
|
||||||
for i, bad_words_ids in bad_words_token_ids.items():
|
for i, bad_words_ids in bad_words_token_ids.items():
|
||||||
_apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i])
|
_apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i])
|
||||||
|
|
||||||
|
|
||||||
|
def apply_bad_words_with_drafts(
    logits: torch.Tensor,
    bad_words_token_ids: dict[int, list[list[int]]],
    past_tokens_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> None:
    """Apply per-request bad-words lists to every draft position.

    `logits` and `past_tokens_ids` are indexed by flattened draft position:
    request `i` owns `num_draft_tokens[i]` consecutive rows, laid out in
    request order. `_apply_bad_words_single_batch` is called once per row of
    each request that has an entry in `bad_words_token_ids`.

    Args:
        logits: One row of logits per draft position.
        bad_words_token_ids: Maps a request index to its bad-words lists.
            Requests without an entry are left untouched.
        past_tokens_ids: Token history for each draft position, aligned
            with the rows of `logits`.
        num_draft_tokens: Number of draft positions for each request.
    """
    pending = len(bad_words_token_ids)
    if pending == 0:
        return

    row_offset = 0
    # Walk every request, not only those with bad words: the row offset of a
    # request depends on the draft counts of all requests before it.
    for request_idx, draft_count in enumerate(num_draft_tokens):
        bad_words_ids = bad_words_token_ids.get(request_idx)
        if bad_words_ids is not None:
            for row in range(row_offset, row_offset + draft_count):
                _apply_bad_words_single_batch(
                    logits[row],
                    bad_words_ids,
                    past_tokens_ids[row],
                )
            pending -= 1
            if pending == 0:
                # Every request with bad words has been handled.
                break
        row_offset += draft_count
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import torch.nn as nn
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
|
||||||
|
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
|
|
||||||
@ -83,6 +85,14 @@ class RejectionSampler(nn.Module):
|
|||||||
A tensor containing the final output token IDs.
|
A tensor containing the final output token IDs.
|
||||||
"""
|
"""
|
||||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||||
|
|
||||||
|
# Use float32 for the target_logits.
|
||||||
|
target_logits = target_logits.to(torch.float32)
|
||||||
|
|
||||||
|
target_logits = self.apply_logits_processors(
|
||||||
|
target_logits, sampling_metadata, metadata
|
||||||
|
)
|
||||||
|
|
||||||
# [num_tokens, vocab_size]
|
# [num_tokens, vocab_size]
|
||||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||||
# `compute_probs` function.
|
# `compute_probs` function.
|
||||||
@ -131,6 +141,100 @@ class RejectionSampler(nn.Module):
|
|||||||
]
|
]
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def apply_logits_processors(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    metadata: SpecDecodeMetadata,
) -> torch.Tensor:
    """Apply penalties, the token-id mask and bad words to target logits.

    Each step runs only when the matching field of `sampling_metadata` is
    set. Rows of `logits` are per draft position, so per-request tensors
    are expanded through an index built from `metadata.num_draft_tokens`.

    Returns:
        The processed logits.
    """
    bad_words = sampling_metadata.bad_words_token_ids
    token_id_mask = sampling_metadata.allowed_token_ids_mask
    has_penalties = not sampling_metadata.no_penalties

    # Penalties and bad words need one token history per draft position;
    # otherwise the per-request outputs are passed through unchanged.
    token_history = sampling_metadata.output_token_ids
    if bad_words or has_penalties:
        token_history = self._combine_outputs_with_spec_tokens(
            sampling_metadata.output_token_ids,
            sampling_metadata.spec_token_ids,
        )

    if token_id_mask is not None or has_penalties:
        # Row -> request index: request i is repeated once per draft token.
        # Built on CPU, then copied to the logits device.
        drafts_per_request = torch.tensor(metadata.num_draft_tokens, device="cpu")
        request_ids = torch.arange(
            len(sampling_metadata.output_token_ids), device="cpu"
        )
        row_to_request = request_ids.repeat_interleave(drafts_per_request).to(
            device=logits.device, non_blocking=True
        )

        logits = self.apply_penalties(
            logits, sampling_metadata, metadata, row_to_request, token_history
        )

        if token_id_mask is not None:
            # Positions where the expanded mask is True are set to -inf.
            logits.masked_fill_(token_id_mask[row_to_request], float("-inf"))

    if bad_words:
        apply_bad_words_with_drafts(
            logits,
            bad_words,
            token_history,
            metadata.num_draft_tokens,
        )

    return logits
|
||||||
|
|
||||||
|
def apply_penalties(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    metadata: SpecDecodeMetadata,
    repeat_indices: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """Apply presence, frequency and repetition penalties to the logits.

    Args:
        logits: Logits to penalize.
        sampling_metadata: Source of the per-request prompt token ids and
            penalty tensors.
        metadata: Speculative decoding metadata; not read by this method.
        repeat_indices: Index tensor used to select the per-request entry
            for each logits row.
        output_token_ids: Token histories passed through to
            ``apply_all_penalties``.

    Returns:
        ``logits`` unchanged when no penalties are configured, otherwise the
        result of ``apply_all_penalties``.
    """
    # Fast path: nothing to do when no request uses penalties.
    if sampling_metadata.no_penalties:
        return logits

    assert sampling_metadata.prompt_token_ids is not None

    # Index every per-request tensor with repeat_indices before handing it
    # to the shared penalty routine.
    return apply_all_penalties(
        logits,
        sampling_metadata.prompt_token_ids[repeat_indices],
        sampling_metadata.presence_penalties[repeat_indices],
        sampling_metadata.frequency_penalties[repeat_indices],
        sampling_metadata.repetition_penalties[repeat_indices],
        output_token_ids,
    )
|
||||||
|
|
||||||
|
def _combine_outputs_with_spec_tokens(
|
||||||
|
self,
|
||||||
|
output_token_ids: list[list[int]],
|
||||||
|
spec_token_ids: Optional[list[list[int]]] = None,
|
||||||
|
) -> list[list[int]]:
|
||||||
|
if spec_token_ids is None:
|
||||||
|
return output_token_ids
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for out, spec in zip(output_token_ids, spec_token_ids):
|
||||||
|
if len(spec) == 0:
|
||||||
|
continue
|
||||||
|
result.append(out)
|
||||||
|
for i in range(len(spec) - 1):
|
||||||
|
result.append([*result[-1], spec[i]])
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def rejection_sample(
|
def rejection_sample(
|
||||||
# [num_tokens]
|
# [num_tokens]
|
||||||
|
|||||||
@ -70,6 +70,7 @@ class Sampler(nn.Module):
|
|||||||
self,
|
self,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
|
predict_bonus_token: bool = False,
|
||||||
) -> SamplerOutput:
|
) -> SamplerOutput:
|
||||||
# NOTE(woosuk): Use the original logits (before any penalties or
|
# NOTE(woosuk): Use the original logits (before any penalties or
|
||||||
# temperature scaling) for the top-k logprobs.
|
# temperature scaling) for the top-k logprobs.
|
||||||
@ -84,18 +85,10 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
# Use float32 for the logits.
|
# Use float32 for the logits.
|
||||||
logits = logits.to(torch.float32)
|
logits = logits.to(torch.float32)
|
||||||
# Apply allowed token ids.
|
|
||||||
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
|
|
||||||
# Apply bad words exclusion.
|
|
||||||
logits = self.apply_bad_words(logits, sampling_metadata)
|
|
||||||
|
|
||||||
# Apply logits processors which can impact greedy sampling
|
|
||||||
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
|
|
||||||
logits = processor.apply(logits)
|
|
||||||
|
|
||||||
# Apply penalties (e.g., min_tokens, freq_penalties).
|
|
||||||
logits = self.apply_penalties(logits, sampling_metadata)
|
|
||||||
|
|
||||||
|
logits = self.apply_logits_processors(
|
||||||
|
logits, sampling_metadata, predict_bonus_token
|
||||||
|
)
|
||||||
# Sample the next token.
|
# Sample the next token.
|
||||||
sampled, processed_logprobs = self.sample(logits, sampling_metadata)
|
sampled, processed_logprobs = self.sample(logits, sampling_metadata)
|
||||||
if processed_logprobs is not None:
|
if processed_logprobs is not None:
|
||||||
@ -245,10 +238,65 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||||
|
|
||||||
|
def _combine_outputs_with_spec_tokens(
|
||||||
|
self,
|
||||||
|
output_token_ids: list[list[int]],
|
||||||
|
spec_token_ids: Optional[list[list[int]]] = None,
|
||||||
|
) -> list[list[int]]:
|
||||||
|
if spec_token_ids is None:
|
||||||
|
return output_token_ids
|
||||||
|
|
||||||
|
return [
|
||||||
|
[*out, *spec] if spec else out
|
||||||
|
for out, spec in zip(output_token_ids, spec_token_ids)
|
||||||
|
]
|
||||||
|
|
||||||
|
def apply_logits_processors(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    predict_bonus_token: bool,
) -> torch.Tensor:
    """Run the logits-modifying steps that precede sampling.

    Steps, in order: allowed-token mask, bad-words exclusion, the
    ``non_argmax_invariant`` logits processors, then penalties.

    Args:
        logits: Logits to process; the mask and bad-words steps modify it
            in place.
        sampling_metadata: Per-request sampling settings.
        predict_bonus_token: When true (and penalties or bad words are in
            use), the token histories are extended with the speculative
            tokens before being used.

    Returns:
        The processed logits.
    """
    bad_words = sampling_metadata.bad_words_token_ids
    needs_token_history = bad_words or not sampling_metadata.no_penalties

    output_token_ids = sampling_metadata.output_token_ids
    if predict_bonus_token and needs_token_history:
        # Combine base outputs with spec tokens when speculative decoding
        # is enabled.
        output_token_ids = self._combine_outputs_with_spec_tokens(
            sampling_metadata.output_token_ids,
            sampling_metadata.spec_token_ids,
        )

    # Apply allowed token ids.
    allowed_mask = sampling_metadata.allowed_token_ids_mask
    if allowed_mask is not None:
        logits.masked_fill_(allowed_mask, float("-inf"))

    # Apply bad words exclusion.
    if bad_words:
        history = output_token_ids
        if history is None:
            # Fall back to the metadata's own output tokens.
            history = sampling_metadata.output_token_ids
        apply_bad_words(logits, bad_words, history)

    # Apply logits processors which can impact greedy sampling.
    for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
        logits = processor.apply(logits)

    # Apply penalties (e.g., freq_penalties).
    return self.apply_penalties(logits, sampling_metadata, output_token_ids)
|
||||||
|
|
||||||
def apply_penalties(
|
def apply_penalties(
|
||||||
self,
|
self,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
|
output_token_ids: Optional[list[list[int]]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not sampling_metadata.no_penalties:
|
if not sampling_metadata.no_penalties:
|
||||||
assert sampling_metadata.prompt_token_ids is not None
|
assert sampling_metadata.prompt_token_ids is not None
|
||||||
@ -258,28 +306,8 @@ class Sampler(nn.Module):
|
|||||||
sampling_metadata.presence_penalties,
|
sampling_metadata.presence_penalties,
|
||||||
sampling_metadata.frequency_penalties,
|
sampling_metadata.frequency_penalties,
|
||||||
sampling_metadata.repetition_penalties,
|
sampling_metadata.repetition_penalties,
|
||||||
sampling_metadata.output_token_ids,
|
output_token_ids
|
||||||
)
|
if output_token_ids is not None
|
||||||
return logits
|
else sampling_metadata.output_token_ids,
|
||||||
|
|
||||||
def apply_allowed_token_ids(
|
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if sampling_metadata.allowed_token_ids_mask is not None:
|
|
||||||
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def apply_bad_words(
|
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if sampling_metadata.bad_words_token_ids:
|
|
||||||
apply_bad_words(
|
|
||||||
logits,
|
|
||||||
sampling_metadata.bad_words_token_ids,
|
|
||||||
sampling_metadata.output_token_ids,
|
|
||||||
)
|
)
|
||||||
return logits
|
return logits
|
||||||
|
|||||||
@ -240,6 +240,9 @@ class InputBatch:
|
|||||||
# data structure
|
# data structure
|
||||||
self.logitsprocs = logitsprocs or LogitsProcessors()
|
self.logitsprocs = logitsprocs or LogitsProcessors()
|
||||||
|
|
||||||
|
# Store last speculative tokens for sampler.
|
||||||
|
self.spec_token_ids: list[Optional[list[int]]] = []
|
||||||
|
|
||||||
# This is updated each time the batch constituents change.
|
# This is updated each time the batch constituents change.
|
||||||
self.sampling_metadata = self._make_sampling_metadata()
|
self.sampling_metadata = self._make_sampling_metadata()
|
||||||
|
|
||||||
@ -292,9 +295,11 @@ class InputBatch:
|
|||||||
if req_index == len(self._req_ids):
|
if req_index == len(self._req_ids):
|
||||||
self._req_ids.append(req_id)
|
self._req_ids.append(req_id)
|
||||||
self.req_output_token_ids.append(request.output_token_ids)
|
self.req_output_token_ids.append(request.output_token_ids)
|
||||||
|
self.spec_token_ids.append([])
|
||||||
else:
|
else:
|
||||||
self._req_ids[req_index] = req_id
|
self._req_ids[req_index] = req_id
|
||||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||||
|
self.spec_token_ids[req_index] = []
|
||||||
|
|
||||||
self.req_id_to_index[req_id] = req_index
|
self.req_id_to_index[req_id] = req_index
|
||||||
|
|
||||||
@ -443,6 +448,7 @@ class InputBatch:
|
|||||||
self.batch_update_builder.removed_append(req_index)
|
self.batch_update_builder.removed_append(req_index)
|
||||||
self._req_ids[req_index] = None
|
self._req_ids[req_index] = None
|
||||||
self.req_output_token_ids[req_index] = None
|
self.req_output_token_ids[req_index] = None
|
||||||
|
self.spec_token_ids[req_index] = None
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
lora_id = self.request_lora_mapping[req_index]
|
lora_id = self.request_lora_mapping[req_index]
|
||||||
@ -486,6 +492,10 @@ class InputBatch:
|
|||||||
self.req_output_token_ids[i2],
|
self.req_output_token_ids[i2],
|
||||||
self.req_output_token_ids[i1],
|
self.req_output_token_ids[i1],
|
||||||
)
|
)
|
||||||
|
self.spec_token_ids[i1], self.spec_token_ids[i2] = (
|
||||||
|
self.spec_token_ids[i2],
|
||||||
|
self.spec_token_ids[i1],
|
||||||
|
)
|
||||||
assert old_id_i1 is not None and old_id_i2 is not None
|
assert old_id_i1 is not None and old_id_i2 is not None
|
||||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
|
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = (
|
||||||
self.req_id_to_index[old_id_i2],
|
self.req_id_to_index[old_id_i2],
|
||||||
@ -601,6 +611,7 @@ class InputBatch:
|
|||||||
# The batched states are empty.
|
# The batched states are empty.
|
||||||
self._req_ids.clear()
|
self._req_ids.clear()
|
||||||
self.req_output_token_ids.clear()
|
self.req_output_token_ids.clear()
|
||||||
|
self.spec_token_ids.clear()
|
||||||
return
|
return
|
||||||
|
|
||||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||||
@ -629,6 +640,10 @@ class InputBatch:
|
|||||||
self.req_output_token_ids[last_req_index] = None
|
self.req_output_token_ids[last_req_index] = None
|
||||||
self.req_id_to_index[req_id] = empty_index
|
self.req_id_to_index[req_id] = empty_index
|
||||||
|
|
||||||
|
spec_token_ids = self.spec_token_ids[last_req_index]
|
||||||
|
self.spec_token_ids[empty_index] = spec_token_ids
|
||||||
|
self.spec_token_ids[last_req_index] = None
|
||||||
|
|
||||||
num_tokens = self.num_tokens[last_req_index]
|
num_tokens = self.num_tokens[last_req_index]
|
||||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||||
last_req_index, :num_tokens
|
last_req_index, :num_tokens
|
||||||
@ -700,6 +715,7 @@ class InputBatch:
|
|||||||
# Trim lists to the batch size.
|
# Trim lists to the batch size.
|
||||||
del self._req_ids[num_reqs:]
|
del self._req_ids[num_reqs:]
|
||||||
del self.req_output_token_ids[num_reqs:]
|
del self.req_output_token_ids[num_reqs:]
|
||||||
|
del self.spec_token_ids[num_reqs:]
|
||||||
|
|
||||||
def refresh_metadata(self):
|
def refresh_metadata(self):
|
||||||
"""Apply any batch updates to sampling metadata."""
|
"""Apply any batch updates to sampling metadata."""
|
||||||
@ -784,6 +800,7 @@ class InputBatch:
|
|||||||
presence_penalties=self.presence_penalties[:num_reqs],
|
presence_penalties=self.presence_penalties[:num_reqs],
|
||||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||||
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
||||||
|
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
|
||||||
no_penalties=self.no_penalties,
|
no_penalties=self.no_penalties,
|
||||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||||
bad_words_token_ids=self.bad_words_token_ids,
|
bad_words_token_ids=self.bad_words_token_ids,
|
||||||
|
|||||||
@ -783,6 +783,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
] = spec_token_ids
|
] = spec_token_ids
|
||||||
# NOTE(woosuk): `num_tokens` here may include spec tokens.
|
# NOTE(woosuk): `num_tokens` here may include spec tokens.
|
||||||
self.input_batch.num_tokens[req_index] += num_spec_tokens
|
self.input_batch.num_tokens[req_index] += num_spec_tokens
|
||||||
|
self.input_batch.spec_token_ids[req_index] = spec_token_ids
|
||||||
|
|
||||||
# Add the new or resumed requests to the persistent batch.
|
# Add the new or resumed requests to the persistent batch.
|
||||||
# The smaller empty indices are filled first.
|
# The smaller empty indices are filled first.
|
||||||
@ -2197,6 +2198,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
sampler_output = self.sampler(
|
sampler_output = self.sampler(
|
||||||
logits=bonus_logits,
|
logits=bonus_logits,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
|
predict_bonus_token=True,
|
||||||
)
|
)
|
||||||
bonus_token_ids = sampler_output.sampled_token_ids
|
bonus_token_ids = sampler_output.sampled_token_ids
|
||||||
|
|
||||||
@ -3491,6 +3493,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
presence_penalties=dummy_tensors(0.1),
|
presence_penalties=dummy_tensors(0.1),
|
||||||
repetition_penalties=dummy_tensors(0.1),
|
repetition_penalties=dummy_tensors(0.1),
|
||||||
output_token_ids=[[] for _ in range(num_reqs)],
|
output_token_ids=[[] for _ in range(num_reqs)],
|
||||||
|
spec_token_ids=[[] for _ in range(num_reqs)],
|
||||||
allowed_token_ids_mask=None,
|
allowed_token_ids_mask=None,
|
||||||
bad_words_token_ids={},
|
bad_words_token_ids={},
|
||||||
logitsprocs=LogitsProcessors(),
|
logitsprocs=LogitsProcessors(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user