mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:55:01 +08:00
[Sampler] Introduce logprobs mode for logging (#21398)
Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
parent
23637dcdef
commit
accac82928
@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
|
||||
assert_incr_detok_str_matches_non_incr_detok_str,
|
||||
compute_correct_cumulative_logprob, get_test_batch)
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import LogprobsMode
|
||||
|
||||
from ...conftest import HfRunner, VllmRunner
|
||||
|
||||
@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
|
||||
# prompt token
|
||||
assert prompt_logprobs is not None
|
||||
assert len(prompt_token_ids) == len(prompt_logprobs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"logprobs_mode",
|
||||
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
|
||||
def test_logprobs_mode(logprobs_mode: LogprobsMode,
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test with LLM engine with different logprobs_mode.
|
||||
For logprobs, we should have non-positive values.
|
||||
For logits, we should expect at least one positive values.
|
||||
"""
|
||||
from vllm import LLM
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
llm = LLM(
|
||||
"facebook/opt-125m",
|
||||
max_logprobs=5,
|
||||
enable_prefix_caching=False,
|
||||
# 2 other llms alive during whole session
|
||||
gpu_memory_utilization=0.05,
|
||||
max_model_len=16,
|
||||
logprobs_mode=logprobs_mode)
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
results = llm.generate(["Hello world"],
|
||||
sampling_params=vllm_sampling_params)
|
||||
|
||||
total_token_with_logprobs = 0
|
||||
positive_values = 0
|
||||
for output in results[0].outputs:
|
||||
for logprobs in output.logprobs:
|
||||
for token_id in logprobs:
|
||||
logprob = logprobs[token_id]
|
||||
if "logprobs" in logprobs_mode:
|
||||
assert logprob.logprob <= 0
|
||||
if logprob.logprob > 0:
|
||||
positive_values = positive_values + 1
|
||||
total_token_with_logprobs = total_token_with_logprobs + 1
|
||||
assert total_token_with_logprobs >= len(results[0].outputs)
|
||||
if "logits" in logprobs_mode:
|
||||
assert positive_values > 0
|
||||
del llm
|
||||
|
||||
@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||
|
||||
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
|
||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
|
||||
"processed_logits"]
|
||||
|
||||
|
||||
@config
|
||||
@ -316,6 +318,13 @@ class ModelConfig:
|
||||
"""Maximum number of log probabilities to return when `logprobs` is
|
||||
specified in `SamplingParams`. The default value comes the default for the
|
||||
OpenAI Chat Completions API."""
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs"
|
||||
"""Indicates the content returned in the logprobs and prompt_logprobs.
|
||||
Supported mode:
|
||||
1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
|
||||
Raw means the values before applying logit processors, like bad words.
|
||||
Processed means the values after applying such processors.
|
||||
"""
|
||||
disable_sliding_window: bool = False
|
||||
"""Whether to disable sliding window. If True, we will disable the sliding
|
||||
window functionality of the model, capping to sliding window size. If the
|
||||
|
||||
@ -26,13 +26,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||
DetailedTraceModules, Device, DeviceConfig,
|
||||
DistributedExecutorBackend, GuidedDecodingBackend,
|
||||
GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
|
||||
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
|
||||
ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
|
||||
ObservabilityConfig, ParallelConfig, PoolerConfig,
|
||||
PrefixCachingHashAlgo, PromptAdapterConfig,
|
||||
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
|
||||
TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
|
||||
get_field)
|
||||
KVTransferConfig, LoadConfig, LoadFormat,
|
||||
LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
|
||||
ModelImpl, MultiModalConfig, ObservabilityConfig,
|
||||
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
|
||||
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
|
||||
SpeculativeConfig, TaskOption, TokenizerMode,
|
||||
VllmConfig, get_attr_docs, get_field)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.plugins import load_general_plugins
|
||||
@ -324,6 +324,7 @@ class EngineArgs:
|
||||
SchedulerConfig.long_prefill_token_threshold
|
||||
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
|
||||
max_logprobs: int = ModelConfig.max_logprobs
|
||||
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
|
||||
disable_log_stats: bool = False
|
||||
revision: Optional[str] = ModelConfig.revision
|
||||
code_revision: Optional[str] = ModelConfig.code_revision
|
||||
@ -490,6 +491,8 @@ class EngineArgs:
|
||||
**model_kwargs["max_seq_len_to_capture"])
|
||||
model_group.add_argument("--max-logprobs",
|
||||
**model_kwargs["max_logprobs"])
|
||||
model_group.add_argument("--logprobs-mode",
|
||||
**model_kwargs["logprobs_mode"])
|
||||
model_group.add_argument("--disable-sliding-window",
|
||||
**model_kwargs["disable_sliding_window"])
|
||||
model_group.add_argument("--disable-cascade-attn",
|
||||
@ -892,6 +895,7 @@ class EngineArgs:
|
||||
enforce_eager=self.enforce_eager,
|
||||
max_seq_len_to_capture=self.max_seq_len_to_capture,
|
||||
max_logprobs=self.max_logprobs,
|
||||
logprobs_mode=self.logprobs_mode,
|
||||
disable_sliding_window=self.disable_sliding_window,
|
||||
disable_cascade_attn=self.disable_cascade_attn,
|
||||
skip_tokenizer_init=self.skip_tokenizer_init,
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import LogprobsMode
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
@ -18,10 +19,11 @@ _SAMPLING_EPS = 1e-5
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.logprobs_mode = logprobs_mode
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -36,7 +38,10 @@ class Sampler(nn.Module):
|
||||
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
if self.logprobs_mode == "raw_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(logits)
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
raw_logprobs = logits.clone()
|
||||
|
||||
# Use float32 for the logits.
|
||||
logits = logits.to(torch.float32)
|
||||
@ -51,6 +56,14 @@ class Sampler(nn.Module):
|
||||
|
||||
# Apply penalties (e.g., min_tokens, freq_penalties).
|
||||
logits = self.apply_penalties(logits, sampling_metadata)
|
||||
|
||||
# Get the process logprobs or logits.
|
||||
if num_logprobs is not None:
|
||||
if self.logprobs_mode == "processed_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(logits)
|
||||
elif self.logprobs_mode == "processed_logits":
|
||||
raw_logprobs = logits.clone()
|
||||
|
||||
# Sample the next token.
|
||||
sampled = self.sample(logits, sampling_metadata)
|
||||
# Convert sampled token ids to int64 (long) type to ensure compatibility
|
||||
|
||||
@ -15,6 +15,7 @@ _SAMPLING_EPS = 1e-5
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
# TODO(houseroad): Add support for logprobs_mode.
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
|
||||
|
||||
@ -151,7 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.encoder_cache_size = encoder_cache_size
|
||||
|
||||
# Sampler
|
||||
self.sampler = Sampler()
|
||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||
|
||||
self.eplb_state: Optional[EplbState] = None
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user