From b87c21fc89c772d231cae97346e0457ef3bb1bf9 Mon Sep 17 00:00:00 2001
From: Mengqing Cao
Date: Mon, 3 Mar 2025 15:40:04 +0800
Subject: [PATCH] [Misc][Platform] Move use allgather to platform (#14010)

Signed-off-by: Mengqing Cao
---
 vllm/model_executor/layers/logits_processor.py | 10 +++-------
 vllm/platforms/interface.py                    | 13 +++++++++++++
 vllm/platforms/neuron.py                       |  4 ++++
 vllm/platforms/tpu.py                          |  4 ++++
 4 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 2f39a0e87854..4a359725bad0 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -8,7 +8,6 @@
 import torch
 import torch.nn as nn
 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -51,11 +50,7 @@ class LogitsProcessor(nn.Module):
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
         # Whether to use gather or all-gather to gather the logits.
-        parallel_config = get_current_vllm_config().parallel_config
-        self.use_all_gather = current_platform.is_tpu() \
-            or current_platform.is_neuron() \
-            or envs.VLLM_USE_V1 \
-            or parallel_config.distributed_executor_backend == "external_launcher"  # noqa
+        self.use_all_gather = current_platform.use_all_gather()
 
     def forward(
         self,
@@ -83,7 +78,8 @@ class LogitsProcessor(nn.Module):
                 logits *= self.scale
 
             # Apply logits processors (if any).
-            if sampling_metadata is not None:
+            if sampling_metadata is not None and \
+                    sampling_metadata.seq_groups is not None:
                 logits = _apply_logits_processors(logits, sampling_metadata)
 
         return logits
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index d81a66e4bcb1..e7e55e11775c 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -330,6 +330,19 @@ class Platform:
         """
         return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
 
+    @classmethod
+    def use_all_gather(cls) -> bool:
+        """
+        Whether to use allgather in LogitsProcessor to gather the logits.
+        """
+        import vllm.envs as envs
+        from vllm.config import get_current_vllm_config
+
+        parallel_config = get_current_vllm_config().parallel_config
+        return (envs.VLLM_USE_V1
+                or parallel_config.distributed_executor_backend
+                == "external_launcher")
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index 5a03f5f7acbc..b2eadb7932f3 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -55,3 +55,7 @@ class NeuronPlatform(Platform):
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on Neuron.")
         return False
+
+    @classmethod
+    def use_all_gather(cls) -> bool:
+        return True
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index cdf835a52c0c..0b66b52713e9 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -119,3 +119,7 @@ class TpuPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa
+
+    @classmethod
+    def use_all_gather(cls) -> bool:
+        return True
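
Note for reviewers: a minimal sketch of how a downstream backend would pick up the new hook. MyDevicePlatform and its module path are hypothetical, not part of this patch; only the use_all_gather override itself mirrors the Neuron/TPU changes above, and PlatformEnum.OOT is assumed to be the out-of-tree platform enum.

# Hypothetical out-of-tree platform (class name is illustrative only).
# Overriding use_all_gather() is now the single switch that tells
# LogitsProcessor to gather logits with all-gather instead of gather.
from vllm.platforms.interface import Platform, PlatformEnum


class MyDevicePlatform(Platform):
    _enum = PlatformEnum.OOT  # assumed enum value for out-of-tree platforms

    @classmethod
    def use_all_gather(cls) -> bool:
        # Return True when the device's collective stack prefers
        # all-gather, as NeuronPlatform and TpuPlatform do above.
        return True

With this in place, LogitsProcessor carries no per-platform checks: a new backend opts in by overriding the classmethod on its Platform subclass instead of touching model-executor code.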