[Sampler] Support returning final logprobs (#22387)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent f64ee61d9e
commit f571ff8eb6
@@ -154,12 +154,15 @@ differences compared to V0:
 
 ##### Logprobs Calculation
 
-Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
+By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
 before applying any logits post-processing such as temperature scaling or penalty
 adjustments). As a result, the returned logprobs do not reflect the final adjusted
 probabilities used during sampling.
 
-Support for logprobs with post-sampling adjustments is in progress and will be added in future updates.
+You can adjust this behavior by setting the `--logprobs-mode` flag.
+Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
+Raw means the values before applying any logit processors, like bad words.
+Processed means the values after applying all processors, including temperature and top_k/top_p.
 
 ##### Prompt Logprobs with Prefix Caching
 
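For reference, a minimal offline-inference sketch of the behavior described in the docs hunk above. This is an illustrative example and not part of the commit: the model name is arbitrary, and it assumes `logprobs_mode` is accepted by the `LLM` constructor in the same way the tests in this PR construct the engine (the CLI equivalent would be `--logprobs-mode processed_logprobs`).

from vllm import LLM, SamplingParams

# Hypothetical usage sketch: request logprobs that reflect temperature/top_p
# processing instead of the default raw values.
llm = LLM(model="facebook/opt-125m", logprobs_mode="processed_logprobs")
params = SamplingParams(temperature=0.8, top_p=0.9, max_tokens=8, logprobs=3)
outputs = llm.generate(["The capital of France is"], params)

for step in outputs[0].outputs[0].logprobs:
    # Each step maps token id -> Logprob; with processed_logprobs these values
    # are log-probabilities over the post-top_p/temperature distribution.
    print({token_id: lp.logprob for token_id, lp in step.items()})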
@@ -456,9 +456,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
             assert len(logprob) == vocab_size
 
 
-@pytest.mark.parametrize(
-    "logprobs_mode",
-    ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
+@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
 def test_logprobs_mode(logprobs_mode: LogprobsMode,
                        monkeypatch: pytest.MonkeyPatch):
     """Test with LLM engine with different logprobs_mode.
@@ -487,12 +485,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
         for logprobs in output.logprobs:
             for token_id in logprobs:
                 logprob = logprobs[token_id]
-                if "logprobs" in logprobs_mode:
+                if logprobs_mode in (LogprobsMode.RAW_LOGPROBS,
+                                     LogprobsMode.PROCESSED_LOGPROBS):
                     assert logprob.logprob <= 0
                 if logprob.logprob > 0:
                     positive_values = positive_values + 1
                 total_token_with_logprobs = total_token_with_logprobs + 1
     assert total_token_with_logprobs >= len(results[0].outputs)
-    if "logits" in logprobs_mode:
+    if logprobs_mode in (LogprobsMode.RAW_LOGITS,
+                         LogprobsMode.PROCESSED_LOGITS):
         assert positive_values > 0
     del llm
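A side note on why the assertions above switch from substring checks to explicit tuples (illustrative snippet, not part of the diff): once `logprobs_mode` is a plain `enum.Enum` member rather than a string, `"logprobs" in logprobs_mode` would raise `TypeError`, so membership is tested against the relevant enum members instead.

from vllm.config import LogprobsMode

mode = LogprobsMode.PROCESSED_LOGPROBS
# String-style containment no longer applies to a non-str Enum member,
# so the test enumerates the members it cares about:
is_logprobs = mode in (LogprobsMode.RAW_LOGPROBS,
                       LogprobsMode.PROCESSED_LOGPROBS)
is_logits = mode in (LogprobsMode.RAW_LOGITS, LogprobsMode.PROCESSED_LOGITS)
assert is_logprobs and not is_logits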
@@ -257,11 +257,16 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
 
 TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
-LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
-                       "processed_logits"]
 MMEncoderTPMode = Literal["weights", "data"]
 
 
+class LogprobsMode(enum.Enum):
+    RAW_LOGITS = "raw_logits"
+    RAW_LOGPROBS = "raw_logprobs"
+    PROCESSED_LOGITS = "processed_logits"
+    PROCESSED_LOGPROBS = "processed_logprobs"
+
+
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class ModelConfig:
@@ -363,12 +368,13 @@ class ModelConfig:
     specified in `SamplingParams`. The default value comes the default for the
     OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length *
     vocab_size) logprobs are allowed to be returned and it may cause OOM."""
-    logprobs_mode: LogprobsMode = "raw_logprobs"
+    logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS
     """Indicates the content returned in the logprobs and prompt_logprobs.
     Supported mode:
     1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits.
-    Raw means the values before applying logit processors, like bad words.
-    Processed means the values after applying such processors.
+    Raw means the values before applying any logit processors, like bad words.
+    Processed means the values after applying all processors, including
+    temperature and top_k/top_p.
     """
     disable_sliding_window: bool = False
     """Whether to disable sliding window. If True, we will disable the sliding
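An illustrative sketch (not part of the diff) of how the new `LogprobsMode` enum relates to the CLI strings: members round-trip by value, which is what the `choices=[f.value for f in LogprobsMode]` argparse wiring in the next hunk relies on.

from vllm.config import LogprobsMode

# Enum members and their CLI string values round-trip by value.
assert LogprobsMode("processed_logits") is LogprobsMode.PROCESSED_LOGITS
assert LogprobsMode.RAW_LOGPROBS.value == "raw_logprobs"

# The valid --logprobs-mode choices can be derived from the enum itself.
choices = [m.value for m in LogprobsMode]
assert sorted(choices) == ["processed_logits", "processed_logprobs",
                           "raw_logits", "raw_logprobs"]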
@@ -516,6 +516,7 @@ class EngineArgs:
         model_group.add_argument("--max-logprobs",
                                  **model_kwargs["max_logprobs"])
         model_group.add_argument("--logprobs-mode",
+                                 choices=[f.value for f in LogprobsMode],
                                  **model_kwargs["logprobs_mode"])
         model_group.add_argument("--disable-sliding-window",
                                  **model_kwargs["disable_sliding_window"])
@@ -8,6 +8,7 @@ import torch.nn as nn
 from packaging import version
 
 from vllm import envs
+from vllm.config import LogprobsMode
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
@@ -28,9 +29,16 @@ class TopKTopPSampler(nn.Module):
     Implementations may update the logits tensor in-place.
     """
 
-    def __init__(self):
+    def __init__(
+            self,
+            logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS) -> None:
         super().__init__()
-        if current_platform.is_cuda():
+        self.logprobs_mode = logprobs_mode
+        # flashinfer optimization does not apply if intermediate
+        # logprobs/logits after top_k/top_p need to be returned
+        if logprobs_mode not in (LogprobsMode.PROCESSED_LOGITS,
+                                 LogprobsMode.PROCESSED_LOGPROBS
+                                 ) and current_platform.is_cuda():
             if is_flashinfer_available:
                 flashinfer_version = flashinfer.__version__
                 if version.parse(flashinfer_version) < version.parse("0.2.3"):
@@ -63,10 +71,12 @@ class TopKTopPSampler(nn.Module):
                     "native implementation of top-p & top-k sampling. For the "
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
-        elif current_platform.is_tpu():
-            self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
+        if current_platform.is_tpu():
+            self.apply_top_k_top_p = apply_top_k_top_p_tpu
+        else:
+            self.apply_top_k_top_p = apply_top_k_top_p
 
     def forward_native(
         self,
@@ -74,15 +84,20 @@ class TopKTopPSampler(nn.Module):
         generators: dict[int, torch.Generator],
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         PyTorch-native implementation of top-k and top-p sampling.
 
         The logits tensor may be updated in-place.
         """
-        logits = apply_top_k_top_p(logits, k, p)
+        logits = self.apply_top_k_top_p(logits, k, p)
+        logits_to_return = None
+        if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
+            logits_to_return = logits
+        elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
+            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
-        return random_sample(probs, generators)
+        return random_sample(probs, generators), logits_to_return
 
     def forward_cuda(
         self,
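A toy numeric illustration (an independent example, not vLLM code) of what the processed branch above returns: after top-k/top-p masking, excluded tokens sit at -inf and `log_softmax` renormalizes only over the surviving tokens, unlike the raw logprobs computed before any processing.

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
k = 2
# Mask everything outside the top-k, as a stand-in for top_k/top_p filtering.
kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
masked = logits.masked_fill(logits < kth_value, float("-inf"))

raw_logprobs = logits.log_softmax(dim=-1)        # all four tokens finite
processed_logprobs = masked.log_softmax(dim=-1)  # masked tokens are -inf,
                                                 # kept ones renormalized
print(raw_logprobs)
print(processed_logprobs)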
@@ -90,34 +105,24 @@ class TopKTopPSampler(nn.Module):
         generators: dict[int, torch.Generator],
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """More optimized implementation for top-k and top-p sampling."""
-        if k is None and p is None:
         # We prefer `random_sample` over `flashinfer_sample` when sorting is
         # not needed. This is because `random_sample` does not require
         # CPU-GPU synchronization while `flashinfer_sample` does.
-            probs = logits.softmax(dim=-1, dtype=torch.float32)
-            return random_sample(probs, generators)
+        if (k is None and p is None) or generators:
             if generators:
                 logger.warning_once("FlashInfer 0.2.3+ does not support "
                                     "per-request generators. Falling back to "
                                     "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
+        assert self.logprobs_mode not in (
+            LogprobsMode.PROCESSED_LOGITS, LogprobsMode.PROCESSED_LOGPROBS
+        ), "FlashInfer does not support returning logits/logprobs"
         # flashinfer sampling functions expect contiguous logits.
         # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
         # because of slicing operation in logits_processor.
-        return flashinfer_sample(logits.contiguous(), k, p, generators)
+        return flashinfer_sample(logits.contiguous(), k, p, generators), None
 
-    def forward_tpu(
-        self,
-        logits: torch.Tensor,
-        generators: dict[int, torch.Generator],
-        k: Optional[torch.Tensor],
-        p: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        logits = apply_top_k_top_p_tpu(logits, k, p)
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
-        return random_sample(probs, generators)
-
 
 def apply_top_k_top_p_tpu(
@@ -2,6 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A layer that samples the next tokens from the model's outputs."""
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
@@ -18,10 +20,50 @@ _SAMPLING_EPS = 1e-5
 
 
 class Sampler(nn.Module):
+    """
+    A layer that samples the next tokens from the model's outputs
+    with the following steps in order:
 
-    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
+    1. If logprobs are requested:
+        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
+           as the final logprobs to return.
+        b) If `logprobs_mode` is `raw_logits`, clone the logits
+           as the final logprobs to return.
+    2. Convert logits to float32.
+    3. Apply allowed token ids whitelist.
+    4. Apply bad words exclusion.
+    5. Apply logit processors which are not argmax-invariant,
+       i.e. that can impact greedy sampling.
+        a) Min tokens processor
+        b) Logit bias processor
+    6. Apply penalties
+        a) Repetition penalty
+        b) Frequency penalty
+        c) Presence penalty
+    7. Sample the next tokens. `sample` method performs the following steps:
+        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
+           return the greedily sampled tokens and final logprobs if requested.
+        b) Apply temperature.
+        c) Apply logit processors which are argmax-invariant, by default
+           the min_p processor.
+        d) Apply top_k and/or top_p.
+        e) Sample the next tokens with the probability distribution.
+        f) If `all_random` or temperature >= epsilon (1e-5), return the
+           randomly sampled tokens and final logprobs if requested. Else,
+           return the greedily sampled tokens and logprobs if requested.
+    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
+       (if requested). Note that if the sampled token is within the top
+       `max_num_logprobs`, the logprob will be eventually merged in
+       `LogprobsProcessor` during output processing. Therefore, the
+       final output may contain either `max_num_logprobs + 1` or
+       `max_num_logprobs` logprobs.
+    9. Return the final `SamplerOutput`.
+    """
+
+    def __init__(self,
+                 logprobs_mode: LogprobsMode = LogprobsMode.RAW_LOGPROBS):
         super().__init__()
-        self.topk_topp_sampler = TopKTopPSampler()
+        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
         self.pin_memory = is_pin_memory_available()
         self.logprobs_mode = logprobs_mode
 
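A compact sketch of the control flow the docstring above describes: capture logprobs either before any processing (raw modes) or after it (processed modes), and hand them back alongside the sampled tokens. This is a simplified stand-in, not the vLLM `Sampler`; the temperature-only "processing" step and the `Mode` enum here are placeholders.

from enum import Enum
from typing import Optional
import torch

class Mode(Enum):  # placeholder mirroring vllm.config.LogprobsMode
    RAW_LOGPROBS = "raw_logprobs"
    PROCESSED_LOGPROBS = "processed_logprobs"

def sample(logits: torch.Tensor, temperature: float,
           mode: Mode) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    to_return: Optional[torch.Tensor] = None
    if mode is Mode.RAW_LOGPROBS:
        to_return = logits.log_softmax(dim=-1)   # before any processing
    logits = logits / temperature                # the "processing" step
    if mode is Mode.PROCESSED_LOGPROBS:
        to_return = logits.log_softmax(dim=-1)   # after processing
    probs = logits.softmax(dim=-1)
    token = torch.multinomial(probs, num_samples=1)
    return token, to_return

tokens, logprobs = sample(torch.randn(1, 8), temperature=0.7,
                          mode=Mode.PROCESSED_LOGPROBS)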
@@ -34,13 +76,11 @@ class Sampler(nn.Module):
         # temperature scaling) for the top-k logprobs.
         # This is different from the V0 sampler, which uses the logits that
         # is used for sampling (after penalties and temperature scaling).
-        # TODO(rob): provide option for logprobs post sampling.
-        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            if self.logprobs_mode == "raw_logprobs":
+            if self.logprobs_mode == LogprobsMode.RAW_LOGPROBS:
                 raw_logprobs = self.compute_logprobs(logits)
-            elif self.logprobs_mode == "raw_logits":
+            elif self.logprobs_mode == LogprobsMode.RAW_LOGITS:
                 raw_logprobs = logits.clone()
 
         # Use float32 for the logits.
@@ -57,15 +97,10 @@ class Sampler(nn.Module):
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
 
-        # Get the process logprobs or logits.
-        if num_logprobs is not None:
-            if self.logprobs_mode == "processed_logprobs":
-                raw_logprobs = self.compute_logprobs(logits)
-            elif self.logprobs_mode == "processed_logits":
-                raw_logprobs = logits.clone()
-
         # Sample the next token.
-        sampled = self.sample(logits, sampling_metadata)
+        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
+        if processed_logprobs is not None:
+            raw_logprobs = processed_logprobs
         # Convert sampled token ids to int64 (long) type to ensure compatibility
         # with subsequent operations that may use these values as indices.
         # This conversion is necessary because FlashInfer sampling operations
@@ -105,7 +140,7 @@ class Sampler(nn.Module):
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Sample logits based on sampling metadata.
 
         The various logits processing functions called in this method
@@ -119,7 +154,13 @@ class Sampler(nn.Module):
         else:
             greedy_sampled = self.greedy_sample(logits)
             if sampling_metadata.all_greedy:
-                return greedy_sampled
+                processed_logprobs = None
+                if sampling_metadata.max_num_logprobs is not None:
+                    if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
+                        processed_logprobs = logits
+                    elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
+                        processed_logprobs = self.compute_logprobs(logits)
+                return greedy_sampled, processed_logprobs
 
         assert sampling_metadata.temperature is not None
 
@@ -132,7 +173,7 @@ class Sampler(nn.Module):
             logits = processor.apply(logits)
 
         # Apply top_k and/or top_p.
-        random_sampled = self.topk_topp_sampler(
+        random_sampled, processed_logprobs = self.topk_topp_sampler(
             logits,
             sampling_metadata.generators,
             sampling_metadata.top_k,
@@ -140,7 +181,7 @@ class Sampler(nn.Module):
         )
 
         if greedy_sampled is None:
-            return random_sampled
+            return random_sampled, processed_logprobs
 
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
@@ -148,7 +189,7 @@ class Sampler(nn.Module):
             random_sampled,
             out=greedy_sampled,  # Reuse tensor
         )
-        return sampled
+        return sampled, processed_logprobs
 
     def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
         return logits.log_softmax(dim=-1, dtype=torch.float32)
@@ -65,7 +65,7 @@ class Sampler(nn.Module):
             logits = self.apply_min_p(logits, sampling_metadata.min_p)
 
         # Apply top_k and/or top_p.
-        random_sampled = self.topk_topp_sampler(
+        random_sampled, _ = self.topk_topp_sampler(
             logits,
             sampling_metadata.generators,
             sampling_metadata.top_k,