diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 914597e7ae629..60d5388b21b0e 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -850,6 +850,10 @@ class Scheduler(SchedulerInterface):
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
 
+        if sampled_token_ids is not None:
+            # Optimization: Avoid a .tolist() call for each request.
+            sampled_token_ids = sampled_token_ids.tolist()
+
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: Optional[SpecDecodingStats] = None
 
@@ -868,11 +872,13 @@ class Scheduler(SchedulerInterface):
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
-            num_sampled = num_sampled_tokens[req_index]
-            if num_sampled > 0:
-                generated_token_ids = sampled_token_ids[:num_sampled].tolist()
-            else:
-                generated_token_ids = []
+            generated_token_ids: list[int] = []
+            if sampled_token_ids is not None:
+                assert num_sampled_tokens is not None
+                num_sampled = num_sampled_tokens[req_index]
+                if num_sampled > 0:
+                    generated_token_ids = sampled_token_ids[
+                        req_index][:num_sampled]
 
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id))
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 2a2a498ec2453..77864beba836d 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -92,9 +92,9 @@ class ModelRunnerOutput:
     # [num_reqs]
     # Number of tokens sampled in the current step. Each request may generate
     # different number of tokens due to chunked prefilling and spec decoding.
-    num_sampled_tokens: np.ndarray
+    num_sampled_tokens: Optional[np.ndarray]
     # [num_reqs, max_num_sampled_tokens]
-    sampled_token_ids: np.ndarray
+    sampled_token_ids: Optional[np.ndarray]
 
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
@@ -128,8 +128,8 @@ class DraftTokenIds:
 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
     req_ids=[],
     req_id_to_index={},
-    num_sampled_tokens=np.empty(0, dtype=np.int32),
-    sampled_token_ids=np.empty((0, 0), dtype=np.int32),
+    num_sampled_tokens=None,
+    sampled_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 3d5e59addfcfa..36d2eacc0ad38 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Optional
 
+import numpy as np
 import torch
 import torch.nn as nn
 
@@ -106,9 +107,9 @@ class RejectionSampler(nn.Module):
 
     @staticmethod
    def parse_output(
-        output_token_ids: torch.Tensor,
+        output_token_ids: np.ndarray,
         vocab_size: int,
-    ) -> list[list[int]]:
+    ) -> np.ndarray:
         """Parse the output of the rejection sampler.
 
         Args:
@@ -119,17 +120,14 @@ class RejectionSampler(nn.Module):
             vocab_size: The size of the vocabulary.
 
         Returns:
-            A list of lists of token IDs.
+            A NumPy array with the number of valid sampled tokens per request.
         """
-        output_token_ids_np = output_token_ids.cpu().numpy()
         # Create mask for valid tokens.
-        valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
-                      (output_token_ids_np < vocab_size))
-        outputs = [
-            row[valid_mask[i]].tolist()
-            for i, row in enumerate(output_token_ids_np)
-        ]
-        return outputs
+        valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
+                      (output_token_ids < vocab_size))
+        # Count the tokens before the first False in valid_mask.
+        num_sampled_tokens = np.cumprod(valid_mask, axis=1).sum(axis=1)
+        return num_sampled_tokens
 
 
 def rejection_sample(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 08e13ab887bf9..99147a2a2e0ec 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1456,7 +1456,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=[],
+            num_sampled_tokens=None,
+            sampled_token_ids=None,
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
@@ -1665,23 +1666,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
             num_nans_in_logits = self._get_nans_in_logits(logits)
 
-        # TODO(woosuk): The following loop can be slow since it iterates over
-        # the requests one by one. Optimize.
-        discard_sampled_tokens_req_indices = []
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len < req_state.num_tokens:
+        # Post-processing for chunked prefill.
+        num_reqs = self.input_batch.num_reqs
+        chunked_prefilling = (
+            self.input_batch.num_computed_tokens_cpu[:num_reqs]
+            + num_scheduled_tokens_np
+            < self.input_batch.num_tokens_no_spec[:num_reqs])
+        if self.input_batch.generators:
+            chunked_prefill_indices = np.where(chunked_prefilling)[0]
+            for i in chunked_prefill_indices:
                 # Ignore the sampled token for partial prefills.
                 # Rewind the generator state as if the token was not sampled.
                 # This relies on cuda-specific torch-internal impl details
                 generator = self.input_batch.generators.get(i)
                 if generator is not None:
                     generator.set_offset(generator.get_offset() - 4)
-                # Record the index of the request that should not be sampled,
-                # so that we could clear the sampled tokens before returning.
-                discard_sampled_tokens_req_indices.append(i)
 
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
@@ -1700,16 +1699,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
             # No spec decode tokens.
-            valid_sampled_token_ids = self._to_list(sampled_token_ids)
+            sampled_token_ids_np = self._to_numpy(sampled_token_ids)
+            num_sampled_tokens = (~chunked_prefilling).astype(np.int32)
         else:
             # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
+            sampled_token_ids_np = sampled_token_ids.cpu().numpy()
+            num_sampled_tokens = self.rejection_sampler.parse_output(
+                sampled_token_ids_np,
                 self.input_batch.vocab_size,
             )
-        # Mask out the sampled tokens that should not be sampled.
-        for i in discard_sampled_tokens_req_indices:
-            valid_sampled_token_ids[i].clear()
+        num_sampled_tokens *= ~chunked_prefilling
 
         # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
@@ -1717,9 +1716,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
-        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
-            if not sampled_ids:
+        for req_idx in range(num_reqs):
+            num_sampled = num_sampled_tokens[req_idx]
+            if num_sampled == 0:
                 continue
+            sampled_ids = sampled_token_ids_np[req_idx][:num_sampled].tolist()
 
             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
             end_idx = start_idx + len(sampled_ids)
@@ -1740,7 +1741,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             assert spec_decode_common_attn_metadata is not None
             self._draft_token_ids = self.propose_draft_token_ids(
                 scheduler_output,
-                valid_sampled_token_ids,
+                num_sampled_tokens,
+                sampled_token_ids_np,
                 sampling_metadata,
                 hidden_states,
                 sample_hidden_states,
@@ -1754,7 +1756,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=valid_sampled_token_ids,
+            num_sampled_tokens=num_sampled_tokens,
+            sampled_token_ids=sampled_token_ids_np,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
@@ -1776,7 +1779,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def propose_draft_token_ids(
         self,
         scheduler_output: "SchedulerOutput",
-        sampled_token_ids: list[list[int]],
+        num_sampled_tokens: np.ndarray,
+        sampled_token_ids: np.ndarray,
         sampling_metadata: SamplingMetadata,
         hidden_states: torch.Tensor,
         sample_hidden_states: torch.Tensor,
@@ -1788,19 +1792,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.speculative_config.method == "ngram":
             assert isinstance(self.drafter, NgramProposer)
             draft_token_ids = self.propose_ngram_draft_token_ids(
-                sampled_token_ids)
+                num_sampled_tokens)
         elif self.speculative_config.method == "medusa":
             assert isinstance(self.drafter, MedusaProposer)
-            if sample_hidden_states.shape[0] == len(sampled_token_ids):
+            if sample_hidden_states.shape[0] == len(num_sampled_tokens):
                 # The input to the target model does not include draft tokens.
                 hidden_states = sample_hidden_states
             else:
                 indices = []
                 offset = 0
-                for num_draft, tokens in zip(
+                for num_draft, num_sampled in zip(
                         spec_decode_metadata.num_draft_tokens,
-                        sampled_token_ids):
-                    indices.append(offset + len(tokens) - 1)
+                        num_sampled_tokens):
+                    indices.append(offset + num_sampled - 1)
                     offset += num_draft + 1
                 indices = torch.tensor(indices, device=self.device)
                 hidden_states = sample_hidden_states[indices]
@@ -1813,11 +1817,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
             req_ids = self.input_batch.req_ids
+            num_reqs = self.input_batch.num_reqs
             next_token_ids: list[int] = []
-            for i, token_ids in enumerate(sampled_token_ids):
-                if token_ids:
+            for i in range(num_reqs):
+                num_sampled = num_sampled_tokens[i]
+                if num_sampled > 0:
                     # Common case.
-                    next_token_id = token_ids[-1]
+                    next_token_id = sampled_token_ids[i][num_sampled - 1]
                 else:
                     # Partial prefill (rare case).
                     # Get the next token id from the request state.
@@ -1844,13 +1850,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 target_hidden_states = hidden_states[:num_scheduled_tokens]
             else:
                 # TODO(woosuk): Refactor this.
-                num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                num_rejected_tokens = [
-                    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
-                    for i, n in enumerate(num_draft_tokens)
-                ]
-                num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
-                                                       dtype=torch.int32)
+                num_draft_tokens = np.asarray(
+                    spec_decode_metadata.num_draft_tokens, dtype=np.int32)
+                num_accepted_tokens = num_sampled_tokens - 1
+                num_rejected_tokens = np.clip(num_draft_tokens -
+                                              num_accepted_tokens,
+                                              a_min=0, a_max=None)
+                num_rejected_tokens_cpu = torch.from_numpy(num_rejected_tokens)
                 common_attn_metadata, token_indices =\
                     self.drafter.prepare_inputs(
                         common_attn_metadata, num_rejected_tokens_cpu)
@@ -1881,13 +1887,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def propose_ngram_draft_token_ids(
         self,
-        sampled_token_ids: list[list[int]],
+        num_sampled_tokens: np.ndarray,
     ) -> list[list[int]]:
         # TODO(woosuk): Optimize.
         req_ids = self.input_batch.req_ids
+        num_reqs = self.input_batch.num_reqs
         draft_token_ids: list[list[int]] = []
-        for i, sampled_ids in enumerate(sampled_token_ids):
-            num_sampled_ids = len(sampled_ids)
+        for i in range(num_reqs):
+            num_sampled_ids = num_sampled_tokens[i]
             if not num_sampled_ids:
                 # Skip speculative decoding.
                 draft_token_ids.append([])
@@ -3267,7 +3274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         return kv_cache_spec
 
-    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
+    def _to_numpy(self, sampled_token_ids: torch.Tensor) -> np.ndarray:
         # This is a short term mitigation for issue mentioned in
         # https://github.com/vllm-project/vllm/issues/22754.
         # `tolist` would trigger a cuda wise stream sync, which
@@ -3280,4 +3287,4 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         pinned.copy_(sampled_token_ids, non_blocking=True)
         self.transfer_event.record()
         self.transfer_event.synchronize()
-        return pinned.tolist()
+        return pinned.numpy()
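For reviewers, here is a small standalone sketch (not part of the diff) of the counting trick the new `RejectionSampler.parse_output` relies on: instead of building a Python list of accepted tokens per request, it only counts how many leading tokens in each row are valid, and the scheduler later slices `sampled_token_ids[req_index][:num_sampled]` itself. The placeholder value, vocab size, and token IDs below are made-up illustration data, not values taken from vLLM.

```python
import numpy as np

# Illustrative values only; the real PLACEHOLDER_TOKEN_ID is defined in
# vllm/v1/sample/rejection_sampler.py and the vocab size comes from the model.
PLACEHOLDER_TOKEN_ID = -1
VOCAB_SIZE = 32_000

# [num_reqs, max_num_sampled_tokens]; rows are padded with the placeholder
# from the first rejected draft position onward.
output_token_ids = np.array([
    [101, 102, 103],                                    # all 3 tokens accepted
    [201, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],  # only the first token
    [PLACEHOLDER_TOKEN_ID] * 3,                         # nothing sampled
])

valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
              (output_token_ids < VOCAB_SIZE))
# cumprod zeroes each row from the first invalid position onward, so the
# row-wise sum is exactly the number of leading valid tokens.
num_sampled_tokens = np.cumprod(valid_mask, axis=1).sum(axis=1)
print(num_sampled_tokens)  # -> [3 1 0]
```

The same per-request counts feed the EAGLE path, where the rejected-token count becomes `num_draft_tokens - (num_sampled_tokens - 1)` clipped at zero, replacing the old per-request `len()` arithmetic on nested lists.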