diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3f381d5199d7c..8af71842224ba 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -12,6 +12,7 @@ from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig from vllm.distributed.communication_op import (broadcast_tensor_dict, get_tp_group, tensor_model_parallel_gather) +from vllm.distributed.parallel_state import model_parallel_is_initialized from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.sampler import SamplerOutput @@ -366,8 +367,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): target_lm_head_weight) self._metrics.init_tensors(self.rank, device_type=self.device) - self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, - device_type=self.device) + if model_parallel_is_initialized(): + self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, + device_type=self.device) + else: + self.spec_decode_sampler.init_tensors(self.rank, + device_type=self.device) scorer_cls: Type[SpeculativeScorer] if self.disable_mqa_scorer: