diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 9bf0c5842c6be..b89768913681e 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -154,12 +154,15 @@ differences compared to V0: ##### Logprobs Calculation -Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. +By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e. before applying any logits post-processing such as temperature scaling or penalty adjustments). As a result, the returned logprobs do not reflect the final adjusted probabilities used during sampling. -Support for logprobs with post-sampling adjustments is in progress and will be added in future updates. +You can adjust this behavior by setting the `--logprobs-mode` flag. +Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`. +Raw means the values before applying any logit processors, like bad words. +Processed means the values after applying all processors, including temperature and top_k/top_p. ##### Prompt Logprobs with Prefix Caching diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 8bd142e87b06e..e835c029634ce 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -456,9 +456,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): assert len(logprob) == vocab_size -@pytest.mark.parametrize( - "logprobs_mode", - ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]) +@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode)) def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch): """Test with LLM engine with different logprobs_mode. @@ -487,12 +485,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode, for logprobs in output.logprobs: for token_id in logprobs: logprob = logprobs[token_id] - if "logprobs" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGPROBS, + LogprobsMode.PROCESSED_LOGPROBS): assert logprob.logprob <= 0 if logprob.logprob > 0: positive_values = positive_values + 1 total_token_with_logprobs = total_token_with_logprobs + 1 assert total_token_with_logprobs >= len(results[0].outputs) - if "logits" in logprobs_mode: + if logprobs_mode in (LogprobsMode.RAW_LOGITS, + LogprobsMode.PROCESSED_LOGITS): assert positive_values > 0 del llm diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 959f111ced22e..2973cb92d195b 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -257,11 +257,16 @@ def is_init_field(cls: ConfigType, name: str) -> bool: TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] -LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs", - "processed_logits"] MMEncoderTPMode = Literal["weights", "data"] +class LogprobsMode(enum.Enum): + RAW_LOGITS = "raw_logits" + RAW_LOGPROBS = "raw_logprobs" + PROCESSED_LOGITS = "processed_logits" + PROCESSED_LOGPROBS = "processed_logprobs" + + @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class ModelConfig: @@ -363,12 +368,13 @@ class ModelConfig: specified in `SamplingParams`. The default value comes the default for the OpenAI Chat Completions API. -1 means no cap, i.e. 
all (output_length * vocab_size) logprobs are allowed to be returned and it may cause OOM.""" - logprobs_mode: LogprobsMode = "raw_logprobs" + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS """Indicates the content returned in the logprobs and prompt_logprobs. Supported mode: 1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits. - Raw means the values before applying logit processors, like bad words. - Processed means the values after applying such processors. + Raw means the values before applying any logit processors, like bad words. + Processed means the values after applying all processors, including + temperature and top_k/top_p. """ disable_sliding_window: bool = False """Whether to disable sliding window. If True, we will disable the sliding @@ -2586,7 +2592,7 @@ class MultiModalConfig: skip_mm_profiling: bool = False """ - When enabled, skips multimodal memory profiling and only profiles with + When enabled, skips multimodal memory profiling and only profiles with language backbone model during engine initialization. This reduces engine startup time but shifts the responsibility to users for @@ -2649,24 +2655,24 @@ class PoolerConfig: ## for embeddings models normalize: Optional[bool] = None """ - Whether to normalize the embeddings outputs. + Whether to normalize the embeddings outputs. """ dimensions: Optional[int] = None """ - Reduce the dimensions of embeddings if model + Reduce the dimensions of embeddings if model support matryoshka representation. """ ## for classification models activation: Optional[bool] = None """ - Whether to apply activation function to the classification outputs. + Whether to apply activation function to the classification outputs. """ ## for reward models softmax: Optional[bool] = None """ - Whether to apply softmax to the reward outputs. + Whether to apply softmax to the reward outputs. """ step_tag_id: Optional[int] = None """ @@ -2692,9 +2698,9 @@ class PoolerConfig: max_embed_len: Optional[int] = None """ - Maximum input length allowed for embedding generation. When set, allows + Maximum input length allowed for embedding generation. When set, allows inputs longer than max_embed_len to be accepted for embedding models. - This parameter enables accepting long inputs without requiring + This parameter enables accepting long inputs without requiring VLLM_ALLOW_LONG_MAX_MODEL_LEN environment variable. When an input exceeds max_embed_len, it will be handled according to the original max_model_len validation logic. Defaults to None (i.e. set to max_model_len). 
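For reference, here is a minimal sketch of how the new `logprobs_mode` option documented above is meant to be exercised from the offline `LLM` entrypoint. It is a sketch only: it assumes the entrypoint forwards `logprobs_mode` to `ModelConfig` the same way the `--logprobs-mode` CLI flag does, and the model name and prompt are purely illustrative.

```python
# Sketch only: assumes LLM() forwards logprobs_mode to ModelConfig just as the
# --logprobs-mode CLI flag does; model name and prompt are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", logprobs_mode="processed_logprobs")
params = SamplingParams(temperature=0.8, top_p=0.95, logprobs=5)

outputs = llm.generate(["The capital of France is"], params)
for completion in outputs[0].outputs:
    for token_logprobs in completion.logprobs:
        # With processed_logprobs, the values reflect temperature and
        # top_k/top_p adjustments, so they are true log-probabilities <= 0
        # (the same property the updated test_logprobs_mode asserts).
        assert all(lp.logprob <= 0 for lp in token_logprobs.values())
```

The same selection should be available on the server side via the `--logprobs-mode` flag, whose accepted values are wired up from the `LogprobsMode` enum in the `arg_utils.py` hunk below.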
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f3afc015f669c..b0f50b4429a82 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -516,6 +516,7 @@ class EngineArgs: model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) model_group.add_argument("--logprobs-mode", + choices=[f.value for f in LogprobsMode], **model_kwargs["logprobs_mode"]) model_group.add_argument("--disable-sliding-window", **model_kwargs["disable_sliding_window"]) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index e0434c8f3d713..7bd4a5a380ac0 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -8,6 +8,7 @@ import torch.nn as nn from packaging import version from vllm import envs +from vllm.config import LogprobsMode from vllm.logger import init_logger from vllm.platforms import current_platform @@ -28,9 +29,16 @@ class TopKTopPSampler(nn.Module): Implementations may update the logits tensor in-place. """ - def __init__(self): + def __init__( + self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None: super().__init__() - if current_platform.is_cuda(): + self.logprobs_mode = logprobs_mode + # flashinfer optimization does not apply if intermediate + # logprobs/logits after top_k/top_p need to be returned + if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS, + LogprobsMode.PROCESSED_LOGPROBS + ) and current_platform.is_cuda(): if is_flashinfer_available: flashinfer_version = flashinfer.__version__ if version.parse(flashinfer_version) < version.parse("0.2.3"): @@ -63,10 +71,12 @@ class TopKTopPSampler(nn.Module): "native implementation of top-p & top-k sampling. For the " "best performance, please install FlashInfer.") self.forward = self.forward_native - elif current_platform.is_tpu(): - self.forward = self.forward_tpu else: self.forward = self.forward_native + if current_platform.is_tpu(): + self.apply_top_k_top_p = apply_top_k_top_p_tpu + else: + self.apply_top_k_top_p = apply_top_k_top_p def forward_native( self, @@ -74,15 +84,20 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ PyTorch-native implementation of top-k and top-p sampling. The logits tensor may be updated in-place. """ - logits = apply_top_k_top_p(logits, k, p) + logits = self.apply_top_k_top_p(logits, k, p) + logits_to_return = None + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + logits_to_return = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return random_sample(probs, generators), logits_to_return def forward_cuda( self, @@ -90,34 +105,24 @@ class TopKTopPSampler(nn.Module): generators: dict[int, torch.Generator], k: Optional[torch.Tensor], p: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """More optimized implementation for top-k and top-p sampling.""" - if k is None and p is None: - # We prefer `random_sample` over `flashinfer_sample` when sorting is - # not needed. This is because `random_sample` does not require - # CPU-GPU synchronization while `flashinfer_sample` does. 
- probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) - if generators: - logger.warning_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + # We prefer `random_sample` over `flashinfer_sample` when sorting is + # not needed. This is because `random_sample` does not require + # CPU-GPU synchronization while `flashinfer_sample` does. + if (k is None and p is None) or generators: + if generators: + logger.warning_once("FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation.") return self.forward_native(logits, generators, k, p) + assert self.logprobs_mode not in ( + LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS + ), "FlashInfer does not support returning logits/logprobs" # flashinfer sampling functions expect contiguous logits. # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # because of slicing operation in logits_processor. - return flashinfer_sample(logits.contiguous(), k, p, generators) - - def forward_tpu( - self, - logits: torch.Tensor, - generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], - ) -> torch.Tensor: - logits = apply_top_k_top_p_tpu(logits, k, p) - probs = logits.softmax(dim=-1, dtype=torch.float32) - return random_sample(probs, generators) + return flashinfer_sample(logits.contiguous(), k, p, generators), None def apply_top_k_top_p_tpu( diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 82f51298f1b59..70ec8a0c26ddf 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that samples the next tokens from the model's outputs.""" +from typing import Optional + import torch import torch.nn as nn @@ -18,10 +20,50 @@ _SAMPLING_EPS = 1e-5 class Sampler(nn.Module): + """ + A layer that samples the next tokens from the model's outputs + with the following steps in order: - def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"): + 1. If logprobs are requested: + a) If `logprobs_mode` is `raw_logprobs`, compute logprobs + as the final logprobs to return. + b) If `logprobs_mode` is `raw_logits`, clone the logits + as the final logprobs to return. + 2. Convert logits to float32. + 3. Apply allowed token ids whitelist. + 4. Apply bad words exclusion. + 5. Apply logit processors which are not argmax-invariant, + i.e. that can impact greedy sampling. + a) Min tokens processor + b) Logit bias processor + 6. Apply penalties + a) Repetition penalty + b) Frequency penalty + c) Presence penalty + 7. Sample the next tokens. `sample` method performs the following steps: + a) If not `all_random`, perform greedy sampling. If `all_greedy`, + return the greedily sampled tokens and final logprobs if requested. + b) Apply temperature. + c) Apply logit processors which are argmax-invariant, by default + the min_p processor. + d) Apply top_k and/or top_p. + e) Sample the next tokens with the probability distribution. + f) If `all_random` or temperature >= epsilon (1e-5), return the + randomly sampled tokens and final logprobs if requested. Else, + return the greedily sampled tokens and logprobs if requested. + 8. Gather the logprobs of the top `max_num_logprobs` and sampled token + (if requested). 
Note that if the sampled token is within the top + `max_num_logprobs`, the logprob will be eventually merged in + `LogprobsProcessor` during output processing. Therefore, the + final output may contain either `max_num_logprobs + 1` or + `max_num_logprobs` logprobs. + 9. Return the final `SamplerOutput`. + """ + + def __init__(self, + logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS): super().__init__() - self.topk_topp_sampler = TopKTopPSampler() + self.topk_topp_sampler = TopKTopPSampler(logprobs_mode) self.pin_memory = is_pin_memory_available() self.logprobs_mode = logprobs_mode @@ -34,13 +76,11 @@ class Sampler(nn.Module): # temperature scaling) for the top-k logprobs. # This is different from the V0 sampler, which uses the logits that # is used for sampling (after penalties and temperature scaling). - # TODO(rob): provide option for logprobs post sampling. - # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: - if self.logprobs_mode == "raw_logprobs": + if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS: raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "raw_logits": + elif self.logprobs_mode == LogprobsMode.RAW_LOGITS: raw_logprobs = logits.clone() # Use float32 for the logits. @@ -57,15 +97,10 @@ class Sampler(nn.Module): # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) - # Get the process logprobs or logits. - if num_logprobs is not None: - if self.logprobs_mode == "processed_logprobs": - raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "processed_logits": - raw_logprobs = logits.clone() - # Sample the next token. - sampled = self.sample(logits, sampling_metadata) + sampled, processed_logprobs = self.sample(logits, sampling_metadata) + if processed_logprobs is not None: + raw_logprobs = processed_logprobs # Convert sampled token ids to int64 (long) type to ensure compatibility # with subsequent operations that may use these values as indices. # This conversion is necessary because FlashInfer sampling operations @@ -105,7 +140,7 @@ class Sampler(nn.Module): self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Sample logits based on sampling metadata. The various logits processing functions called in this method @@ -119,7 +154,13 @@ class Sampler(nn.Module): else: greedy_sampled = self.greedy_sample(logits) if sampling_metadata.all_greedy: - return greedy_sampled + processed_logprobs = None + if sampling_metadata.max_num_logprobs is not None: + if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS: + processed_logprobs = logits + elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS: + processed_logprobs = self.compute_logprobs(logits) + return greedy_sampled, processed_logprobs assert sampling_metadata.temperature is not None @@ -132,7 +173,7 @@ class Sampler(nn.Module): logits = processor.apply(logits) # Apply top_k and/or top_p. 
- random_sampled = self.topk_topp_sampler( + random_sampled, processed_logprobs = self.topk_topp_sampler( logits, sampling_metadata.generators, sampling_metadata.top_k, @@ -140,7 +181,7 @@ class Sampler(nn.Module): ) if greedy_sampled is None: - return random_sampled + return random_sampled, processed_logprobs sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, @@ -148,7 +189,7 @@ class Sampler(nn.Module): random_sampled, out=greedy_sampled, # Reuse tensor ) - return sampled + return sampled, processed_logprobs def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: return logits.log_softmax(dim=-1, dtype=torch.float32) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 2c9f4892bc247..04545d587e4a9 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -65,7 +65,7 @@ class Sampler(nn.Module): logits = self.apply_min_p(logits, sampling_metadata.min_p) # Apply top_k and/or top_p. - random_sampled = self.topk_topp_sampler( + random_sampled, _ = self.topk_topp_sampler( logits, sampling_metadata.generators, sampling_metadata.top_k,
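To make the new tuple contract easier to follow end to end, here is a condensed, self-contained sketch of the control flow the diff establishes: raw values are snapshotted before any logits processing, processed values are snapshotted after temperature and top_k/top_p and then replace the raw snapshot, and the second element of the returned tuple is `None` whenever no logprobs were requested. The tensor shapes, the fixed placeholder temperature, and the helper name are illustrative stand-ins, not the actual `Sampler`/`TopKTopPSampler` code paths.

```python
# Condensed sketch of the logprobs_mode flow introduced above. Shapes, the
# fixed temperature, and the function name are illustrative stand-ins.
from enum import Enum
from typing import Optional

import torch


class LogprobsMode(Enum):
    RAW_LOGITS = "raw_logits"
    RAW_LOGPROBS = "raw_logprobs"
    PROCESSED_LOGITS = "processed_logits"
    PROCESSED_LOGPROBS = "processed_logprobs"


def sample_with_logprobs(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    mode: LogprobsMode,
    need_logprobs: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    returned: Optional[torch.Tensor] = None

    # "Raw" modes snapshot the values before any logits processing.
    if need_logprobs and mode == LogprobsMode.RAW_LOGPROBS:
        returned = logits.log_softmax(dim=-1, dtype=torch.float32)
    elif need_logprobs and mode == LogprobsMode.RAW_LOGITS:
        returned = logits.clone()

    # Stand-in for the real pipeline: penalties, temperature, top_k/top_p.
    processed = logits / 0.8

    # "Processed" modes snapshot the values after processing; in the real
    # Sampler this is the processed_logprobs tensor handed back by sample(),
    # which overwrites the raw snapshot.
    if need_logprobs and mode == LogprobsMode.PROCESSED_LOGITS:
        returned = processed
    elif need_logprobs and mode == LogprobsMode.PROCESSED_LOGPROBS:
        returned = processed.log_softmax(dim=-1, dtype=torch.float32)

    probs = processed.softmax(dim=-1, dtype=torch.float32)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return sampled, returned
```

When no logprobs are requested the second element stays `None`, which is why callers that do not need it, such as the TPU sampler in the last hunk, can simply discard it with `random_sampled, _ = self.topk_topp_sampler(...)`.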