mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-21 17:54:27 +08:00
[bugfix] spec decode worker get tp group only when initialized (#13578)
This commit is contained in:
parent
ba81163997
commit
8c755c3b6d
@ -12,6 +12,7 @@ from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
|
||||
from vllm.distributed.communication_op import (broadcast_tensor_dict,
|
||||
get_tp_group,
|
||||
tensor_model_parallel_gather)
|
||||
from vllm.distributed.parallel_state import model_parallel_is_initialized
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
@ -366,8 +367,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
target_lm_head_weight)
|
||||
|
||||
self._metrics.init_tensors(self.rank, device_type=self.device)
|
||||
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
|
||||
device_type=self.device)
|
||||
if model_parallel_is_initialized():
|
||||
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
|
||||
device_type=self.device)
|
||||
else:
|
||||
self.spec_decode_sampler.init_tensors(self.rank,
|
||||
device_type=self.device)
|
||||
|
||||
scorer_cls: Type[SpeculativeScorer]
|
||||
if self.disable_mqa_scorer:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user