[Chore] Remove unused sampler in models (#25324)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon authored 2025-09-20 19:53:20 -07:00, committed by GitHub
parent 86647d1cd0
commit 572ddf83ce
5 changed files with 0 additions and 49 deletions
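Why these are safe to delete: nothing in the current engine path calls them. Sampling is applied by the engine to the logits each model returns from compute_logits(), so the per-model self.sampler attributes and sample() methods removed below were dead code. A minimal sketch of the shape change (toy modules standing in for the real vLLM classes, greedy sampling purely for illustration):

# Toy sketch of the pattern this commit deletes; these are stand-in
# modules, not the real vLLM classes.
from typing import Optional

import torch
import torch.nn as nn

class OldStyleModel(nn.Module):
    """Pre-commit shape: the model owns a sampler the engine never calls."""

    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        # Dead code: constructed on every model, invoked by nothing.
        self.sampler = lambda logits: torch.argmax(logits, dim=-1)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)

    def sample(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
        return self.sampler(logits)

class NewStyleModel(nn.Module):
    """Post-commit shape: the model's responsibility ends at the logits."""

    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)

# The engine-side sampler (greedy here, purely illustrative) consumes the
# logits; the model class takes no part in token selection.
model = NewStyleModel()
logits = model.compute_logits(torch.randn(1, 8))
next_token = torch.argmax(logits, dim=-1)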

tests/lora/test_lora_manager.py

@@ -17,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA
 from vllm.platforms import current_platform
@@ -97,7 +96,6 @@ def dummy_model() -> nn.Module:
         # Special handling for lm_head & sampler
         ("lm_head", ParallelLMHead(512, 10)),
         ("logits_processor", LogitsProcessor(512)),
-        ("sampler", Sampler())
     ]))
     model.config = MagicMock()
     model.embedding_modules = {"lm_head": "lm_head"}
@@ -125,7 +123,6 @@ def dummy_model_gate_up() -> nn.Module:
         # Special handling for lm_head & sampler
         ("lm_head", ParallelLMHead(512, 10)),
         ("logits_processor", LogitsProcessor(512)),
-        ("sampler", Sampler())
     ]))
     model.config = MagicMock()
     model.packed_modules_mapping = {
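The test fixture change mirrors the model changes: the dummy model's module list now ends at the logits processor, so there is no "sampler" entry left to special-case. A trimmed, self-contained sketch of the post-change fixture shape (plain torch layers standing in for the vLLM ones):

# Post-change fixture shape: no ("sampler", Sampler()) entry remains.
# nn.Linear / nn.Identity stand in for ParallelLMHead / LogitsProcessor.
from collections import OrderedDict

import torch.nn as nn

dummy_model = nn.Sequential(
    OrderedDict([
        ("dense", nn.Linear(512, 512)),       # stand-in for the hidden stack
        ("lm_head", nn.Linear(512, 10)),      # ParallelLMHead(512, 10) in the test
        ("logits_processor", nn.Identity()),  # LogitsProcessor(512) in the test
    ]))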

vllm/model_executor/models/ernie_mtp.py

@@ -33,7 +33,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -160,7 +159,6 @@ class ErnieMTP(nn.Module, SupportsPP):
         self.lm_head = ParallelLMHead(self.config.vocab_size,
                                       self.config.hidden_size,
                                       prefix=maybe_prefix(prefix, "lm_head"))
-        self.sampler = get_sampler()

         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
@@ -188,14 +186,6 @@ class ErnieMTP(nn.Module, SupportsPP):
         return self.model.compute_logits(hidden_states, self.lm_head,
                                          sampling_metadata, spec_step_idx)

-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [

vllm/model_executor/models/plamo2.py

@@ -41,7 +41,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -932,7 +931,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 self.config.vocab_size)
-        self.sampler = get_sampler()

         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
@@ -1030,14 +1028,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
                                        sampling_metadata)
         return logits

-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:

vllm/model_executor/models/step3_text.py

@@ -26,7 +26,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -391,7 +390,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
             )
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()
@@ -413,14 +411,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
                                        sampling_metadata)
         return logits

-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         qkv_params_mapping = [

vllm/model_executor/models/step3_vl.py

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from functools import cached_property
 from itertools import product
 from math import ceil, sqrt
 from typing import Any, Literal, Optional, TypedDict, Union
@@ -24,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -897,13 +895,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

-    @cached_property
-    def sampler(self):
-        if hasattr(self.language_model, "sampler"):
-            return self.language_model.sampler
-        return get_sampler()
-
     @property
     def device(self):
         return next(self.parameters()).device
@@ -1069,13 +1060,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)

-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        return self.language_model.sample(logits, sampling_metadata)
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         skip_prefixes = []
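User-facing sampling is untouched by this cleanup: temperature, top-p, and the rest are still configured per request through SamplingParams and applied by the engine. A quick usage sketch (the model name is a placeholder):

# Sampling behavior is requested via SamplingParams and applied by the
# engine, not by the model classes edited above.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)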