[Bugfix] Fix device ordinal for multi-node spec decode (#13269)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc 2025-02-19 22:13:15 +08:00 committed by GitHub
parent 377d10bd14
commit 5ae9f26a5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -10,6 +10,7 @@ import torch.nn as nn
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import (broadcast_tensor_dict,
+get_tp_group,
tensor_model_parallel_gather)
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -365,7 +366,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
target_lm_head_weight)
self._metrics.init_tensors(self.rank, device_type=self.device)
-self.spec_decode_sampler.init_tensors(self.rank,
+self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
device_type=self.device)
scorer_cls: Type[SpeculativeScorer]