[bugfix] spec decode worker get tp group only when initialized (#13578)

This commit is contained in:
Simon Mo 2025-02-19 20:46:28 -08:00 committed by GitHub
parent ba81163997
commit 8c755c3b6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,6 +12,7 @@ from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import (broadcast_tensor_dict,
get_tp_group,
tensor_model_parallel_gather)
from vllm.distributed.parallel_state import model_parallel_is_initialized
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
@ -366,8 +367,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
target_lm_head_weight)
self._metrics.init_tensors(self.rank, device_type=self.device)
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
device_type=self.device)
if model_parallel_is_initialized():
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
device_type=self.device)
else:
self.spec_decode_sampler.init_tensors(self.rank,
device_type=self.device)
scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer: