[Platform] Custom ops support for LMhead and LogitsProcessor (#23564)

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
This commit is contained in:
zzhxxx 2025-09-10 21:26:31 +08:00 committed by GitHub
parent 2eb9986a2d
commit 736569da8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 2 deletions

View File

@@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
envs.VLLM_LOGITS_PROCESSOR_THREADS)
-class LogitsProcessor(nn.Module):
+@CustomOp.register("logits_processor")
+class LogitsProcessor(CustomOp):
"""Process logits and apply logits processors from sampling metadata.
This layer does the following:

View File

@@ -429,6 +429,7 @@ class VocabParallelEmbedding(CustomOp):
return s
@CustomOp.register("parallel_lm_head")
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.