diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index fce06a81ff04a..3f381d5199d7c 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -10,6 +10,7 @@ import torch.nn as nn from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig from vllm.distributed.communication_op import (broadcast_tensor_dict, + get_tp_group, tensor_model_parallel_gather) from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler @@ -365,7 +366,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): target_lm_head_weight) self._metrics.init_tensors(self.rank, device_type=self.device) - self.spec_decode_sampler.init_tensors(self.rank, + self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, device_type=self.device) scorer_cls: Type[SpeculativeScorer]