From 5ae9f26a5ae2aa370cbee48924aa7ab885a45925 Mon Sep 17 00:00:00 2001
From: shangmingc
Date: Wed, 19 Feb 2025 22:13:15 +0800
Subject: [PATCH] [Bugfix] Fix device ordinal for multi-node spec decode (#13269)

Signed-off-by: Shangming Cai
---
 vllm/spec_decode/spec_decode_worker.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index fce06a81ff04a..3f381d5199d7c 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -10,6 +10,7 @@ import torch.nn as nn
 
 from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
 from vllm.distributed.communication_op import (broadcast_tensor_dict,
+                                               get_tp_group,
                                                tensor_model_parallel_gather)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -365,7 +366,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 target_lm_head_weight)
 
         self._metrics.init_tensors(self.rank, device_type=self.device)
-        self.spec_decode_sampler.init_tensors(self.rank,
+        self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
                                               device_type=self.device)
 
         scorer_cls: Type[SpeculativeScorer]