From 8c755c3b6d37954895a9952eb6bf7691c4d25a50 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Wed, 19 Feb 2025 20:46:28 -0800 Subject: [PATCH] [bugfix] spec decode worker get tp group only when initialized (#13578) --- vllm/spec_decode/spec_decode_worker.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3f381d5199d7c..8af71842224ba 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -12,6 +12,7 @@ from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig from vllm.distributed.communication_op import (broadcast_tensor_dict, get_tp_group, tensor_model_parallel_gather) +from vllm.distributed.parallel_state import model_parallel_is_initialized from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.sampler import SamplerOutput @@ -366,8 +367,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): target_lm_head_weight) self._metrics.init_tensors(self.rank, device_type=self.device) - self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, - device_type=self.device) + if model_parallel_is_initialized(): + self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, + device_type=self.device) + else: + self.spec_decode_sampler.init_tensors(self.rank, + device_type=self.device) scorer_cls: Type[SpeculativeScorer] if self.disable_mqa_scorer: