diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 6829fed33e455..295df1fbc1179 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1474,7 +1474,7 @@ class Scheduler(SchedulerInterface):
 
             affected_req_ids.add(request.request_id)
 
-        return (affected_req_ids, total_affected_tokens)
+        return affected_req_ids, total_affected_tokens
 
     def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
         total_requests_to_reschedule = 0
diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py
index c5989b37d22d7..5906a73382a2d 100644
--- a/vllm/v1/core/sched/utils.py
+++ b/vllm/v1/core/sched/utils.py
@@ -59,8 +59,7 @@ def check_stop(
     sampling_params = request.sampling_params
     assert sampling_params is not None
 
-    min_tokens = sampling_params.min_tokens
-    if request.num_output_tokens < min_tokens:
+    if request.num_output_tokens < sampling_params.min_tokens:
         return False
 
     last_token_id = request.output_token_ids[-1]
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 76555a8666857..72cee8c73969a 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -147,22 +147,20 @@ class RejectionSampler(nn.Module):
         sampling_metadata: SamplingMetadata,
         metadata: SpecDecodeMetadata,
     ) -> torch.Tensor:
+        has_penalties = not sampling_metadata.no_penalties
         any_penalties_or_bad_words = (
-            sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
+            sampling_metadata.bad_words_token_ids or has_penalties
         )
 
         output_token_ids = sampling_metadata.output_token_ids
         if any_penalties_or_bad_words:
             output_token_ids = self._combine_outputs_with_spec_tokens(
-                sampling_metadata.output_token_ids,
+                output_token_ids,
                 sampling_metadata.spec_token_ids,
             )
 
         # Calculate indices of target logits.
-        if (
-            sampling_metadata.allowed_token_ids_mask is not None
-            or not sampling_metadata.no_penalties
-        ):
+        if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
             num_requests = len(sampling_metadata.output_token_ids)
             num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
             original_indices = torch.arange(num_requests, device="cpu")
@@ -180,18 +178,15 @@ class RejectionSampler(nn.Module):
             logits.masked_fill_(token_mask, float("-inf"))
 
         # Apply bad words exclusion.
-        if sampling_metadata.bad_words_token_ids:
+        if bad_words_token_ids := sampling_metadata.bad_words_token_ids:
             apply_bad_words_with_drafts(
-                logits,
-                sampling_metadata.bad_words_token_ids,
-                output_token_ids,
-                metadata.num_draft_tokens,
+                logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens
             )
 
         return logits
 
+    @staticmethod
     def apply_penalties(
-        self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         metadata: SpecDecodeMetadata,
@@ -218,8 +213,8 @@ class RejectionSampler(nn.Module):
         )
         return logits
 
+    @staticmethod
     def _combine_outputs_with_spec_tokens(
-        self,
         output_token_ids: list[list[int]],
         spec_token_ids: Optional[list[list[int]]] = None,
     ) -> list[list[int]]:
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 80cc866487f53..2e076ca8e3c84 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -120,8 +120,8 @@ class Sampler(nn.Module):
         )
         return sampler_output
 
+    @staticmethod
     def apply_temperature(
-        self,
         logits: torch.Tensor,
         temp: torch.Tensor,
         all_random: bool,
@@ -132,7 +132,8 @@ class Sampler(nn.Module):
         temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         return logits.div_(temp.unsqueeze(dim=1))
 
-    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
         return logits.argmax(dim=-1).view(-1)
 
     def sample(
@@ -191,11 +192,12 @@ class Sampler(nn.Module):
         )
         return sampled, processed_logprobs
 
-    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
         return logits.log_softmax(dim=-1, dtype=torch.float32)
 
+    @staticmethod
     def gather_logprobs(
-        self,
         logprobs: torch.Tensor,
         num_logprobs: int,
         token_ids: torch.Tensor,
@@ -238,8 +240,8 @@ class Sampler(nn.Module):
 
         return LogprobsTensors(indices, logprobs, token_ranks)
 
+    @staticmethod
     def _combine_outputs_with_spec_tokens(
-        self,
        output_token_ids: list[list[int]],
        spec_token_ids: Optional[list[list[int]]] = None,
    ) -> list[list[int]]:
@@ -257,8 +259,9 @@ class Sampler(nn.Module):
         sampling_metadata: SamplingMetadata,
         predict_bonus_token: bool,
     ) -> torch.Tensor:
+        bad_words_token_ids = sampling_metadata.bad_words_token_ids
         any_penalties_or_bad_words = (
-            sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
+            bool(bad_words_token_ids) or not sampling_metadata.no_penalties
         )
 
         output_token_ids = sampling_metadata.output_token_ids
@@ -266,7 +269,7 @@ class Sampler(nn.Module):
             # Combine base outputs with spec tokens when speculative decoding
             # is enabled.
             output_token_ids = self._combine_outputs_with_spec_tokens(
-                sampling_metadata.output_token_ids,
+                output_token_ids,
                 sampling_metadata.spec_token_ids,
             )
 
@@ -275,14 +278,8 @@ class Sampler(nn.Module):
             logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
 
         # Apply bad words exclusion.
-        if sampling_metadata.bad_words_token_ids:
-            apply_bad_words(
-                logits,
-                sampling_metadata.bad_words_token_ids,
-                output_token_ids
-                if output_token_ids is not None
-                else sampling_metadata.output_token_ids,
-            )
+        if bad_words_token_ids:
+            apply_bad_words(logits, bad_words_token_ids, output_token_ids)
 
         # Apply logits processors which can impact greedy sampling.
         for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
@@ -292,22 +289,21 @@ class Sampler(nn.Module):
         logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
         return logits
 
+    @staticmethod
     def apply_penalties(
-        self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        output_token_ids: Optional[list[list[int]]] = None,
+        output_token_ids: list[list[int]],
     ) -> torch.Tensor:
-        if not sampling_metadata.no_penalties:
-            assert sampling_metadata.prompt_token_ids is not None
-            logits = apply_all_penalties(
-                logits,
-                sampling_metadata.prompt_token_ids,
-                sampling_metadata.presence_penalties,
-                sampling_metadata.frequency_penalties,
-                sampling_metadata.repetition_penalties,
-                output_token_ids
-                if output_token_ids is not None
-                else sampling_metadata.output_token_ids,
-            )
-        return logits
+        if sampling_metadata.no_penalties:
+            return logits
+
+        assert sampling_metadata.prompt_token_ids is not None
+        return apply_all_penalties(
+            logits,
+            sampling_metadata.prompt_token_ids,
+            sampling_metadata.presence_penalties,
+            sampling_metadata.frequency_penalties,
+            sampling_metadata.repetition_penalties,
+            output_token_ids,
+        )
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 6d7473d8f44ba..5db843e99d6a7 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -62,10 +62,9 @@ class CachedRequestState:
                 "provided via prompt_embeds, and its ID is unknown."
             )
             return self.prompt_token_ids[idx]
-        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
+        if idx - self.num_prompt_tokens < len(self.output_token_ids):
             return self.output_token_ids[idx - self.num_prompt_tokens]
-        else:
-            return -1
+        return -1
 
 
 class InputBatch:
@@ -770,14 +769,13 @@ class InputBatch:
             not self.no_penalties
             or self.logits_processing_needs_token_ids[:num_reqs].any()
         )
-        if needs_prompt_token_ids:
-            # The prompt tokens are used only for applying penalties or
-            # step pooling during the sampling/pooling process.
-            # Hence copy these tensors only when there are requests which
-            # need penalties/step_pooler to be applied.
-            prompt_token_ids = self._make_prompt_token_ids_tensor()
-        else:
-            prompt_token_ids = None
+        # The prompt tokens are used only for applying penalties or
+        # step pooling during the sampling/pooling process.
+        # Hence copy these tensors only when there are requests which
+        # need penalties/step_pooler to be applied.
+        prompt_token_ids = (
+            self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
+        )
 
         allowed_token_ids_mask: Optional[torch.Tensor] = None
         if not self.no_allowed_token_ids:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ea3b18b447f3a..8ebfb1f2b8577 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1996,7 +1996,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     # Should be called after attention metadata creation. This just pads
     # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
-    def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
+    @staticmethod
+    def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
         padded_second_ubatch_slice = slice(
             ubatch_slices[1].token_slice.start, num_total_tokens
         )
@@ -2085,12 +2086,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         dict[str, Any],
     ]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        is_first_rank = get_pp_group().is_first_rank
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
         if (
             self.supports_mm_inputs
-            and get_pp_group().is_first_rank
+            and is_first_rank
             and not self.model_config.is_encoder_decoder
         ):
             # Run the multimodal encoder if any.
@@ -2115,7 +2117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-        elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
+        elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
             # TODO(qthequartermasterman): Since even when prompt embeds are
@@ -2155,7 +2157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             positions = self.positions.gpu[:num_input_tokens]
 
-        if get_pp_group().is_first_rank:
+        if is_first_rank:
             intermediate_tensors = None
         else:
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
@@ -2186,38 +2188,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
-            sampler_output = self.sampler(
+            return self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
             )
-        else:
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            assert logits is not None
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-                predict_bonus_token=True,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
-
-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
-                None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
-                sampling_metadata,
-            )
-            sampler_output.sampled_token_ids = output_token_ids
-            self._update_states_after_model_execute(output_token_ids)
+        # When indexing with a tensor (bonus_logits_indices), PyTorch
+        # creates a new tensor with separate storage from the original
+        # logits tensor. This means any in-place operations on bonus_logits
+        # won't affect the original logits tensor.
+        assert logits is not None
+        bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
+        sampler_output = self.sampler(
+            logits=bonus_logits,
+            sampling_metadata=sampling_metadata,
+            predict_bonus_token=True,
+        )
+        bonus_token_ids = sampler_output.sampled_token_ids
+
+        # Just like `bonus_logits`, `target_logits` is a new tensor with
+        # separate storage from the original `logits` tensor. Therefore,
+        # it is safe to update `target_logits` in place.
+        target_logits = logits[spec_decode_metadata.target_logits_indices]
+        output_token_ids = self.rejection_sampler(
+            spec_decode_metadata,
+            None,  # draft_probs
+            target_logits,
+            bonus_token_ids,
+            sampling_metadata,
+        )
+        sampler_output.sampled_token_ids = output_token_ids
+        self._update_states_after_model_execute(output_token_ids)
         return sampler_output
 
     def _bookkeeping_sync(
@@ -3741,7 +3742,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             decode_cudagraph_batch_sizes = [
                 x
                 for x in self.cudagraph_batch_sizes
-                if x <= max_num_tokens and x >= self.uniform_decode_query_len
+                if max_num_tokens >= x >= self.uniform_decode_query_len
             ]
             compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
             self._capture_cudagraphs(
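For reviewers skimming the diff: the changes above are behavior-preserving refactors built from a few recurring Python patterns, not logic changes. Methods that never touch `self` become `@staticmethod`, repeated lookups (`get_pp_group().is_first_rank`, `sampling_metadata.bad_words_token_ids`) are hoisted into locals, `if`/`else` nesting is flattened with early returns, a walrus operator binds and tests a value in one step, and `x <= a and x >= b` becomes a chained comparison. The snippet below is a minimal, self-contained sketch of those patterns using hypothetical names (`Processor`, `Metadata`); it is illustrative only and not vLLM code.

```python
# Standalone sketch of the refactoring patterns applied in this diff.
# All names here are invented for illustration.
from dataclasses import dataclass, field


@dataclass
class Metadata:
    bad_words: list[str] = field(default_factory=list)
    no_penalties: bool = True


class Processor:
    @staticmethod  # never touches `self`, so it can be a @staticmethod
    def clip(values: list[int], lo: int, hi: int) -> list[int]:
        # Chained comparison replaces `x <= hi and x >= lo`.
        return [x for x in values if hi >= x >= lo]

    @staticmethod
    def apply_penalties(values: list[int], meta: Metadata) -> list[int]:
        # Guard clause: return early rather than wrapping the body in `if not ...`.
        if meta.no_penalties:
            return values
        return [v - 1 for v in values]

    @staticmethod
    def filter_bad_words(tokens: list[str], meta: Metadata) -> list[str]:
        # Walrus operator binds the attribute once and tests its truthiness.
        if bad_words := meta.bad_words:
            return [t for t in tokens if t not in bad_words]
        return tokens


if __name__ == "__main__":
    meta = Metadata(bad_words=["spam"], no_penalties=False)
    print(Processor.clip([1, 5, 9, 12], lo=2, hi=10))        # [5, 9]
    print(Processor.apply_penalties([3, 4], meta))           # [2, 3]
    print(Processor.filter_bad_words(["ok", "spam"], meta))  # ['ok']
```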