[Chore] Remove unused sampler in models (#25324)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent 86647d1cd0
commit 572ddf83ce
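
This commit deletes the per-model sampling plumbing that is no longer used: the `Sampler`/`get_sampler` imports, the `self.sampler = get_sampler()` assignments, and the boilerplate `sample()` methods in ErnieMTP, Plamo2ForCausalLM, Step3TextForCausalLM, and Step3VLForConditionalGeneration, plus the `Sampler` module in the LoRA test fixtures. Each model keeps `compute_logits()`; token selection happens outside the model. A minimal sketch of the before/after contract (illustrative only, with hypothetical names; not code from this commit):

# Illustrative sketch only; the class and sizes are hypothetical, not from vLLM.
import torch
import torch.nn as nn


class ModelSketch(nn.Module):

    def __init__(self, hidden_size: int = 512, vocab_size: int = 1000):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # self.sampler = get_sampler()      # <- pattern removed by this commit

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # the model's responsibility now ends at logits
        return self.lm_head(hidden_states)

    # def sample(self, logits, sampling_metadata):        # <- pattern removed
    #     return self.sampler(logits, sampling_metadata)


model = ModelSketch()
logits = model.compute_logits(torch.randn(2, 512))
next_tokens = torch.argmax(logits, dim=-1)  # greedy pick, outside the model
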
@@ -17,7 +17,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA
 from vllm.platforms import current_platform
@@ -97,7 +96,6 @@ def dummy_model() -> nn.Module:
             # Special handling for lm_head & sampler
             ("lm_head", ParallelLMHead(512, 10)),
             ("logits_processor", LogitsProcessor(512)),
-            ("sampler", Sampler())
         ]))
     model.config = MagicMock()
     model.embedding_modules = {"lm_head": "lm_head"}
@@ -125,7 +123,6 @@ def dummy_model_gate_up() -> nn.Module:
             # Special handling for lm_head & sampler
             ("lm_head", ParallelLMHead(512, 10)),
             ("logits_processor", LogitsProcessor(512)),
-            ("sampler", Sampler())
         ]))
     model.config = MagicMock()
     model.packed_modules_mapping = {
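
With the `Sampler` entry gone, the LoRA test fixtures only register `lm_head` and `logits_processor`. A minimal sketch of such a fixture, using plain torch stand-ins for vLLM's parallel layers (assumed shape, not verbatim from the test file):

# Assumed sketch of the simplified fixture; nn.Linear / nn.Identity stand in
# for ParallelLMHead / LogitsProcessor, which require a distributed setup.
from collections import OrderedDict
from unittest.mock import MagicMock

import torch.nn as nn


def dummy_model_sketch() -> nn.Module:
    model = nn.ModuleDict(OrderedDict([
        ("lm_head", nn.Linear(512, 10)),
        ("logits_processor", nn.Identity()),
        # ("sampler", Sampler())  <- no longer registered
    ]))
    model.config = MagicMock()
    model.embedding_modules = {"lm_head": "lm_head"}
    return model
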
@@ -33,7 +33,6 @@ from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -160,7 +159,6 @@ class ErnieMTP(nn.Module, SupportsPP):
             self.lm_head = ParallelLMHead(self.config.vocab_size,
                                           self.config.hidden_size,
                                           prefix=maybe_prefix(prefix, "lm_head"))
-            self.sampler = get_sampler()

         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
@@ -188,14 +186,6 @@ class ErnieMTP(nn.Module, SupportsPP):
         return self.model.compute_logits(hidden_states, self.lm_head,
                                          sampling_metadata, spec_step_idx)

-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
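
After this hunk, ErnieMTP's generation-facing surface ends at `compute_logits()`. A self-contained sketch of that division of labor, with a hypothetical stand-in class and caller-side temperature sampling (both are assumptions for illustration, not vLLM's actual sampler):

# Hypothetical stand-in for an MTP-style model under the new contract.
import torch
import torch.nn as nn


class MTPSketch(nn.Module):

    def __init__(self, hidden_size: int = 512, vocab_size: int = 1000):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # project hidden states onto the vocabulary; no sample() method here
        return self.lm_head(hidden_states)


# Sampling moves to the caller, e.g. temperature sampling at the call site:
model = MTPSketch()
logits = model.compute_logits(torch.randn(4, 512))
probs = torch.softmax(logits / 0.8, dim=-1)               # temperature 0.8
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
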
@@ -41,7 +41,6 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -932,7 +931,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):

         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 self.config.vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

@@ -1030,14 +1028,6 @@ class Plamo2ForCausalLM(torch.nn.Module, HasInnerState, SupportsPP, IsHybrid):
                                        sampling_metadata)
         return logits

-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -391,7 +390,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
             )
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size)
-            self.sampler = get_sampler()
         else:
             self.lm_head = PPMissingLayer()

@@ -413,14 +411,6 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
                                        sampling_metadata)
         return logits

-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         qkv_params_mapping = [

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from functools import cached_property
 from itertools import product
 from math import ceil, sqrt
 from typing import Any, Literal, Optional, TypedDict, Union
@@ -24,7 +23,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -897,13 +895,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

-    @cached_property
-    def sampler(self):
-        if hasattr(self.language_model, "sampler"):
-            return self.language_model.sampler
-
-        return get_sampler()
-
     @property
     def device(self):
         return next(self.parameters()).device
@@ -1069,13 +1060,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)

-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        return self.language_model.sample(logits, sampling_metadata)
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

         skip_prefixes = []
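
The removed `sampler` cached_property was a fallback that preferred the inner language model's sampler and otherwise built one with `get_sampler()`; with it gone, Step3VLForConditionalGeneration only delegates `compute_logits()` to its language model. A hedged sketch of the remaining delegation pattern (hypothetical classes, not the real model):

# Hypothetical classes illustrating the remaining delegation.
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Stand-in inner language model exposing compute_logits()."""

    def __init__(self, hidden: int = 64, vocab: int = 100):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)


class VLWrapperSketch(nn.Module):
    """Stand-in multimodal wrapper: forwards compute_logits() to the inner
    LM and no longer resolves a sampler of its own."""

    def __init__(self, language_model: nn.Module):
        super().__init__()
        self.language_model = language_model

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states)


wrapper = VLWrapperSketch(TinyLM())
logits = wrapper.compute_logits(torch.randn(2, 64))  # sampling is the caller's job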