From 736569da8d856d5a66d7d7523a7e05f97a4d21f1 Mon Sep 17 00:00:00 2001
From: zzhxxx <96690582+zzhx1@users.noreply.github.com>
Date: Wed, 10 Sep 2025 21:26:31 +0800
Subject: [PATCH] [Platform] Custom ops support for LMhead and LogitsProcessor (#23564)

Signed-off-by: zzhx1
---
 vllm/model_executor/layers/logits_processor.py         | 5 +++--
 vllm/model_executor/layers/vocab_parallel_embedding.py | 1 +
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index e93be9bfb1657..8a4ac214443eb 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
 import torch
-import torch.nn as nn
 
 import vllm.envs as envs
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
         envs.VLLM_LOGITS_PROCESSOR_THREADS)
 
 
-class LogitsProcessor(nn.Module):
+@CustomOp.register("logits_processor")
+class LogitsProcessor(CustomOp):
     """Process logits and apply logits processors from sampling metadata.
 
     This layer does the following:
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index c92a7978195bc..15e628177b3f6 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -429,6 +429,7 @@ class VocabParallelEmbedding(CustomOp):
         return s
 
 
+@CustomOp.register("parallel_lm_head")
 class ParallelLMHead(VocabParallelEmbedding):
     """Parallelized LM head.
 