mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 16:25:01 +08:00
107 lines
3.8 KiB
Python
107 lines
3.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""A layer that computes logits from hidden_states."""
|
|
|
|
import torch
|
|
|
|
from vllm.distributed import (
|
|
tensor_model_parallel_all_gather,
|
|
tensor_model_parallel_gather,
|
|
)
|
|
from vllm.model_executor.custom_op import CustomOp
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
@CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
    """Compute next-token logits from model hidden states.

    This layer does the following:
    1. Project hidden states to logits through the LM head (skipped when
       the input already holds logits).
    2. Gather the logits across the model parallel group and remove
       vocabulary padding (if any).
    3. Soft-cap the logits if a soft cap is configured.
    4. Scale the logits if needed.
    """

    def __init__(
        self,
        vocab_size: int,
        org_vocab_size: int | None = None,
        scale: float = 1.0,
        logits_as_input: bool = False,
        soft_cap: float | None = None,
    ) -> None:
        """
        Args:
            vocab_size: Vocabulary size; stored and reported by
                `extra_repr`.
            org_vocab_size: Original vocabulary size (without LoRA).
                Computed logits are truncated to this many entries along
                the last dimension. Falls back to `vocab_size` when not
                given.
            scale: A scaling factor to apply to the logits.
            logits_as_input: If True, `forward` treats its
                `hidden_states` argument as already-computed logits and
                skips the LM head projection and gathering.
            soft_cap: If not None, logits are soft-capped as
                `soft_cap * tanh(logits / soft_cap)` before scaling.
        """
        super().__init__()
        self.scale = scale
        self.vocab_size = vocab_size
        # Whether the input is logits (default is hidden states).
        self.logits_as_input = logits_as_input
        # original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size
        # Soft cap the logits. Used in Gemma 2.
        self.soft_cap = soft_cap
        # Whether to use gather or all-gather to gather the logits.
        self.use_all_gather = current_platform.use_all_gather()

    def forward(
        self,
        lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        embedding_bias: torch.Tensor | None = None,
    ) -> torch.Tensor | None:
        """Return soft-capped and scaled logits for `hidden_states`.

        Args:
            lm_head: Embedding layer whose `quant_method` performs the
                projection to the vocabulary.
            hidden_states: Hidden states, or the logits themselves when
                `logits_as_input` is set.
            embedding_bias: Optional bias forwarded to the LM head.

        Returns:
            The processed logits, or None when the gather step yields no
            tensor on this rank.
        """
        if self.logits_as_input:
            logits = hidden_states
        else:
            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, lm_head, embedding_bias)

        if logits is None:
            return None

        if self.soft_cap is not None:
            # Each step is out of place, so `logits` no longer aliases the
            # caller's tensor afterwards.
            logits = logits / self.soft_cap
            logits = torch.tanh(logits)
            logits = logits * self.soft_cap

        if self.scale != 1.0:
            if logits is hidden_states:
                # `logits` is still the caller's own tensor
                # (logits_as_input without a soft cap). Scale out of place
                # so the input is not silently modified.
                logits = logits * self.scale
            else:
                # Freshly computed tensor owned by this call: scale in
                # place to avoid an extra allocation.
                logits *= self.scale
        return logits

    def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """gather/all-gather the logits tensor across model parallel group."""
        if self.use_all_gather:
            # Gather is not supported for some devices such as TPUs.
            # Use all-gather instead.
            # NOTE(woosuk): Here, the outputs of every device should not be None
            # because XLA requires strict SPMD among all devices. Every device
            # should execute the same operations after gathering the logits.
            logits = tensor_model_parallel_all_gather(logits)
        else:
            # None may be returned for rank > 0
            logits = tensor_model_parallel_gather(logits)
        return logits

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Project hidden states to logits, gather them and trim padding.

        Returns None when the gather step returns None on this rank.
        """
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)

        # Gather logits for TP
        logits = self._gather_logits(logits)

        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[..., : self.org_vocab_size]
        return logits

    def extra_repr(self) -> str:
        """Return the configuration summary appended to the module repr."""
        s = f"vocab_size={self.vocab_size}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
        return s
|