mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:45:00 +08:00
[V1] Optimize handling of sampling metadata and req_ids list (#13244)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
a4d577b379
commit
30172b4947
@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
|
||||
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
|
||||
batch_size = len(spec_tokens)
|
||||
return SamplingMetadata(
|
||||
temperature=0.0,
|
||||
temperature=torch.tensor([]),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
rejection_sampling=True,
|
||||
spec_token_ids=spec_tokens,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
no_top_p=False,
|
||||
no_top_k=False,
|
||||
min_p=torch.empty(batch_size, ),
|
||||
no_min_p=True,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
no_penalties=False,
|
||||
@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
|
||||
presence_penalties=torch.tensor([]),
|
||||
repetition_penalties=torch.tensor([]),
|
||||
output_token_ids=[],
|
||||
min_tokens=[],
|
||||
stop_token_ids=[],
|
||||
min_tokens={},
|
||||
logit_bias=[None] * batch_size,
|
||||
)
|
||||
|
||||
|
||||
@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
|
||||
temperature=torch.full((batch_size, ), 0.0),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
rejection_sampling=False,
|
||||
top_p=torch.empty(batch_size, ),
|
||||
top_k=torch.empty(batch_size, ),
|
||||
no_top_p=True,
|
||||
no_top_k=True,
|
||||
min_p=torch.empty(batch_size, ),
|
||||
no_min_p=True,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
min_p=None,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
|
||||
vocab_size, device),
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=[],
|
||||
spec_token_ids=None,
|
||||
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
|
||||
no_penalties=True,
|
||||
min_tokens=[],
|
||||
stop_token_ids=[],
|
||||
min_tokens={},
|
||||
logit_bias=[None] * batch_size,
|
||||
)
|
||||
return fake_sampling_metadata
|
||||
@ -104,10 +99,10 @@ def _create_default_sampling_metadata(
|
||||
def _generate_min_token_penalties_and_stop_tokens(
|
||||
num_output_tokens: int, batch_size: int, vocab_size: int,
|
||||
batch_indices_for_min_token_penalty: List[int]
|
||||
) -> Tuple[List[int], List[Set[int]]]:
|
||||
) -> Dict[int, Tuple[int, Set[int]]]:
|
||||
"""
|
||||
Generates and returns a list of minimum token penalties (`min_tokens`)
|
||||
and a corresponding list of stop token IDs (`stop_token_ids`) for each
|
||||
Generates and returns a dict of minimum token penalties and
|
||||
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
|
||||
batch.
|
||||
|
||||
If a batch index is included in `batch_indices_for_min_token_penalty`,
|
||||
@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens(
|
||||
and a random set of stop token IDs is created. Otherwise, a lower
|
||||
`min_tokens` value is assigned, and the stop token IDs set is empty.
|
||||
"""
|
||||
stop_token_ids: List[Set[int]] = []
|
||||
min_tokens: List[int] = []
|
||||
min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
|
||||
for index in range(batch_size):
|
||||
if index in batch_indices_for_min_token_penalty:
|
||||
min_tokens.append(
|
||||
min_tokens[index] = (
|
||||
np.random.randint(num_output_tokens + 1,
|
||||
2 * num_output_tokens))
|
||||
stop_token_ids.append(
|
||||
2 * num_output_tokens),
|
||||
set(
|
||||
np.random.randint(0, vocab_size - 1)
|
||||
for _ in range(np.random.randint(0, vocab_size))))
|
||||
|
||||
else:
|
||||
min_tokens.append(np.random.randint(0, num_output_tokens))
|
||||
stop_token_ids.append(set())
|
||||
return (min_tokens, stop_token_ids)
|
||||
min_tokens[index] = (np.random.randint(0,
|
||||
num_output_tokens), set())
|
||||
return min_tokens
|
||||
|
||||
|
||||
def _create_weighted_output_token_list(
|
||||
@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
|
||||
output_token_ids_for_batch.extend(
|
||||
[token_id for _ in range(index + 1)])
|
||||
output_token_ids.append(output_token_ids_for_batch)
|
||||
return (output_token_ids, sorted_token_ids_in_output)
|
||||
return output_token_ids, sorted_token_ids_in_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
batch_indices_for_min_token_penalty = np.random.randint(
|
||||
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
|
||||
min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
|
||||
min_tokens = _generate_min_token_penalties_and_stop_tokens(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
|
||||
batch_indices_for_min_token_penalty)
|
||||
sampling_metadata.min_tokens = min_tokens
|
||||
sampling_metadata.stop_token_ids = stop_token_ids
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
if token_id in stop_token_ids[batch_idx]:
|
||||
_, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
|
||||
if token_id in stop_token_ids:
|
||||
assert logits[batch_idx][token_id] == -float("inf")
|
||||
else:
|
||||
assert logits[batch_idx][token_id] != -float("inf")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -41,7 +41,7 @@ def _remove_requests(
|
||||
for index in req_indices_to_remove:
|
||||
input_batch.remove_request(reqs[index].req_id)
|
||||
req_ids_to_remove.add(reqs[index].req_id)
|
||||
return (req_ids_to_remove, req_indices_to_remove_list)
|
||||
return req_ids_to_remove, req_indices_to_remove_list
|
||||
|
||||
|
||||
def _construct_expected_sampling_metadata(
|
||||
@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
|
||||
top_p = [0.0 for _ in range(num_reqs)]
|
||||
min_p = [0.0 for _ in range(num_reqs)]
|
||||
temperature = [0.0 for _ in range(num_reqs)]
|
||||
stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
|
||||
min_tokens = [0 for _ in range(num_reqs)]
|
||||
min_tokens = {}
|
||||
logit_bias = [None] * num_reqs
|
||||
for req in reqs:
|
||||
if req.req_id not in req_ids_retained:
|
||||
@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
|
||||
top_p[index_in_input_batch] = req.sampling_params.top_p
|
||||
min_p[index_in_input_batch] = req.sampling_params.min_p
|
||||
temperature[index_in_input_batch] = req.sampling_params.temperature
|
||||
stop_token_ids[
|
||||
index_in_input_batch] = req.sampling_params.all_stop_token_ids
|
||||
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
|
||||
min_tokens[index_in_input_batch] = (
|
||||
req.sampling_params.min_tokens,
|
||||
req.sampling_params.all_stop_token_ids)
|
||||
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
|
||||
return SamplingMetadata(
|
||||
temperature=torch.tensor(temperature, dtype=torch.float,
|
||||
device=device),
|
||||
all_greedy=False,
|
||||
all_random=True,
|
||||
rejection_sampling=False,
|
||||
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
|
||||
top_k=torch.tensor(top_k, dtype=torch.int, device=device),
|
||||
no_top_p=all(x == 1.0 for x in top_p),
|
||||
no_top_k=all(x == 0 for x in top_k),
|
||||
min_p=torch.tensor(min_p, dtype=torch.float, device=device),
|
||||
no_min_p=all(x == 0.0 for x in min_p),
|
||||
top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
|
||||
top_p, dtype=torch.float, device=device),
|
||||
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
|
||||
top_k, dtype=torch.int, device=device),
|
||||
min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
|
||||
min_p, dtype=torch.float, device=device),
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=make_tensor_with_pad(
|
||||
@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
|
||||
dtype=torch.float,
|
||||
device=device),
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=[],
|
||||
spec_token_ids=None,
|
||||
min_tokens=min_tokens,
|
||||
stop_token_ids=stop_token_ids,
|
||||
no_penalties=(all(x == 0 for x in presence_penalties)
|
||||
and all(x == 0 for x in frequency_penalties)
|
||||
and all(x == 1 for x in repetition_penalties)),
|
||||
@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
input_batch.condense(req_indices_to_remove)
|
||||
|
||||
# Generate the sampling metadata
|
||||
sampling_metadata = input_batch.make_sampling_metadata(
|
||||
req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
|
||||
sampling_metadata = input_batch._make_sampling_metadata()
|
||||
|
||||
# Create expected output.
|
||||
expected_sampling_metadata = _construct_expected_sampling_metadata(
|
||||
@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
input_batch.req_id_to_index,
|
||||
device=torch.device(device))
|
||||
|
||||
def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
|
||||
return (t1 is None
|
||||
and t2 is None) or (t1 is not None and t2 is not None
|
||||
and torch.allclose(t1, t2))
|
||||
|
||||
# Assert the actual and expected output.
|
||||
assert torch.allclose(expected_sampling_metadata.temperature,
|
||||
sampling_metadata.temperature)
|
||||
assert torch.allclose(expected_sampling_metadata.top_p,
|
||||
sampling_metadata.top_p)
|
||||
assert torch.allclose(expected_sampling_metadata.top_k,
|
||||
sampling_metadata.top_k)
|
||||
assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
|
||||
assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
|
||||
assert torch.allclose(
|
||||
expected_sampling_metadata.frequency_penalties,
|
||||
sampling_metadata.frequency_penalties,
|
||||
@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
assert (expected_sampling_metadata.output_token_ids ==
|
||||
sampling_metadata.output_token_ids)
|
||||
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
|
||||
assert expected_sampling_metadata.stop_token_ids == \
|
||||
sampling_metadata.stop_token_ids
|
||||
assert expected_sampling_metadata.no_penalties == \
|
||||
sampling_metadata.no_penalties
|
||||
assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
|
||||
assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
|
||||
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
|
||||
|
||||
@ -5,6 +5,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
|
||||
return req_id in model_runner.requests
|
||||
|
||||
|
||||
def _is_sampling_metadata_changed(model_runner,
|
||||
sampling_metadata_before: SamplingMetadata):
|
||||
return model_runner.input_batch.sampling_metadata is not (
|
||||
sampling_metadata_before)
|
||||
|
||||
|
||||
def test_update_states_new_request(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
scheduler_output = _schedule_new_request(req_id)
|
||||
|
||||
batch_changed = model_runner._update_states(scheduler_output)
|
||||
assert batch_changed is True
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
batch_changed = model_runner._update_states(scheduler_output)
|
||||
assert batch_changed is True
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert not _is_req_added(model_runner, req_id)
|
||||
assert not _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids={},
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner):
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
batch_changed = model_runner._update_states(scheduler_output)
|
||||
assert batch_changed is True
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
batch_changed = model_runner._update_states(scheduler_output)
|
||||
assert batch_changed is False
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert not _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
|
||||
free_encoder_input_ids=[],
|
||||
)
|
||||
|
||||
batch_changed = model_runner._update_states(scheduler_output)
|
||||
assert batch_changed is True
|
||||
metadata_before = model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[0])
|
||||
assert _is_req_scheduled(model_runner, req_ids[0])
|
||||
|
||||
@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
||||
vocab_size, num_seqs)
|
||||
output_bin_counts, output_mask = get_token_bin_counts_and_mask(
|
||||
output_tokens_tensor, vocab_size, num_seqs)
|
||||
repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
|
||||
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
|
||||
1, vocab_size)
|
||||
logits[logits > 0] /= torch.where(prompt_mask | output_mask,
|
||||
repetition_penalties, 1.0)[logits > 0]
|
||||
@ -53,6 +53,6 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
||||
repetition_penalties, 1.0)[logits <= 0]
|
||||
# We follow the definition in OpenAI API.
|
||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||
logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
|
||||
logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
|
||||
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
|
||||
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
|
||||
return logits
|
||||
|
||||
@ -195,8 +195,10 @@ class Scheduler:
|
||||
request.num_computed_tokens -
|
||||
request.num_tokens)
|
||||
if num_scheduled_spec_tokens > 0:
|
||||
# Trim spec_token_ids list to num_scheduled_spec_tokens.
|
||||
del request.spec_token_ids[num_scheduled_spec_tokens:]
|
||||
scheduled_spec_decode_tokens[request.request_id] = (
|
||||
request.spec_token_ids[:num_scheduled_spec_tokens])
|
||||
request.spec_token_ids)
|
||||
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
@ -567,7 +569,7 @@ class Scheduler:
|
||||
outputs.append(
|
||||
EngineCoreOutput(
|
||||
request_id=req_id,
|
||||
new_token_ids=new_token_ids or [],
|
||||
new_token_ids=new_token_ids,
|
||||
finish_reason=request.get_finished_reason(),
|
||||
new_logprobs=new_logprobs,
|
||||
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -12,15 +12,13 @@ class SamplingMetadata:
|
||||
temperature: torch.Tensor
|
||||
all_greedy: bool
|
||||
all_random: bool
|
||||
rejection_sampling: bool
|
||||
spec_token_ids: List[List[int]]
|
||||
|
||||
top_p: torch.Tensor
|
||||
top_k: torch.Tensor
|
||||
no_top_p: bool
|
||||
no_top_k: bool
|
||||
min_p: torch.Tensor
|
||||
no_min_p: bool
|
||||
# None when there are no speculated tokens.
|
||||
spec_token_ids: Optional[List[List[int]]]
|
||||
|
||||
top_p: Optional[torch.Tensor]
|
||||
top_k: Optional[torch.Tensor]
|
||||
min_p: Optional[torch.Tensor]
|
||||
|
||||
generators: Dict[int, torch.Generator]
|
||||
|
||||
@ -34,7 +32,8 @@ class SamplingMetadata:
|
||||
repetition_penalties: torch.Tensor
|
||||
|
||||
output_token_ids: List[List[int]]
|
||||
min_tokens: List[int]
|
||||
stop_token_ids: List[Set[int]]
|
||||
|
||||
# req_index -> (min_tokens, stop_token_ids)
|
||||
min_tokens: Dict[int, Tuple[int, Set[int]]]
|
||||
|
||||
logit_bias: List[Optional[Dict[int, float]]]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List, Set, Tuple
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -8,18 +8,17 @@ from vllm.model_executor.layers.utils import apply_penalties
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
|
||||
|
||||
def apply_min_token_penalties(logits: torch.Tensor,
|
||||
output_token_ids: List[List[int]],
|
||||
stop_token_ids: List[Set[int]],
|
||||
min_tokens: List[int]) -> None:
|
||||
def apply_min_token_penalties(
|
||||
logits: torch.Tensor, output_token_ids: List[List[int]],
|
||||
min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None:
|
||||
"""
|
||||
Applies minimum token penalty by setting the logits of the stop tokens
|
||||
to -inf.
|
||||
"""
|
||||
min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
|
||||
for index, min_token in enumerate(min_tokens):
|
||||
for index, (min_token, stop_token_ids) in min_tokens.items():
|
||||
if len(output_token_ids[index]) < min_token:
|
||||
for stop_token_id in stop_token_ids[index]:
|
||||
for stop_token_id in stop_token_ids:
|
||||
min_tokens_logits_to_penalize.append((index, stop_token_id))
|
||||
if min_tokens_logits_to_penalize:
|
||||
logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -55,13 +55,11 @@ class TopKTopPSampler(nn.Module):
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: Dict[int, torch.Generator],
|
||||
no_top_k: bool,
|
||||
k: torch.Tensor,
|
||||
no_top_p: bool,
|
||||
p: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""PyTorch-native implementation of top-k and top-p sampling."""
|
||||
logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
|
||||
logits = apply_top_k_top_p(logits, k, p)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
|
||||
@ -69,37 +67,33 @@ class TopKTopPSampler(nn.Module):
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: Dict[int, torch.Generator],
|
||||
no_top_k: bool,
|
||||
k: torch.Tensor,
|
||||
no_top_p: bool,
|
||||
p: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""More optimized implementation for top-k and top-p sampling."""
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
if no_top_k and no_top_p:
|
||||
if k is None and p is None:
|
||||
# We prefer `random_sample` over `flashinfer_sample` when sorting is
|
||||
# not needed. This is because `random_sample` does not require
|
||||
# CPU-GPU synchronization while `flashinfer_sample` does.
|
||||
return random_sample(probs, generators)
|
||||
return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
|
||||
return flashinfer_sample(probs, k, p, generators)
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
no_top_k: bool,
|
||||
k: torch.Tensor,
|
||||
no_top_p: bool,
|
||||
p: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Apply top-k and top-p masks to the logits.
|
||||
|
||||
This function sorts the logits tensor, which can be slow for large batches.
|
||||
"""
|
||||
if no_top_k and no_top_p:
|
||||
if k is None and p is None:
|
||||
return logits
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
||||
|
||||
if not no_top_k:
|
||||
if k is not None:
|
||||
# Apply top-k.
|
||||
top_k_mask = logits_sort.size(1) - k.to(torch.long)
|
||||
# Get all the top_k values.
|
||||
@ -107,7 +101,7 @@ def apply_top_k_top_p(
|
||||
top_k_mask = logits_sort < top_k_mask
|
||||
logits_sort.masked_fill_(top_k_mask, -float("inf"))
|
||||
|
||||
if not no_top_p:
|
||||
if p is not None:
|
||||
# Apply top-p.
|
||||
probs_sort = logits_sort.softmax(dim=-1)
|
||||
probs_sum = probs_sort.cumsum(dim=-1)
|
||||
@ -147,10 +141,8 @@ def random_sample(
|
||||
|
||||
def flashinfer_sample(
|
||||
probs: torch.Tensor,
|
||||
no_top_k: bool,
|
||||
k: torch.Tensor,
|
||||
no_top_p: bool,
|
||||
p: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
generators: Dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
"""Sample from the probabilities using FlashInfer.
|
||||
@ -167,7 +159,7 @@ def flashinfer_sample(
|
||||
does not. Call this function at the end of the forward pass to minimize
|
||||
the synchronization overhead.
|
||||
"""
|
||||
assert not (no_top_k and no_top_p)
|
||||
assert not (k is None and p is None)
|
||||
max_top_k_round = 32
|
||||
batch_size = probs.shape[0]
|
||||
uniform_samples = torch.empty((max_top_k_round, batch_size),
|
||||
@ -178,11 +170,11 @@ def flashinfer_sample(
|
||||
for i, generator in generators.items():
|
||||
uniform_samples[:, i].uniform_(generator=generator)
|
||||
|
||||
if no_top_k:
|
||||
if k is None:
|
||||
# Top-p only.
|
||||
next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
|
||||
probs, uniform_samples, p, deterministic=True)
|
||||
elif no_top_p:
|
||||
elif p is None:
|
||||
# Top-k only.
|
||||
next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
|
||||
probs, uniform_samples, k, deterministic=True)
|
||||
@ -194,9 +186,9 @@ def flashinfer_sample(
|
||||
|
||||
# NOTE: CPU-GPU synchronization happens here.
|
||||
if not success.all():
|
||||
if not no_top_k:
|
||||
if k is not None:
|
||||
probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
|
||||
if not no_top_p:
|
||||
if p is not None:
|
||||
probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
|
||||
next_token_ids = flashinfer.sampling.sampling_from_probs(
|
||||
probs, uniform_samples[0], deterministic=True)
|
||||
|
||||
@ -68,6 +68,7 @@ class RejectionSampler(nn.Module):
|
||||
# NOTE: The following input preparationg can be moved
|
||||
# to the model runner with a persistent manner for better
|
||||
# performance.
|
||||
assert sampling_metadata.spec_token_ids is not None
|
||||
spec_token_ids = sampling_metadata.spec_token_ids
|
||||
max_spec_len = max(len(s) for s in spec_token_ids)
|
||||
batch_size = len(spec_token_ids)
|
||||
@ -119,6 +120,7 @@ class RejectionSampler(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
assert sampling_metadata.spec_token_ids is not None
|
||||
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
|
||||
# Add 1 to include the 'bonus' token.
|
||||
sample_lens = [x + 1 for x in spec_lens]
|
||||
|
||||
@ -26,7 +26,7 @@ class Sampler(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
if sampling_metadata.rejection_sampling:
|
||||
if sampling_metadata.spec_token_ids:
|
||||
if sampling_metadata.max_num_logprobs:
|
||||
raise NotImplementedError(
|
||||
"Rejection sampling does not support logprobs.")
|
||||
@ -104,16 +104,14 @@ class Sampler(nn.Module):
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# Apply min_p.
|
||||
if not sampling_metadata.no_min_p:
|
||||
if sampling_metadata.min_p is not None:
|
||||
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled = self.topk_topp_sampler(
|
||||
logits,
|
||||
sampling_metadata.generators,
|
||||
sampling_metadata.no_top_k,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.no_top_p,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
@ -179,9 +177,10 @@ class Sampler(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
|
||||
sampling_metadata.stop_token_ids,
|
||||
sampling_metadata.min_tokens)
|
||||
if sampling_metadata.min_tokens:
|
||||
apply_min_token_penalties(logits,
|
||||
sampling_metadata.output_token_ids,
|
||||
sampling_metadata.min_tokens)
|
||||
if not sampling_metadata.no_penalties:
|
||||
assert sampling_metadata.prompt_token_ids is not None
|
||||
logits = apply_all_penalties(
|
||||
|
||||
@ -188,3 +188,14 @@ def bind_kv_cache(
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
|
||||
|
||||
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
|
||||
length: int) -> None:
|
||||
"""
|
||||
Copy the first length elements of a tensor into another tensor in a
|
||||
non-blocking manner.
|
||||
|
||||
Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
|
||||
"""
|
||||
to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Datastructures defining an input batch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -12,6 +11,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
@ -63,7 +63,7 @@ class InputBatch:
|
||||
self.pin_memory = pin_memory
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
|
||||
self._req_ids: List[Optional[str]] = []
|
||||
self.req_id_to_index: Dict[str, int] = {}
|
||||
|
||||
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
||||
@ -171,11 +171,8 @@ class InputBatch:
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: Set[str] = set()
|
||||
|
||||
self.min_tokens: List[int] = [0] * max_num_reqs
|
||||
self.stop_token_ids: List[Set[int]] = [
|
||||
set() for _ in range(max_num_reqs)
|
||||
]
|
||||
self.prompt_token_ids: Optional[torch.Tensor] = None
|
||||
# req_index -> (min_tokens, stop_token_ids)
|
||||
self.min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||
@ -196,6 +193,17 @@ class InputBatch:
|
||||
self.logit_bias: List[Optional[Dict[int,
|
||||
float]]] = [None] * max_num_reqs
|
||||
|
||||
self.req_output_token_ids: List[Optional[List[int]]] = []
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
@property
|
||||
def req_ids(self) -> List[str]:
|
||||
# None elements should only be present transiently
|
||||
# while performing state updates to the batch.
|
||||
return cast(List[str], self._req_ids)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
@ -206,7 +214,13 @@ class InputBatch:
|
||||
assert req_index < self.max_num_reqs
|
||||
|
||||
req_id = request.req_id
|
||||
self.req_ids[req_index] = req_id
|
||||
if req_index == len(self._req_ids):
|
||||
self._req_ids.append(req_id)
|
||||
self.req_output_token_ids.append(request.output_token_ids)
|
||||
else:
|
||||
self._req_ids[req_index] = req_id
|
||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
# Copy the prompt token ids and output token ids.
|
||||
@ -255,8 +269,9 @@ class InputBatch:
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
self.min_tokens[req_index] = sampling_params.min_tokens
|
||||
self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids
|
||||
if sampling_params.min_tokens:
|
||||
self.min_tokens[req_index] = (sampling_params.min_tokens,
|
||||
sampling_params.all_stop_token_ids)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
@ -284,16 +299,20 @@ class InputBatch:
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
def remove_request(self, req_id: str) -> Optional[int]:
|
||||
"""This method must always be followed by a call to condense()."""
|
||||
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
self.req_ids[req_index] = None
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.min_p_reqs.discard(req_id)
|
||||
self.min_tokens.pop(req_index, None)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
@ -313,33 +332,17 @@ class InputBatch:
|
||||
self.logit_bias[req_index] = None
|
||||
return req_index
|
||||
|
||||
def clear(self) -> None:
|
||||
self.req_ids = [None] * self.max_num_reqs
|
||||
self.req_id_to_index.clear()
|
||||
self.greedy_reqs.clear()
|
||||
self.random_reqs.clear()
|
||||
self.top_p_reqs.clear()
|
||||
self.top_k_reqs.clear()
|
||||
self.min_p_reqs.clear()
|
||||
self.frequency_penalties_reqs.clear()
|
||||
self.presence_penalties_reqs.clear()
|
||||
self.repetition_penalties_reqs.clear()
|
||||
self.generators.clear()
|
||||
self.num_logprobs.clear()
|
||||
self.num_prompt_logprobs.clear()
|
||||
self.request_lora_mapping.fill(0)
|
||||
self.lora_id_to_lora_request.clear()
|
||||
self.lora_id_to_request_ids.clear()
|
||||
self.logit_bias = [None] * self.max_num_reqs
|
||||
|
||||
def condense(self, empty_req_indices: List[int]) -> None:
|
||||
if self.num_reqs == 0:
|
||||
num_reqs = self.num_reqs
|
||||
if num_reqs == 0:
|
||||
# The batched states are empty.
|
||||
self._req_ids.clear()
|
||||
self.req_output_token_ids.clear()
|
||||
return
|
||||
|
||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||
# is sorted in descending order.
|
||||
last_req_index = self.num_reqs + len(empty_req_indices) - 1
|
||||
last_req_index = num_reqs + len(empty_req_indices) - 1
|
||||
while empty_req_indices:
|
||||
# Find the largest non-empty index.
|
||||
while last_req_index in empty_req_indices:
|
||||
@ -351,10 +354,13 @@ class InputBatch:
|
||||
break
|
||||
|
||||
# Swap the states.
|
||||
req_id = self.req_ids[last_req_index]
|
||||
req_id = self._req_ids[last_req_index]
|
||||
output_token_ids = self.req_output_token_ids[last_req_index]
|
||||
assert req_id is not None
|
||||
self.req_ids[empty_index] = req_id
|
||||
self.req_ids[last_req_index] = None
|
||||
self._req_ids[empty_index] = req_id
|
||||
self._req_ids[last_req_index] = None
|
||||
self.req_output_token_ids[empty_index] = output_token_ids
|
||||
self.req_output_token_ids[last_req_index] = None
|
||||
self.req_id_to_index[req_id] = empty_index
|
||||
|
||||
num_tokens = self.num_tokens[last_req_index]
|
||||
@ -379,13 +385,14 @@ class InputBatch:
|
||||
self.repetition_penalties_cpu[
|
||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
||||
self.min_tokens[empty_index] = self.min_tokens[last_req_index]
|
||||
self.stop_token_ids[empty_index] = self.stop_token_ids[
|
||||
last_req_index]
|
||||
generator = self.generators.pop(last_req_index, None)
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
|
||||
min_token = self.min_tokens.pop(last_req_index, None)
|
||||
if min_token is not None:
|
||||
self.min_tokens[empty_index] = min_token
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
@ -394,87 +401,71 @@ class InputBatch:
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
def make_sampling_metadata(
|
||||
self,
|
||||
req_id_output_token_ids: Dict[str, List[int]],
|
||||
req_id_to_spec_token_ids: Dict[str, List[int]],
|
||||
skip_copy: bool = False,
|
||||
) -> SamplingMetadata:
|
||||
if not skip_copy:
|
||||
self.temperature[:self.num_reqs].copy_(
|
||||
self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||
self.top_p[:self.num_reqs].copy_(
|
||||
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||
self.top_k[:self.num_reqs].copy_(
|
||||
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||
self.min_p[:self.num_reqs].copy_(
|
||||
self.min_p_cpu_tensor[:self.num_reqs], non_blocking=True)
|
||||
if not self.no_penalties:
|
||||
# Since syncing these tensors is expensive only copy them
|
||||
# if necessary i.e. if there are requests which require
|
||||
# penalties to be applied during sampling.
|
||||
self.frequency_penalties[:self.num_reqs].copy_(
|
||||
self.frequency_penalties_cpu_tensor[:self.num_reqs],
|
||||
non_blocking=True,
|
||||
)
|
||||
self.presence_penalties[:self.num_reqs].copy_(
|
||||
self.presence_penalties_cpu_tensor[:self.num_reqs],
|
||||
non_blocking=True,
|
||||
)
|
||||
self.repetition_penalties[:self.num_reqs].copy_(
|
||||
self.repetition_penalties_cpu_tensor[:self.num_reqs],
|
||||
non_blocking=True,
|
||||
)
|
||||
# The prompt tokens are used only for applying penalties during
|
||||
# the sampling process. Hence copy these tensors only when
|
||||
# there are requests which need penalties to be applied.
|
||||
self.prompt_token_ids = self._make_prompt_token_ids_tensor()
|
||||
# Trim lists to the batch size.
|
||||
del self._req_ids[self.num_reqs:]
|
||||
del self.req_output_token_ids[self.num_reqs:]
|
||||
|
||||
output_token_ids: List[List[int]] = []
|
||||
spec_token_ids: List[List[int]] = []
|
||||
rejection_sampling = False
|
||||
for req_id in self.req_ids[:self.num_reqs]:
|
||||
assert req_id is not None
|
||||
# Currently we create a tensor for output_token_ids from scratch
|
||||
# at each step. However, for the penalties computation what we
|
||||
# need is stats about the token ids present in the output. This
|
||||
# stats can be maintained incrementally instead of computing it
|
||||
# from scratch at each step.
|
||||
# TODO - Replace this with incremental update to output token
|
||||
# statistics.
|
||||
output_token_ids.append(req_id_output_token_ids[req_id])
|
||||
req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, [])
|
||||
spec_token_ids.append(req_spec_token_ids)
|
||||
if req_spec_token_ids:
|
||||
# If any of the requests require speculative decoding, set the
|
||||
# flag to True.
|
||||
rejection_sampling = True
|
||||
def refresh_sampling_metadata(self):
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
def _make_sampling_metadata(self) -> SamplingMetadata:
|
||||
num_reqs = self.num_reqs
|
||||
copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs)
|
||||
if not self.no_top_p:
|
||||
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
||||
if not self.no_top_k:
|
||||
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
||||
if not self.no_min_p:
|
||||
copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
|
||||
|
||||
if not self.no_penalties:
|
||||
# Since syncing these tensors is expensive only copy them
|
||||
# if necessary i.e. if there are requests which require
|
||||
# penalties to be applied during sampling.
|
||||
copy_slice(self.frequency_penalties_cpu_tensor,
|
||||
self.frequency_penalties, num_reqs)
|
||||
copy_slice(self.presence_penalties_cpu_tensor,
|
||||
self.presence_penalties, num_reqs)
|
||||
copy_slice(self.repetition_penalties_cpu_tensor,
|
||||
self.repetition_penalties, num_reqs)
|
||||
|
||||
# The prompt tokens are used only for applying penalties during
|
||||
# the sampling process. Hence copy these tensors only when
|
||||
# there are requests which need penalties to be applied.
|
||||
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
||||
else:
|
||||
prompt_token_ids = None
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=self.temperature[:self.num_reqs],
|
||||
temperature=self.temperature[:num_reqs],
|
||||
all_greedy=self.all_greedy,
|
||||
all_random=self.all_random,
|
||||
rejection_sampling=rejection_sampling,
|
||||
top_p=self.top_p[:self.num_reqs],
|
||||
top_k=self.top_k[:self.num_reqs],
|
||||
min_p=self.min_p[:self.num_reqs],
|
||||
no_min_p=self.no_min_p,
|
||||
no_top_p=self.no_top_p,
|
||||
no_top_k=self.no_top_k,
|
||||
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
||||
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
||||
min_p=None if self.no_min_p else self.min_p[:num_reqs],
|
||||
generators=self.generators,
|
||||
max_num_logprobs=self.max_num_logprobs,
|
||||
prompt_token_ids=self.prompt_token_ids,
|
||||
frequency_penalties=self.frequency_penalties[:self.num_reqs],
|
||||
presence_penalties=self.presence_penalties[:self.num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:self.num_reqs],
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
min_tokens=self.min_tokens[:self.num_reqs],
|
||||
stop_token_ids=self.stop_token_ids[:self.num_reqs],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=self.frequency_penalties[:num_reqs],
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(List[List[int]], self.req_output_token_ids),
|
||||
spec_token_ids=None,
|
||||
min_tokens=self.min_tokens,
|
||||
no_penalties=self.no_penalties,
|
||||
logit_bias=self.logit_bias[:self.num_reqs],
|
||||
logit_bias=self.logit_bias[:num_reqs],
|
||||
)
|
||||
|
||||
def get_sampling_metadata(
|
||||
self,
|
||||
req_id_to_spec_token_ids: Dict[str, List[int]],
|
||||
) -> SamplingMetadata:
|
||||
# Set the new spec token ids in the cached sampling metadata.
|
||||
self.sampling_metadata.spec_token_ids = [
|
||||
req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids
|
||||
] if req_id_to_spec_token_ids else None
|
||||
return self.sampling_metadata
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||
prompt_token_ids_cpu_tensor = torch.empty(
|
||||
|
||||
@ -31,7 +31,6 @@ from vllm.v1.engine.mm_input_cache import MMInputCacheClient
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
@ -224,16 +223,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
pin_memory=self.pin_memory)
|
||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""Update the cached states and the persistent batch with the scheduler
|
||||
output.
|
||||
|
||||
The updated states are used by the `_prepare_inputs` function to create
|
||||
the input GPU tensors for the model.
|
||||
|
||||
Returns:
|
||||
True if there is a new/resumed/paused/finished request in the batch.
|
||||
If False, we can skip copying SamplingMetadata to the GPU.
|
||||
The SamplingMetadata is updated and copied to the GPU if there is a
|
||||
new/resumed/paused/finished request in the batch.
|
||||
"""
|
||||
# Remove finished requests from the cached states.
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
@ -344,9 +342,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_new_tokens = (num_computed_tokens +
|
||||
len(req_data.new_token_ids) -
|
||||
req_state.num_tokens)
|
||||
new_token_ids = (req_data.new_token_ids[-num_new_tokens:]
|
||||
if num_new_tokens > 0 else [])
|
||||
req_state.output_token_ids.extend(new_token_ids)
|
||||
if num_new_tokens == 1:
|
||||
# Avoid slicing list in most common case.
|
||||
req_state.output_token_ids.append(req_data.new_token_ids[-1])
|
||||
elif num_new_tokens > 0:
|
||||
req_state.output_token_ids.extend(
|
||||
req_data.new_token_ids[-num_new_tokens:])
|
||||
# Update the block IDs.
|
||||
if not req_data.resumed_from_preemption:
|
||||
# Append the new blocks to the existing block IDs.
|
||||
@ -380,7 +381,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
|
||||
# Add spec_token_ids to token_ids_cpu.
|
||||
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
||||
req_id, [])
|
||||
req_id, ())
|
||||
if spec_token_ids:
|
||||
start_index = end_token_index
|
||||
end_token_index += len(spec_token_ids)
|
||||
@ -410,7 +411,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
return batch_changed
|
||||
if batch_changed:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
@ -429,8 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# TODO: The Python loop can be slow. Optimize.
|
||||
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
|
||||
max_num_scheduled_tokens = 0
|
||||
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
|
||||
assert req_id is not None
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
num_scheduled_tokens[i] = num_tokens
|
||||
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
|
||||
@ -669,10 +670,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
||||
mrope_pos_ptr = 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
|
||||
assert req_id is not None
|
||||
|
||||
for index, req_id in enumerate(self.input_batch.req_ids):
|
||||
req = self.requests[req_id]
|
||||
assert req.mrope_positions is not None
|
||||
|
||||
@ -726,12 +724,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
cu_num_tokens: np.ndarray,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
# Get the number of spec decode tokens for each request.
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
|
||||
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
|
||||
assert req_id is not None
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
num_spec_decode_tokens[i] = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
|
||||
|
||||
@ -769,22 +766,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return torch.from_numpy(spec_decode_logits_indices).to(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
def _prepare_sampling(
|
||||
self,
|
||||
batch_changed: bool,
|
||||
req_to_spec_token_ids: Dict[str, List[int]],
|
||||
) -> SamplingMetadata:
|
||||
# Create the sampling metadata.
|
||||
req_id_output_token_ids: Dict[str, List[int]] = \
|
||||
{req_id: req.output_token_ids \
|
||||
for req_id, req in self.requests.items()}
|
||||
|
||||
sampling_metadata = self.input_batch.make_sampling_metadata(
|
||||
req_id_output_token_ids,
|
||||
req_to_spec_token_ids,
|
||||
skip_copy=not batch_changed)
|
||||
return sampling_metadata
|
||||
|
||||
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||
if not scheduled_encoder_inputs:
|
||||
@ -838,9 +819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> List[torch.Tensor]:
|
||||
encoder_outputs: List[torch.Tensor] = []
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
for req_id in self.input_batch.req_ids[:num_reqs]:
|
||||
assert req_id is not None
|
||||
for req_id in self.input_batch.req_ids:
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
req_state = self.requests[req_id]
|
||||
@ -882,7 +861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[ModelRunnerOutput, torch.Tensor]:
|
||||
batch_changed = self._update_states(scheduler_output)
|
||||
self._update_states(scheduler_output)
|
||||
|
||||
if self.is_multimodal_model:
|
||||
# Run the multimodal encoder if any.
|
||||
@ -964,8 +943,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self._prepare_sampling(
|
||||
batch_changed, scheduler_output.scheduled_spec_decode_tokens)
|
||||
sampling_metadata = self.input_batch.get_sampling_metadata(
|
||||
scheduler_output.scheduled_spec_decode_tokens)
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
@ -973,14 +952,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
req_ids: List[str] = []
|
||||
# Because `input_batch.req_ids` is a list of length `max_num_reqs`,
|
||||
# we need to stop at `num_reqs`.
|
||||
# FIXME(woosuk): This is hacky. Refactor.
|
||||
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
|
||||
assert req_id is not None
|
||||
req_ids.append(req_id)
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
@ -1027,7 +999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
valid_sampled_token_ids)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
@ -1041,19 +1013,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
sampled_token_ids: List[List[int]],
|
||||
) -> List[List[int]]:
|
||||
# TODO(woosuk): Optimize.
|
||||
num_reqs = len(sampled_token_ids)
|
||||
draft_token_ids: List[List[int]] = []
|
||||
for i in range(num_reqs):
|
||||
if len(sampled_token_ids[i]) == 0:
|
||||
for i, sampled_ids in enumerate(sampled_token_ids):
|
||||
num_sampled_ids = len(sampled_ids)
|
||||
if not num_sampled_ids:
|
||||
# Skip speculative decoding.
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
|
||||
# Add sampled_token_ids to token_ids_cpu.
|
||||
start_idx = self.input_batch.num_tokens_no_spec[i]
|
||||
end_idx = start_idx + len(sampled_token_ids[i])
|
||||
self.input_batch.token_ids_cpu[
|
||||
i, start_idx:end_idx] = sampled_token_ids[i]
|
||||
end_idx = start_idx + num_sampled_ids
|
||||
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
|
||||
drafter_output = self.drafter.propose(
|
||||
self.input_batch.token_ids_cpu[i, :end_idx],
|
||||
self.speculative_config.ngram_prompt_lookup_min,
|
||||
@ -1204,7 +1175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# multiplying the list, to avoid Dynamo from treating them as
|
||||
# tensor aliasing.
|
||||
dummy_kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
torch.tensor((), dtype=torch.float32, device=self.device)
|
||||
for _ in range(self.num_attn_layers)
|
||||
]
|
||||
|
||||
|
||||
@ -1048,8 +1048,6 @@ def swap_positions(b: InputBatch, id_1, id_2):
|
||||
|
||||
b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
|
||||
id_1]
|
||||
b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
|
||||
id_2], b.stop_token_ids[id_1]
|
||||
|
||||
gen_1 = b.generators.pop(id_1, None)
|
||||
gen_2 = b.generators.pop(id_2, None)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user