mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-02 12:07:27 +08:00
[Bugfix] Fix device ordinal for multi-node spec decode (#13269)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
parent
377d10bd14
commit
5ae9f26a5a
@@ -10,6 +10,7 @@ import torch.nn as nn
|
||||
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
|
||||
from vllm.distributed.communication_op import (broadcast_tensor_dict,
|
||||
+get_tp_group,
|
||||
tensor_model_parallel_gather)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
@@ -365,7 +366,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
target_lm_head_weight)
|
||||
|
||||
self._metrics.init_tensors(self.rank, device_type=self.device)
|
||||
-self.spec_decode_sampler.init_tensors(self.rank,
|
||||
+self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
|
||||
device_type=self.device)
|
||||
|
||||
scorer_cls: Type[SpeculativeScorer]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user