Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-06-06 23:02:22 +08:00)
[Bugfix] Fix device ordinal for multi-node spec decode (#13269)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
parent 377d10bd14
commit 5ae9f26a5a
@@ -10,6 +10,7 @@ import torch.nn as nn

 from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
 from vllm.distributed.communication_op import (broadcast_tensor_dict,
+get_tp_group,
 tensor_model_parallel_gather)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -365,7 +366,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
 target_lm_head_weight)

 self._metrics.init_tensors(self.rank, device_type=self.device)
-self.spec_decode_sampler.init_tensors(self.rank,
+self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
 device_type=self.device)

 scorer_cls: Type[SpeculativeScorer]
Loading…
x
Reference in New Issue
Block a user