diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 8bc33e84194c..3e810e525e1c 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int], def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: batch_size = len(spec_tokens) return SamplingMetadata( - temperature=0.0, + temperature=torch.tensor([]), all_greedy=True, all_random=False, - rejection_sampling=True, spec_token_ids=spec_tokens, top_p=None, top_k=None, - no_top_p=False, - no_top_k=False, min_p=torch.empty(batch_size, ), - no_min_p=True, generators={}, max_num_logprobs=0, no_penalties=False, @@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata: presence_penalties=torch.tensor([]), repetition_penalties=torch.tensor([]), output_token_ids=[], - min_tokens=[], - stop_token_ids=[], + min_tokens={}, logit_bias=[None] * batch_size, ) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index a4bd651f8224..3f6301c54267 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -77,25 +77,20 @@ def _create_default_sampling_metadata( temperature=torch.full((batch_size, ), 0.0), all_greedy=True, all_random=False, - rejection_sampling=False, - top_p=torch.empty(batch_size, ), - top_k=torch.empty(batch_size, ), - no_top_p=True, - no_top_k=True, - min_p=torch.empty(batch_size, ), - no_min_p=True, + top_p=None, + top_k=None, + min_p=None, generators={}, max_num_logprobs=0, prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, vocab_size, device), output_token_ids=output_token_ids, - spec_token_ids=[], + spec_token_ids=None, frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), no_penalties=True, - min_tokens=[], - stop_token_ids=[], + min_tokens={}, logit_bias=[None] * batch_size, ) return fake_sampling_metadata @@ -104,10 +99,10 @@ def _create_default_sampling_metadata( def _generate_min_token_penalties_and_stop_tokens( num_output_tokens: int, batch_size: int, vocab_size: int, batch_indices_for_min_token_penalty: List[int] -) -> Tuple[List[int], List[Set[int]]]: +) -> Dict[int, Tuple[int, Set[int]]]: """ - Generates and returns a list of minimum token penalties (`min_tokens`) - and a corresponding list of stop token IDs (`stop_token_ids`) for each + Generates and returns a dict of minimum token penalties and + corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each batch. If a batch index is included in `batch_indices_for_min_token_penalty`, @@ -115,22 +110,19 @@ def _generate_min_token_penalties_and_stop_tokens( and a random set of stop token IDs is created. Otherwise, a lower `min_tokens` value is assigned, and the stop token IDs set is empty. 
""" - stop_token_ids: List[Set[int]] = [] - min_tokens: List[int] = [] + min_tokens: Dict[int, Tuple[int, Set[int]]] = {} for index in range(batch_size): if index in batch_indices_for_min_token_penalty: - min_tokens.append( + min_tokens[index] = ( np.random.randint(num_output_tokens + 1, - 2 * num_output_tokens)) - stop_token_ids.append( + 2 * num_output_tokens), set( np.random.randint(0, vocab_size - 1) for _ in range(np.random.randint(0, vocab_size)))) - else: - min_tokens.append(np.random.randint(0, num_output_tokens)) - stop_token_ids.append(set()) - return (min_tokens, stop_token_ids) + min_tokens[index] = (np.random.randint(0, + num_output_tokens), set()) + return min_tokens def _create_weighted_output_token_list( @@ -165,7 +157,7 @@ def _create_weighted_output_token_list( output_token_ids_for_batch.extend( [token_id for _ in range(index + 1)]) output_token_ids.append(output_token_ids_for_batch) - return (output_token_ids, sorted_token_ids_in_output) + return output_token_ids, sorted_token_ids_in_output @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int): NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) batch_indices_for_min_token_penalty = np.random.randint( 0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist() - min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens( + min_tokens = _generate_min_token_penalties_and_stop_tokens( NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, batch_indices_for_min_token_penalty) sampling_metadata.min_tokens = min_tokens - sampling_metadata.stop_token_ids = stop_token_ids sampler = Sampler() logits = sampler.apply_penalties(fake_logits, sampling_metadata) logits = logits.cpu() for batch_idx in range(batch_size): for token_id in range(VOCAB_SIZE): - if token_id in stop_token_ids[batch_idx]: + _, stop_token_ids = min_tokens.get(batch_idx, (0, set())) + if token_id in stop_token_ids: assert logits[batch_idx][token_id] == -float("inf") else: assert logits[batch_idx][token_id] != -float("inf") diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index c0ab356f5c93..cb3b3d21fbb3 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple import numpy as np import pytest @@ -41,7 +41,7 @@ def _remove_requests( for index in req_indices_to_remove: input_batch.remove_request(reqs[index].req_id) req_ids_to_remove.add(reqs[index].req_id) - return (req_ids_to_remove, req_indices_to_remove_list) + return req_ids_to_remove, req_indices_to_remove_list def _construct_expected_sampling_metadata( @@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata( top_p = [0.0 for _ in range(num_reqs)] min_p = [0.0 for _ in range(num_reqs)] temperature = [0.0 for _ in range(num_reqs)] - stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)] - min_tokens = [0 for _ in range(num_reqs)] + min_tokens = {} logit_bias = [None] * num_reqs for req in reqs: if req.req_id not in req_ids_retained: @@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata( top_p[index_in_input_batch] = req.sampling_params.top_p min_p[index_in_input_batch] = req.sampling_params.min_p temperature[index_in_input_batch] = req.sampling_params.temperature - stop_token_ids[ - index_in_input_batch] = req.sampling_params.all_stop_token_ids - 
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens + min_tokens[index_in_input_batch] = ( + req.sampling_params.min_tokens, + req.sampling_params.all_stop_token_ids) logit_bias[index_in_input_batch] = req.sampling_params.logit_bias return SamplingMetadata( temperature=torch.tensor(temperature, dtype=torch.float, device=device), all_greedy=False, all_random=True, - rejection_sampling=False, - top_p=torch.tensor(top_p, dtype=torch.float, device=device), - top_k=torch.tensor(top_k, dtype=torch.int, device=device), - no_top_p=all(x == 1.0 for x in top_p), - no_top_k=all(x == 0 for x in top_k), - min_p=torch.tensor(min_p, dtype=torch.float, device=device), - no_min_p=all(x == 0.0 for x in min_p), + top_p=None if all(x == 1.0 for x in top_p) else torch.tensor( + top_p, dtype=torch.float, device=device), + top_k=None if all(x == 0 for x in top_k) else torch.tensor( + top_k, dtype=torch.int, device=device), + min_p=None if all(x == 0.0 for x in min_p) else torch.tensor( + min_p, dtype=torch.float, device=device), generators={}, max_num_logprobs=0, prompt_token_ids=make_tensor_with_pad( @@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata( dtype=torch.float, device=device), output_token_ids=output_token_ids, - spec_token_ids=[], + spec_token_ids=None, min_tokens=min_tokens, - stop_token_ids=stop_token_ids, no_penalties=(all(x == 0 for x in presence_penalties) and all(x == 0 for x in frequency_penalties) and all(x == 1 for x in repetition_penalties)), @@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch.condense(req_indices_to_remove) # Generate the sampling metadata - sampling_metadata = input_batch.make_sampling_metadata( - req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False) + sampling_metadata = input_batch._make_sampling_metadata() # Create expected output. expected_sampling_metadata = _construct_expected_sampling_metadata( @@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch.req_id_to_index, device=torch.device(device)) + def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: + return (t1 is None + and t2 is None) or (t1 is not None and t2 is not None + and torch.allclose(t1, t2)) + # Assert the actual and expected output. 
assert torch.allclose(expected_sampling_metadata.temperature, sampling_metadata.temperature) - assert torch.allclose(expected_sampling_metadata.top_p, - sampling_metadata.top_p) - assert torch.allclose(expected_sampling_metadata.top_k, - sampling_metadata.top_k) + assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p) + assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k) assert torch.allclose( expected_sampling_metadata.frequency_penalties, sampling_metadata.frequency_penalties, @@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): assert (expected_sampling_metadata.output_token_ids == sampling_metadata.output_token_ids) assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens - assert expected_sampling_metadata.stop_token_ids == \ - sampling_metadata.stop_token_ids assert expected_sampling_metadata.no_penalties == \ sampling_metadata.no_penalties - assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p - assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index c655b0fded6e..973efcbf8e50 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -5,6 +5,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.sampling_params import SamplingParams from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, SchedulerOutput) +from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool: return req_id in model_runner.requests +def _is_sampling_metadata_changed(model_runner, + sampling_metadata_before: SamplingMetadata): + return model_runner.input_batch.sampling_metadata is not ( + sampling_metadata_before) + + def test_update_states_new_request(model_runner): req_id = "req_0" # new req scheduler_output = _schedule_new_request(req_id) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner.input_batch.sampling_metadata + model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner.input_batch.sampling_metadata + model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert not _is_req_added(model_runner, req_id) assert not _is_req_scheduled(model_runner, req_id) @@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner): scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=0, - finished_req_ids={}, + finished_req_ids=set(), free_encoder_input_ids=[], ) @@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner.input_batch.sampling_metadata + 
model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is False + metadata_before = model_runner.input_batch.sampling_metadata + model_runner._update_states(scheduler_output) + assert not _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_id) assert _is_req_scheduled(model_runner, req_id) @@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner): free_encoder_input_ids=[], ) - batch_changed = model_runner._update_states(scheduler_output) - assert batch_changed is True + metadata_before = model_runner._update_states(scheduler_output) + assert _is_sampling_metadata_changed(model_runner, metadata_before) assert _is_req_added(model_runner, req_ids[0]) assert _is_req_scheduled(model_runner, req_ids[0]) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index dfe71028c1bc..a9ef973917e1 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -45,7 +45,7 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, vocab_size, num_seqs) output_bin_counts, output_mask = get_token_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat( + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( 1, vocab_size) logits[logits > 0] /= torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)[logits > 0] @@ -53,6 +53,6 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, repetition_penalties, 1.0)[logits <= 0] # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze_(dim=1) * output_mask + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 8f10834251c1..535aa644c53c 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -195,8 +195,10 @@ class Scheduler: request.num_computed_tokens - request.num_tokens) if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids[:num_scheduled_spec_tokens]) + request.spec_token_ids) # Encoder-related. 
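Aside on the spec_token_ids trimming above: `del lst[n:]` truncates the existing list in place, so the request object and the scheduled_spec_decode_tokens entry now share one trimmed list instead of storing a sliced copy each step. A minimal standalone sketch with made-up values (not vLLM objects):

    spec_token_ids = [11, 22, 33, 44, 55]
    num_scheduled = 3

    # In-place truncation: no new list is allocated, and every existing
    # reference observes the shortened contents.
    del spec_token_ids[num_scheduled:]
    scheduled = spec_token_ids
    assert scheduled is spec_token_ids and scheduled == [11, 22, 33]

    # The previous behaviour sliced, which copies on every schedule step:
    # scheduled = spec_token_ids[:num_scheduled]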
if encoder_inputs_to_schedule: @@ -567,7 +569,7 @@ class Scheduler: outputs.append( EngineCoreOutput( request_id=req_id, - new_token_ids=new_token_ids or [], + new_token_ids=new_token_ids, finish_reason=request.get_finished_reason(), new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index ea64181c0aeb..2184a1866ff5 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple import torch @@ -12,15 +12,13 @@ class SamplingMetadata: temperature: torch.Tensor all_greedy: bool all_random: bool - rejection_sampling: bool - spec_token_ids: List[List[int]] - top_p: torch.Tensor - top_k: torch.Tensor - no_top_p: bool - no_top_k: bool - min_p: torch.Tensor - no_min_p: bool + # None when there are no speculated tokens. + spec_token_ids: Optional[List[List[int]]] + + top_p: Optional[torch.Tensor] + top_k: Optional[torch.Tensor] + min_p: Optional[torch.Tensor] generators: Dict[int, torch.Generator] @@ -34,7 +32,8 @@ class SamplingMetadata: repetition_penalties: torch.Tensor output_token_ids: List[List[int]] - min_tokens: List[int] - stop_token_ids: List[Set[int]] + + # req_index -> (min_tokens, stop_token_ids) + min_tokens: Dict[int, Tuple[int, Set[int]]] logit_bias: List[Optional[Dict[int, float]]] diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index ba368b44ab9c..8d9f6529fa0b 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Set, Tuple +from typing import Dict, List, Set, Tuple import torch @@ -8,18 +8,17 @@ from vllm.model_executor.layers.utils import apply_penalties from vllm.utils import is_pin_memory_available, make_tensor_with_pad -def apply_min_token_penalties(logits: torch.Tensor, - output_token_ids: List[List[int]], - stop_token_ids: List[Set[int]], - min_tokens: List[int]) -> None: +def apply_min_token_penalties( + logits: torch.Tensor, output_token_ids: List[List[int]], + min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None: """ Applies minimum token penalty by setting the logits of the stop tokens to -inf. 
""" min_tokens_logits_to_penalize: List[Tuple[int, int]] = [] - for index, min_token in enumerate(min_tokens): + for index, (min_token, stop_token_ids) in min_tokens.items(): if len(output_token_ids[index]) < min_token: - for stop_token_id in stop_token_ids[index]: + for stop_token_id in stop_token_ids: min_tokens_logits_to_penalize.append((index, stop_token_id)) if min_tokens_logits_to_penalize: logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 27431001e3e7..78c88ad8b830 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict +from typing import Dict, Optional import torch import torch.nn as nn @@ -55,13 +55,11 @@ class TopKTopPSampler(nn.Module): self, logits: torch.Tensor, generators: Dict[int, torch.Generator], - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], ) -> torch.Tensor: """PyTorch-native implementation of top-k and top-p sampling.""" - logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p) + logits = apply_top_k_top_p(logits, k, p) probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators) @@ -69,37 +67,33 @@ class TopKTopPSampler(nn.Module): self, logits: torch.Tensor, generators: Dict[int, torch.Generator], - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], ) -> torch.Tensor: """More optimized implementation for top-k and top-p sampling.""" probs = logits.softmax(dim=-1, dtype=torch.float32) - if no_top_k and no_top_p: + if k is None and p is None: # We prefer `random_sample` over `flashinfer_sample` when sorting is # not needed. This is because `random_sample` does not require # CPU-GPU synchronization while `flashinfer_sample` does. return random_sample(probs, generators) - return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators) + return flashinfer_sample(probs, k, p, generators) def apply_top_k_top_p( logits: torch.Tensor, - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], ) -> torch.Tensor: """Apply top-k and top-p masks to the logits. This function sorts the logits tensor, which can be slow for large batches. """ - if no_top_k and no_top_p: + if k is None and p is None: return logits logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - if not no_top_k: + if k is not None: # Apply top-k. top_k_mask = logits_sort.size(1) - k.to(torch.long) # Get all the top_k values. @@ -107,7 +101,7 @@ def apply_top_k_top_p( top_k_mask = logits_sort < top_k_mask logits_sort.masked_fill_(top_k_mask, -float("inf")) - if not no_top_p: + if p is not None: # Apply top-p. probs_sort = logits_sort.softmax(dim=-1) probs_sum = probs_sort.cumsum(dim=-1) @@ -147,10 +141,8 @@ def random_sample( def flashinfer_sample( probs: torch.Tensor, - no_top_k: bool, - k: torch.Tensor, - no_top_p: bool, - p: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], generators: Dict[int, torch.Generator], ) -> torch.Tensor: """Sample from the probabilities using FlashInfer. @@ -167,7 +159,7 @@ def flashinfer_sample( does not. Call this function at the end of the forward pass to minimize the synchronization overhead. 
""" - assert not (no_top_k and no_top_p) + assert not (k is None and p is None) max_top_k_round = 32 batch_size = probs.shape[0] uniform_samples = torch.empty((max_top_k_round, batch_size), @@ -178,11 +170,11 @@ def flashinfer_sample( for i, generator in generators.items(): uniform_samples[:, i].uniform_(generator=generator) - if no_top_k: + if k is None: # Top-p only. next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs( probs, uniform_samples, p, deterministic=True) - elif no_top_p: + elif p is None: # Top-k only. next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs( probs, uniform_samples, k, deterministic=True) @@ -194,9 +186,9 @@ def flashinfer_sample( # NOTE: CPU-GPU synchronization happens here. if not success.all(): - if not no_top_k: + if k is not None: probs = flashinfer.sampling.top_k_renorm_prob(probs, k) - if not no_top_p: + if p is not None: probs = flashinfer.sampling.top_p_renorm_prob(probs, p) next_token_ids = flashinfer.sampling.sampling_from_probs( probs, uniform_samples[0], deterministic=True) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index df1da8930211..580ad44297aa 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -68,6 +68,7 @@ class RejectionSampler(nn.Module): # NOTE: The following input preparationg can be moved # to the model runner with a persistent manner for better # performance. + assert sampling_metadata.spec_token_ids is not None spec_token_ids = sampling_metadata.spec_token_ids max_spec_len = max(len(s) for s in spec_token_ids) batch_size = len(spec_token_ids) @@ -119,6 +120,7 @@ class RejectionSampler(nn.Module): logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: + assert sampling_metadata.spec_token_ids is not None spec_lens = [len(x) for x in sampling_metadata.spec_token_ids] # Add 1 to include the 'bonus' token. sample_lens = [x + 1 for x in spec_lens] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index ec6374d12b17..8e2533eefab0 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -26,7 +26,7 @@ class Sampler(nn.Module): logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - if sampling_metadata.rejection_sampling: + if sampling_metadata.spec_token_ids: if sampling_metadata.max_num_logprobs: raise NotImplementedError( "Rejection sampling does not support logprobs.") @@ -104,16 +104,14 @@ class Sampler(nn.Module): logits = self.apply_temperature(logits, sampling_metadata.temperature) # Apply min_p. - if not sampling_metadata.no_min_p: + if sampling_metadata.min_p is not None: logits = self.apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. 
random_sampled = self.topk_topp_sampler( logits, sampling_metadata.generators, - sampling_metadata.no_top_k, sampling_metadata.top_k, - sampling_metadata.no_top_p, sampling_metadata.top_p, ) @@ -179,9 +177,10 @@ class Sampler(nn.Module): logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - apply_min_token_penalties(logits, sampling_metadata.output_token_ids, - sampling_metadata.stop_token_ids, - sampling_metadata.min_tokens) + if sampling_metadata.min_tokens: + apply_min_token_penalties(logits, + sampling_metadata.output_token_ids, + sampling_metadata.min_tokens) if not sampling_metadata.no_penalties: assert sampling_metadata.prompt_token_ids is not None logits = apply_all_penalties( diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 5494542c181d..5be465014242 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -188,3 +188,14 @@ def bind_kv_cache( for layer_name, kv_cache in kv_caches.items(): # NOTE: Use list because of v0 PP virtual engine. forward_context[layer_name].kv_cache = [kv_cache] + + +def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, + length: int) -> None: + """ + Copy the first length elements of a tensor into another tensor in a + non-blocking manner. + + Used to copy pinned CPU tensor data to pre-allocated GPU tensors. + """ + to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index cb7411a44e2f..ccafc325b53f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 - # Datastructures defining an input batch from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import numpy as np import torch @@ -12,6 +11,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import copy_slice from vllm.v1.worker.block_table import BlockTable _SAMPLING_EPS = 1e-5 @@ -63,7 +63,7 @@ class InputBatch: self.pin_memory = pin_memory self.vocab_size = vocab_size - self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self._req_ids: List[Optional[str]] = [] self.req_id_to_index: Dict[str, int] = {} # TODO(woosuk): This buffer could be too large if max_model_len is big. @@ -171,11 +171,8 @@ class InputBatch: self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: Set[str] = set() - self.min_tokens: List[int] = [0] * max_num_reqs - self.stop_token_ids: List[Set[int]] = [ - set() for _ in range(max_num_reqs) - ] - self.prompt_token_ids: Optional[torch.Tensor] = None + # req_index -> (min_tokens, stop_token_ids) + self.min_tokens: Dict[int, Tuple[int, Set[int]]] = {} # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), @@ -196,6 +193,17 @@ class InputBatch: self.logit_bias: List[Optional[Dict[int, float]]] = [None] * max_num_reqs + self.req_output_token_ids: List[Optional[List[int]]] = [] + + # This is updated each time the batch constituents change. + self.sampling_metadata = self._make_sampling_metadata() + + @property + def req_ids(self) -> List[str]: + # None elements should only be present transiently + # while performing state updates to the batch. 
+ return cast(List[str], self._req_ids) + def add_request( self, request: "CachedRequestState", @@ -206,7 +214,13 @@ class InputBatch: assert req_index < self.max_num_reqs req_id = request.req_id - self.req_ids[req_index] = req_id + if req_index == len(self._req_ids): + self._req_ids.append(req_id) + self.req_output_token_ids.append(request.output_token_ids) + else: + self._req_ids[req_index] = req_id + self.req_output_token_ids[req_index] = request.output_token_ids + self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. @@ -255,8 +269,9 @@ class InputBatch: req_index] = sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) - self.min_tokens[req_index] = sampling_params.min_tokens - self.stop_token_ids[req_index] = sampling_params.all_stop_token_ids + if sampling_params.min_tokens: + self.min_tokens[req_index] = (sampling_params.min_tokens, + sampling_params.all_stop_token_ids) # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. @@ -284,16 +299,20 @@ class InputBatch: self.request_lora_mapping[req_index] = 0 def remove_request(self, req_id: str) -> Optional[int]: + """This method must always be followed by a call to condense().""" + req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: return None - self.req_ids[req_index] = None + self._req_ids[req_index] = None + self.req_output_token_ids[req_index] = None self.greedy_reqs.discard(req_id) self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) self.min_p_reqs.discard(req_id) + self.min_tokens.pop(req_index, None) self.frequency_penalties_reqs.discard(req_id) self.presence_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id) @@ -313,33 +332,17 @@ class InputBatch: self.logit_bias[req_index] = None return req_index - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.min_p_reqs.clear() - self.frequency_penalties_reqs.clear() - self.presence_penalties_reqs.clear() - self.repetition_penalties_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.num_prompt_logprobs.clear() - self.request_lora_mapping.fill(0) - self.lora_id_to_lora_request.clear() - self.lora_id_to_request_ids.clear() - self.logit_bias = [None] * self.max_num_reqs - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: + num_reqs = self.num_reqs + if num_reqs == 0: # The batched states are empty. + self._req_ids.clear() + self.req_output_token_ids.clear() return # NOTE(woosuk): This function assumes that the empty_req_indices # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 + last_req_index = num_reqs + len(empty_req_indices) - 1 while empty_req_indices: # Find the largest non-empty index. while last_req_index in empty_req_indices: @@ -351,10 +354,13 @@ class InputBatch: break # Swap the states. 
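The swap below moves the last live request into the freed slot; per-index state that is now stored sparsely (the min_tokens dict) is re-keyed only if an entry exists. A toy standalone version of the pattern, with made-up values rather than vLLM objects:

    from typing import Dict, List, Optional, Set, Tuple

    req_ids: List[Optional[str]] = ["a", None, "c"]   # index 1 was removed
    min_tokens: Dict[int, Tuple[int, Set[int]]] = {2: (5, {100, 101})}

    empty_index, last_req_index = 1, 2
    req_ids[empty_index], req_ids[last_req_index] = req_ids[last_req_index], None

    # Sparse per-index state moves with the request only when present.
    entry = min_tokens.pop(last_req_index, None)
    if entry is not None:
        min_tokens[empty_index] = entry

    assert req_ids == ["a", "c", None]
    assert min_tokens == {1: (5, {100, 101})}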
- req_id = self.req_ids[last_req_index] + req_id = self._req_ids[last_req_index] + output_token_ids = self.req_output_token_ids[last_req_index] assert req_id is not None - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None + self._req_ids[empty_index] = req_id + self._req_ids[last_req_index] = None + self.req_output_token_ids[empty_index] = output_token_ids + self.req_output_token_ids[last_req_index] = None self.req_id_to_index[req_id] = empty_index num_tokens = self.num_tokens[last_req_index] @@ -379,13 +385,14 @@ class InputBatch: self.repetition_penalties_cpu[ empty_index] = self.repetition_penalties_cpu[last_req_index] self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] - self.min_tokens[empty_index] = self.min_tokens[last_req_index] - self.stop_token_ids[empty_index] = self.stop_token_ids[ - last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator + min_token = self.min_tokens.pop(last_req_index, None) + if min_token is not None: + self.min_tokens[empty_index] = min_token + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ last_req_index] @@ -394,87 +401,71 @@ class InputBatch: # Decrement last_req_index since it is now empty. last_req_index -= 1 - def make_sampling_metadata( - self, - req_id_output_token_ids: Dict[str, List[int]], - req_id_to_spec_token_ids: Dict[str, List[int]], - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - self.min_p[:self.num_reqs].copy_( - self.min_p_cpu_tensor[:self.num_reqs], non_blocking=True) - if not self.no_penalties: - # Since syncing these tensors is expensive only copy them - # if necessary i.e. if there are requests which require - # penalties to be applied during sampling. - self.frequency_penalties[:self.num_reqs].copy_( - self.frequency_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True, - ) - self.presence_penalties[:self.num_reqs].copy_( - self.presence_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True, - ) - self.repetition_penalties[:self.num_reqs].copy_( - self.repetition_penalties_cpu_tensor[:self.num_reqs], - non_blocking=True, - ) - # The prompt tokens are used only for applying penalties during - # the sampling process. Hence copy these tensors only when - # there are requests which need penalties to be applied. - self.prompt_token_ids = self._make_prompt_token_ids_tensor() + # Trim lists to the batch size. + del self._req_ids[self.num_reqs:] + del self.req_output_token_ids[self.num_reqs:] - output_token_ids: List[List[int]] = [] - spec_token_ids: List[List[int]] = [] - rejection_sampling = False - for req_id in self.req_ids[:self.num_reqs]: - assert req_id is not None - # Currently we create a tensor for output_token_ids from scratch - # at each step. However, for the penalties computation what we - # need is stats about the token ids present in the output. This - # stats can be maintained incrementally instead of computing it - # from scratch at each step. - # TODO - Replace this with incremental update to output token - # statistics. 
- output_token_ids.append(req_id_output_token_ids[req_id]) - req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, []) - spec_token_ids.append(req_spec_token_ids) - if req_spec_token_ids: - # If any of the requests require speculative decoding, set the - # flag to True. - rejection_sampling = True + def refresh_sampling_metadata(self): + self.sampling_metadata = self._make_sampling_metadata() + + def _make_sampling_metadata(self) -> SamplingMetadata: + num_reqs = self.num_reqs + copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs) + if not self.no_top_p: + copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) + if not self.no_top_k: + copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) + if not self.no_min_p: + copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs) + + if not self.no_penalties: + # Since syncing these tensors is expensive only copy them + # if necessary i.e. if there are requests which require + # penalties to be applied during sampling. + copy_slice(self.frequency_penalties_cpu_tensor, + self.frequency_penalties, num_reqs) + copy_slice(self.presence_penalties_cpu_tensor, + self.presence_penalties, num_reqs) + copy_slice(self.repetition_penalties_cpu_tensor, + self.repetition_penalties, num_reqs) + + # The prompt tokens are used only for applying penalties during + # the sampling process. Hence copy these tensors only when + # there are requests which need penalties to be applied. + prompt_token_ids = self._make_prompt_token_ids_tensor() + else: + prompt_token_ids = None return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], + temperature=self.temperature[:num_reqs], all_greedy=self.all_greedy, all_random=self.all_random, - rejection_sampling=rejection_sampling, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - min_p=self.min_p[:self.num_reqs], - no_min_p=self.no_min_p, - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, + top_p=None if self.no_top_p else self.top_p[:num_reqs], + top_k=None if self.no_top_k else self.top_k[:num_reqs], + min_p=None if self.no_min_p else self.min_p[:num_reqs], generators=self.generators, max_num_logprobs=self.max_num_logprobs, - prompt_token_ids=self.prompt_token_ids, - frequency_penalties=self.frequency_penalties[:self.num_reqs], - presence_penalties=self.presence_penalties[:self.num_reqs], - repetition_penalties=self.repetition_penalties[:self.num_reqs], - output_token_ids=output_token_ids, - spec_token_ids=spec_token_ids, - min_tokens=self.min_tokens[:self.num_reqs], - stop_token_ids=self.stop_token_ids[:self.num_reqs], + prompt_token_ids=prompt_token_ids, + frequency_penalties=self.frequency_penalties[:num_reqs], + presence_penalties=self.presence_penalties[:num_reqs], + repetition_penalties=self.repetition_penalties[:num_reqs], + output_token_ids=cast(List[List[int]], self.req_output_token_ids), + spec_token_ids=None, + min_tokens=self.min_tokens, no_penalties=self.no_penalties, - logit_bias=self.logit_bias[:self.num_reqs], + logit_bias=self.logit_bias[:num_reqs], ) + def get_sampling_metadata( + self, + req_id_to_spec_token_ids: Dict[str, List[int]], + ) -> SamplingMetadata: + # Set the new spec token ids in the cached sampling metadata. 
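The expression added below maps per-request speculative tokens into batch order, and collapses to None when nothing was speculated this step. A standalone illustration with made-up request ids:

    from typing import Dict, List

    req_ids: List[str] = ["req-a", "req-b", "req-c"]
    req_id_to_spec_token_ids: Dict[str, List[int]] = {"req-b": [42, 43]}

    # Requests without speculation get an empty list; an empty mapping yields None.
    spec_token_ids = ([req_id_to_spec_token_ids.get(r, []) for r in req_ids]
                      if req_id_to_spec_token_ids else None)
    assert spec_token_ids == [[], [42, 43], []]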
+ self.sampling_metadata.spec_token_ids = [ + req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids + ] if req_id_to_spec_token_ids else None + return self.sampling_metadata + def _make_prompt_token_ids_tensor(self) -> torch.Tensor: max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5754422cb1f7..0ecc00acc790 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -31,7 +31,6 @@ from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput -from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import bind_kv_cache @@ -224,16 +223,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. The updated states are used by the `_prepare_inputs` function to create the input GPU tensors for the model. - Returns: - True if there is a new/resumed/paused/finished request in the batch. - If False, we can skip copying SamplingMetadata to the GPU. + The SamplingMetadata is updated and copied to the GPU if there is a + new/resumed/paused/finished request in the batch. """ # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: @@ -344,9 +342,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_new_tokens = (num_computed_tokens + len(req_data.new_token_ids) - req_state.num_tokens) - new_token_ids = (req_data.new_token_ids[-num_new_tokens:] - if num_new_tokens > 0 else []) - req_state.output_token_ids.extend(new_token_ids) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(req_data.new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + req_data.new_token_ids[-num_new_tokens:]) # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. @@ -380,7 +381,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.input_batch.num_tokens_no_spec[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, []) + req_id, ()) if spec_token_ids: start_index = end_token_index end_token_index += len(spec_token_ids) @@ -410,7 +411,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): if removed_req_indices: self.input_batch.condense(removed_req_indices) - return batch_changed + if batch_changed: + self.input_batch.refresh_sampling_metadata() def _prepare_inputs( self, @@ -429,8 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # TODO: The Python loop can be slow. Optimize. 
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) max_num_scheduled_tokens = 0 - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None + for i, req_id in enumerate(self.input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens[i] = num_tokens max_num_scheduled_tokens = max(max_num_scheduled_tokens, @@ -669,10 +670,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr = 0 - num_reqs = self.input_batch.num_reqs - for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - assert req_id is not None - + for index, req_id in enumerate(self.input_batch.req_ids): req = self.requests[req_id] assert req.mrope_positions is not None @@ -726,12 +724,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self, scheduler_output: "SchedulerOutput", cu_num_tokens: np.ndarray, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: # Get the number of spec decode tokens for each request. num_reqs = self.input_batch.num_reqs num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32) - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None + for i, req_id in enumerate(self.input_batch.req_ids): num_spec_decode_tokens[i] = len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) @@ -769,22 +766,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): return torch.from_numpy(spec_decode_logits_indices).to( self.device, non_blocking=True) - def _prepare_sampling( - self, - batch_changed: bool, - req_to_spec_token_ids: Dict[str, List[int]], - ) -> SamplingMetadata: - # Create the sampling metadata. - req_id_output_token_ids: Dict[str, List[int]] = \ - {req_id: req.output_token_ids \ - for req_id, req in self.requests.items()} - - sampling_metadata = self.input_batch.make_sampling_metadata( - req_id_output_token_ids, - req_to_spec_token_ids, - skip_copy=not batch_changed) - return sampling_metadata - def _execute_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: @@ -838,9 +819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output: "SchedulerOutput", ) -> List[torch.Tensor]: encoder_outputs: List[torch.Tensor] = [] - num_reqs = self.input_batch.num_reqs - for req_id in self.input_batch.req_ids[:num_reqs]: - assert req_id is not None + for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] @@ -882,7 +861,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, torch.Tensor]: - batch_changed = self._update_states(scheduler_output) + self._update_states(scheduler_output) if self.is_multimodal_model: # Run the multimodal encoder if any. @@ -964,8 +943,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): logits = self.model.compute_logits(sample_hidden_states, None) # Sample the next token and get logprobs if needed. 
- sampling_metadata = self._prepare_sampling( - batch_changed, scheduler_output.scheduled_spec_decode_tokens) + sampling_metadata = self.input_batch.get_sampling_metadata( + scheduler_output.scheduled_spec_decode_tokens) sampler_output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, @@ -973,14 +952,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. - num_reqs = self.input_batch.num_reqs - req_ids: List[str] = [] - # Because `input_batch.req_ids` is a list of length `max_num_reqs`, - # we need to stop at `num_reqs`. - # FIXME(woosuk): This is hacky. Refactor. - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None - req_ids.append(req_id) + for i, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) @@ -1027,7 +999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): valid_sampled_token_ids) model_runner_output = ModelRunnerOutput( - req_ids=req_ids, + req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, spec_token_ids=spec_token_ids, @@ -1041,19 +1013,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): sampled_token_ids: List[List[int]], ) -> List[List[int]]: # TODO(woosuk): Optimize. - num_reqs = len(sampled_token_ids) draft_token_ids: List[List[int]] = [] - for i in range(num_reqs): - if len(sampled_token_ids[i]) == 0: + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: # Skip speculative decoding. draft_token_ids.append([]) continue # Add sampled_token_ids to token_ids_cpu. start_idx = self.input_batch.num_tokens_no_spec[i] - end_idx = start_idx + len(sampled_token_ids[i]) - self.input_batch.token_ids_cpu[ - i, start_idx:end_idx] = sampled_token_ids[i] + end_idx = start_idx + num_sampled_ids + self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids drafter_output = self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx], self.speculative_config.ngram_prompt_lookup_min, @@ -1204,7 +1175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # multiplying the list, to avoid Dynamo from treating them as # tensor aliasing. dummy_kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) + torch.tensor((), dtype=torch.float32, device=self.device) for _ in range(self.num_attn_layers) ] diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4ee6853ba7ef..e60268f04527 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1048,8 +1048,6 @@ def swap_positions(b: InputBatch, id_1, id_2): b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[ id_1] - b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[ - id_2], b.stop_token_ids[id_1] gen_1 = b.generators.pop(id_1, None) gen_2 = b.generators.pop(id_2, None)
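Closing illustration of the central data-structure change: min_tokens is now a req_index-keyed dict of (min_tokens, stop_token_ids) pairs consumed directly by apply_min_token_penalties. A small sketch with toy sizes and token ids, assuming this patched module is importable:

    import torch
    from vllm.v1.sample.ops.penalties import apply_min_token_penalties

    logits = torch.zeros(2, 6)                 # (batch, vocab), toy sizes
    output_token_ids = [[3], [0, 1, 4]]        # tokens generated so far per request

    # req_index -> (min_tokens, stop_token_ids); only request 0 is constrained.
    min_tokens = {0: (4, {2, 5})}

    apply_min_token_penalties(logits, output_token_ids, min_tokens)

    # Request 0 has produced 1 < 4 tokens, so its stop tokens are masked out.
    assert logits[0, 2] == -float("inf") and logits[0, 5] == -float("inf")
    # Request 1 has no entry and is untouched.
    assert torch.all(logits[1] == 0)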