diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index e93be9bfb1657..8a4ac214443eb 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import torch -import torch.nn as nn import vllm.envs as envs from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) +from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: envs.VLLM_LOGITS_PROCESSOR_THREADS) -class LogitsProcessor(nn.Module): +@CustomOp.register("logits_processor") +class LogitsProcessor(CustomOp): """Process logits and apply logits processors from sampling metadata. This layer does the following: diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index c92a7978195bc..15e628177b3f6 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -429,6 +429,7 @@ class VocabParallelEmbedding(CustomOp): return s +@CustomOp.register("parallel_lm_head") class ParallelLMHead(VocabParallelEmbedding): """Parallelized LM head.