[V1] Optimize handling of sampling metadata and req_ids list (#13244)

Signed-off-by: Nick Hill <nhill@redhat.com>

parent a4d577b379
commit 30172b4947
@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
 def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
     batch_size = len(spec_tokens)
     return SamplingMetadata(
-        temperature=0.0,
+        temperature=torch.tensor([]),
         all_greedy=True,
         all_random=False,
-        rejection_sampling=True,
         spec_token_ids=spec_tokens,
         top_p=None,
         top_k=None,
-        no_top_p=False,
-        no_top_k=False,
         min_p=torch.empty(batch_size, ),
-        no_min_p=True,
         generators={},
         max_num_logprobs=0,
         no_penalties=False,
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
         presence_penalties=torch.tensor([]),
         repetition_penalties=torch.tensor([]),
         output_token_ids=[],
-        min_tokens=[],
-        stop_token_ids=[],
+        min_tokens={},
         logit_bias=[None] * batch_size,
     )
 
@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
         temperature=torch.full((batch_size, ), 0.0),
         all_greedy=True,
         all_random=False,
-        rejection_sampling=False,
-        top_p=torch.empty(batch_size, ),
-        top_k=torch.empty(batch_size, ),
-        no_top_p=True,
-        no_top_k=True,
-        min_p=torch.empty(batch_size, ),
-        no_min_p=True,
+        top_p=None,
+        top_k=None,
+        min_p=None,
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                       vocab_size, device),
         output_token_ids=output_token_ids,
-        spec_token_ids=[],
+        spec_token_ids=None,
         frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
         no_penalties=True,
-        min_tokens=[],
-        stop_token_ids=[],
+        min_tokens={},
         logit_bias=[None] * batch_size,
     )
     return fake_sampling_metadata
@@ -104,10 +99,10 @@ def _create_default_sampling_metadata(
 def _generate_min_token_penalties_and_stop_tokens(
         num_output_tokens: int, batch_size: int, vocab_size: int,
         batch_indices_for_min_token_penalty: List[int]
-) -> Tuple[List[int], List[Set[int]]]:
+) -> Dict[int, Tuple[int, Set[int]]]:
     """
-    Generates and returns a list of minimum token penalties (`min_tokens`)
-    and a corresponding list of stop token IDs (`stop_token_ids`) for each
+    Generates and returns a dict of minimum token penalties and
+    corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
     batch.
 
     If a batch index is included in `batch_indices_for_min_token_penalty`,
@@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens(
     and a random set of stop token IDs is created. Otherwise, a lower
     `min_tokens` value is assigned, and the stop token IDs set is empty.
     """
-    stop_token_ids: List[Set[int]] = []
-    min_tokens: List[int] = []
+    min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
     for index in range(batch_size):
         if index in batch_indices_for_min_token_penalty:
-            min_tokens.append(
+            min_tokens[index] = (
                 np.random.randint(num_output_tokens + 1,
-                                  2 * num_output_tokens))
-            stop_token_ids.append(
+                                  2 * num_output_tokens),
                 set(
                     np.random.randint(0, vocab_size - 1)
                     for _ in range(np.random.randint(0, vocab_size))))
-
         else:
-            min_tokens.append(np.random.randint(0, num_output_tokens))
-            stop_token_ids.append(set())
-    return (min_tokens, stop_token_ids)
+            min_tokens[index] = (np.random.randint(0,
+                                                   num_output_tokens), set())
+    return min_tokens
 
 
 def _create_weighted_output_token_list(
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
             output_token_ids_for_batch.extend(
                 [token_id for _ in range(index + 1)])
         output_token_ids.append(output_token_ids_for_batch)
-    return (output_token_ids, sorted_token_ids_in_output)
+    return output_token_ids, sorted_token_ids_in_output
 
 
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
         NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
     batch_indices_for_min_token_penalty = np.random.randint(
         0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
-    min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
+    min_tokens = _generate_min_token_penalties_and_stop_tokens(
         NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
         batch_indices_for_min_token_penalty)
     sampling_metadata.min_tokens = min_tokens
-    sampling_metadata.stop_token_ids = stop_token_ids
     sampler = Sampler()
     logits = sampler.apply_penalties(fake_logits, sampling_metadata)
     logits = logits.cpu()
     for batch_idx in range(batch_size):
         for token_id in range(VOCAB_SIZE):
-            if token_id in stop_token_ids[batch_idx]:
+            _, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
+            if token_id in stop_token_ids:
                 assert logits[batch_idx][token_id] == -float("inf")
             else:
                 assert logits[batch_idx][token_id] != -float("inf")
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import pytest
@@ -41,7 +41,7 @@ def _remove_requests(
     for index in req_indices_to_remove:
         input_batch.remove_request(reqs[index].req_id)
         req_ids_to_remove.add(reqs[index].req_id)
-    return (req_ids_to_remove, req_indices_to_remove_list)
+    return req_ids_to_remove, req_indices_to_remove_list
 
 
 def _construct_expected_sampling_metadata(
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
     top_p = [0.0 for _ in range(num_reqs)]
     min_p = [0.0 for _ in range(num_reqs)]
     temperature = [0.0 for _ in range(num_reqs)]
-    stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
-    min_tokens = [0 for _ in range(num_reqs)]
+    min_tokens = {}
     logit_bias = [None] * num_reqs
     for req in reqs:
         if req.req_id not in req_ids_retained:
@@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
         top_p[index_in_input_batch] = req.sampling_params.top_p
         min_p[index_in_input_batch] = req.sampling_params.min_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
-        stop_token_ids[
-            index_in_input_batch] = req.sampling_params.all_stop_token_ids
-        min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
+        min_tokens[index_in_input_batch] = (
+            req.sampling_params.min_tokens,
+            req.sampling_params.all_stop_token_ids)
         logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
                                  device=device),
         all_greedy=False,
         all_random=True,
-        rejection_sampling=False,
-        top_p=torch.tensor(top_p, dtype=torch.float, device=device),
-        top_k=torch.tensor(top_k, dtype=torch.int, device=device),
-        no_top_p=all(x == 1.0 for x in top_p),
-        no_top_k=all(x == 0 for x in top_k),
-        min_p=torch.tensor(min_p, dtype=torch.float, device=device),
-        no_min_p=all(x == 0.0 for x in min_p),
+        top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
+            top_p, dtype=torch.float, device=device),
+        top_k=None if all(x == 0 for x in top_k) else torch.tensor(
+            top_k, dtype=torch.int, device=device),
+        min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
+            min_p, dtype=torch.float, device=device),
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=make_tensor_with_pad(
@@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
             dtype=torch.float,
             device=device),
         output_token_ids=output_token_ids,
-        spec_token_ids=[],
+        spec_token_ids=None,
         min_tokens=min_tokens,
-        stop_token_ids=stop_token_ids,
         no_penalties=(all(x == 0 for x in presence_penalties)
                       and all(x == 0 for x in frequency_penalties)
                       and all(x == 1 for x in repetition_penalties)),
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     input_batch.condense(req_indices_to_remove)
 
     # Generate the sampling metadata
-    sampling_metadata = input_batch.make_sampling_metadata(
-        req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
+    sampling_metadata = input_batch._make_sampling_metadata()
 
     # Create expected output.
     expected_sampling_metadata = _construct_expected_sampling_metadata(
@@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         input_batch.req_id_to_index,
         device=torch.device(device))
 
+    def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
+        return (t1 is None
+                and t2 is None) or (t1 is not None and t2 is not None
+                                    and torch.allclose(t1, t2))
+
     # Assert the actual and expected output.
     assert torch.allclose(expected_sampling_metadata.temperature,
                           sampling_metadata.temperature)
-    assert torch.allclose(expected_sampling_metadata.top_p,
-                          sampling_metadata.top_p)
-    assert torch.allclose(expected_sampling_metadata.top_k,
-                          sampling_metadata.top_k)
+    assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
+    assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
     assert torch.allclose(
         expected_sampling_metadata.frequency_penalties,
         sampling_metadata.frequency_penalties,
@@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
     assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
-    assert expected_sampling_metadata.stop_token_ids == \
-        sampling_metadata.stop_token_ids
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
-    assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
-    assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
     assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
@@ -5,6 +5,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
                                            SchedulerOutput)
+from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
@@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
     return req_id in model_runner.requests
 
 
+def _is_sampling_metadata_changed(model_runner,
+                                  sampling_metadata_before: SamplingMetadata):
+    return model_runner.input_batch.sampling_metadata is not (
+        sampling_metadata_before)
+
+
 def test_update_states_new_request(model_runner):
     req_id = "req_0"
 
     # new req
     scheduler_output = _schedule_new_request(req_id)
 
-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
 
@@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
         free_encoder_input_ids=[],
     )
 
-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert not _is_req_added(model_runner, req_id)
     assert not _is_req_scheduled(model_runner, req_id)
 
@@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
         num_common_prefix_blocks=0,
-        finished_req_ids={},
+        finished_req_ids=set(),
         free_encoder_input_ids=[],
     )
 
@@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner):
         free_encoder_input_ids=[],
     )
 
-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
 
@@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
         free_encoder_input_ids=[],
     )
 
-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is False
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert not _is_sampling_metadata_changed(model_runner, metadata_before)
    assert _is_req_added(model_runner, req_id)
    assert _is_req_scheduled(model_runner, req_id)
 
@@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
         free_encoder_input_ids=[],
     )
 
-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
 
     assert _is_req_added(model_runner, req_ids[0])
     assert _is_req_scheduled(model_runner, req_ids[0])
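Note: the updated model-runner tests above no longer check a boolean returned by _update_states; they detect a rebuilt batch by object identity (is not) on input_batch.sampling_metadata. A minimal standalone sketch of that pattern, independent of vLLM (the Batch class and its fields here are hypothetical):

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Metadata:
    temperature: float


class Batch:
    """Toy stand-in: rebuilds its cached metadata only when the batch changes."""

    def __init__(self) -> None:
        self.metadata = Metadata(temperature=1.0)

    def update(self, changed: bool) -> None:
        if changed:
            # A new object is created, so identity changes even when the
            # field values happen to be equal.
            self.metadata = replace(self.metadata)


batch = Batch()
before = batch.metadata
batch.update(changed=False)
assert batch.metadata is before          # unchanged: same cached object
batch.update(changed=True)
assert batch.metadata is not before      # rebuilt: identity differs
assert batch.metadata == before          # ...even though the values are equal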
@@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                                                    vocab_size, num_seqs)
     output_bin_counts, output_mask = get_token_bin_counts_and_mask(
         output_tokens_tensor, vocab_size, num_seqs)
-    repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
+    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
         1, vocab_size)
     logits[logits > 0] /= torch.where(prompt_mask | output_mask,
                                       repetition_penalties, 1.0)[logits > 0]
@@ -53,6 +53,6 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                                       repetition_penalties, 1.0)[logits <= 0]
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits
@@ -195,8 +195,10 @@ class Scheduler:
                     request.num_computed_tokens -
                     request.num_tokens)
                 if num_scheduled_spec_tokens > 0:
+                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
+                    del request.spec_token_ids[num_scheduled_spec_tokens:]
                     scheduled_spec_decode_tokens[request.request_id] = (
-                        request.spec_token_ids[:num_scheduled_spec_tokens])
+                        request.spec_token_ids)
 
             # Encoder-related.
             if encoder_inputs_to_schedule:
@@ -567,7 +569,7 @@ class Scheduler:
             outputs.append(
                 EngineCoreOutput(
                     request_id=req_id,
-                    new_token_ids=new_token_ids or [],
+                    new_token_ids=new_token_ids,
                     finish_reason=request.get_finished_reason(),
                     new_logprobs=new_logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs_tensors,
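Note: the scheduler hunk above replaces a list slice (which allocates a new list every step) with an in-place del of the tail, so the request's own spec_token_ids list is reused. A small standalone illustration of the difference (plain Python, not vLLM code):

spec_token_ids = [11, 12, 13, 14, 15]
num_scheduled = 3

# Old pattern: slicing builds a brand-new list and leaves the original intact.
copied = spec_token_ids[:num_scheduled]
assert copied == [11, 12, 13] and copied is not spec_token_ids
assert len(spec_token_ids) == 5

# New pattern: deleting the tail trims the existing list in place, so later
# consumers can keep a reference to the same object without extra copies.
del spec_token_ids[num_scheduled:]
assert spec_token_ids == [11, 12, 13]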
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, Tuple
 
 import torch
 
@@ -12,15 +12,13 @@ class SamplingMetadata:
     temperature: torch.Tensor
     all_greedy: bool
     all_random: bool
-    rejection_sampling: bool
-    spec_token_ids: List[List[int]]
 
-    top_p: torch.Tensor
-    top_k: torch.Tensor
-    no_top_p: bool
-    no_top_k: bool
-    min_p: torch.Tensor
-    no_min_p: bool
+    # None when there are no speculated tokens.
+    spec_token_ids: Optional[List[List[int]]]
+
+    top_p: Optional[torch.Tensor]
+    top_k: Optional[torch.Tensor]
+    min_p: Optional[torch.Tensor]
 
     generators: Dict[int, torch.Generator]
 
@@ -34,7 +32,8 @@ class SamplingMetadata:
     repetition_penalties: torch.Tensor
 
     output_token_ids: List[List[int]]
-    min_tokens: List[int]
-    stop_token_ids: List[Set[int]]
+
+    # req_index -> (min_tokens, stop_token_ids)
+    min_tokens: Dict[int, Tuple[int, Set[int]]]
 
     logit_bias: List[Optional[Dict[int, float]]]
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Set, Tuple
+from typing import Dict, List, Set, Tuple
 
 import torch
 
@@ -8,18 +8,17 @@ from vllm.model_executor.layers.utils import apply_penalties
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 
 
-def apply_min_token_penalties(logits: torch.Tensor,
-                              output_token_ids: List[List[int]],
-                              stop_token_ids: List[Set[int]],
-                              min_tokens: List[int]) -> None:
+def apply_min_token_penalties(
+        logits: torch.Tensor, output_token_ids: List[List[int]],
+        min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None:
     """
     Applies minimum token penalty by setting the logits of the stop tokens
     to -inf.
     """
     min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
-    for index, min_token in enumerate(min_tokens):
+    for index, (min_token, stop_token_ids) in min_tokens.items():
         if len(output_token_ids[index]) < min_token:
-            for stop_token_id in stop_token_ids[index]:
+            for stop_token_id in stop_token_ids:
                 min_tokens_logits_to_penalize.append((index, stop_token_id))
     if min_tokens_logits_to_penalize:
         logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
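Note: with the new signature, min_tokens is a sparse mapping from request index to (min_tokens, stop_token_ids), so requests with no minimum-length constraint are simply absent. A self-contained sketch of the same masking logic on a toy logits tensor (this re-implements the behaviour for illustration rather than importing the vLLM module):

from typing import Dict, List, Set, Tuple

import torch

# Two requests over a 6-token vocabulary.
logits = torch.zeros(2, 6)
output_token_ids: List[List[int]] = [[7], [7, 8, 9]]  # tokens generated so far

# req_index -> (min_tokens, stop_token_ids); request 1 has no constraint.
min_tokens: Dict[int, Tuple[int, Set[int]]] = {0: (4, {2, 5})}

to_penalize: List[Tuple[int, int]] = []
for index, (min_token, stop_token_ids) in min_tokens.items():
    if len(output_token_ids[index]) < min_token:
        to_penalize.extend((index, t) for t in stop_token_ids)

if to_penalize:
    # Advanced indexing: set the stop-token logits of request 0 to -inf.
    logits[tuple(zip(*to_penalize))] = -float("inf")

assert logits[0, 2] == -float("inf") and logits[0, 5] == -float("inf")
assert torch.isfinite(logits[1]).all()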
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
@@ -55,13 +55,11 @@ class TopKTopPSampler(nn.Module):
         self,
         logits: torch.Tensor,
         generators: Dict[int, torch.Generator],
-        no_top_k: bool,
-        k: torch.Tensor,
-        no_top_p: bool,
-        p: torch.Tensor,
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """PyTorch-native implementation of top-k and top-p sampling."""
-        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
+        logits = apply_top_k_top_p(logits, k, p)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
 
@@ -69,37 +67,33 @@ class TopKTopPSampler(nn.Module):
         self,
         logits: torch.Tensor,
         generators: Dict[int, torch.Generator],
-        no_top_k: bool,
-        k: torch.Tensor,
-        no_top_p: bool,
-        p: torch.Tensor,
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
         probs = logits.softmax(dim=-1, dtype=torch.float32)
-        if no_top_k and no_top_p:
+        if k is None and p is None:
             # We prefer `random_sample` over `flashinfer_sample` when sorting is
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
             return random_sample(probs, generators)
-        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
+        return flashinfer_sample(probs, k, p, generators)
 
 
 def apply_top_k_top_p(
     logits: torch.Tensor,
-    no_top_k: bool,
-    k: torch.Tensor,
-    no_top_p: bool,
-    p: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
 ) -> torch.Tensor:
     """Apply top-k and top-p masks to the logits.
 
     This function sorts the logits tensor, which can be slow for large batches.
     """
-    if no_top_k and no_top_p:
+    if k is None and p is None:
         return logits
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
 
-    if not no_top_k:
+    if k is not None:
         # Apply top-k.
         top_k_mask = logits_sort.size(1) - k.to(torch.long)
         # Get all the top_k values.
@@ -107,7 +101,7 @@ def apply_top_k_top_p(
         top_k_mask = logits_sort < top_k_mask
         logits_sort.masked_fill_(top_k_mask, -float("inf"))
 
-    if not no_top_p:
+    if p is not None:
         # Apply top-p.
         probs_sort = logits_sort.softmax(dim=-1)
         probs_sum = probs_sort.cumsum(dim=-1)
@@ -147,10 +141,8 @@ def random_sample(
 
 def flashinfer_sample(
     probs: torch.Tensor,
-    no_top_k: bool,
-    k: torch.Tensor,
-    no_top_p: bool,
-    p: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
     generators: Dict[int, torch.Generator],
 ) -> torch.Tensor:
     """Sample from the probabilities using FlashInfer.
@@ -167,7 +159,7 @@ def flashinfer_sample(
     does not. Call this function at the end of the forward pass to minimize
     the synchronization overhead.
     """
-    assert not (no_top_k and no_top_p)
+    assert not (k is None and p is None)
     max_top_k_round = 32
     batch_size = probs.shape[0]
     uniform_samples = torch.empty((max_top_k_round, batch_size),
@@ -178,11 +170,11 @@ def flashinfer_sample(
     for i, generator in generators.items():
         uniform_samples[:, i].uniform_(generator=generator)
 
-    if no_top_k:
+    if k is None:
         # Top-p only.
         next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
             probs, uniform_samples, p, deterministic=True)
-    elif no_top_p:
+    elif p is None:
         # Top-k only.
         next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
             probs, uniform_samples, k, deterministic=True)
@@ -194,9 +186,9 @@ def flashinfer_sample(
 
     # NOTE: CPU-GPU synchronization happens here.
     if not success.all():
-        if not no_top_k:
+        if k is not None:
             probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
-        if not no_top_p:
+        if p is not None:
             probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
         next_token_ids = flashinfer.sampling.sampling_from_probs(
             probs, uniform_samples[0], deterministic=True)
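Note: after this change the samplers signal "no filtering" with None instead of separate no_top_k/no_top_p flags. A minimal self-contained sketch of the same convention (a simplified top-k mask in the spirit of apply_top_k_top_p, not the vLLM implementation itself):

from typing import Optional

import torch


def mask_top_k(logits: torch.Tensor, k: Optional[torch.Tensor]) -> torch.Tensor:
    """Keep the k highest logits per row; k=None means no top-k filtering."""
    if k is None:
        return logits
    logits_sort, _ = logits.sort(dim=-1, descending=False)
    # Per-row threshold: the k-th largest value of each row.
    thresholds = logits_sort.gather(1, (logits.size(1) - k.long()).unsqueeze(1))
    return logits.masked_fill(logits < thresholds, -float("inf"))


logits = torch.tensor([[0.1, 0.4, 0.2, 0.9],
                       [0.5, 0.3, 0.8, 0.7]])
assert torch.equal(mask_top_k(logits, None), logits)  # no-op when k is None

k = torch.tensor([1, 2])  # per-request top-k, as in the batched sampler
masked = mask_top_k(logits, k)
assert torch.isfinite(masked[0]).sum() == 1
assert torch.isfinite(masked[1]).sum() == 2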
@@ -68,6 +68,7 @@ class RejectionSampler(nn.Module):
         # NOTE: The following input preparationg can be moved
         # to the model runner with a persistent manner for better
         # performance.
+        assert sampling_metadata.spec_token_ids is not None
         spec_token_ids = sampling_metadata.spec_token_ids
         max_spec_len = max(len(s) for s in spec_token_ids)
         batch_size = len(spec_token_ids)
@@ -119,6 +120,7 @@ class RejectionSampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
+        assert sampling_metadata.spec_token_ids is not None
         spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
         # Add 1 to include the 'bonus' token.
         sample_lens = [x + 1 for x in spec_lens]
@@ -26,7 +26,7 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        if sampling_metadata.rejection_sampling:
+        if sampling_metadata.spec_token_ids:
             if sampling_metadata.max_num_logprobs:
                 raise NotImplementedError(
                     "Rejection sampling does not support logprobs.")
@@ -104,16 +104,14 @@ class Sampler(nn.Module):
         logits = self.apply_temperature(logits, sampling_metadata.temperature)
 
         # Apply min_p.
-        if not sampling_metadata.no_min_p:
+        if sampling_metadata.min_p is not None:
             logits = self.apply_min_p(logits, sampling_metadata.min_p)
 
         # Apply top_k and/or top_p.
         random_sampled = self.topk_topp_sampler(
             logits,
             sampling_metadata.generators,
-            sampling_metadata.no_top_k,
             sampling_metadata.top_k,
-            sampling_metadata.no_top_p,
             sampling_metadata.top_p,
         )
 
@@ -179,8 +177,9 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
-        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
-                                  sampling_metadata.stop_token_ids,
-                                  sampling_metadata.min_tokens)
+        if sampling_metadata.min_tokens:
+            apply_min_token_penalties(logits,
+                                      sampling_metadata.output_token_ids,
+                                      sampling_metadata.min_tokens)
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
@@ -188,3 +188,14 @@ def bind_kv_cache(
     for layer_name, kv_cache in kv_caches.items():
         # NOTE: Use list because of v0 PP virtual engine.
         forward_context[layer_name].kv_cache = [kv_cache]
+
+
+def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
+               length: int) -> None:
+    """
+    Copy the first length elements of a tensor into another tensor in a
+    non-blocking manner.
+
+    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
+    """
+    to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
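Note: the new copy_slice helper wraps the pattern the input batch uses later in this diff: copy only the first num_reqs elements from a pinned CPU staging tensor into a pre-allocated device tensor without forcing a sync. A hedged usage sketch (the helper is re-defined inline so the snippet runs without vLLM; it falls back to CPU when CUDA is unavailable, and the tensor names are illustrative):

import torch


def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> None:
    # Same body as the helper added above.
    to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)


device = "cuda" if torch.cuda.is_available() else "cpu"
max_num_reqs, num_reqs = 8, 3

# Pinned CPU staging buffer (pinning requires CUDA; skipped otherwise) and a
# persistent device tensor sized for the maximum batch.
temperature_cpu = torch.zeros(max_num_reqs,
                              pin_memory=torch.cuda.is_available())
temperature_dev = torch.zeros(max_num_reqs, device=device)

temperature_cpu[:num_reqs] = torch.tensor([0.7, 1.0, 0.0])
copy_slice(temperature_cpu, temperature_dev, num_reqs)

if device == "cuda":
    torch.cuda.synchronize()  # make the async copy visible before reading
assert torch.allclose(temperature_dev[:num_reqs].cpu(),
                      torch.tensor([0.7, 1.0, 0.0]))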
@ -1,9 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
# Datastructures defining an input batch
|
# Datastructures defining an input batch
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -12,6 +11,7 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.utils import copy_slice
|
||||||
from vllm.v1.worker.block_table import BlockTable
|
from vllm.v1.worker.block_table import BlockTable
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
@ -63,7 +63,7 @@ class InputBatch:
|
|||||||
self.pin_memory = pin_memory
|
self.pin_memory = pin_memory
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
|
self._req_ids: List[Optional[str]] = []
|
||||||
self.req_id_to_index: Dict[str, int] = {}
|
self.req_id_to_index: Dict[str, int] = {}
|
||||||
|
|
||||||
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
||||||
@ -171,11 +171,8 @@ class InputBatch:
|
|||||||
self.repetition_penalties_cpu_tensor.numpy()
|
self.repetition_penalties_cpu_tensor.numpy()
|
||||||
self.repetition_penalties_reqs: Set[str] = set()
|
self.repetition_penalties_reqs: Set[str] = set()
|
||||||
|
|
||||||
self.min_tokens: List[int] = [0] * max_num_reqs
|
# req_index -> (min_tokens, stop_token_ids)
|
||||||
self.stop_token_ids: List[Set[int]] = [
|
self.min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
|
||||||
set() for _ in range(max_num_reqs)
|
|
||||||
]
|
|
||||||
self.prompt_token_ids: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
# lora related
|
# lora related
|
||||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||||
@ -196,6 +193,17 @@ class InputBatch:
|
|||||||
self.logit_bias: List[Optional[Dict[int,
|
self.logit_bias: List[Optional[Dict[int,
|
||||||
float]]] = [None] * max_num_reqs
|
float]]] = [None] * max_num_reqs
|
||||||
|
|
||||||
|
self.req_output_token_ids: List[Optional[List[int]]] = []
|
||||||
|
|
||||||
|
# This is updated each time the batch constituents change.
|
||||||
|
self.sampling_metadata = self._make_sampling_metadata()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def req_ids(self) -> List[str]:
|
||||||
|
# None elements should only be present transiently
|
||||||
|
# while performing state updates to the batch.
|
||||||
|
return cast(List[str], self._req_ids)
|
||||||
|
|
||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self,
|
||||||
request: "CachedRequestState",
|
request: "CachedRequestState",
|
||||||
@ -206,7 +214,13 @@ class InputBatch:
|
|||||||
assert req_index < self.max_num_reqs
|
assert req_index < self.max_num_reqs
|
||||||
|
|
||||||
req_id = request.req_id
|
req_id = request.req_id
|
||||||
self.req_ids[req_index] = req_id
|
if req_index == len(self._req_ids):
|
||||||
|
self._req_ids.append(req_id)
|
||||||
|
self.req_output_token_ids.append(request.output_token_ids)
|
||||||
|
else:
|
||||||
|
self._req_ids[req_index] = req_id
|
||||||
|
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||||
|
|
||||||
self.req_id_to_index[req_id] = req_index
|
self.req_id_to_index[req_id] = req_index
|
||||||
|
|
||||||
# Copy the prompt token ids and output token ids.
|
# Copy the prompt token ids and output token ids.
|
||||||
@ -255,8 +269,9 @@ class InputBatch:
|
|||||||
req_index] = sampling_params.repetition_penalty
|
req_index] = sampling_params.repetition_penalty
|
||||||
if sampling_params.repetition_penalty != 1.0:
|
if sampling_params.repetition_penalty != 1.0:
|
||||||
self.repetition_penalties_reqs.add(req_id)
|
self.repetition_penalties_reqs.add(req_id)
|
||||||
self.min_tokens[req_index] = sampling_params.min_tokens
|
if sampling_params.min_tokens:
|
||||||
self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids
|
self.min_tokens[req_index] = (sampling_params.min_tokens,
|
||||||
|
sampling_params.all_stop_token_ids)
|
||||||
|
|
||||||
# NOTE(woosuk): self.generators should not include the requests that
|
# NOTE(woosuk): self.generators should not include the requests that
|
||||||
# do not have their own generator.
|
# do not have their own generator.
|
||||||
@ -284,16 +299,20 @@ class InputBatch:
|
|||||||
self.request_lora_mapping[req_index] = 0
|
self.request_lora_mapping[req_index] = 0
|
||||||
|
|
||||||
def remove_request(self, req_id: str) -> Optional[int]:
|
def remove_request(self, req_id: str) -> Optional[int]:
|
||||||
|
"""This method must always be followed by a call to condense()."""
|
||||||
|
|
||||||
req_index = self.req_id_to_index.pop(req_id, None)
|
req_index = self.req_id_to_index.pop(req_id, None)
|
||||||
if req_index is None:
|
if req_index is None:
|
||||||
return None
|
return None
|
||||||
self.req_ids[req_index] = None
|
self._req_ids[req_index] = None
|
||||||
|
self.req_output_token_ids[req_index] = None
|
||||||
|
|
||||||
self.greedy_reqs.discard(req_id)
|
self.greedy_reqs.discard(req_id)
|
||||||
self.random_reqs.discard(req_id)
|
self.random_reqs.discard(req_id)
|
||||||
self.top_p_reqs.discard(req_id)
|
self.top_p_reqs.discard(req_id)
|
||||||
self.top_k_reqs.discard(req_id)
|
self.top_k_reqs.discard(req_id)
|
||||||
self.min_p_reqs.discard(req_id)
|
self.min_p_reqs.discard(req_id)
|
||||||
|
self.min_tokens.pop(req_index, None)
|
||||||
self.frequency_penalties_reqs.discard(req_id)
|
self.frequency_penalties_reqs.discard(req_id)
|
||||||
self.presence_penalties_reqs.discard(req_id)
|
self.presence_penalties_reqs.discard(req_id)
|
||||||
self.repetition_penalties_reqs.discard(req_id)
|
self.repetition_penalties_reqs.discard(req_id)
|
||||||
@ -313,33 +332,17 @@ class InputBatch:
|
|||||||
self.logit_bias[req_index] = None
|
self.logit_bias[req_index] = None
|
||||||
return req_index
|
return req_index
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
self.req_ids = [None] * self.max_num_reqs
|
|
||||||
self.req_id_to_index.clear()
|
|
||||||
self.greedy_reqs.clear()
|
|
||||||
self.random_reqs.clear()
|
|
||||||
self.top_p_reqs.clear()
|
|
||||||
self.top_k_reqs.clear()
|
|
||||||
self.min_p_reqs.clear()
|
|
||||||
self.frequency_penalties_reqs.clear()
|
|
||||||
self.presence_penalties_reqs.clear()
|
|
||||||
self.repetition_penalties_reqs.clear()
|
|
||||||
self.generators.clear()
|
|
||||||
self.num_logprobs.clear()
|
|
||||||
self.num_prompt_logprobs.clear()
|
|
||||||
self.request_lora_mapping.fill(0)
|
|
||||||
self.lora_id_to_lora_request.clear()
|
|
||||||
self.lora_id_to_request_ids.clear()
|
|
||||||
self.logit_bias = [None] * self.max_num_reqs
|
|
||||||
|
|
||||||
def condense(self, empty_req_indices: List[int]) -> None:
|
def condense(self, empty_req_indices: List[int]) -> None:
|
||||||
if self.num_reqs == 0:
|
num_reqs = self.num_reqs
|
||||||
|
if num_reqs == 0:
|
||||||
# The batched states are empty.
|
# The batched states are empty.
|
||||||
|
self._req_ids.clear()
|
||||||
|
self.req_output_token_ids.clear()
|
||||||
return
|
return
|
||||||
|
|
||||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||||
# is sorted in descending order.
|
# is sorted in descending order.
|
||||||
last_req_index = self.num_reqs + len(empty_req_indices) - 1
|
last_req_index = num_reqs + len(empty_req_indices) - 1
|
||||||
while empty_req_indices:
|
while empty_req_indices:
|
||||||
# Find the largest non-empty index.
|
# Find the largest non-empty index.
|
||||||
while last_req_index in empty_req_indices:
|
while last_req_index in empty_req_indices:
|
||||||
@ -351,10 +354,13 @@ class InputBatch:
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Swap the states.
|
# Swap the states.
|
||||||
req_id = self.req_ids[last_req_index]
|
req_id = self._req_ids[last_req_index]
|
||||||
|
output_token_ids = self.req_output_token_ids[last_req_index]
|
||||||
assert req_id is not None
|
assert req_id is not None
|
||||||
self.req_ids[empty_index] = req_id
|
self._req_ids[empty_index] = req_id
|
||||||
self.req_ids[last_req_index] = None
|
self._req_ids[last_req_index] = None
|
||||||
|
self.req_output_token_ids[empty_index] = output_token_ids
|
||||||
|
self.req_output_token_ids[last_req_index] = None
|
||||||
self.req_id_to_index[req_id] = empty_index
|
self.req_id_to_index[req_id] = empty_index
|
||||||
|
|
||||||
num_tokens = self.num_tokens[last_req_index]
|
num_tokens = self.num_tokens[last_req_index]
|
||||||
@ -379,13 +385,14 @@ class InputBatch:
|
|||||||
self.repetition_penalties_cpu[
|
self.repetition_penalties_cpu[
|
||||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||||
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
||||||
self.min_tokens[empty_index] = self.min_tokens[last_req_index]
|
|
||||||
self.stop_token_ids[empty_index] = self.stop_token_ids[
|
|
||||||
last_req_index]
|
|
||||||
generator = self.generators.pop(last_req_index, None)
|
generator = self.generators.pop(last_req_index, None)
|
||||||
if generator is not None:
|
if generator is not None:
|
||||||
self.generators[empty_index] = generator
|
self.generators[empty_index] = generator
|
||||||
|
|
||||||
|
min_token = self.min_tokens.pop(last_req_index, None)
|
||||||
|
if min_token is not None:
|
||||||
|
self.min_tokens[empty_index] = min_token
|
||||||
|
|
||||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||||
last_req_index]
|
last_req_index]
|
||||||
|
|
||||||
@ -394,87 +401,71 @@ class InputBatch:
|
|||||||
# Decrement last_req_index since it is now empty.
|
# Decrement last_req_index since it is now empty.
|
||||||
last_req_index -= 1
|
last_req_index -= 1
|
||||||
|
|
||||||
def make_sampling_metadata(
|
# Trim lists to the batch size.
|
||||||
self,
|
del self._req_ids[self.num_reqs:]
|
||||||
req_id_output_token_ids: Dict[str, List[int]],
|
del self.req_output_token_ids[self.num_reqs:]
|
||||||
req_id_to_spec_token_ids: Dict[str, List[int]],
|
|
||||||
skip_copy: bool = False,
|
def refresh_sampling_metadata(self):
|
||||||
) -> SamplingMetadata:
|
self.sampling_metadata = self._make_sampling_metadata()
|
||||||
if not skip_copy:
|
|
||||||
self.temperature[:self.num_reqs].copy_(
|
def _make_sampling_metadata(self) -> SamplingMetadata:
|
||||||
self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
|
num_reqs = self.num_reqs
|
||||||
self.top_p[:self.num_reqs].copy_(
|
copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs)
|
||||||
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
|
if not self.no_top_p:
|
||||||
self.top_k[:self.num_reqs].copy_(
|
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
||||||
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
|
if not self.no_top_k:
|
||||||
self.min_p[:self.num_reqs].copy_(
|
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
||||||
self.min_p_cpu_tensor[:self.num_reqs], non_blocking=True)
|
if not self.no_min_p:
|
||||||
|
copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
|
||||||
|
|
||||||
if not self.no_penalties:
|
if not self.no_penalties:
|
||||||
# Since syncing these tensors is expensive only copy them
|
# Since syncing these tensors is expensive only copy them
|
||||||
# if necessary i.e. if there are requests which require
|
# if necessary i.e. if there are requests which require
|
||||||
# penalties to be applied during sampling.
|
# penalties to be applied during sampling.
|
||||||
self.frequency_penalties[:self.num_reqs].copy_(
|
copy_slice(self.frequency_penalties_cpu_tensor,
|
||||||
self.frequency_penalties_cpu_tensor[:self.num_reqs],
|
self.frequency_penalties, num_reqs)
|
||||||
non_blocking=True,
|
copy_slice(self.presence_penalties_cpu_tensor,
|
||||||
)
|
self.presence_penalties, num_reqs)
|
||||||
self.presence_penalties[:self.num_reqs].copy_(
|
copy_slice(self.repetition_penalties_cpu_tensor,
|
||||||
self.presence_penalties_cpu_tensor[:self.num_reqs],
|
self.repetition_penalties, num_reqs)
|
||||||
non_blocking=True,
|
|
||||||
)
|
|
||||||
self.repetition_penalties[:self.num_reqs].copy_(
|
|
||||||
self.repetition_penalties_cpu_tensor[:self.num_reqs],
|
|
||||||
non_blocking=True,
|
|
||||||
)
|
|
||||||
# The prompt tokens are used only for applying penalties during
|
# The prompt tokens are used only for applying penalties during
|
||||||
# the sampling process. Hence copy these tensors only when
|
# the sampling process. Hence copy these tensors only when
|
||||||
# there are requests which need penalties to be applied.
|
# there are requests which need penalties to be applied.
|
||||||
self.prompt_token_ids = self._make_prompt_token_ids_tensor()
|
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
||||||
|
else:
|
||||||
output_token_ids: List[List[int]] = []
|
prompt_token_ids = None
|
||||||
spec_token_ids: List[List[int]] = []
|
|
||||||
rejection_sampling = False
|
|
||||||
for req_id in self.req_ids[:self.num_reqs]:
|
|
||||||
-            assert req_id is not None
-            # Currently we create a tensor for output_token_ids from scratch
-            # at each step. However, for the penalties computation what we
-            # need is stats about the token ids present in the output. This
-            # stats can be maintained incrementally instead of computing it
-            # from scratch at each step.
-            # TODO - Replace this with incremental update to output token
-            # statistics.
-            output_token_ids.append(req_id_output_token_ids[req_id])
-            req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, [])
-            spec_token_ids.append(req_spec_token_ids)
-            if req_spec_token_ids:
-                # If any of the requests require speculative decoding, set the
-                # flag to True.
-                rejection_sampling = True
 
         return SamplingMetadata(
-            temperature=self.temperature[:self.num_reqs],
+            temperature=self.temperature[:num_reqs],
             all_greedy=self.all_greedy,
             all_random=self.all_random,
-            rejection_sampling=rejection_sampling,
-            top_p=self.top_p[:self.num_reqs],
-            top_k=self.top_k[:self.num_reqs],
-            min_p=self.min_p[:self.num_reqs],
-            no_min_p=self.no_min_p,
-            no_top_p=self.no_top_p,
-            no_top_k=self.no_top_k,
+            top_p=None if self.no_top_p else self.top_p[:num_reqs],
+            top_k=None if self.no_top_k else self.top_k[:num_reqs],
+            min_p=None if self.no_min_p else self.min_p[:num_reqs],
             generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
-            prompt_token_ids=self.prompt_token_ids,
-            frequency_penalties=self.frequency_penalties[:self.num_reqs],
-            presence_penalties=self.presence_penalties[:self.num_reqs],
-            repetition_penalties=self.repetition_penalties[:self.num_reqs],
-            output_token_ids=output_token_ids,
-            spec_token_ids=spec_token_ids,
-            min_tokens=self.min_tokens[:self.num_reqs],
-            stop_token_ids=self.stop_token_ids[:self.num_reqs],
+            prompt_token_ids=prompt_token_ids,
+            frequency_penalties=self.frequency_penalties[:num_reqs],
+            presence_penalties=self.presence_penalties[:num_reqs],
+            repetition_penalties=self.repetition_penalties[:num_reqs],
+            output_token_ids=cast(List[List[int]], self.req_output_token_ids),
+            spec_token_ids=None,
+            min_tokens=self.min_tokens,
             no_penalties=self.no_penalties,
-            logit_bias=self.logit_bias[:self.num_reqs],
+            logit_bias=self.logit_bias[:num_reqs],
         )
 
+    def get_sampling_metadata(
+        self,
+        req_id_to_spec_token_ids: Dict[str, List[int]],
+    ) -> SamplingMetadata:
+        # Set the new spec token ids in the cached sampling metadata.
+        self.sampling_metadata.spec_token_ids = [
+            req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids
+        ] if req_id_to_spec_token_ids else None
+        return self.sampling_metadata
+
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
         max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
         prompt_token_ids_cpu_tensor = torch.empty(
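The rewritten constructor above drops the `no_top_p`/`no_top_k`/`no_min_p` flags and instead passes `None` whenever a knob is inactive, so consumers can branch on the value itself rather than on a parallel boolean. A minimal sketch of that convention for top-p, assuming plain torch and an illustrative function name (this is not the vLLM sampler implementation):

    from typing import Optional

    import torch


    def maybe_apply_top_p(probs: torch.Tensor,
                          top_p: Optional[torch.Tensor]) -> torch.Tensor:
        """Zero out the tail of the distribution only when top_p is provided."""
        if top_p is None:
            # Equivalent of the old no_top_p=True fast path: nothing to do.
            return probs
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        # Mask tokens whose preceding cumulative probability already exceeds p.
        cum_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        sorted_probs[cum_before > top_p.unsqueeze(-1)] = 0.0
        out = torch.zeros_like(probs)
        out.scatter_(-1, sorted_idx, sorted_probs)
        return out / out.sum(dim=-1, keepdim=True)


    probs = torch.softmax(torch.randn(2, 8), dim=-1)
    print(maybe_apply_top_p(probs, None))                      # unchanged
    print(maybe_apply_top_p(probs, torch.tensor([0.9, 0.5])))  # renormalized

The same `None`-means-disabled convention applies to `top_k` and `min_p` in the reworked `SamplingMetadata`.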
@@ -31,7 +31,6 @@ from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
@@ -224,16 +223,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
 
-    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
+    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
 
         The updated states are used by the `_prepare_inputs` function to create
         the input GPU tensors for the model.
 
-        Returns:
-            True if there is a new/resumed/paused/finished request in the batch.
-            If False, we can skip copying SamplingMetadata to the GPU.
+        The SamplingMetadata is updated and copied to the GPU if there is a
+        new/resumed/paused/finished request in the batch.
         """
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
@@ -344,9 +342,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_new_tokens = (num_computed_tokens +
                               len(req_data.new_token_ids) -
                               req_state.num_tokens)
-            new_token_ids = (req_data.new_token_ids[-num_new_tokens:]
-                             if num_new_tokens > 0 else [])
-            req_state.output_token_ids.extend(new_token_ids)
+            if num_new_tokens == 1:
+                # Avoid slicing list in most common case.
+                req_state.output_token_ids.append(req_data.new_token_ids[-1])
+            elif num_new_tokens > 0:
+                req_state.output_token_ids.extend(
+                    req_data.new_token_ids[-num_new_tokens:])
             # Update the block IDs.
             if not req_data.resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
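Since almost every scheduled request adds exactly one new output token per decode step, the new branch appends directly instead of building a one-element slice. A tiny self-contained illustration of the two paths on plain lists (the helper name is made up for this sketch):

    from typing import List


    def append_new_tokens(output_token_ids: List[int],
                          new_token_ids: List[int],
                          num_new_tokens: int) -> None:
        """Mirror of the fast path: append for the 1-token case, slice otherwise."""
        if num_new_tokens == 1:
            # Avoid slicing the list in the most common (decode) case.
            output_token_ids.append(new_token_ids[-1])
        elif num_new_tokens > 0:
            output_token_ids.extend(new_token_ids[-num_new_tokens:])


    out: List[int] = [1, 2]
    append_new_tokens(out, [3], 1)        # decode step: single sampled token
    append_new_tokens(out, [4, 5, 6], 2)  # e.g. accepted speculative tokens
    print(out)  # [1, 2, 3, 5, 6]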
@@ -380,7 +381,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.input_batch.num_tokens_no_spec[req_index] = end_token_index
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, [])
+                req_id, ())
             if spec_token_ids:
                 start_index = end_token_index
                 end_token_index += len(spec_token_ids)
@@ -410,7 +411,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
 
-        return batch_changed
+        if batch_changed:
+            self.input_batch.refresh_sampling_metadata()
 
     def _prepare_inputs(
         self,
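Instead of returning `batch_changed` for the caller to thread through, `_update_states` now refreshes the batch's cached sampling metadata itself when requests were added, removed, or resumed. A rough, self-contained sketch of that control flow, with `FakeBatch` standing in for the real `InputBatch`:

    from typing import List, Set


    class FakeBatch:
        def __init__(self) -> None:
            self.req_ids: List[str] = []
            self.refresh_count = 0

        def refresh_sampling_metadata(self) -> None:
            # In vLLM this rebuilds the cached SamplingMetadata tensors;
            # here we just count how often a rebuild would happen.
            self.refresh_count += 1


    def update_states(batch: FakeBatch, finished: Set[str],
                      new_reqs: List[str]) -> None:
        batch_changed = bool(finished) or bool(new_reqs)
        batch.req_ids = [r for r in batch.req_ids if r not in finished] + new_reqs
        # Refresh only when the batch composition changed, not on every step.
        if batch_changed:
            batch.refresh_sampling_metadata()


    b = FakeBatch()
    update_states(b, set(), ["r0", "r1"])  # batch changed -> refresh
    update_states(b, set(), [])            # steady-state decode -> no refresh
    print(b.req_ids, b.refresh_count)      # ['r0', 'r1'] 1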
@@ -429,8 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
-        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
-            assert req_id is not None
+        for i, req_id in enumerate(self.input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens[i] = num_tokens
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
@@ -669,10 +670,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
         mrope_pos_ptr = 0
-        num_reqs = self.input_batch.num_reqs
-        for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
-            assert req_id is not None
-
+        for index, req_id in enumerate(self.input_batch.req_ids):
             req = self.requests[req_id]
             assert req.mrope_positions is not None
 
@@ -726,12 +724,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
         cu_num_tokens: np.ndarray,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         # Get the number of spec decode tokens for each request.
         num_reqs = self.input_batch.num_reqs
         num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
-        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
-            assert req_id is not None
+        for i, req_id in enumerate(self.input_batch.req_ids):
             num_spec_decode_tokens[i] = len(
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
 
@@ -769,22 +766,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return torch.from_numpy(spec_decode_logits_indices).to(
             self.device, non_blocking=True)
 
-    def _prepare_sampling(
-        self,
-        batch_changed: bool,
-        req_to_spec_token_ids: Dict[str, List[int]],
-    ) -> SamplingMetadata:
-        # Create the sampling metadata.
-        req_id_output_token_ids: Dict[str, List[int]] = \
-            {req_id: req.output_token_ids \
-                for req_id, req in self.requests.items()}
-
-        sampling_metadata = self.input_batch.make_sampling_metadata(
-            req_id_output_token_ids,
-            req_to_spec_token_ids,
-            skip_copy=not batch_changed)
-        return sampling_metadata
-
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
@@ -838,9 +819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
     ) -> List[torch.Tensor]:
         encoder_outputs: List[torch.Tensor] = []
-        num_reqs = self.input_batch.num_reqs
-        for req_id in self.input_batch.req_ids[:num_reqs]:
-            assert req_id is not None
+        for req_id in self.input_batch.req_ids:
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]
             req_state = self.requests[req_id]
@@ -882,7 +861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
-        batch_changed = self._update_states(scheduler_output)
+        self._update_states(scheduler_output)
 
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -964,8 +943,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         logits = self.model.compute_logits(sample_hidden_states, None)
 
         # Sample the next token and get logprobs if needed.
-        sampling_metadata = self._prepare_sampling(
-            batch_changed, scheduler_output.scheduled_spec_decode_tokens)
+        sampling_metadata = self.input_batch.get_sampling_metadata(
+            scheduler_output.scheduled_spec_decode_tokens)
         sampler_output = self.model.sample(
             logits=logits,
             sampling_metadata=sampling_metadata,
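With `_prepare_sampling` removed, the sampling step simply asks the input batch for its cached metadata, passing only the scheduled speculative tokens. A stub sketch of that call shape (class and values here are placeholders, not the runner's real types):

    from typing import Dict, List, Optional


    class StubInputBatch:
        """Placeholder for InputBatch: one metadata object reused across steps."""

        def __init__(self, req_ids: List[str]) -> None:
            self.req_ids = req_ids
            self.spec_token_ids: Optional[List[List[int]]] = None

        def get_sampling_metadata(
                self, spec: Dict[str, List[int]]) -> "StubInputBatch":
            # Per-step work is just (re)setting spec_token_ids on the cache.
            self.spec_token_ids = ([spec.get(r, []) for r in self.req_ids]
                                   if spec else None)
            return self


    batch = StubInputBatch(["r0", "r1"])
    step1 = batch.get_sampling_metadata({"r1": [5, 6]})  # spec decode scheduled
    step2 = batch.get_sampling_metadata({})              # plain decode step
    print(step1 is step2, step1.spec_token_ids)          # True None

The identity check in the last line is the point of the design: the metadata object survives across steps, so nothing per-request is rebuilt unless the batch itself changed.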
@@ -973,14 +952,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
-        num_reqs = self.input_batch.num_reqs
-        req_ids: List[str] = []
-        # Because `input_batch.req_ids` is a list of length `max_num_reqs`,
-        # we need to stop at `num_reqs`.
-        # FIXME(woosuk): This is hacky. Refactor.
-        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
-            assert req_id is not None
-            req_ids.append(req_id)
+        for i, req_id in enumerate(self.input_batch.req_ids):
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
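The deleted comments describe the old contract: `input_batch.req_ids` was padded with `None` up to `max_num_reqs`, so every consumer had to stop at `num_reqs` and assert non-`None`. With the list now holding exactly the active request IDs, a plain `enumerate` is enough. A toy contrast of the two loop shapes (data invented for illustration):

    from typing import List, Optional

    max_num_reqs = 4
    num_reqs = 2

    # Old shape: padded with None up to max_num_reqs.
    padded_req_ids: List[Optional[str]] = ["r0", "r1", None, None]
    for i, req_id in zip(range(num_reqs), padded_req_ids):
        assert req_id is not None      # needed only because of the padding
        print(i, req_id)

    # New shape: exactly num_reqs entries, no padding, no asserts.
    req_ids: List[str] = ["r0", "r1"]
    for i, req_id in enumerate(req_ids):
        print(i, req_id)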
@@ -1027,7 +999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             valid_sampled_token_ids)
 
         model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids,
+            req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=spec_token_ids,
@@ -1041,19 +1013,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         sampled_token_ids: List[List[int]],
     ) -> List[List[int]]:
         # TODO(woosuk): Optimize.
-        num_reqs = len(sampled_token_ids)
         draft_token_ids: List[List[int]] = []
-        for i in range(num_reqs):
-            if len(sampled_token_ids[i]) == 0:
+        for i, sampled_ids in enumerate(sampled_token_ids):
+            num_sampled_ids = len(sampled_ids)
+            if not num_sampled_ids:
                 # Skip speculative decoding.
                 draft_token_ids.append([])
                 continue
 
             # Add sampled_token_ids to token_ids_cpu.
             start_idx = self.input_batch.num_tokens_no_spec[i]
-            end_idx = start_idx + len(sampled_token_ids[i])
-            self.input_batch.token_ids_cpu[
-                i, start_idx:end_idx] = sampled_token_ids[i]
+            end_idx = start_idx + num_sampled_ids
+            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
             drafter_output = self.drafter.propose(
                 self.input_batch.token_ids_cpu[i, :end_idx],
                 self.speculative_config.ngram_prompt_lookup_min,
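The draft-token loop now unpacks each row once via `enumerate` and computes its length a single time instead of repeatedly indexing `sampled_token_ids[i]`. A self-contained sketch of the reshaped loop with the drafter call stubbed out:

    from typing import List


    def propose_stub(prefix: List[int]) -> List[int]:
        # Stand-in for the n-gram drafter; just echoes the last token.
        return [prefix[-1]] if prefix else []


    def generate_draft_token_ids(
            sampled_token_ids: List[List[int]]) -> List[List[int]]:
        draft_token_ids: List[List[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding for requests with no sampled token.
                draft_token_ids.append([])
                continue
            draft_token_ids.append(propose_stub(sampled_ids))
        return draft_token_ids


    print(generate_draft_token_ids([[3, 4], [], [9]]))  # [[4], [], [9]]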
@@ -1204,7 +1175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # multiplying the list, to avoid Dynamo from treating them as
             # tensor aliasing.
             dummy_kv_caches = [
-                torch.tensor([], dtype=torch.float32, device=self.device)
+                torch.tensor((), dtype=torch.float32, device=self.device)
                 for _ in range(self.num_attn_layers)
             ]
 
@@ -1048,8 +1048,6 @@ def swap_positions(b: InputBatch, id_1, id_2):
 
     b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
         id_1]
-    b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
-        id_2], b.stop_token_ids[id_1]
 
     gen_1 = b.generators.pop(id_1, None)
     gen_2 = b.generators.pop(id_2, None)