[Speculative Decoding] Support draft model on different tensor-parallel size than target model (#5414)

Woo-Yeon Lee 2024-06-25 18:56:06 +09:00 committed by GitHub
parent f23871e9ee
commit 2ce5d6688b
11 changed files with 388 additions and 59 deletions
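The change lets the draft model run with a smaller tensor-parallel degree than the target model, which cuts collective-communication overhead for small draft models. A minimal sketch of how the new option is surfaced through the Python API, assuming an illustrative target/draft pairing taken from the tests and CI config below (the exact models and token counts are not prescribed by this commit):

from vllm import LLM, SamplingParams

# Target model sharded across 4 GPUs; draft model kept at tp=1 via the new option.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",           # illustrative target model
    tensor_parallel_size=4,
    speculative_model="JackFram/llama-68m",     # illustrative draft model
    num_speculative_tokens=5,
    speculative_draft_tensor_parallel_size=1,   # new option added by this commit
    use_v2_block_manager=True,                  # required for spec decode (see tests)
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))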

View File

@ -54,7 +54,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
@ -71,6 +71,7 @@ steps:
# See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
- label: Engine Test
mirror_hardwares: [amd]

View File

@ -25,6 +25,8 @@ def main(args: argparse.Namespace):
model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
speculative_draft_tensor_parallel_size=\
args.speculative_draft_tensor_parallel_size,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
@ -127,6 +129,10 @@ if __name__ == '__main__':
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
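With the new --speculative-draft-tensor-parallel-size / -spec-draft-tp flag wired into the latency benchmark, the feature can be exercised directly; a hypothetical invocation (model choices are illustrative, and depending on engine defaults at this revision the v2 block manager may also need to be enabled):

python benchmarks/benchmark_latency.py --model meta-llama/Llama-2-7b-hf --tensor-parallel-size 4 --speculative-model JackFram/llama-68m --num-speculative-tokens 5 -spec-draft-tp 1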

View File

@ -0,0 +1,111 @@
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""
import pytest
import torch
from vllm.utils import is_hip
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
baseline_llm_generator,
batch_size: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)

View File

@ -5,16 +5,16 @@ tensor parallelism.
import pytest
import torch
from vllm.utils import is_hip
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
@ -22,7 +22,7 @@ from .conftest import run_greedy_equality_correctness_test
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
"tensor_parallel_size": 4,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
@ -31,35 +31,30 @@ from .conftest import run_greedy_equality_correctness_test
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"output_len",
"test_llm_kwargs",
[
# Use smaller output len for fast test.
32,
# TODO(wooyeon): add spec_draft_tp=2 case
{
"speculative_draft_tensor_parallel_size": 1,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when tensor parallelism is used.
def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
baseline_llm_generator,
batch_size: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
max_output_len=32,
force_output_len=True)

View File

@ -797,6 +797,7 @@ class SpeculativeConfig:
target_parallel_config: ParallelConfig,
target_dtype: str,
speculative_model: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
@ -819,6 +820,8 @@ class SpeculativeConfig:
target_dtype (str): The data type used for the target model.
speculative_model (Optional[str]): The name of the speculative
model, if provided.
speculative_draft_tensor_parallel_size (Optional[int]): The degree
of tensor parallelism for the draft model.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
@ -939,7 +942,8 @@ class SpeculativeConfig:
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
target_parallel_config,
speculative_draft_tensor_parallel_size))
if num_speculative_tokens is None:
raise ValueError(
@ -993,16 +997,26 @@ class SpeculativeConfig:
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig) -> ParallelConfig:
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int]
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config. In the future the
draft worker can have a different parallel strategy, e.g. TP=1.
This is mostly a copy of the target parallel config, except the tp_size.
"""
if speculative_draft_tensor_parallel_size is None:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"any value other than 1")
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=target_parallel_config.tensor_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
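The resolution rule introduced here: when the new flag is omitted, the draft model inherits the target's tensor-parallel size; an explicit value is currently accepted only if it is 1. A distilled, standalone sketch of that rule with worked examples (not vLLM's API, just the logic above restated):

from typing import Optional

def resolve_draft_tp(target_tp: int, spec_draft_tp: Optional[int]) -> int:
    # Omitted flag: draft model uses the same tp size as the target model.
    if spec_draft_tp is None:
        return target_tp
    # Current restriction: only tp=1 draft models are supported.
    if spec_draft_tp != 1:
        raise ValueError("speculative_draft_tensor_parallel_size must be 1")
    return spec_draft_tp

resolve_draft_tp(4, None)   # -> 4 (inherit target tp)
resolve_draft_tp(4, 1)      # -> 1 (draft runs on a single GPU)
try:
    resolve_draft_tp(4, 2)  # not supported yet
except ValueError as e:
    print(e)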

View File

@ -676,6 +676,28 @@ def get_world_group() -> GroupCoordinator:
return _WORLD
def init_world_group(ranks: List[int], local_rank: int,
backend: str) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=[ranks],
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=False,
use_custom_allreduce=False,
)
def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
backend: str) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
)
_TP: Optional[GroupCoordinator] = None
@ -764,13 +786,7 @@ def init_distributed_environment(
global _WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = GroupCoordinator(
group_ranks=[ranks],
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=False,
use_custom_allreduce=False,
)
_WORLD = init_world_group(ranks, local_rank, backend)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size")
@ -827,13 +843,8 @@ def initialize_model_parallel(
range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size))
group_ranks.append(ranks)
_TP = GroupCoordinator(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
)
_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank, backend)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
@ -845,13 +856,8 @@ def initialize_model_parallel(
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
_PP = GroupCoordinator(
group_ranks=group_ranks,
local_rank=get_world_group().local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
)
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank, backend)
def ensure_model_parallel_initialized(
@ -887,6 +893,34 @@ def model_parallel_is_initialized():
return (_TP is not None and _PP is not None)
_TP_STATE_PATCHED = False
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
"""Patch the tp group temporarily until this function ends.
This method is for draft workers of speculative decoding to run draft model
with different tp degree from that of target model workers.
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _TP_STATE_PATCHED
assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
_TP_STATE_PATCHED = True
old_tp_group = get_tp_group()
global _TP
_TP = tp_group
try:
yield
finally:
# restore the original state
_TP_STATE_PATCHED = False
_TP = old_tp_group
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
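The two helpers added here are what the draft-side worker uses to carve out a smaller tp group and temporarily install it as the global one. A sketch of the intended usage, assuming distributed state is already initialized and a 4-GPU world where only rank 0 generates draft tokens (this mirrors smaller_tp_proposer_worker.py further down):

import torch
from vllm.distributed.parallel_state import (get_tp_group, get_world_group,
                                             init_model_parallel_group,
                                             patch_tensor_parallel_group)

draft_ranks = [0]  # GPU ranks that generate draft tokens together
backend = torch.distributed.get_backend(get_tp_group().device_group)
draft_tp_group = init_model_parallel_group([draft_ranks],
                                           get_world_group().local_rank,
                                           backend)
with patch_tensor_parallel_group(draft_tp_group):
    # Inside this block get_tp_group() returns the 1-GPU group, so the draft
    # model loads and runs exactly as it would under tensor_parallel_size=1.
    ...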

View File

@ -94,6 +94,7 @@ class EngineArgs:
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
@ -537,6 +538,13 @@ class EngineArgs:
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=EngineArgs.speculative_draft_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-max-model-len',
@ -686,6 +694,8 @@ class EngineArgs:
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_draft_tensor_parallel_size=\
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
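Since the new field lives on EngineArgs and is registered in add_cli_args, it should be reachable from any front end built on these args; for example, the OpenAI-compatible server could presumably be launched as follows (a hypothetical command line; flags other than the new one follow existing vLLM usage):

python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --tensor-parallel-size 4 --speculative-model JackFram/llama-68m --num-speculative-tokens 5 --speculative-draft-tensor-parallel-size 1 --use-v2-block-manager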

View File

@ -6,7 +6,8 @@ import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
@ -28,9 +29,9 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
super().__init__(*args, **kwargs)
# Lazy initialization list.
self._proposer: Top1Proposer
self._proposer: SpeculativeProposer
def init_device(self):
def init_device(self) -> None:
super().init_device()
self._proposer = Top1Proposer(
@ -40,7 +41,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self):
def set_include_gpu_probs_tensor(self) -> None:
# Need include_gpu_probs_tensor for multi_step_worker
self.model_runner.model.sampler.include_gpu_probs_tensor = True
@ -73,7 +74,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
for _ in range(sample_len):
model_output = super().execute_model(
model_output: List[SamplerOutput] = super().execute_model(
execute_model_req=copied_execute_model_req)
assert (len(model_output) == 1
), "composing multistep workers not supported"

View File

@ -3,10 +3,10 @@ from typing import List, Optional, Tuple
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposer
from vllm.worker.worker_base import WorkerBase
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
"""Interface for proposer workers"""
@abstractmethod

View File

@ -0,0 +1,149 @@
from typing import List, Optional, Tuple
import torch
from vllm.distributed.parallel_state import (get_tp_group,
init_model_parallel_group,
patch_tensor_parallel_group)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
logger = init_logger(__name__)
class SmallerTpProposerWorker(ProposerWorkerBase):
"""Class which allows a speculative draft model to run with smaller tensor
parallel degree than target model.
This reduces the communication overhead of small draft models.
To implement this, the class behaves differently based on an is_dummy flag,
where a dummy worker is one that does not participate in draft generation.
Participating workers use a smaller tp group by patching vLLM's tensor
parallel group temporarily during forward passes of draft models.
"""
@classmethod
def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int,
target_tensor_parallel_size: int):
"""Wrap the worker in a SmallerTpProposerWorker if necessary.
"""
if draft_tensor_parallel_size == target_tensor_parallel_size:
return worker
# gpu ranks that will generate draft tokens together
draft_ranks = list(range(draft_tensor_parallel_size))
logger.info("Wrapping {%s} in {%s}", type(worker), cls)
return cls(worker, draft_ranks)
def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
"""Create a SmallerTpProposerWorker.
Args:
worker (MultiStepWorker): the actual worker wrapped by this class
draft_ranks (List[int]): the GPU ranks that participate in draft
generation
"""
self._worker = worker
self._draft_ranks = draft_ranks
# init during init_device
self._is_dummy = False
self._tp_group = None
def _patch_tensor_parallel_group(self):
"""Temporarily patch the global tp group state with its own tp group
state.
"""
return patch_tensor_parallel_group(self._tp_group)
def init_device(self) -> None:
self._is_dummy = get_tp_group().rank not in self._draft_ranks
# dummy workers do nothing
if self._is_dummy:
return
# creates tp process group containing only a subset of gpu ranks
local_rank = get_tp_group().local_rank
tp_backend = torch.distributed.get_backend(get_tp_group().device_group)
self._tp_group = init_model_parallel_group([self._draft_ranks],
local_rank, tp_backend)
with self._patch_tensor_parallel_group():
self._worker.init_device()
def set_include_gpu_probs_tensor(self) -> None:
if self._is_dummy:
return
# Need include_gpu_probs_tensor for multi_step_worker
self._worker.set_include_gpu_probs_tensor()
def load_model(self) -> None:
if self._is_dummy:
return
with self._patch_tensor_parallel_group():
self._worker.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
if self._is_dummy:
# this case is not used now
return -1, -1
with self._patch_tensor_parallel_group():
return self._worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
if self._is_dummy:
return
with self._patch_tensor_parallel_group():
self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[List[SamplerOutput], bool]:
# Do not check _is_dummy, as it's always called by get_spec_proposals
return self._worker.sampler_output(execute_model_req, sample_len)
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
if self._is_dummy:
return SpeculativeProposals(None, None, None)
with self._patch_tensor_parallel_group():
return self._worker.get_spec_proposals(execute_model_req)
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
if self._is_dummy:
return []
with self._patch_tensor_parallel_group():
return self._worker.execute_model(execute_model_req)
def get_cache_block_size_bytes(self) -> int:
if self._is_dummy:
# by returning zero, the target worker can use the entire KV cache space
return 0
return self._worker.get_cache_block_size_bytes()
@property
def vocab_size(self) -> int:
return self._worker.vocab_size
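A concrete reading of the wrapping rule, assuming a target tp of 4, a draft tp of 1, and an already-constructed MultiStepWorker named worker:

SmallerTpProposerWorker.maybe_wrap_worker(worker,
                                          draft_tensor_parallel_size=1,
                                          target_tensor_parallel_size=4)
# 1 != 4, so the MultiStepWorker is wrapped with draft_ranks=[0]. Rank 0 builds
# a 1-GPU tp group and runs the real draft model under the patched tp state;
# ranks 1-3 become dummies that skip device/model init, report zero draft
# KV-cache usage, and return empty proposals/outputs to SpecDecodeWorker.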

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.config import SpeculativeConfig
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@ -18,6 +18,7 @@ from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
@ -90,7 +91,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@classmethod
def create_worker(
cls,
scorer_worker: WorkerBase,
scorer_worker: Worker,
draft_worker_kwargs: Dict[str, Any],
disable_by_batch_size: Optional[int],
) -> "SpecDecodeWorker":
@ -111,7 +112,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
disable_bonus_tokens = False
else:
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'parallel_config']
draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp)
logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker))