[Platform] Custom ops support for LMhead and LogitsProcessor (#23564)

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
This commit is contained in:
zzhxxx 2025-09-10 21:26:31 +08:00 committed by GitHub
parent 2eb9986a2d
commit 736569da8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 2 deletions

View File

@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional from typing import Optional
import torch import torch
import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed import (tensor_model_parallel_all_gather, from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather) tensor_model_parallel_gather)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
envs.VLLM_LOGITS_PROCESSOR_THREADS) envs.VLLM_LOGITS_PROCESSOR_THREADS)
class LogitsProcessor(nn.Module): @CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
"""Process logits and apply logits processors from sampling metadata. """Process logits and apply logits processors from sampling metadata.
This layer does the following: This layer does the following:

View File

@ -429,6 +429,7 @@ class VocabParallelEmbedding(CustomOp):
return s return s
@CustomOp.register("parallel_lm_head")
class ParallelLMHead(VocabParallelEmbedding): class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head. """Parallelized LM head.