Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 23:55:44 +08:00)

Remove hard-dependencies of Speculative decode to CUDA workers (#10587)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>

This commit is contained in:
parent 2f0a0a17a4
commit 0a71900bc9
@@ -595,8 +595,8 @@ def test_init_device(acceptance_sampler_method: str):

     target_worker.init_device.assert_called_once()

-    metrics_collector.init_gpu_tensors.assert_called_once()
-    spec_decode_sampler.init_gpu_tensors.assert_called_once()
+    metrics_collector.init_tensors.assert_called_once()
+    spec_decode_sampler.init_tensors.assert_called_once()


 @pytest.mark.parametrize("acceptance_sampler_method",
@@ -990,6 +990,7 @@ class ParallelConfig:
     # the full name of the worker class to use. If "auto", the worker class
     # will be determined based on the platform.
     worker_cls: str = "auto"
+    sd_worker_cls: str = "auto"

     world_size: int = field(init=False)

@@ -43,6 +43,21 @@ class SpecDecodeBaseSampler(nn.Module):
                                                 dtype=torch.long,
                                                 device=device)

+    def init_tensors(self,
+                     device: Union[int, str],
+                     device_type: Union[torch.device, str] = 'cuda') -> None:
+        assert self.num_accepted_tokens is None
+        if isinstance(device_type, torch.device):
+            device_type = device_type.type
+        if isinstance(device, int):
+            device = f"{device_type}:{device}"
+        self.num_accepted_tokens = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=device)
+        self.num_emitted_tokens = torch.tensor(0,
+                                               dtype=torch.long,
+                                               device=device)
+
     @property
     def probs_dtype(self):
         return torch.float32

@@ -77,7 +92,7 @@ class SpecDecodeBaseSampler(nn.Module):
             tensor is [batch_size, k + num_bonus_tokens]
         """
         batch_size, k = substitute_token_ids.shape
-        bonus_token_ids = bonus_token_ids.squeeze()
+        bonus_token_ids = bonus_token_ids.squeeze(-1)
         # Determine the index of the first False value for each row.
         limits = (accepted == 0).max(1).indices
         limits[~(accepted == 0).any(1)] = k
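Note on the init_tensors addition above: the old init_gpu_tensors only understood integer CUDA ranks, while the new method also accepts a device-type string or a torch.device, so the acceptance counters can be allocated on CPU (or another backend). The snippet below is a self-contained sketch of that normalization logic only, not the vLLM class itself; the helper name resolve_device is invented for illustration.

from typing import Union

import torch


def resolve_device(device: Union[int, str],
                   device_type: Union[torch.device, str] = "cuda") -> str:
    """Mirror of the normalization done in init_tensors (illustrative only)."""
    if isinstance(device_type, torch.device):
        device_type = device_type.type      # e.g. torch.device("cpu") -> "cpu"
    if isinstance(device, int):
        device = f"{device_type}:{device}"  # rank 0 on CPU -> "cpu:0"
    return device


# init_tensors(rank, device_type=torch.device("cpu")) ends up allocating:
counter = torch.tensor(0, dtype=torch.long,
                       device=resolve_device(0, torch.device("cpu")))
print(counter.device)  # cpu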
@@ -86,4 +86,10 @@ class CpuPlatform(Platform):
                 parallel_config.distributed_executor_backend)
             parallel_config.distributed_executor_backend = "mp"
         if parallel_config.worker_cls == "auto":
-            parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
+            if vllm_config.speculative_config:
+                parallel_config.worker_cls = \
+                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+                parallel_config.sd_worker_cls = \
+                    "vllm.worker.cpu_worker.CPUWorker"
+            else:
+                parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
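In effect, a CPU deployment with speculative decoding enabled now gets create_spec_worker as its top-level worker, while the new sd_worker_cls field records which platform worker should actually be instantiated underneath it. The sketch below imitates that branching with a hypothetical ParallelConfigSketch dataclass (the real logic lives in the check_and_update_config hunk above); it only illustrates the resulting field values.

from dataclasses import dataclass


@dataclass
class ParallelConfigSketch:
    """Hypothetical stand-in for the ParallelConfig fields touched by the diff."""
    worker_cls: str = "auto"
    sd_worker_cls: str = "auto"


def check_and_update_config(cfg: ParallelConfigSketch, has_spec_config: bool) -> None:
    # Same branching as the CPU platform hunk above.
    if cfg.worker_cls == "auto":
        if has_spec_config:
            cfg.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
            cfg.sd_worker_cls = "vllm.worker.cpu_worker.CPUWorker"
        else:
            cfg.worker_cls = "vllm.worker.cpu_worker.CPUWorker"


cfg = ParallelConfigSketch()
check_and_update_config(cfg, has_spec_config=True)
print(cfg.worker_cls)     # the composite spec-decode entry point
print(cfg.sd_worker_cls)  # the platform worker used underneath it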
@@ -106,6 +106,8 @@ class CudaPlatformBase(Platform):
         elif vllm_config.speculative_config:
             parallel_config.worker_cls = \
                 "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+            parallel_config.sd_worker_cls = \
+                "vllm.worker.worker.Worker"
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"

@@ -236,4 +238,4 @@ try:
     if not isinstance(pynvml, _MockModule):
         CudaPlatform.log_warnings()
 except ModuleNotFoundError:
     CudaPlatform.log_warnings()
@@ -20,8 +20,9 @@ except (ModuleNotFoundError, ImportError) as err:
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalKwargs
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
-                                      ModelRunner)
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                           ModelRunnerInputBase,
+                                           ModelRunnerWrapperBase)

 logger = init_logger(__name__)

@@ -33,7 +34,7 @@ debug_advance_input = False
 allow_gpu_advance_step = True


-class TP1DraftModelRunner(ModelRunner):
+class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding draft model.
     Since the draft model always execute k forward passes consecutively to
     generate k speculative tokens in a single speculative decoding step,

@@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner):
     any broadcasting inside execute_model).
     """

-    def __init__(self, *args, **kwargs):
-        if kwargs.get("return_hidden_states"):
+    def __init__(self, model_runner: ModelRunnerBase):
+        if hasattr(
+                model_runner,
+                "return_hidden_states") and model_runner.return_hidden_states:
             raise ValueError(
                 "return_hidden_states is not supported for TP1DraftModelRunner."
             )
-
-        super().__init__(*args, **kwargs)
+        super().__init__(model_runner)

         self.indices_of_seq_with_bonus_tokens = None

@@ -73,10 +75,8 @@ class TP1DraftModelRunner(ModelRunner):
             assert seq_group.prompt_logprob_indices == []  # No prompt
             assert seq_group.sample_indices == [i]  # Simple

-    def _gpu_advance_step(
-        self, model_input: ModelInputForGPUWithSamplingMetadata,
-        last_output: SamplerOutput
-    ) -> ModelInputForGPUWithSamplingMetadata:
+    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
+                          last_output: SamplerOutput) -> ModelRunnerInputBase:
         # Currently, we expect "decode mode" only
         assert not model_input.is_prompt

@@ -168,7 +168,7 @@ class TP1DraftModelRunner(ModelRunner):
     @torch.inference_mode()
     def execute_model(
         self,
-        model_input: ModelInputForGPUWithSamplingMetadata,
+        model_input: ModelRunnerInputBase,
         kv_caches: List[torch.Tensor],
         previous_hidden_states: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Set
+from typing import Optional, Set, Union

 import torch

@@ -75,9 +75,11 @@ class SpeculativeProposer(ABC):

 class SpeculativeScorer(ABC):

-    def __init__(self, scorer_worker: WorkerBase, device: str,
-                 vocab_size: int):
+    def __init__(self, scorer_worker: WorkerBase,
+                 device: Union[torch.device, str], vocab_size: int):
         self._scorer_worker = scorer_worker
+        if isinstance(device, torch.device):
+            device = device.type
         self._device = device
         self._vocab_size = vocab_size

@@ -9,21 +9,22 @@ from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker import Worker
+from vllm.worker.worker_base import WorkerWrapperBase


-class MedusaWorker(NonLLMProposerWorkerBase, Worker):
+class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase):
     """Worker for Medusa.
     """

     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(kwargs.get("vllm_config"))
+        self.init_worker(*args, **kwargs)

         # Lazy initialization list.
         self._proposer: Top1Proposer

     def init_device(self):
-        super().init_device()
+        self.worker.init_device()

         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]
@@ -1,11 +1,12 @@
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Union

 import msgspec
 import torch

 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler)
+from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available

@@ -81,8 +82,20 @@ class AsyncMetricsCollector:
         self._rank = rank
         self._copy_stream = torch.cuda.Stream()

+    def init_tensors(self,
+                     rank: int,
+                     device_type: Union[torch.device, str] = 'cuda') -> None:
+        self._rank = rank
+        if isinstance(device_type, torch.device):
+            device_type = device_type.type
+        if device_type == 'cuda':
+            self._copy_stream = torch.cuda.Stream()
+
     def maybe_collect_rejsample_metrics(
             self, k: int) -> Optional[SpecDecodeWorkerMetrics]:

+        # currently using cuda.Event, skip for any non_cuda_alike platform
+        if not current_platform.is_cuda_alike():
+            return None
+
         # If a copy was initiated in the previous call, collect and return.
         if self._in_flight_copy is not None:
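Since AsyncMetricsCollector depends on torch.cuda.Stream and CUDA events, rejection-sampling metrics are now skipped entirely on platforms that are not CUDA-alike. The stand-in class below sketches that early-return behaviour; it uses torch.cuda.is_available() instead of vLLM's current_platform.is_cuda_alike() purely to stay self-contained, and MetricsSketch is an invented name.

from typing import Optional

import torch


class MetricsSketch:
    """Illustrative stand-in for the AsyncMetricsCollector pattern above."""

    def __init__(self) -> None:
        self._copy_stream = None

    def init_tensors(self, rank: int, device_type: str = "cuda") -> None:
        self._rank = rank
        # Only CUDA-alike platforms get an async copy stream.
        if device_type == "cuda" and torch.cuda.is_available():
            self._copy_stream = torch.cuda.Stream()

    def maybe_collect(self) -> Optional[dict]:
        # The real code guards on current_platform.is_cuda_alike().
        if self._copy_stream is None:
            return None
        return {}  # metrics would be gathered via CUDA events here


m = MetricsSketch()
m.init_tensors(rank=0, device_type="cpu")
print(m.maybe_collect())  # None on a CPU-only setup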
@@ -5,17 +5,21 @@ from typing import Dict, List, Set, Tuple
 import torch

 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+if current_platform.is_cuda_alike():
+    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker import Worker
+from vllm.worker.worker_base import WorkerWrapperBase


-class MultiStepWorker(Worker, ProposerWorkerBase):
+class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
     """The MultiStepWorker is equivalent to a Worker except that it allows
     multiple forward passes in a single call, assuming the scheduler has
     allocated enough space to store the additional KV. This reduces overhead

@@ -28,13 +32,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
     """

     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(kwargs.get("vllm_config"))
+        self.init_worker(*args, **kwargs)

         # Lazy initialization list.
         self._proposer: SpeculativeProposer

     def init_device(self) -> None:
-        super().init_device()
+        self.worker.init_device()

         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]

@@ -51,6 +56,18 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
             True)

+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        return self.worker.determine_num_available_blocks()
+
+    def get_cache_block_size_bytes(self) -> int:
+        return self.worker.get_cache_block_size_bytes()
+
+    def initialize_cache(self, *args, **kwargs) -> None:
+        self.worker.initialize_cache(*args, **kwargs)
+
+    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
+        return self.worker.execute_model(*args, **kwargs)
+
     @torch.inference_mode()
     def sampler_output(
         self,

@@ -75,7 +92,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):

         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
-        if isinstance(
+        if current_platform.is_cuda_alike() and isinstance(
                 self.model_runner, TP1DraftModelRunner
         ) and self.model_runner.supports_gpu_multi_step(expanded_request):
             # Here we run the draft_model_runner with multi-step prepare

@@ -92,7 +109,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
             for _ in range(sample_len):
-                model_output: List[SamplerOutput] = super().execute_model(
+                model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
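MedusaWorker and MultiStepWorker no longer inherit from the CUDA Worker; they subclass WorkerWrapperBase, let init_worker construct whatever worker class the config names, and forward calls to self.worker (explicitly for a few methods, via __getattr__ for the rest). A minimal sketch of that composition pattern, with invented stand-in classes rather than the real vLLM ones:

class PlatformWorkerStub:
    """Stand-in for the platform worker that WorkerWrapperBase would construct."""

    def init_device(self) -> None:
        print("device initialized")

    def execute_model(self, request: str) -> str:
        return f"ran {request}"


class WorkerWrapperSketch:
    """Minimal analogue of the wrapper: hold a worker, forward unknown attributes."""

    def __init__(self) -> None:
        self.worker = None

    def init_worker(self) -> None:
        self.worker = PlatformWorkerStub()

    def __getattr__(self, attr):
        # Fallback: anything not defined on the wrapper is served by the worker.
        return getattr(self.worker, attr)


class MultiStepWorkerSketch(WorkerWrapperSketch):
    def __init__(self) -> None:
        super().__init__()
        self.init_worker()

    def init_device(self) -> None:
        # Explicit delegation, mirroring init_device/execute_model in the diff.
        self.worker.init_device()


w = MultiStepWorkerSketch()
w.init_device()
print(w.execute_model("draft step"))  # resolved through __getattr__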
@@ -22,6 +22,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
         # Get local_rank/vocab_size from kwargs attribute
         self.local_rank = kwargs["local_rank"]
         self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
+        self.device_type = kwargs.get("device_type", "cuda")

         # Lazy initialization list.
         self._proposer: Top1Proposer

@@ -34,7 +35,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
         self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

     def init_device(self):
-        self.device = torch.device(f"cuda:{self.local_rank}")
+        self.device = torch.device(f"{self.device_type}:{self.local_rank}")
         self.load_model = lambda *args, **kwargs: None

         # Current NGramWorker only supports Top1Proposer
@@ -14,12 +14,16 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
+from vllm.platforms import current_platform
 from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                            CompletionSequenceGroupOutput, ExecuteModelRequest,
                            HiddenStates, SequenceGroupMetadata,
                            get_all_seq_ids_and_request_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+if current_platform.is_cuda_alike():
+    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.medusa_worker import MedusaWorker

@@ -36,8 +40,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
-from vllm.worker.worker import Worker
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
+from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase,
+                                     WorkerWrapperBase)

 logger = init_logger(__name__)

@@ -53,7 +57,11 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
     draft_worker_kwargs = kwargs.copy()

     kwargs["model_runner_cls"] = TargetModelRunner
-    target_worker = Worker(*args, **kwargs)
+    target_worker_config = copy.deepcopy(vllm_config)
+    target_worker_config.parallel_config.worker_cls =\
+        target_worker_config.parallel_config.sd_worker_cls
+    target_worker = WorkerWrapperBase(vllm_config=target_worker_config)
+    target_worker.init_worker(*args, **kwargs)
     # Set the disable_logprobs variable in the TargetModelRunner instance
     # as per its value specified in the SpeculativeConfig.
     target_worker.model_runner.disable_logprobs =\

@@ -65,6 +73,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         draft_worker_config.model_config,
         vllm_config.load_config,
     )
+    speculative_config.draft_parallel_config.worker_cls =\
+        draft_worker_config.parallel_config.sd_worker_cls
     draft_worker_config.parallel_config = speculative_config.draft_parallel_config  # noqa
     # TODO allow draft-model specific load config.

@@ -125,7 +135,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     @classmethod
     def create_worker(
         cls,
-        scorer_worker: Worker,
+        scorer_worker: WorkerBase,
         draft_worker_kwargs: Dict[str, Any],
         disable_mqa_scorer: bool,
         disable_by_batch_size: Optional[int],

@@ -145,6 +155,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         draft_parallel_config: ParallelConfig = draft_worker_kwargs[
             'vllm_config'].parallel_config
         if ngram_prompt_lookup_max > 0:
+            draft_worker_kwargs[
+                "device_type"] = scorer_worker.device_config.device.type
             proposer_worker = NGramWorker(**draft_worker_kwargs)
             proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)

@@ -158,8 +170,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker = MedusaWorker(**draft_worker_kwargs)
         else:
             if draft_tp == 1:
-                draft_worker_kwargs[
-                    "model_runner_cls"] = TP1DraftModelRunner
+                if current_platform.is_cuda_alike():
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = TP1DraftModelRunner
             else:
                 if draft_model_config.hf_config.model_type == "eagle":
                     raise NotImplementedError(

@@ -306,8 +319,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.scorer_worker.load_model()
         self.proposer_worker.load_model()

-        self._metrics.init_gpu_tensors(self.rank)
-        self.spec_decode_sampler.init_gpu_tensors(self.rank)
+        self._metrics.init_tensors(self.rank, device_type=self.device)
+        self.spec_decode_sampler.init_tensors(self.rank,
+                                              device_type=self.device)

         scorer_cls: Type[SpeculativeScorer]
         if self.disable_mqa_scorer:

@@ -1111,11 +1125,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         raise NotImplementedError

     def start_profile(self):
-        if isinstance(self.scorer_worker, Worker):
+        if isinstance(self.scorer_worker, WorkerBase):
             self.scorer_worker.start_profile()

     def stop_profile(self):
-        if isinstance(self.scorer_worker, Worker):
+        if isinstance(self.scorer_worker, WorkerBase):
             self.scorer_worker.stop_profile()

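create_spec_worker now builds the target worker indirectly: it deep-copies the engine config, swaps parallel_config.worker_cls for the platform-specific sd_worker_cls, and hands the copy to WorkerWrapperBase, which imports and instantiates that class. The sketch below mimics only that rewiring; the config classes and the placeholder class path are invented so the example runs without vLLM installed.

import copy
import importlib
from dataclasses import dataclass, field


@dataclass
class ParallelCfg:
    worker_cls: str = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
    sd_worker_cls: str = "collections.OrderedDict"  # placeholder class path for the demo


@dataclass
class VllmCfg:
    parallel_config: ParallelCfg = field(default_factory=ParallelCfg)


def resolve_cls(path: str):
    module, name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module), name)


def build_target_worker(vllm_config: VllmCfg):
    # Mirror of the rewiring in create_spec_worker: the deep copy keeps the
    # original config (which still points at create_spec_worker) untouched.
    target_cfg = copy.deepcopy(vllm_config)
    target_cfg.parallel_config.worker_cls = target_cfg.parallel_config.sd_worker_cls
    return resolve_cls(target_cfg.parallel_config.worker_cls)()


print(type(build_target_worker(VllmCfg())))  # the class named by sd_worker_cls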
@@ -1,12 +1,12 @@
 from typing import List, Optional

-from vllm.config import VllmConfig
 from vllm.sequence import SequenceGroupMetadata
-from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
-                                      ModelRunner)
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                           ModelRunnerInputBase,
+                                           ModelRunnerWrapperBase)


-class TargetModelRunner(ModelRunner):
+class TargetModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding target model.
     In speculative decoding, the log probabilities selected finally may not
     be the same ones as selected by the target model sampling. This means

@@ -18,32 +18,21 @@ class TargetModelRunner(ModelRunner):
     requested or not.
     """

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        kv_cache_dtype: Optional[str] = "auto",
-        is_driver_worker: bool = False,
-        return_hidden_states: bool = False,
-    ):
+    def __init__(self, model_runner: ModelRunnerBase):
         # An internal boolean member variable to indicate if token log
         # probabilities are needed or not.
+        super().__init__(model_runner)
         self.disable_logprobs = True
-        super().__init__(
-            vllm_config=vllm_config,
-            kv_cache_dtype=kv_cache_dtype,
-            is_driver_worker=is_driver_worker,
-            return_hidden_states=return_hidden_states,
-        )

     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None
-    ) -> ModelInputForGPUWithSamplingMetadata:
-        model_input: ModelInputForGPUWithSamplingMetadata = super(
-        ).prepare_model_input(seq_group_metadata_list, virtual_engine,
-                              finished_requests_ids)
+        finished_requests_ids: Optional[List[str]] = None,
+    ) -> ModelRunnerInputBase:
+        model_input: ModelRunnerInputBase =\
+            self.model_runner.prepare_model_input(
+                seq_group_metadata_list, virtual_engine, finished_requests_ids)
         # If token log probabilities is disabled then skip generating sampler
         # CPU output. We directly serialize the GPU sampled_token_id tensors
         # as needed. If log probabilities is enabled then synchronize all the
@@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
 import torch

 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                            PromptLogprobs, SequenceGroupMetadata,
                            SequenceOutput)

@@ -247,11 +248,14 @@ def nvtx_range(msg, *args, **kwargs):
     Arguments:
         msg (string): message to associate with the range
     """
-    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
-    try:
-        yield
-    finally:
-        torch.cuda.nvtx.range_pop()
+    if current_platform.is_cuda_alike():
+        torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
+        try:
+            yield
+        finally:
+            torch.cuda.nvtx.range_pop()
+    else:
+        yield


 class Timer:
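The guarded nvtx_range keeps profiling markers on CUDA-alike platforms and degrades to a plain pass-through context manager elsewhere, so utilities that wrap scoring and proposal steps keep working on CPU. A self-contained version of the same guard, using torch.cuda.is_available() in place of vLLM's current_platform.is_cuda_alike() so it runs anywhere:

from contextlib import contextmanager

import torch


@contextmanager
def nvtx_range_sketch(msg: str, *args, **kwargs):
    """Push an NVTX range on CUDA builds; otherwise just yield."""
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
        try:
            yield
        finally:
            torch.cuda.nvtx.range_pop()
    else:
        yield


with nvtx_range_sketch("score_step {}", 0):
    pass  # profiled on CUDA, a no-op elsewhere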
@@ -80,6 +80,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
     Used by the ModelRunner.
     """
     sampling_metadata: Optional["SamplingMetadata"] = None
+    is_prompt: Optional[bool] = None

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {

@@ -395,6 +396,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         vllm_config: VllmConfig,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
+        return_hidden_states: bool = False,
         *args,
         **kwargs,
     ):

@@ -403,19 +405,25 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
         cache_config = self.cache_config

         self.is_driver_worker = is_driver_worker
+        self.return_hidden_states = return_hidden_states

         self.device = self.device_config.device
+        self.pin_memory = False

         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
+        ) if needs_attn_backend else None

         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY

@@ -444,6 +452,15 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):

         return builder.build()  # type: ignore

+    # sampler property will be used by spec_decode_worker
+    @property
+    def sampler(self):
+        return self.model.sampler
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
+

 class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
     _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (

@@ -480,9 +497,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
             pin_memory=False,
             generators=generators)

+        is_prompt = (seq_group_metadata_list[0].is_prompt
+                     if seq_group_metadata_list else None)
         return dataclasses.replace(model_input,
                                    sampling_metadata=sampling_metadata,
-                                   virtual_engine=virtual_engine)
+                                   virtual_engine=virtual_engine,
+                                   is_prompt=is_prompt)

     @torch.no_grad()
     def execute_model(

@@ -491,16 +511,22 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
+        previous_hidden_states: Optional[torch.Tensor] = None,
     ) -> Optional[List[SamplerOutput]]:
         if num_steps > 1:
             raise ValueError(
                 "CPU worker does not support multi-step execution.")

         model_executable = self.model

         multimodal_kwargs = {}
         if model_input.multi_modal_kwargs is not None:
             multimodal_kwargs = MultiModalKwargs.as_kwargs(
                 model_input.multi_modal_kwargs, device=self.device)
+        execute_model_kwargs = {}
+        if previous_hidden_states is not None:
+            execute_model_kwargs.update(
+                {"previous_hidden_states": previous_hidden_states})

         with set_forward_context(model_input.attn_metadata, self.vllm_config):
             hidden_states = model_executable(

@@ -509,6 +535,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
                 kv_caches=kv_caches,
                 attn_metadata=model_input.attn_metadata,
                 intermediate_tensors=intermediate_tensors,
+                **execute_model_kwargs,
                 **multimodal_kwargs,
             )

@@ -525,4 +552,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
+        if self.return_hidden_states:
+            # we only need to pass hidden states of most recent token
+            if model_input.is_prompt:
+                output.prefill_hidden_states = hidden_states
+            output.hidden_states = hidden_states
         return [output]
+
+    def generate_proposals(self, *args, **kwargs):
+        return self.model.generate_proposals(*args, **kwargs)
@@ -128,6 +128,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         distributed_init_method: str,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
+        model_runner_cls: Optional[Type[CPUModelRunner]] = None,
     ) -> None:
         WorkerBase.__init__(self, vllm_config=vllm_config)

@@ -151,6 +152,16 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         else:
             self.local_omp_cpuid = omp_cpuids.split("|")[rank]

+        # Return hidden states from target model if the draft model is an
+        # mlp_speculator
+        speculative_config = self.speculative_config
+        model_config = self.model_config
+        speculative_args = {} if speculative_config is None \
+            or (speculative_config.draft_model_config.model ==
+                model_config.model) \
+            or (speculative_config.draft_model_config.hf_config.model_type
+                not in ["medusa", "mlp_speculator", "eagle"]) \
+            else {"return_hidden_states": True}
         ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
         if self.model_config.task == "embedding":
             ModelRunnerClass = CPUEmbeddingModelRunner

@@ -159,7 +170,11 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
             vllm_config=vllm_config,
             kv_cache_dtype=kv_cache_dtype,
-            is_driver_worker=is_driver_worker)
+            is_driver_worker=is_driver_worker,
+            **speculative_args,
+        )
+        if model_runner_cls is not None:
+            self.model_runner = model_runner_cls(self.model_runner)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CPUCacheEngine]

@@ -197,7 +212,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
         if ret:
             logger.info(ret)
-
+        self.device = torch.device("cpu")
         self.init_distributed_environment()
         # Set random seed.
         set_random_seed(self.model_config.seed)

@@ -297,6 +312,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
         return self.cpu_cache

+    @property
+    def vocab_size(self) -> int:
+        return self.model_runner.vocab_size
+
+    @property
+    def max_model_len(self) -> int:
+        return self.model_config.max_model_len
+
     def execute_worker(
         self,
         worker_input: WorkerInput,
@@ -289,3 +289,18 @@ class ModelRunnerBase(ABC, Generic[T]):
             self.generators.pop(request_id, None)

         return self.generators
+
+
+class ModelRunnerWrapperBase:
+    """
+    The whole point of this class is to lazily initialize the model_runner.
+    """
+
+    def __init__(
+        self,
+        moderl_runner: ModelRunnerBase,
+    ) -> None:
+        self.model_runner: ModelRunnerBase = moderl_runner
+
+    def __getattr__(self, attr):
+        return getattr(self.model_runner, attr)
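ModelRunnerWrapperBase is the model-runner counterpart of WorkerWrapperBase: TP1DraftModelRunner and TargetModelRunner now hold an already-constructed platform model runner and forward anything they do not override through __getattr__. Worth noting: __getattr__ only fires for attributes the wrapper itself lacks, so an override such as prepare_model_input still takes precedence. A minimal illustration with invented stand-in classes, not the real vLLM types:

class BaseRunnerStub:
    """Stand-in for a concrete platform model runner."""

    def prepare_model_input(self, seqs):
        return {"seqs": seqs, "sampling": "full"}

    def vocab_size(self) -> int:
        return 32000


class RunnerWrapperSketch:
    """Analogue of ModelRunnerWrapperBase: hold a runner, forward unknown attributes."""

    def __init__(self, model_runner: BaseRunnerStub) -> None:
        self.model_runner = model_runner

    def __getattr__(self, attr):
        return getattr(self.model_runner, attr)


class TargetRunnerSketch(RunnerWrapperSketch):
    def prepare_model_input(self, seqs):
        # Override wins over __getattr__ forwarding; post-process the wrapped result.
        model_input = self.model_runner.prepare_model_input(seqs)
        model_input["sampling"] = "logprobs skipped"
        return model_input


r = TargetRunnerSketch(BaseRunnerStub())
print(r.prepare_model_input(["a"]))  # uses the override
print(r.vocab_size())                # forwarded to the wrapped runner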
@@ -74,9 +74,7 @@ class Worker(LocalOrDistributedWorkerBase):
             else {"return_hidden_states": True}

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
-        if model_runner_cls is not None:
-            ModelRunnerClass = model_runner_cls
-        elif model_config.task == "embedding":
+        if model_config.task == "embedding":
             ModelRunnerClass = EmbeddingModelRunner
         elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner

@@ -86,6 +84,9 @@ class Worker(LocalOrDistributedWorkerBase):
             is_driver_worker=is_driver_worker,
             **speculative_args,
         )
+        if model_runner_cls is not None:
+            self.model_runner = model_runner_cls(self.model_runner)
+
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[CacheEngine]
@@ -466,6 +466,9 @@ class WorkerWrapperBase:
             logger.exception(msg)
             raise e

+    def __getattr__(self, attr):
+        return getattr(self.worker, attr)
+

 def extract_previous_hidden_states(
         data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \