Remove hard dependencies of speculative decoding on CUDA workers (#10587)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Authored by Chendi.Xue on 2024-11-26 19:57:11 -06:00; committed by GitHub.
parent 2f0a0a17a4
commit 0a71900bc9
19 changed files with 219 additions and 77 deletions

tests/spec_decode/test_spec_decode_worker.py

@@ -595,8 +595,8 @@ def test_init_device(acceptance_sampler_method: str):
     target_worker.init_device.assert_called_once()

-    metrics_collector.init_gpu_tensors.assert_called_once()
-    spec_decode_sampler.init_gpu_tensors.assert_called_once()
+    metrics_collector.init_tensors.assert_called_once()
+    spec_decode_sampler.init_tensors.assert_called_once()


 @pytest.mark.parametrize("acceptance_sampler_method",

vllm/config.py

@@ -990,6 +990,7 @@ class ParallelConfig:
     # the full name of the worker class to use. If "auto", the worker class
     # will be determined based on the platform.
     worker_cls: str = "auto"
+    sd_worker_cls: str = "auto"
     world_size: int = field(init=False)

vllm/model_executor/layers/spec_decode_base_sampler.py

@@ -43,6 +43,21 @@ class SpecDecodeBaseSampler(nn.Module):
                                                dtype=torch.long,
                                                device=device)

+    def init_tensors(self,
+                     device: Union[int, str],
+                     device_type: Union[torch.device, str] = 'cuda') -> None:
+        assert self.num_accepted_tokens is None
+        if isinstance(device_type, torch.device):
+            device_type = device_type.type
+        if isinstance(device, int):
+            device = f"{device_type}:{device}"
+        self.num_accepted_tokens = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=device)
+        self.num_emitted_tokens = torch.tensor(0,
+                                               dtype=torch.long,
+                                               device=device)
+
     @property
     def probs_dtype(self):
         return torch.float32

@@ -77,7 +92,7 @@ class SpecDecodeBaseSampler(nn.Module):
             tensor is [batch_size, k + num_bonus_tokens]
         """
         batch_size, k = substitute_token_ids.shape
-        bonus_token_ids = bonus_token_ids.squeeze()
+        bonus_token_ids = bonus_token_ids.squeeze(-1)
         # Determine the index of the first False value for each row.
         limits = (accepted == 0).max(1).indices
         limits[~(accepted == 0).any(1)] = k
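For readers following the new init_tensors path, here is a minimal standalone sketch (not vLLM code; resolve_device is a hypothetical helper) of how the (device, device_type) pair above collapses into a torch device string:

import torch

def resolve_device(device, device_type="cuda") -> str:
    # Accept a torch.device for device_type and keep only its type string.
    if isinstance(device_type, torch.device):
        device_type = device_type.type
    # An integer device is interpreted as a rank/index on that device type.
    if isinstance(device, int):
        device = f"{device_type}:{device}"
    return device

print(resolve_device(0, torch.device("cpu")))  # -> "cpu:0"
print(resolve_device("cuda:1"))                # strings pass through unchanged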

vllm/platforms/cpu.py

@@ -86,4 +86,10 @@ class CpuPlatform(Platform):
                 parallel_config.distributed_executor_backend)
             parallel_config.distributed_executor_backend = "mp"
         if parallel_config.worker_cls == "auto":
-            parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
+            if vllm_config.speculative_config:
+                parallel_config.worker_cls = \
+                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+                parallel_config.sd_worker_cls = \
+                    "vllm.worker.cpu_worker.CPUWorker"
+            else:
+                parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
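To make the two-field dispatch concrete, here is a small illustrative sketch (the config class below is a stand-in, not vLLM's real ParallelConfig) of how a platform hook is expected to leave worker_cls and sd_worker_cls when speculative decoding is enabled:

from dataclasses import dataclass

@dataclass
class FakeParallelConfig:
    worker_cls: str = "auto"
    sd_worker_cls: str = "auto"

def cpu_like_dispatch(cfg: FakeParallelConfig, has_spec_decode: bool) -> None:
    if cfg.worker_cls != "auto":
        return  # the user already pinned a worker class
    if has_spec_decode:
        # The generic spec-decode factory becomes the top-level worker...
        cfg.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
        # ...and the platform-native worker is stashed for the factory to use.
        cfg.sd_worker_cls = "vllm.worker.cpu_worker.CPUWorker"
    else:
        cfg.worker_cls = "vllm.worker.cpu_worker.CPUWorker"

cfg = FakeParallelConfig()
cpu_like_dispatch(cfg, has_spec_decode=True)
assert cfg.sd_worker_cls == "vllm.worker.cpu_worker.CPUWorker"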

vllm/platforms/cuda.py

@@ -106,6 +106,8 @@ class CudaPlatformBase(Platform):
         elif vllm_config.speculative_config:
             parallel_config.worker_cls = \
                 "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+            parallel_config.sd_worker_cls = \
+                "vllm.worker.worker.Worker"
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"

@@ -236,4 +238,4 @@ try:
     if not isinstance(pynvml, _MockModule):
         CudaPlatform.log_warnings()
 except ModuleNotFoundError:
     CudaPlatform.log_warnings()

vllm/spec_decode/draft_model_runner.py

@@ -20,8 +20,9 @@ except (ModuleNotFoundError, ImportError) as err:
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalKwargs
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
-                                      ModelRunner)
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                           ModelRunnerInputBase,
+                                           ModelRunnerWrapperBase)

 logger = init_logger(__name__)

@@ -33,7 +34,7 @@ debug_advance_input = False
 allow_gpu_advance_step = True


-class TP1DraftModelRunner(ModelRunner):
+class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding draft model.

     Since the draft model always execute k forward passes consecutively to
     generate k speculative tokens in a single speculative decoding step,

@@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner):
     any broadcasting inside execute_model).
     """

-    def __init__(self, *args, **kwargs):
-        if kwargs.get("return_hidden_states"):
+    def __init__(self, model_runner: ModelRunnerBase):
+        if hasattr(
+                model_runner,
+                "return_hidden_states") and model_runner.return_hidden_states:
             raise ValueError(
                 "return_hidden_states is not supported for TP1DraftModelRunner."
             )
-
-        super().__init__(*args, **kwargs)
+        super().__init__(model_runner)

         self.indices_of_seq_with_bonus_tokens = None

@@ -73,10 +75,8 @@ class TP1DraftModelRunner(ModelRunner):
             assert seq_group.prompt_logprob_indices == []  # No prompt
             assert seq_group.sample_indices == [i]  # Simple

-    def _gpu_advance_step(
-            self, model_input: ModelInputForGPUWithSamplingMetadata,
-            last_output: SamplerOutput
-    ) -> ModelInputForGPUWithSamplingMetadata:
+    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
+                          last_output: SamplerOutput) -> ModelRunnerInputBase:
         # Currently, we expect "decode mode" only
         assert not model_input.is_prompt

@@ -168,7 +168,7 @@ class TP1DraftModelRunner(ModelRunner):
     @torch.inference_mode()
     def execute_model(
         self,
-        model_input: ModelInputForGPUWithSamplingMetadata,
+        model_input: ModelRunnerInputBase,
         kv_caches: List[torch.Tensor],
         previous_hidden_states: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,

vllm/spec_decode/interfaces.py

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Set
+from typing import Optional, Set, Union

 import torch

@@ -75,9 +75,11 @@ class SpeculativeProposer(ABC):
 class SpeculativeScorer(ABC):

-    def __init__(self, scorer_worker: WorkerBase, device: str,
-                 vocab_size: int):
+    def __init__(self, scorer_worker: WorkerBase,
+                 device: Union[torch.device, str], vocab_size: int):
         self._scorer_worker = scorer_worker
+        if isinstance(device, torch.device):
+            device = device.type
         self._device = device
         self._vocab_size = vocab_size

vllm/spec_decode/medusa_worker.py

@@ -9,21 +9,22 @@ from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker import Worker
+from vllm.worker.worker_base import WorkerWrapperBase


-class MedusaWorker(NonLLMProposerWorkerBase, Worker):
+class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase):
     """Worker for Medusa.
     """

     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(kwargs.get("vllm_config"))
+        self.init_worker(*args, **kwargs)

         # Lazy initialization list.
         self._proposer: Top1Proposer

     def init_device(self):
-        super().init_device()
+        self.worker.init_device()

         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]

vllm/spec_decode/metrics.py

@@ -1,11 +1,12 @@
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Union

 import msgspec
 import torch

 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler)
+from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available

@@ -81,8 +82,20 @@ class AsyncMetricsCollector:
         self._rank = rank
         self._copy_stream = torch.cuda.Stream()

+    def init_tensors(self,
+                     rank: int,
+                     device_type: Union[torch.device, str] = 'cuda') -> None:
+        self._rank = rank
+        if isinstance(device_type, torch.device):
+            device_type = device_type.type
+        if device_type == 'cuda':
+            self._copy_stream = torch.cuda.Stream()
+
     def maybe_collect_rejsample_metrics(
             self, k: int) -> Optional[SpecDecodeWorkerMetrics]:

+        # currently using cuda.Event, skip for any non_cuda_alike platform
+        if not current_platform.is_cuda_alike():
+            return None
+
         # If a copy was initiated in the previous call, collect and return.
         if self._in_flight_copy is not None:
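As a quick illustration of the guarded initialization above, here is a hedged sketch (FakeCollector is a stand-in, not the real AsyncMetricsCollector) showing that a CPU caller simply skips the CUDA copy stream:

import torch

class FakeCollector:
    def __init__(self):
        self.copy_stream = None

    def init_tensors(self, rank, device_type="cuda"):
        if isinstance(device_type, torch.device):
            device_type = device_type.type
        # A dedicated copy stream only makes sense on CUDA-alike devices.
        if device_type == "cuda":
            self.copy_stream = torch.cuda.Stream()

collector = FakeCollector()
collector.init_tensors(rank=0, device_type=torch.device("cpu"))
assert collector.copy_stream is None  # no CUDA stream is created on CPU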

vllm/spec_decode/multi_step_worker.py

@@ -5,17 +5,21 @@ from typing import Dict, List, Set, Tuple

 import torch

 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+if current_platform.is_cuda_alike():
+    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker import Worker
+from vllm.worker.worker_base import WorkerWrapperBase


-class MultiStepWorker(Worker, ProposerWorkerBase):
+class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
     """The MultiStepWorker is equivalent to a Worker except that it allows
     multiple forward passes in a single call, assuming the scheduler has
     allocated enough space to store the additional KV. This reduces overhead

@@ -28,13 +32,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
     """

     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(kwargs.get("vllm_config"))
+        self.init_worker(*args, **kwargs)

         # Lazy initialization list.
         self._proposer: SpeculativeProposer

     def init_device(self) -> None:
-        super().init_device()
+        self.worker.init_device()

         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]

@@ -51,6 +56,18 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
             True)

+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        return self.worker.determine_num_available_blocks()
+
+    def get_cache_block_size_bytes(self) -> int:
+        return self.worker.get_cache_block_size_bytes()
+
+    def initialize_cache(self, *args, **kwargs) -> None:
+        self.worker.initialize_cache(*args, **kwargs)
+
+    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
+        return self.worker.execute_model(*args, **kwargs)
+
     @torch.inference_mode()
     def sampler_output(
         self,

@@ -75,7 +92,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
-        if isinstance(
+        if current_platform.is_cuda_alike() and isinstance(
                 self.model_runner, TP1DraftModelRunner
         ) and self.model_runner.supports_gpu_multi_step(expanded_request):
             # Here we run the draft_model_runner with multi-step prepare

@@ -92,7 +109,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
             for _ in range(sample_len):
-                model_output: List[SamplerOutput] = super().execute_model(
+                model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"

vllm/spec_decode/ngram_worker.py

@@ -22,6 +22,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
         # Get local_rank/vocab_size from kwargs attribute
         self.local_rank = kwargs["local_rank"]
         self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
+        self.device_type = kwargs.get("device_type", "cuda")

         # Lazy initialization list.
         self._proposer: Top1Proposer

@@ -34,7 +35,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
         self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

     def init_device(self):
-        self.device = torch.device(f"cuda:{self.local_rank}")
+        self.device = torch.device(f"{self.device_type}:{self.local_rank}")
         self.load_model = lambda *args, **kwargs: None

         # Current NGramWorker only supports Top1Proposer

vllm/spec_decode/spec_decode_worker.py

@@ -14,12 +14,16 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
+from vllm.platforms import current_platform
 from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                            CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SequenceGroupMetadata,
                            get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+if current_platform.is_cuda_alike():
+    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.medusa_worker import MedusaWorker

@@ -36,8 +40,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
-from vllm.worker.worker import Worker
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
+from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase,
+                                     WorkerWrapperBase)

 logger = init_logger(__name__)

@@ -53,7 +57,11 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     draft_worker_kwargs = kwargs.copy()

     kwargs["model_runner_cls"] = TargetModelRunner
-    target_worker = Worker(*args, **kwargs)
+    target_worker_config = copy.deepcopy(vllm_config)
+    target_worker_config.parallel_config.worker_cls =\
+        target_worker_config.parallel_config.sd_worker_cls
+    target_worker = WorkerWrapperBase(vllm_config=target_worker_config)
+    target_worker.init_worker(*args, **kwargs)
     # Set the disable_logprobs variable in the TargetModelRunner instance
     # as per its value specified in the SpeculativeConfig.
     target_worker.model_runner.disable_logprobs =\

@@ -65,6 +73,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         draft_worker_config.model_config,
         vllm_config.load_config,
     )
+    speculative_config.draft_parallel_config.worker_cls =\
+        draft_worker_config.parallel_config.sd_worker_cls
     draft_worker_config.parallel_config = speculative_config.draft_parallel_config  # noqa
     # TODO allow draft-model specific load config.

@@ -125,7 +135,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     @classmethod
     def create_worker(
         cls,
-        scorer_worker: Worker,
+        scorer_worker: WorkerBase,
         draft_worker_kwargs: Dict[str, Any],
         disable_mqa_scorer: bool,
         disable_by_batch_size: Optional[int],

@@ -145,6 +155,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         draft_parallel_config: ParallelConfig = draft_worker_kwargs[
             'vllm_config'].parallel_config
         if ngram_prompt_lookup_max > 0:
+            draft_worker_kwargs[
+                "device_type"] = scorer_worker.device_config.device.type
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)

@@ -158,8 +170,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 proposer_worker = MedusaWorker(**draft_worker_kwargs)
             else:
                 if draft_tp == 1:
-                    draft_worker_kwargs[
-                        "model_runner_cls"] = TP1DraftModelRunner
+                    if current_platform.is_cuda_alike():
+                        draft_worker_kwargs[
+                            "model_runner_cls"] = TP1DraftModelRunner
                 else:
                     if draft_model_config.hf_config.model_type == "eagle":
                         raise NotImplementedError(

@@ -306,8 +319,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.scorer_worker.load_model()
         self.proposer_worker.load_model()

-        self._metrics.init_gpu_tensors(self.rank)
-        self.spec_decode_sampler.init_gpu_tensors(self.rank)
+        self._metrics.init_tensors(self.rank, device_type=self.device)
+        self.spec_decode_sampler.init_tensors(self.rank,
+                                              device_type=self.device)

         scorer_cls: Type[SpeculativeScorer]
         if self.disable_mqa_scorer:

@@ -1111,11 +1125,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         raise NotImplementedError

     def start_profile(self):
-        if isinstance(self.scorer_worker, Worker):
+        if isinstance(self.scorer_worker, WorkerBase):
             self.scorer_worker.start_profile()

     def stop_profile(self):
-        if isinstance(self.scorer_worker, Worker):
+        if isinstance(self.scorer_worker, WorkerBase):
             self.scorer_worker.stop_profile()
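To see how sd_worker_cls is consumed downstream, here is an illustrative sketch (a SimpleNamespace stand-in, not the real VllmConfig/ParallelConfig) of the handshake between the platform hook and create_spec_worker above:

import copy
from types import SimpleNamespace

# The platform left worker_cls pointing at the factory and stashed the real
# worker in sd_worker_cls (values taken from the CPU path above).
parallel_config = SimpleNamespace(
    worker_cls="vllm.spec_decode.spec_decode_worker.create_spec_worker",
    sd_worker_cls="vllm.worker.cpu_worker.CPUWorker",
)

# create_spec_worker deep-copies the config so the swap does not leak into
# the draft worker's config, then builds the target (scorer) worker from the
# swapped class name (in vLLM this goes through WorkerWrapperBase.init_worker).
target_config = copy.deepcopy(parallel_config)
target_config.worker_cls = target_config.sd_worker_cls
print(target_config.worker_cls)  # vllm.worker.cpu_worker.CPUWorker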

vllm/spec_decode/target_model_runner.py

@@ -1,12 +1,12 @@
 from typing import List, Optional

-from vllm.config import VllmConfig
 from vllm.sequence import SequenceGroupMetadata
-from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
-                                      ModelRunner)
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                           ModelRunnerInputBase,
+                                           ModelRunnerWrapperBase)


-class TargetModelRunner(ModelRunner):
+class TargetModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding target model.

     In speculative decoding, the log probabilities selected finally may not
     be the same ones as selected by the target model sampling. This means

@@ -18,32 +18,21 @@ class TargetModelRunner(ModelRunner):
     requested or not.
     """

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        kv_cache_dtype: Optional[str] = "auto",
-        is_driver_worker: bool = False,
-        return_hidden_states: bool = False,
-    ):
+    def __init__(self, model_runner: ModelRunnerBase):
         # An internal boolean member variable to indicate if token log
         # probabilities are needed or not.
+        super().__init__(model_runner)
         self.disable_logprobs = True
-        super().__init__(
-            vllm_config=vllm_config,
-            kv_cache_dtype=kv_cache_dtype,
-            is_driver_worker=is_driver_worker,
-            return_hidden_states=return_hidden_states,
-        )

     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None
-    ) -> ModelInputForGPUWithSamplingMetadata:
-        model_input: ModelInputForGPUWithSamplingMetadata = super(
-        ).prepare_model_input(seq_group_metadata_list, virtual_engine,
-                              finished_requests_ids)
+        finished_requests_ids: Optional[List[str]] = None,
+    ) -> ModelRunnerInputBase:
+        model_input: ModelRunnerInputBase =\
+            self.model_runner.prepare_model_input(
+                seq_group_metadata_list, virtual_engine, finished_requests_ids)
         # If token log probabilities is disabled then skip generating sampler
         # CPU output. We directly serialize the GPU sampled_token_id tensors
         # as needed. If log probabilities is enabled then synchronize all the

vllm/spec_decode/util.py

@@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
 import torch

 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            PromptLogprobs, SequenceGroupMetadata,
                            SequenceOutput)

@@ -247,11 +248,14 @@ def nvtx_range(msg, *args, **kwargs):
     Arguments:
         msg (string): message to associate with the range
     """
-    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
-    try:
-        yield
-    finally:
-        torch.cuda.nvtx.range_pop()
+    if current_platform.is_cuda_alike():
+        torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
+        try:
+            yield
+        finally:
+            torch.cuda.nvtx.range_pop()
+    else:
+        yield


 class Timer:
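Outside vLLM, the same degrade-to-no-op idea can be written as a small context manager; this is a hedged sketch under the assumption that NVTX annotation is only meaningful on CUDA:

from contextlib import contextmanager

import torch

@contextmanager
def maybe_nvtx_range(msg: str, cuda_alike: bool):
    if cuda_alike:
        torch.cuda.nvtx.range_push(msg)
        try:
            yield
        finally:
            torch.cuda.nvtx.range_pop()
    else:
        # Nothing to annotate on CPU; just run the wrapped block.
        yield

with maybe_nvtx_range("score_proposals", cuda_alike=False):
    pass  # the profiled region runs unannotated on non-CUDA platforms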

vllm/worker/cpu_model_runner.py

@@ -80,6 +80,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
     Used by the ModelRunner.
     """
     sampling_metadata: Optional["SamplingMetadata"] = None
+    is_prompt: Optional[bool] = None

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {

@@ -395,6 +396,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         vllm_config: VllmConfig,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
+        return_hidden_states: bool = False,
         *args,
         **kwargs,
     ):

@@ -403,19 +405,25 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         cache_config = self.cache_config

         self.is_driver_worker = is_driver_worker
+        self.return_hidden_states = return_hidden_states

         self.device = self.device_config.device
+        self.pin_memory = False

         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
+        ) if needs_attn_backend else None

         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY

@@ -444,6 +452,15 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):

         return builder.build()  # type: ignore

+    # sampler property will be used by spec_decode_worker
+    @property
+    def sampler(self):
+        return self.model.sampler
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
+

 class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
     _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (

@@ -480,9 +497,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
             pin_memory=False,
             generators=generators)

+        is_prompt = (seq_group_metadata_list[0].is_prompt
+                     if seq_group_metadata_list else None)
         return dataclasses.replace(model_input,
                                    sampling_metadata=sampling_metadata,
-                                   virtual_engine=virtual_engine)
+                                   virtual_engine=virtual_engine,
+                                   is_prompt=is_prompt)

     @torch.no_grad()
     def execute_model(

@@ -491,16 +511,22 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        previous_hidden_states: Optional[torch.Tensor] = None,
     ) -> Optional[List[SamplerOutput]]:
         if num_steps > 1:
             raise ValueError(
                 "CPU worker does not support multi-step execution.")

         model_executable = self.model

         multimodal_kwargs = {}
         if model_input.multi_modal_kwargs is not None:
             multimodal_kwargs = MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs, device=self.device)
+        execute_model_kwargs = {}
+        if previous_hidden_states is not None:
+            execute_model_kwargs.update(
+                {"previous_hidden_states": previous_hidden_states})

         with set_forward_context(model_input.attn_metadata, self.vllm_config):
             hidden_states = model_executable(

@@ -509,6 +535,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
                 kv_caches=kv_caches,
                 attn_metadata=model_input.attn_metadata,
                 intermediate_tensors=intermediate_tensors,
+                **execute_model_kwargs,
                 **multimodal_kwargs,
             )

@@ -525,4 +552,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
+        if self.return_hidden_states:
+            # we only need to pass hidden states of most recent token
+            if model_input.is_prompt:
+                output.prefill_hidden_states = hidden_states
+            output.hidden_states = hidden_states
         return [output]
+
+    def generate_proposals(self, *args, **kwargs):
+        return self.model.generate_proposals(*args, **kwargs)

vllm/worker/cpu_worker.py

@@ -128,6 +128,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         distributed_init_method: str,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
+        model_runner_cls: Optional[Type[CPUModelRunner]] = None,
     ) -> None:
         WorkerBase.__init__(self, vllm_config=vllm_config)

@@ -151,6 +152,16 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         else:
             self.local_omp_cpuid = omp_cpuids.split("|")[rank]

+        # Return hidden states from target model if the draft model is an
+        # mlp_speculator
+        speculative_config = self.speculative_config
+        model_config = self.model_config
+        speculative_args = {} if speculative_config is None \
+            or (speculative_config.draft_model_config.model ==
+                model_config.model) \
+            or (speculative_config.draft_model_config.hf_config.model_type
+                not in ["medusa", "mlp_speculator", "eagle"]) \
+            else {"return_hidden_states": True}
         ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
         if self.model_config.task == "embedding":
             ModelRunnerClass = CPUEmbeddingModelRunner

@@ -159,7 +170,11 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
             vllm_config=vllm_config,
             kv_cache_dtype=kv_cache_dtype,
-            is_driver_worker=is_driver_worker)
+            is_driver_worker=is_driver_worker,
+            **speculative_args,
+        )
+        if model_runner_cls is not None:
+            self.model_runner = model_runner_cls(self.model_runner)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CPUCacheEngine]

@@ -197,7 +212,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
         if ret:
             logger.info(ret)
+        self.device = torch.device("cpu")
         self.init_distributed_environment()
         # Set random seed.
         set_random_seed(self.model_config.seed)

@@ -297,6 +312,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
         return self.cpu_cache

+    @property
+    def vocab_size(self) -> int:
+        return self.model_runner.vocab_size
+
+    @property
+    def max_model_len(self) -> int:
+        return self.model_config.max_model_len
+
     def execute_worker(
         self,
         worker_input: WorkerInput,

vllm/worker/model_runner_base.py

@@ -289,3 +289,18 @@ class ModelRunnerBase(ABC, Generic[T]):
             self.generators.pop(request_id, None)
         return self.generators
+
+
+class ModelRunnerWrapperBase:
+    """
+    The whole point of this class is to lazily initialize the model_runner.
+    """
+
+    def __init__(
+        self,
+        moderl_runner: ModelRunnerBase,
+    ) -> None:
+        self.model_runner: ModelRunnerBase = moderl_runner
+
+    def __getattr__(self, attr):
+        return getattr(self.model_runner, attr)
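The wrapper classes above (and WorkerWrapperBase.__getattr__ further down) rely on Python's attribute-fallback protocol; a self-contained sketch of the pattern, with illustrative stand-in classes rather than vLLM code, looks like this:

class BaseRunner:
    vocab_size = 32000

    def prepare_model_input(self, seqs):
        return {"seqs": seqs}

class RunnerWrapper:
    def __init__(self, runner):
        self.runner = runner

    def prepare_model_input(self, seqs):
        # The wrapper can intercept and post-process selected calls...
        model_input = self.runner.prepare_model_input(seqs)
        model_input["skip_sampler_cpu_output"] = True  # illustrative key
        return model_input

    def __getattr__(self, attr):
        # ...while every attribute it does not define itself falls through
        # to the wrapped runner, so callers see a single object.
        return getattr(self.runner, attr)

wrapped = RunnerWrapper(BaseRunner())
print(wrapped.vocab_size)                  # 32000, served by BaseRunner
print(wrapped.prepare_model_input(["s"]))  # intercepted by the wrapper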

vllm/worker/worker.py

@@ -74,9 +74,7 @@ class Worker(LocalOrDistributedWorkerBase):
             else {"return_hidden_states": True}

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
-        if model_runner_cls is not None:
-            ModelRunnerClass = model_runner_cls
-        elif model_config.task == "embedding":
+        if model_config.task == "embedding":
             ModelRunnerClass = EmbeddingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner

@@ -86,6 +84,9 @@ class Worker(LocalOrDistributedWorkerBase):
             is_driver_worker=is_driver_worker,
             **speculative_args,
         )
+        if model_runner_cls is not None:
+            self.model_runner = model_runner_cls(self.model_runner)
+
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CacheEngine]

vllm/worker/worker_base.py

@@ -466,6 +466,9 @@ class WorkerWrapperBase:
             logger.exception(msg)
             raise e

+    def __getattr__(self, attr):
+        return getattr(self.worker, attr)
+

 def extract_previous_hidden_states(
         data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \