[V1] Optimize handling of sampling metadata and req_ids list (#13244)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill, 2025-02-18 12:15:33 -08:00 (committed by GitHub)
Parent: a4d577b379
Commit: 30172b4947
15 changed files with 254 additions and 297 deletions

View File

@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
 def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
     batch_size = len(spec_tokens)
     return SamplingMetadata(
-        temperature=0.0,
+        temperature=torch.tensor([]),
         all_greedy=True,
         all_random=False,
-        rejection_sampling=True,
         spec_token_ids=spec_tokens,
         top_p=None,
         top_k=None,
-        no_top_p=False,
-        no_top_k=False,
         min_p=torch.empty(batch_size, ),
-        no_min_p=True,
         generators={},
         max_num_logprobs=0,
         no_penalties=False,
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
         presence_penalties=torch.tensor([]),
         repetition_penalties=torch.tensor([]),
         output_token_ids=[],
-        min_tokens=[],
-        stop_token_ids=[],
+        min_tokens={},
         logit_bias=[None] * batch_size,
     )

View File

@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
         temperature=torch.full((batch_size, ), 0.0),
         all_greedy=True,
         all_random=False,
-        rejection_sampling=False,
-        top_p=torch.empty(batch_size, ),
-        top_k=torch.empty(batch_size, ),
-        no_top_p=True,
-        no_top_k=True,
-        min_p=torch.empty(batch_size, ),
-        no_min_p=True,
+        top_p=None,
+        top_k=None,
+        min_p=None,
        generators={},
         max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                       vocab_size, device),
         output_token_ids=output_token_ids,
-        spec_token_ids=[],
+        spec_token_ids=None,
         frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
         no_penalties=True,
-        min_tokens=[],
-        stop_token_ids=[],
+        min_tokens={},
         logit_bias=[None] * batch_size,
     )
     return fake_sampling_metadata
@@ -104,10 +99,10 @@ def _create_default_sampling_metadata(
 def _generate_min_token_penalties_and_stop_tokens(
     num_output_tokens: int, batch_size: int, vocab_size: int,
     batch_indices_for_min_token_penalty: List[int]
-) -> Tuple[List[int], List[Set[int]]]:
+) -> Dict[int, Tuple[int, Set[int]]]:
     """
-    Generates and returns a list of minimum token penalties (`min_tokens`)
-    and a corresponding list of stop token IDs (`stop_token_ids`) for each
+    Generates and returns a dict of minimum token penalties and
+    corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
     batch.

     If a batch index is included in `batch_indices_for_min_token_penalty`,
@@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens(
     and a random set of stop token IDs is created. Otherwise, a lower
     `min_tokens` value is assigned, and the stop token IDs set is empty.
     """
-    stop_token_ids: List[Set[int]] = []
-    min_tokens: List[int] = []
+    min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
     for index in range(batch_size):
         if index in batch_indices_for_min_token_penalty:
-            min_tokens.append(
+            min_tokens[index] = (
                 np.random.randint(num_output_tokens + 1,
-                                  2 * num_output_tokens))
-            stop_token_ids.append(
+                                  2 * num_output_tokens),
                 set(
                     np.random.randint(0, vocab_size - 1)
                     for _ in range(np.random.randint(0, vocab_size))))
         else:
-            min_tokens.append(np.random.randint(0, num_output_tokens))
-            stop_token_ids.append(set())
-    return (min_tokens, stop_token_ids)
+            min_tokens[index] = (np.random.randint(0,
+                                                   num_output_tokens), set())
+    return min_tokens


 def _create_weighted_output_token_list(
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
         output_token_ids_for_batch.extend(
             [token_id for _ in range(index + 1)])
         output_token_ids.append(output_token_ids_for_batch)
-    return (output_token_ids, sorted_token_ids_in_output)
+    return output_token_ids, sorted_token_ids_in_output


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
         NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
     batch_indices_for_min_token_penalty = np.random.randint(
         0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
-    min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
+    min_tokens = _generate_min_token_penalties_and_stop_tokens(
         NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
         batch_indices_for_min_token_penalty)
     sampling_metadata.min_tokens = min_tokens
-    sampling_metadata.stop_token_ids = stop_token_ids
     sampler = Sampler()
     logits = sampler.apply_penalties(fake_logits, sampling_metadata)
     logits = logits.cpu()
     for batch_idx in range(batch_size):
         for token_id in range(VOCAB_SIZE):
-            if token_id in stop_token_ids[batch_idx]:
+            _, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
+            if token_id in stop_token_ids:
                 assert logits[batch_idx][token_id] == -float("inf")
             else:
                 assert logits[batch_idx][token_id] != -float("inf")
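
The `min_tokens` structure used above replaces two parallel lists with a single sparse dict keyed by request index. A minimal sketch (not part of the diff) of how a consumer reads it back, assuming a 4-request batch where only index 2 carries a penalty:

from typing import Dict, Set, Tuple

# req_index -> (min_tokens, stop_token_ids); absent index == no penalty
min_tokens: Dict[int, Tuple[int, Set[int]]] = {2: (5, {7, 11})}

for batch_idx in range(4):
    # Requests without an entry fall back to "no penalty".
    min_count, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
    print(batch_idx, min_count, stop_token_ids)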

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple

 import numpy as np
 import pytest
@@ -41,7 +41,7 @@ def _remove_requests(
     for index in req_indices_to_remove:
         input_batch.remove_request(reqs[index].req_id)
         req_ids_to_remove.add(reqs[index].req_id)
-    return (req_ids_to_remove, req_indices_to_remove_list)
+    return req_ids_to_remove, req_indices_to_remove_list


 def _construct_expected_sampling_metadata(
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
     top_p = [0.0 for _ in range(num_reqs)]
     min_p = [0.0 for _ in range(num_reqs)]
     temperature = [0.0 for _ in range(num_reqs)]
-    stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
-    min_tokens = [0 for _ in range(num_reqs)]
+    min_tokens = {}
     logit_bias = [None] * num_reqs
     for req in reqs:
         if req.req_id not in req_ids_retained:
@@ -83,22 +82,21 @@
         top_p[index_in_input_batch] = req.sampling_params.top_p
         min_p[index_in_input_batch] = req.sampling_params.min_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
-        stop_token_ids[
-            index_in_input_batch] = req.sampling_params.all_stop_token_ids
-        min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
+        min_tokens[index_in_input_batch] = (
+            req.sampling_params.min_tokens,
+            req.sampling_params.all_stop_token_ids)
         logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
                                  device=device),
         all_greedy=False,
         all_random=True,
-        rejection_sampling=False,
-        top_p=torch.tensor(top_p, dtype=torch.float, device=device),
-        top_k=torch.tensor(top_k, dtype=torch.int, device=device),
-        no_top_p=all(x == 1.0 for x in top_p),
-        no_top_k=all(x == 0 for x in top_k),
-        min_p=torch.tensor(min_p, dtype=torch.float, device=device),
-        no_min_p=all(x == 0.0 for x in min_p),
+        top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
+            top_p, dtype=torch.float, device=device),
+        top_k=None if all(x == 0 for x in top_k) else torch.tensor(
+            top_k, dtype=torch.int, device=device),
+        min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
+            min_p, dtype=torch.float, device=device),
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=make_tensor_with_pad(
@@ -117,9 +115,8 @@
             dtype=torch.float,
             device=device),
         output_token_ids=output_token_ids,
-        spec_token_ids=[],
+        spec_token_ids=None,
         min_tokens=min_tokens,
-        stop_token_ids=stop_token_ids,
         no_penalties=(all(x == 0 for x in presence_penalties)
                       and all(x == 0 for x in frequency_penalties)
                       and all(x == 1 for x in repetition_penalties)),
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     input_batch.condense(req_indices_to_remove)

     # Generate the sampling metadata
-    sampling_metadata = input_batch.make_sampling_metadata(
-        req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
+    sampling_metadata = input_batch._make_sampling_metadata()

     # Create expected output.
     expected_sampling_metadata = _construct_expected_sampling_metadata(
@@ -216,13 +212,16 @@
         input_batch.req_id_to_index,
         device=torch.device(device))

+    def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
+        return (t1 is None
+                and t2 is None) or (t1 is not None and t2 is not None
+                                    and torch.allclose(t1, t2))
+
     # Assert the actual and expected output.
     assert torch.allclose(expected_sampling_metadata.temperature,
                           sampling_metadata.temperature)
-    assert torch.allclose(expected_sampling_metadata.top_p,
-                          sampling_metadata.top_p)
-    assert torch.allclose(expected_sampling_metadata.top_k,
-                          sampling_metadata.top_k)
+    assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
+    assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
     assert torch.allclose(
         expected_sampling_metadata.frequency_penalties,
         sampling_metadata.frequency_penalties,
@@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
     assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
-    assert expected_sampling_metadata.stop_token_ids == \
-        sampling_metadata.stop_token_ids
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
-    assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
-    assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
     assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias

View File

@@ -5,6 +5,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
                                            SchedulerOutput)
+from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
     return req_id in model_runner.requests


+def _is_sampling_metadata_changed(model_runner,
+                                  sampling_metadata_before: SamplingMetadata):
+    return model_runner.input_batch.sampling_metadata is not (
+        sampling_metadata_before)
+
+
 def test_update_states_new_request(model_runner):
     req_id = "req_0"

     # new req
     scheduler_output = _schedule_new_request(req_id)

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)

     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
@@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
         free_encoder_input_ids=[],
     )

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)

     assert not _is_req_added(model_runner, req_id)
     assert not _is_req_scheduled(model_runner, req_id)
@@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
         num_common_prefix_blocks=0,
-        finished_req_ids={},
+        finished_req_ids=set(),
         free_encoder_input_ids=[],
     )
@@ -171,8 +180,9 @@
         free_encoder_input_ids=[],
     )

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)

     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
@@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
         free_encoder_input_ids=[],
     )

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is False
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert not _is_sampling_metadata_changed(model_runner, metadata_before)

     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
@@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
         free_encoder_input_ids=[],
     )

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)

     assert _is_req_added(model_runner, req_ids[0])
     assert _is_req_scheduled(model_runner, req_ids[0])
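
These tests no longer rely on a boolean returned from `_update_states`; they detect a batch change by checking whether the cached `sampling_metadata` object was replaced. A minimal sketch of the identity-based pattern, using a stand-in class rather than the real runner:

class FakeBatch:
    def __init__(self):
        self.sampling_metadata = object()

    def refresh_sampling_metadata(self):
        # Rebuilding the metadata yields a new object, so "is not" detects it.
        self.sampling_metadata = object()

batch = FakeBatch()
before = batch.sampling_metadata
batch.refresh_sampling_metadata()
assert batch.sampling_metadata is not before   # batch changed
before = batch.sampling_metadata
assert batch.sampling_metadata is before       # unchanged until the next refresh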

View File

@@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                                                   vocab_size, num_seqs)
     output_bin_counts, output_mask = get_token_bin_counts_and_mask(
         output_tokens_tensor, vocab_size, num_seqs)
-    repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
+    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
         1, vocab_size)
     logits[logits > 0] /= torch.where(prompt_mask | output_mask,
                                       repetition_penalties, 1.0)[logits > 0]
@@ -53,6 +53,6 @@
                                       repetition_penalties, 1.0)[logits <= 0]
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits
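
The switch from `unsqueeze_` to `unsqueeze` matters here because the trailing-underscore variant mutates the penalty tensors passed in by the caller, which now alias the persistent sampling metadata. A small illustration of the difference (values are made up):

import torch

penalties = torch.tensor([1.0, 2.0])

reshaped = penalties.unsqueeze(dim=1)   # returns a view; caller's tensor keeps shape (2,)
print(penalties.shape)                  # torch.Size([2])

penalties.unsqueeze_(dim=1)             # in-place; caller's tensor is now shape (2, 1)
print(penalties.shape)                  # torch.Size([2, 1])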

View File

@@ -195,8 +195,10 @@ class Scheduler:
                                              request.num_computed_tokens -
                                              request.num_tokens)
                 if num_scheduled_spec_tokens > 0:
+                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
+                    del request.spec_token_ids[num_scheduled_spec_tokens:]
                     scheduled_spec_decode_tokens[request.request_id] = (
-                        request.spec_token_ids[:num_scheduled_spec_tokens])
+                        request.spec_token_ids)

             # Encoder-related.
             if encoder_inputs_to_schedule:
@@ -567,7 +569,7 @@
             outputs.append(
                 EngineCoreOutput(
                     request_id=req_id,
-                    new_token_ids=new_token_ids or [],
+                    new_token_ids=new_token_ids,
                     finish_reason=request.get_finished_reason(),
                     new_logprobs=new_logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs_tensors,
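
The scheduler now trims `spec_token_ids` in place with `del`, so the request object and the scheduler output share one list instead of allocating a slice copy every step. A minimal sketch of the trimming idiom (hypothetical values, not from the diff):

spec_token_ids = [101, 102, 103, 104]
num_scheduled_spec_tokens = 2

# Drop everything past the scheduled prefix, in place.
del spec_token_ids[num_scheduled_spec_tokens:]
assert spec_token_ids == [101, 102]

# Contrast: slicing would leave the original list untouched and allocate a copy.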

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from dataclasses import dataclass
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, Tuple

 import torch
@@ -12,15 +12,13 @@ class SamplingMetadata:
     temperature: torch.Tensor
     all_greedy: bool
     all_random: bool
-    rejection_sampling: bool

-    spec_token_ids: List[List[int]]
-    top_p: torch.Tensor
-    top_k: torch.Tensor
-    no_top_p: bool
-    no_top_k: bool
-    min_p: torch.Tensor
-    no_min_p: bool
+    # None when there are no speculated tokens.
+    spec_token_ids: Optional[List[List[int]]]
+
+    top_p: Optional[torch.Tensor]
+    top_k: Optional[torch.Tensor]
+    min_p: Optional[torch.Tensor]

     generators: Dict[int, torch.Generator]
@@ -34,7 +32,8 @@
     repetition_penalties: torch.Tensor

     output_token_ids: List[List[int]]
-    min_tokens: List[int]
-    stop_token_ids: List[Set[int]]
+
+    # req_index -> (min_tokens, stop_token_ids)
+    min_tokens: Dict[int, Tuple[int, Set[int]]]

     logit_bias: List[Optional[Dict[int, float]]]
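
With the Optional fields, a single `None` now encodes what previously needed a placeholder tensor plus a separate `no_*` flag. A minimal sketch of how a consumer branches on the new fields (standalone helper, not from the diff):

from typing import Optional
import torch

def describe(top_p: Optional[torch.Tensor], top_k: Optional[torch.Tensor]) -> str:
    # None means the filter is disabled for the whole batch.
    if top_p is None and top_k is None:
        return "no top-k/top-p filtering"
    parts = []
    if top_k is not None:
        parts.append(f"top-k over {top_k.numel()} requests")
    if top_p is not None:
        parts.append(f"top-p over {top_p.numel()} requests")
    return ", ".join(parts)

print(describe(None, torch.tensor([5, 0, 10])))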

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import List, Set, Tuple
+from typing import Dict, List, Set, Tuple

 import torch
@@ -8,18 +8,17 @@ from vllm.model_executor.layers.utils import apply_penalties
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad


-def apply_min_token_penalties(logits: torch.Tensor,
-                              output_token_ids: List[List[int]],
-                              stop_token_ids: List[Set[int]],
-                              min_tokens: List[int]) -> None:
+def apply_min_token_penalties(
+        logits: torch.Tensor, output_token_ids: List[List[int]],
+        min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None:
     """
     Applies minimum token penalty by setting the logits of the stop tokens
     to -inf.
     """
     min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
-    for index, min_token in enumerate(min_tokens):
+    for index, (min_token, stop_token_ids) in min_tokens.items():
         if len(output_token_ids[index]) < min_token:
-            for stop_token_id in stop_token_ids[index]:
+            for stop_token_id in stop_token_ids:
                 min_tokens_logits_to_penalize.append((index, stop_token_id))
     if min_tokens_logits_to_penalize:
         logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
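
A short usage sketch of the new signature (the values are made up, and the import path is assumed to be the module this hunk belongs to):

import torch
from vllm.v1.sample.ops.penalties import apply_min_token_penalties  # assumed module path

batch_size, vocab_size = 2, 8
logits = torch.zeros(batch_size, vocab_size)
output_token_ids = [[3], [3, 4, 5]]     # tokens generated so far per request
min_tokens = {0: (4, {6, 7})}           # request 0 must emit >= 4 tokens; 6/7 are its stop ids

apply_min_token_penalties(logits, output_token_ids, min_tokens)

assert logits[0, 6] == -float("inf") and logits[0, 7] == -float("inf")
assert torch.isfinite(logits[1]).all()  # request 1 has no min_tokens entry, so it is untouched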

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Dict
+from typing import Dict, Optional

 import torch
 import torch.nn as nn
@@ -55,13 +55,11 @@ class TopKTopPSampler(nn.Module):
         self,
         logits: torch.Tensor,
         generators: Dict[int, torch.Generator],
-        no_top_k: bool,
-        k: torch.Tensor,
-        no_top_p: bool,
-        p: torch.Tensor,
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """PyTorch-native implementation of top-k and top-p sampling."""
-        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
+        logits = apply_top_k_top_p(logits, k, p)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
@@ -69,37 +67,33 @@
         self,
         logits: torch.Tensor,
         generators: Dict[int, torch.Generator],
-        no_top_k: bool,
-        k: torch.Tensor,
-        no_top_p: bool,
-        p: torch.Tensor,
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
         probs = logits.softmax(dim=-1, dtype=torch.float32)
-        if no_top_k and no_top_p:
+        if k is None and p is None:
             # We prefer `random_sample` over `flashinfer_sample` when sorting is
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
             return random_sample(probs, generators)
-        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
+        return flashinfer_sample(probs, k, p, generators)


 def apply_top_k_top_p(
     logits: torch.Tensor,
-    no_top_k: bool,
-    k: torch.Tensor,
-    no_top_p: bool,
-    p: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
 ) -> torch.Tensor:
     """Apply top-k and top-p masks to the logits.

     This function sorts the logits tensor, which can be slow for large batches.
     """
-    if no_top_k and no_top_p:
+    if k is None and p is None:
         return logits
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-    if not no_top_k:
+    if k is not None:
         # Apply top-k.
         top_k_mask = logits_sort.size(1) - k.to(torch.long)
         # Get all the top_k values.
@@ -107,7 +101,7 @@ def apply_top_k_top_p(
         top_k_mask = logits_sort < top_k_mask
         logits_sort.masked_fill_(top_k_mask, -float("inf"))

-    if not no_top_p:
+    if p is not None:
         # Apply top-p.
         probs_sort = logits_sort.softmax(dim=-1)
         probs_sum = probs_sort.cumsum(dim=-1)
@@ -147,10 +141,8 @@

 def flashinfer_sample(
     probs: torch.Tensor,
-    no_top_k: bool,
-    k: torch.Tensor,
-    no_top_p: bool,
-    p: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
     generators: Dict[int, torch.Generator],
 ) -> torch.Tensor:
     """Sample from the probabilities using FlashInfer.
@@ -167,7 +159,7 @@
     does not. Call this function at the end of the forward pass to minimize
     the synchronization overhead.
     """
-    assert not (no_top_k and no_top_p)
+    assert not (k is None and p is None)
     max_top_k_round = 32
     batch_size = probs.shape[0]
     uniform_samples = torch.empty((max_top_k_round, batch_size),
@@ -178,11 +170,11 @@
         for i, generator in generators.items():
             uniform_samples[:, i].uniform_(generator=generator)

-    if no_top_k:
+    if k is None:
         # Top-p only.
         next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
             probs, uniform_samples, p, deterministic=True)
-    elif no_top_p:
+    elif p is None:
         # Top-k only.
         next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
             probs, uniform_samples, k, deterministic=True)
@@ -194,9 +186,9 @@
     # NOTE: CPU-GPU synchronization happens here.
     if not success.all():
-        if not no_top_k:
+        if k is not None:
             probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
-        if not no_top_p:
+        if p is not None:
             probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
         next_token_ids = flashinfer.sampling.sampling_from_probs(
             probs, uniform_samples[0], deterministic=True)
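
A short usage sketch of the reworked helper: passing `None` now disables a filter entirely, so no placeholder tensors or `no_top_*` flags are needed. The import path is an assumption about where the module shown above lives:

import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p  # assumed path

logits = torch.randn(4, 32000)

out = apply_top_k_top_p(logits, k=None, p=None)   # both disabled: no-op, logits returned as-is
assert out is logits

k = torch.full((4, ), 50, dtype=torch.int)
out = apply_top_k_top_p(logits, k=k, p=None)      # top-k only; the top-p branch is skipped
assert torch.isinf(out).any()                     # everything outside the top 50 is masked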

View File

@@ -68,6 +68,7 @@ class RejectionSampler(nn.Module):
         # NOTE: The following input preparationg can be moved
         # to the model runner with a persistent manner for better
         # performance.
+        assert sampling_metadata.spec_token_ids is not None
         spec_token_ids = sampling_metadata.spec_token_ids
         max_spec_len = max(len(s) for s in spec_token_ids)
         batch_size = len(spec_token_ids)
@@ -119,6 +120,7 @@
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
+        assert sampling_metadata.spec_token_ids is not None
         spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
         # Add 1 to include the 'bonus' token.
         sample_lens = [x + 1 for x in spec_lens]

View File

@@ -26,7 +26,7 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        if sampling_metadata.rejection_sampling:
+        if sampling_metadata.spec_token_ids:
             if sampling_metadata.max_num_logprobs:
                 raise NotImplementedError(
                     "Rejection sampling does not support logprobs.")
@@ -104,16 +104,14 @@
         logits = self.apply_temperature(logits, sampling_metadata.temperature)

         # Apply min_p.
-        if not sampling_metadata.no_min_p:
+        if sampling_metadata.min_p is not None:
             logits = self.apply_min_p(logits, sampling_metadata.min_p)

         # Apply top_k and/or top_p.
         random_sampled = self.topk_topp_sampler(
             logits,
             sampling_metadata.generators,
-            sampling_metadata.no_top_k,
             sampling_metadata.top_k,
-            sampling_metadata.no_top_p,
             sampling_metadata.top_p,
         )
@@ -179,8 +177,9 @@
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
-        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
-                                  sampling_metadata.stop_token_ids,
-                                  sampling_metadata.min_tokens)
+        if sampling_metadata.min_tokens:
+            apply_min_token_penalties(logits,
+                                      sampling_metadata.output_token_ids,
+                                      sampling_metadata.min_tokens)
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
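
The sampler now gates optional work on the fields themselves: an empty `min_tokens` dict skips the penalty pass and a `None` `min_p` skips that filter. A minimal sketch of the gating pattern with a stand-in metadata object (not the real class):

from dataclasses import dataclass, field
from typing import Dict, Optional, Set, Tuple
import torch

@dataclass
class FakeMetadata:
    min_p: Optional[torch.Tensor] = None
    min_tokens: Dict[int, Tuple[int, Set[int]]] = field(default_factory=dict)

def apply_optional_stages(logits: torch.Tensor, meta: FakeMetadata) -> torch.Tensor:
    if meta.min_tokens:                  # empty dict -> skip entirely
        print("would apply min-token penalties")
    if meta.min_p is not None:           # None -> filter disabled
        print("would apply min-p filtering")
    return logits

apply_optional_stages(torch.zeros(2, 8), FakeMetadata())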

View File

@@ -188,3 +188,14 @@ def bind_kv_cache(
     for layer_name, kv_cache in kv_caches.items():
         # NOTE: Use list because of v0 PP virtual engine.
         forward_context[layer_name].kv_cache = [kv_cache]
+
+
+def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
+               length: int) -> None:
+    """
+    Copy the first length elements of a tensor into another tensor in a
+    non-blocking manner.
+
+    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
+    """
+    to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
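
A short usage sketch of `copy_slice`, pairing a pinned CPU staging tensor with a persistent device tensor (tensor names are illustrative; the device is chosen at runtime so the sketch also runs on CPU):

import torch
from vllm.v1.utils import copy_slice

device = "cuda" if torch.cuda.is_available() else "cpu"
max_num_reqs = 8

# Pre-allocated buffers: pinned CPU staging tensor plus persistent device tensor.
temperature_cpu = torch.zeros(max_num_reqs, pin_memory=torch.cuda.is_available())
temperature_dev = torch.zeros(max_num_reqs, device=device)

# Only the first `num_reqs` slots are meaningful this step.
num_reqs = 3
temperature_cpu[:num_reqs] = torch.tensor([0.7, 1.0, 0.0])

copy_slice(temperature_cpu, temperature_dev, num_reqs)  # async H2D copy of the live prefix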

View File

@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # Datastructures defining an input batch

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast

 import numpy as np
 import torch
@@ -12,6 +11,7 @@ from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import BlockTable

 _SAMPLING_EPS = 1e-5
@@ -63,7 +63,7 @@ class InputBatch:
         self.pin_memory = pin_memory
         self.vocab_size = vocab_size

-        self.req_ids: List[Optional[str]] = [None] * max_num_reqs
+        self._req_ids: List[Optional[str]] = []
         self.req_id_to_index: Dict[str, int] = {}

         # TODO(woosuk): This buffer could be too large if max_model_len is big.
@@ -171,11 +171,8 @@
             self.repetition_penalties_cpu_tensor.numpy()
         self.repetition_penalties_reqs: Set[str] = set()

-        self.min_tokens: List[int] = [0] * max_num_reqs
-        self.stop_token_ids: List[Set[int]] = [
-            set() for _ in range(max_num_reqs)
-        ]
-        self.prompt_token_ids: Optional[torch.Tensor] = None
+        # req_index -> (min_tokens, stop_token_ids)
+        self.min_tokens: Dict[int, Tuple[int, Set[int]]] = {}

         # lora related
         self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
@@ -196,6 +193,17 @@
         self.logit_bias: List[Optional[Dict[int,
                                             float]]] = [None] * max_num_reqs

+        self.req_output_token_ids: List[Optional[List[int]]] = []
+
+        # This is updated each time the batch constituents change.
+        self.sampling_metadata = self._make_sampling_metadata()
+
+    @property
+    def req_ids(self) -> List[str]:
+        # None elements should only be present transiently
+        # while performing state updates to the batch.
+        return cast(List[str], self._req_ids)
+
     def add_request(
         self,
         request: "CachedRequestState",
@@ -206,7 +214,13 @@
         assert req_index < self.max_num_reqs

         req_id = request.req_id
-        self.req_ids[req_index] = req_id
+        if req_index == len(self._req_ids):
+            self._req_ids.append(req_id)
+            self.req_output_token_ids.append(request.output_token_ids)
+        else:
+            self._req_ids[req_index] = req_id
+            self.req_output_token_ids[req_index] = request.output_token_ids
+
         self.req_id_to_index[req_id] = req_index

         # Copy the prompt token ids and output token ids.
@@ -255,8 +269,9 @@
                 req_index] = sampling_params.repetition_penalty
             if sampling_params.repetition_penalty != 1.0:
                 self.repetition_penalties_reqs.add(req_id)
-        self.min_tokens[req_index] = sampling_params.min_tokens
-        self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids
+        if sampling_params.min_tokens:
+            self.min_tokens[req_index] = (sampling_params.min_tokens,
+                                          sampling_params.all_stop_token_ids)

         # NOTE(woosuk): self.generators should not include the requests that
         # do not have their own generator.
@@ -284,16 +299,20 @@
             self.request_lora_mapping[req_index] = 0

     def remove_request(self, req_id: str) -> Optional[int]:
+        """This method must always be followed by a call to condense()."""
         req_index = self.req_id_to_index.pop(req_id, None)
         if req_index is None:
             return None
-        self.req_ids[req_index] = None
+        self._req_ids[req_index] = None
+        self.req_output_token_ids[req_index] = None

         self.greedy_reqs.discard(req_id)
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
         self.min_p_reqs.discard(req_id)
+        self.min_tokens.pop(req_index, None)
         self.frequency_penalties_reqs.discard(req_id)
         self.presence_penalties_reqs.discard(req_id)
         self.repetition_penalties_reqs.discard(req_id)
@@ -313,33 +332,17 @@
         self.logit_bias[req_index] = None
         return req_index

-    def clear(self) -> None:
-        self.req_ids = [None] * self.max_num_reqs
-        self.req_id_to_index.clear()
-        self.greedy_reqs.clear()
-        self.random_reqs.clear()
-        self.top_p_reqs.clear()
-        self.top_k_reqs.clear()
-        self.min_p_reqs.clear()
-        self.frequency_penalties_reqs.clear()
-        self.presence_penalties_reqs.clear()
-        self.repetition_penalties_reqs.clear()
-        self.generators.clear()
-        self.num_logprobs.clear()
-        self.num_prompt_logprobs.clear()
-        self.request_lora_mapping.fill(0)
-        self.lora_id_to_lora_request.clear()
-        self.lora_id_to_request_ids.clear()
-        self.logit_bias = [None] * self.max_num_reqs
-
     def condense(self, empty_req_indices: List[int]) -> None:
-        if self.num_reqs == 0:
+        num_reqs = self.num_reqs
+        if num_reqs == 0:
             # The batched states are empty.
+            self._req_ids.clear()
+            self.req_output_token_ids.clear()
             return

         # NOTE(woosuk): This function assumes that the empty_req_indices
         # is sorted in descending order.
-        last_req_index = self.num_reqs + len(empty_req_indices) - 1
+        last_req_index = num_reqs + len(empty_req_indices) - 1
         while empty_req_indices:
             # Find the largest non-empty index.
             while last_req_index in empty_req_indices:
@@ -351,10 +354,13 @@
                 break

             # Swap the states.
-            req_id = self.req_ids[last_req_index]
+            req_id = self._req_ids[last_req_index]
+            output_token_ids = self.req_output_token_ids[last_req_index]
             assert req_id is not None
-            self.req_ids[empty_index] = req_id
-            self.req_ids[last_req_index] = None
+            self._req_ids[empty_index] = req_id
+            self._req_ids[last_req_index] = None
+            self.req_output_token_ids[empty_index] = output_token_ids
+            self.req_output_token_ids[last_req_index] = None
             self.req_id_to_index[req_id] = empty_index

             num_tokens = self.num_tokens[last_req_index]
@@ -379,13 +385,14 @@
             self.repetition_penalties_cpu[
                 empty_index] = self.repetition_penalties_cpu[last_req_index]
             self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
-            self.min_tokens[empty_index] = self.min_tokens[last_req_index]
-            self.stop_token_ids[empty_index] = self.stop_token_ids[
-                last_req_index]
             generator = self.generators.pop(last_req_index, None)
             if generator is not None:
                 self.generators[empty_index] = generator

+            min_token = self.min_tokens.pop(last_req_index, None)
+            if min_token is not None:
+                self.min_tokens[empty_index] = min_token
+
             self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                 last_req_index]

@@ -394,87 +401,71 @@
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1

-    def make_sampling_metadata(
-        self,
-        req_id_output_token_ids: Dict[str, List[int]],
-        req_id_to_spec_token_ids: Dict[str, List[int]],
-        skip_copy: bool = False,
-    ) -> SamplingMetadata:
-        if not skip_copy:
-            self.temperature[:self.num_reqs].copy_(
-                self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
-            self.top_p[:self.num_reqs].copy_(
-                self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
-            self.top_k[:self.num_reqs].copy_(
-                self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
-            self.min_p[:self.num_reqs].copy_(
-                self.min_p_cpu_tensor[:self.num_reqs], non_blocking=True)
-            if not self.no_penalties:
-                # Since syncing these tensors is expensive only copy them
-                # if necessary i.e. if there are requests which require
-                # penalties to be applied during sampling.
-                self.frequency_penalties[:self.num_reqs].copy_(
-                    self.frequency_penalties_cpu_tensor[:self.num_reqs],
-                    non_blocking=True,
-                )
-                self.presence_penalties[:self.num_reqs].copy_(
-                    self.presence_penalties_cpu_tensor[:self.num_reqs],
-                    non_blocking=True,
-                )
-                self.repetition_penalties[:self.num_reqs].copy_(
-                    self.repetition_penalties_cpu_tensor[:self.num_reqs],
-                    non_blocking=True,
-                )
-                # The prompt tokens are used only for applying penalties during
-                # the sampling process. Hence copy these tensors only when
-                # there are requests which need penalties to be applied.
-                self.prompt_token_ids = self._make_prompt_token_ids_tensor()
-
-        output_token_ids: List[List[int]] = []
-        spec_token_ids: List[List[int]] = []
-        rejection_sampling = False
-        for req_id in self.req_ids[:self.num_reqs]:
-            assert req_id is not None
-            # Currently we create a tensor for output_token_ids from scratch
-            # at each step. However, for the penalties computation what we
-            # need is stats about the token ids present in the output. This
-            # stats can be maintained incrementally instead of computing it
-            # from scratch at each step.
-            # TODO - Replace this with incremental update to output token
-            # statistics.
-            output_token_ids.append(req_id_output_token_ids[req_id])
-            req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, [])
-            spec_token_ids.append(req_spec_token_ids)
-            if req_spec_token_ids:
-                # If any of the requests require speculative decoding, set the
-                # flag to True.
-                rejection_sampling = True
+        # Trim lists to the batch size.
+        del self._req_ids[self.num_reqs:]
+        del self.req_output_token_ids[self.num_reqs:]
+
+    def refresh_sampling_metadata(self):
+        self.sampling_metadata = self._make_sampling_metadata()
+
+    def _make_sampling_metadata(self) -> SamplingMetadata:
+        num_reqs = self.num_reqs
+        copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs)
+        if not self.no_top_p:
+            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
+        if not self.no_top_k:
+            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
+        if not self.no_min_p:
+            copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
+
+        if not self.no_penalties:
+            # Since syncing these tensors is expensive only copy them
+            # if necessary i.e. if there are requests which require
+            # penalties to be applied during sampling.
+            copy_slice(self.frequency_penalties_cpu_tensor,
+                       self.frequency_penalties, num_reqs)
+            copy_slice(self.presence_penalties_cpu_tensor,
+                       self.presence_penalties, num_reqs)
+            copy_slice(self.repetition_penalties_cpu_tensor,
+                       self.repetition_penalties, num_reqs)
+
+            # The prompt tokens are used only for applying penalties during
+            # the sampling process. Hence copy these tensors only when
+            # there are requests which need penalties to be applied.
+            prompt_token_ids = self._make_prompt_token_ids_tensor()
+        else:
+            prompt_token_ids = None

         return SamplingMetadata(
-            temperature=self.temperature[:self.num_reqs],
+            temperature=self.temperature[:num_reqs],
             all_greedy=self.all_greedy,
             all_random=self.all_random,
-            rejection_sampling=rejection_sampling,
-            top_p=self.top_p[:self.num_reqs],
-            top_k=self.top_k[:self.num_reqs],
-            min_p=self.min_p[:self.num_reqs],
-            no_min_p=self.no_min_p,
-            no_top_p=self.no_top_p,
-            no_top_k=self.no_top_k,
+            top_p=None if self.no_top_p else self.top_p[:num_reqs],
+            top_k=None if self.no_top_k else self.top_k[:num_reqs],
+            min_p=None if self.no_min_p else self.min_p[:num_reqs],
             generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
-            prompt_token_ids=self.prompt_token_ids,
-            frequency_penalties=self.frequency_penalties[:self.num_reqs],
-            presence_penalties=self.presence_penalties[:self.num_reqs],
-            repetition_penalties=self.repetition_penalties[:self.num_reqs],
-            output_token_ids=output_token_ids,
-            spec_token_ids=spec_token_ids,
-            min_tokens=self.min_tokens[:self.num_reqs],
-            stop_token_ids=self.stop_token_ids[:self.num_reqs],
+            prompt_token_ids=prompt_token_ids,
+            frequency_penalties=self.frequency_penalties[:num_reqs],
+            presence_penalties=self.presence_penalties[:num_reqs],
+            repetition_penalties=self.repetition_penalties[:num_reqs],
+            output_token_ids=cast(List[List[int]], self.req_output_token_ids),
+            spec_token_ids=None,
+            min_tokens=self.min_tokens,
             no_penalties=self.no_penalties,
-            logit_bias=self.logit_bias[:self.num_reqs],
+            logit_bias=self.logit_bias[:num_reqs],
         )

+    def get_sampling_metadata(
+        self,
+        req_id_to_spec_token_ids: Dict[str, List[int]],
+    ) -> SamplingMetadata:
+        # Set the new spec token ids in the cached sampling metadata.
+        self.sampling_metadata.spec_token_ids = [
+            req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids
+        ] if req_id_to_spec_token_ids else None
+        return self.sampling_metadata
+
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
         max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
         prompt_token_ids_cpu_tensor = torch.empty(
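
The lifecycle this change introduces: `_make_sampling_metadata` is built once and cached, `refresh_sampling_metadata` replaces the cached object only when the batch constituents change, and `get_sampling_metadata` merely patches in the per-step spec token ids. A condensed stand-in sketch of that flow (not the real class):

from typing import Dict, List

class TinyBatch:
    def __init__(self) -> None:
        self.req_ids: List[str] = []
        self.sampling_metadata = self._make_sampling_metadata()

    def _make_sampling_metadata(self) -> dict:
        # Stand-in for the real SamplingMetadata construction.
        return {"req_ids": list(self.req_ids), "spec_token_ids": None}

    def refresh_sampling_metadata(self) -> None:
        self.sampling_metadata = self._make_sampling_metadata()

    def get_sampling_metadata(self, spec: Dict[str, List[int]]) -> dict:
        # Per-step update: only the spec token ids change on the cached object.
        self.sampling_metadata["spec_token_ids"] = [
            spec.get(req_id, []) for req_id in self.req_ids
        ] if spec else None
        return self.sampling_metadata

batch = TinyBatch()
batch.req_ids.append("req_0")
batch.refresh_sampling_metadata()                     # batch changed: rebuild once
md = batch.get_sampling_metadata({"req_0": [5, 7]})   # per-step spec tokens patched in
assert md["spec_token_ids"] == [[5, 7]]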

View File

@@ -31,7 +31,6 @@ from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
@@ -224,16 +223,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()

-    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
+    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.

         The updated states are used by the `_prepare_inputs` function to create
         the input GPU tensors for the model.

-        Returns:
-            True if there is a new/resumed/paused/finished request in the batch.
-            If False, we can skip copying SamplingMetadata to the GPU.
+        The SamplingMetadata is updated and copied to the GPU if there is a
+        new/resumed/paused/finished request in the batch.
         """
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
@@ -344,9 +342,12 @@
                 num_new_tokens = (num_computed_tokens +
                                   len(req_data.new_token_ids) -
                                   req_state.num_tokens)
-                new_token_ids = (req_data.new_token_ids[-num_new_tokens:]
-                                 if num_new_tokens > 0 else [])
-                req_state.output_token_ids.extend(new_token_ids)
+                if num_new_tokens == 1:
+                    # Avoid slicing list in most common case.
+                    req_state.output_token_ids.append(req_data.new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        req_data.new_token_ids[-num_new_tokens:])
                 # Update the block IDs.
                 if not req_data.resumed_from_preemption:
                     # Append the new blocks to the existing block IDs.
@@ -380,7 +381,7 @@
             self.input_batch.num_tokens_no_spec[req_index] = end_token_index
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, [])
+                req_id, ())
             if spec_token_ids:
                 start_index = end_token_index
                 end_token_index += len(spec_token_ids)
@@ -410,7 +411,8 @@
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)

-        return batch_changed
+        if batch_changed:
+            self.input_batch.refresh_sampling_metadata()

     def _prepare_inputs(
         self,
@@ -429,8 +431,7 @@
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
-        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
-            assert req_id is not None
+        for i, req_id in enumerate(self.input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens[i] = num_tokens
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
@@ -669,10 +670,7 @@
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
         mrope_pos_ptr = 0
-        num_reqs = self.input_batch.num_reqs
-        for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
-            assert req_id is not None
-
+        for index, req_id in enumerate(self.input_batch.req_ids):
             req = self.requests[req_id]

             assert req.mrope_positions is not None
@@ -726,12 +724,11 @@
         self,
         scheduler_output: "SchedulerOutput",
         cu_num_tokens: np.ndarray,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         # Get the number of spec decode tokens for each request.
         num_reqs = self.input_batch.num_reqs
         num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
-        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
-            assert req_id is not None
+        for i, req_id in enumerate(self.input_batch.req_ids):
             num_spec_decode_tokens[i] = len(
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
@@ -769,22 +766,6 @@
         return torch.from_numpy(spec_decode_logits_indices).to(
             self.device, non_blocking=True)

-    def _prepare_sampling(
-        self,
-        batch_changed: bool,
-        req_to_spec_token_ids: Dict[str, List[int]],
-    ) -> SamplingMetadata:
-        # Create the sampling metadata.
-        req_id_output_token_ids: Dict[str, List[int]] = \
-            {req_id: req.output_token_ids \
-                for req_id, req in self.requests.items()}
-
-        sampling_metadata = self.input_batch.make_sampling_metadata(
-            req_id_output_token_ids,
-            req_to_spec_token_ids,
-            skip_copy=not batch_changed)
-        return sampling_metadata
-
     def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
@@ -838,9 +819,7 @@
         scheduler_output: "SchedulerOutput",
     ) -> List[torch.Tensor]:
         encoder_outputs: List[torch.Tensor] = []
-        num_reqs = self.input_batch.num_reqs
-        for req_id in self.input_batch.req_ids[:num_reqs]:
-            assert req_id is not None
+        for req_id in self.input_batch.req_ids:
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]
             req_state = self.requests[req_id]
@@ -882,7 +861,7 @@
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
-        batch_changed = self._update_states(scheduler_output)
+        self._update_states(scheduler_output)

         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -964,8 +943,8 @@
         logits = self.model.compute_logits(sample_hidden_states, None)

         # Sample the next token and get logprobs if needed.
-        sampling_metadata = self._prepare_sampling(
-            batch_changed, scheduler_output.scheduled_spec_decode_tokens)
+        sampling_metadata = self.input_batch.get_sampling_metadata(
+            scheduler_output.scheduled_spec_decode_tokens)
         sampler_output = self.model.sample(
             logits=logits,
             sampling_metadata=sampling_metadata,
@@ -973,14 +952,7 @@
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
-        num_reqs = self.input_batch.num_reqs
-        req_ids: List[str] = []
-        # Because `input_batch.req_ids` is a list of length `max_num_reqs`,
-        # we need to stop at `num_reqs`.
-        # FIXME(woosuk): This is hacky. Refactor.
-        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
-            assert req_id is not None
-            req_ids.append(req_id)
+        for i, req_id in enumerate(self.input_batch.req_ids):
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
@@ -1027,7 +999,7 @@
                 valid_sampled_token_ids)

         model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids,
+            req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=spec_token_ids,
@@ -1041,19 +1013,18 @@
         sampled_token_ids: List[List[int]],
     ) -> List[List[int]]:
         # TODO(woosuk): Optimize.
-        num_reqs = len(sampled_token_ids)
         draft_token_ids: List[List[int]] = []
-        for i in range(num_reqs):
-            if len(sampled_token_ids[i]) == 0:
+        for i, sampled_ids in enumerate(sampled_token_ids):
+            num_sampled_ids = len(sampled_ids)
+            if not num_sampled_ids:
                 # Skip speculative decoding.
                 draft_token_ids.append([])
                 continue

             # Add sampled_token_ids to token_ids_cpu.
             start_idx = self.input_batch.num_tokens_no_spec[i]
-            end_idx = start_idx + len(sampled_token_ids[i])
-            self.input_batch.token_ids_cpu[
-                i, start_idx:end_idx] = sampled_token_ids[i]
+            end_idx = start_idx + num_sampled_ids
+            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
             drafter_output = self.drafter.propose(
                 self.input_batch.token_ids_cpu[i, :end_idx],
                 self.speculative_config.ngram_prompt_lookup_min,
@@ -1204,7 +1175,7 @@
             # multiplying the list, to avoid Dynamo from treating them as
             # tensor aliasing.
             dummy_kv_caches = [
-                torch.tensor([], dtype=torch.float32, device=self.device)
+                torch.tensor((), dtype=torch.float32, device=self.device)
                 for _ in range(self.num_attn_layers)
             ]

View File

@@ -1048,8 +1048,6 @@ def swap_positions(b: InputBatch, id_1, id_2):
     b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
         id_1]

-    b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
-        id_2], b.stop_token_ids[id_1]
-
     gen_1 = b.generators.pop(id_1, None)
     gen_2 = b.generators.pop(id_2, None)