[Sampler] Introduce logprobs mode for logging (#21398)
Signed-off-by: Lu Fang <lufang@fb.com>
Parent: 23637dcdef
Commit: accac82928
@@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
     assert_incr_detok_str_matches_non_incr_detok_str,
     compute_correct_cumulative_logprob, get_test_batch)
 from vllm import SamplingParams
+from vllm.config import LogprobsMode
 
 from ...conftest import HfRunner, VllmRunner
 
@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
         # prompt token
         assert prompt_logprobs is not None
         assert len(prompt_token_ids) == len(prompt_logprobs)
+
+
+@pytest.mark.parametrize(
+    "logprobs_mode",
+    ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
+def test_logprobs_mode(logprobs_mode: LogprobsMode,
+                       monkeypatch: pytest.MonkeyPatch):
+    """Test with LLM engine with different logprobs_mode.
+    For logprobs, we should have non-positive values.
+    For logits, we should expect at least one positive values.
+    """
+    from vllm import LLM
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        llm = LLM(
+            "facebook/opt-125m",
+            max_logprobs=5,
+            enable_prefix_caching=False,
+            # 2 other llms alive during whole session
+            gpu_memory_utilization=0.05,
+            max_model_len=16,
+            logprobs_mode=logprobs_mode)
+        vllm_sampling_params = SamplingParams(logprobs=1)
+        results = llm.generate(["Hello world"],
+                               sampling_params=vllm_sampling_params)
+
+        total_token_with_logprobs = 0
+        positive_values = 0
+        for output in results[0].outputs:
+            for logprobs in output.logprobs:
+                for token_id in logprobs:
+                    logprob = logprobs[token_id]
+                    if "logprobs" in logprobs_mode:
+                        assert logprob.logprob <= 0
+                    if logprob.logprob > 0:
+                        positive_values = positive_values + 1
+                    total_token_with_logprobs = total_token_with_logprobs + 1
+        assert total_token_with_logprobs >= len(results[0].outputs)
+        if "logits" in logprobs_mode:
+            assert positive_values > 0
+        del llm
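Distilled from the test above, a minimal usage sketch of the new option from the offline LLM entrypoint; the model, prompt, and sampling settings are illustrative only and assume a local GPU with the facebook/opt-125m weights available:

```python
from vllm import LLM, SamplingParams

# Sketch only: any of the four modes introduced by this commit can be passed.
llm = LLM("facebook/opt-125m",
          max_logprobs=5,
          max_model_len=16,
          logprobs_mode="raw_logits")
params = SamplingParams(logprobs=1)

results = llm.generate(["Hello world"], sampling_params=params)
for output in results[0].outputs:
    for step_logprobs in output.logprobs:
        for token_id, entry in step_logprobs.items():
            # With a *_logits mode the returned value is an unnormalized score
            # and may be positive; with a *_logprobs mode it stays <= 0.
            print(token_id, entry.logprob)
```

The test exercises exactly this path for all four modes, asserting non-positive values for the *_logprobs modes and at least one positive value for the *_logits modes.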
@@ -219,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
 
 TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
+LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
+                       "processed_logits"]
 
 
 @config
@@ -316,6 +318,13 @@ class ModelConfig:
     """Maximum number of log probabilities to return when `logprobs` is
     specified in `SamplingParams`. The default value comes the default for the
     OpenAI Chat Completions API."""
+    logprobs_mode: LogprobsMode = "raw_logprobs"
+    """Indicates the content returned in the logprobs and prompt_logprobs.
+    Supported mode:
+    1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
+    Raw means the values before applying logit processors, like bad words.
+    Processed means the values after applying such processors.
+    """
     disable_sliding_window: bool = False
     """Whether to disable sliding window. If True, we will disable the sliding
     window functionality of the model, capping to sliding window size. If the
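The Literal alias and the new ModelConfig.logprobs_mode field above are the whole configuration surface. As a self-contained illustration of how such a Literal alias can back a runtime check, here is a small sketch; the validate_logprobs_mode helper is hypothetical and not part of this diff:

```python
from typing import Literal, get_args

# Local mirror of the LogprobsMode alias added above, repeated so the sketch
# runs on its own.
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
                       "processed_logits"]


def validate_logprobs_mode(value: str) -> str:
    # Hypothetical helper: reject anything outside the four supported modes.
    allowed = get_args(LogprobsMode)
    if value not in allowed:
        raise ValueError(f"logprobs_mode must be one of {allowed}, got {value!r}")
    return value


print(validate_logprobs_mode("processed_logits"))  # processed_logits
```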
@@ -26,13 +26,13 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DetailedTraceModules, Device, DeviceConfig,
                          DistributedExecutorBackend, GuidedDecodingBackend,
                          GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
-                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
-                         ObservabilityConfig, ParallelConfig, PoolerConfig,
-                         PrefixCachingHashAlgo, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                         get_field)
+                         KVTransferConfig, LoadConfig, LoadFormat,
+                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
+                         ModelImpl, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerMode,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -324,6 +324,7 @@ class EngineArgs:
         SchedulerConfig.long_prefill_token_threshold
     max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
     max_logprobs: int = ModelConfig.max_logprobs
+    logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
    disable_log_stats: bool = False
     revision: Optional[str] = ModelConfig.revision
     code_revision: Optional[str] = ModelConfig.code_revision
@@ -490,6 +491,8 @@ class EngineArgs:
                                  **model_kwargs["max_seq_len_to_capture"])
         model_group.add_argument("--max-logprobs",
                                  **model_kwargs["max_logprobs"])
+        model_group.add_argument("--logprobs-mode",
+                                 **model_kwargs["logprobs_mode"])
         model_group.add_argument("--disable-sliding-window",
                                  **model_kwargs["disable_sliding_window"])
         model_group.add_argument("--disable-cascade-attn",
@@ -892,6 +895,7 @@ class EngineArgs:
             enforce_eager=self.enforce_eager,
             max_seq_len_to_capture=self.max_seq_len_to_capture,
             max_logprobs=self.max_logprobs,
+            logprobs_mode=self.logprobs_mode,
             disable_sliding_window=self.disable_sliding_window,
             disable_cascade_attn=self.disable_cascade_attn,
             skip_tokenizer_init=self.skip_tokenizer_init,
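Because logprobs_mode is a plain EngineArgs dataclass field defaulting to ModelConfig.logprobs_mode, it can be set either through the generated --logprobs-mode CLI flag or programmatically. A hedged sketch of the programmatic path (only logprobs_mode is specific to this change; the model name is illustrative):

```python
from vllm.engine.arg_utils import EngineArgs

# Sketch: set the new engine arg directly; it defaults to "raw_logprobs".
args = EngineArgs(model="facebook/opt-125m",
                  logprobs_mode="processed_logprobs")
print(args.logprobs_mode)  # processed_logprobs

# The equivalent command-line form would be:
#   --logprobs-mode processed_logprobs
# which the argument group above forwards into ModelConfig when the engine
# config is built.
```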
@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 
+from vllm.config import LogprobsMode
 from vllm.utils import is_pin_memory_available
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -18,10 +19,11 @@ _SAMPLING_EPS = 1e-5
 
 class Sampler(nn.Module):
 
-    def __init__(self):
+    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
         self.pin_memory = is_pin_memory_available()
+        self.logprobs_mode = logprobs_mode
 
     def forward(
         self,
@@ -36,7 +38,10 @@ class Sampler(nn.Module):
         # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            raw_logprobs = self.compute_logprobs(logits)
+            if self.logprobs_mode == "raw_logprobs":
+                raw_logprobs = self.compute_logprobs(logits)
+            elif self.logprobs_mode == "raw_logits":
+                raw_logprobs = logits.clone()
 
         # Use float32 for the logits.
         logits = logits.to(torch.float32)
@@ -51,6 +56,14 @@ class Sampler(nn.Module):
 
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
+
+        # Get the process logprobs or logits.
+        if num_logprobs is not None:
+            if self.logprobs_mode == "processed_logprobs":
+                raw_logprobs = self.compute_logprobs(logits)
+            elif self.logprobs_mode == "processed_logits":
+                raw_logprobs = logits.clone()
+
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)
         # Convert sampled token ids to int64 (long) type to ensure compatibility
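The sampler changes above only decide where the tensor used for logging is snapshotted: before or after the penalty/logit-processor stage, and as log-probabilities or as the underlying logits. A self-contained torch sketch of that distinction, using a toy "ban token 0" step as a stand-in for vLLM's real logit processors:

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])            # one request, vocab of 3

raw_logits = logits.clone()                           # "raw_logits"
raw_logprobs = torch.log_softmax(logits, dim=-1)      # "raw_logprobs" (all <= 0)

# Toy stand-in for a logit processor (e.g. a bad-words ban on token 0).
processed = logits.clone()
processed[:, 0] = float("-inf")

processed_logits = processed.clone()                       # "processed_logits"
processed_logprobs = torch.log_softmax(processed, dim=-1)  # "processed_logprobs"

# The banned token keeps a finite raw logprob but becomes -inf once processed,
# and the remaining tokens are renormalized.
print(raw_logprobs[0, 0].item(), processed_logprobs[0, 0].item())
```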
@@ -15,6 +15,7 @@ _SAMPLING_EPS = 1e-5
 class Sampler(nn.Module):
 
     def __init__(self):
+        # TODO(houseroad): Add support for logprobs_mode.
         super().__init__()
         self.topk_topp_sampler = TopKTopPSampler()
 
@@ -389,7 +389,7 @@ class InputBatch:
 
     def remove_request(self, req_id: str) -> Optional[int]:
         """This method must always be followed by a call to condense().
 
         Args:
             req_id: request to remove
 
@@ -590,7 +590,7 @@ class InputBatch:
 
     def refresh_metadata(self):
         """Apply batch updates, reset input batch at end of step
 
         * Apply batch add/remove/permute to logits procs' states
         * If batch state is modified, update sampling metadata
         """
@@ -151,7 +151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.encoder_cache_size = encoder_cache_size
 
         # Sampler
-        self.sampler = Sampler()
+        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
 
         self.eplb_state: Optional[EplbState] = None
         """
@@ -1996,7 +1996,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
         This is to help balance expert-selection
         - during profile_run
         - during DP rank dummy run
         """
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1