[Speculative Decoding] Support draft model on different tensor-parallel size than target model (#5414)
This commit is contained in: parent f23871e9ee, commit 2ce5d6688b
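The change introduces a `speculative_draft_tensor_parallel_size` option so the draft model can run at a smaller tensor-parallel degree than the target model. A minimal usage sketch, not part of the commit itself, assuming the vLLM `LLM` entry point forwards these keyword arguments to the engine the same way the benchmark-script hunk below does (the model names are the ones used in this PR's tests):

from vllm import LLM, SamplingParams

# Sketch only: target model runs with TP=4, draft model runs with TP=1.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    tensor_parallel_size=4,
    speculative_draft_tensor_parallel_size=1,
    use_v2_block_manager=True,  # required for spec decode at this point
)
print(llm.generate(["Hello, my name is"], SamplingParams(temperature=0.0)))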
@@ -54,7 +54,7 @@ steps:
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
-  - pytest -v -s spec_decode/e2e/test_integration_dist.py
+  - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
 
@@ -71,6 +71,7 @@ steps:
   # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
+  - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
 
 - label: Engine Test
   mirror_hardwares: [amd]
@@ -25,6 +25,8 @@ def main(args: argparse.Namespace):
         model=args.model,
         speculative_model=args.speculative_model,
         num_speculative_tokens=args.num_speculative_tokens,
+        speculative_draft_tensor_parallel_size=\
+            args.speculative_draft_tensor_parallel_size,
         tokenizer=args.tokenizer,
         quantization=args.quantization,
         tensor_parallel_size=args.tensor_parallel_size,
@@ -127,6 +129,10 @@ if __name__ == '__main__':
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
     parser.add_argument('--speculative-model', type=str, default=None)
     parser.add_argument('--num-speculative-tokens', type=int, default=None)
+    parser.add_argument('--speculative-draft-tensor-parallel-size',
+                        '-spec-draft-tp',
+                        type=int,
+                        default=None)
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',

tests/spec_decode/e2e/test_integration_dist_tp2.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+"""Tests which cover integration of the speculative decoding framework with
+tensor parallelism.
+"""
+
+import pytest
+import torch
+
+from vllm.utils import is_hip
+
+from .conftest import run_greedy_equality_correctness_test
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "tensor_parallel_size": 2,
+
+        # Use AsyncLLM engine, so that the engine runs in its own process.
+        # Otherwise, since vLLM does not follow true SPMD, the test runner
+        # process will have both the engine and the rank0 worker. NCCL is not
+        # cleaned up properly, and its server host thread leaks, causing the
+        # second run of the test to fail with internal NCCL error.
+        "use_async": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 3,
+    },
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
+                              batch_size: int, output_len: int):
+    """Verify greedy equality when tensor parallelism is used.
+    """
+    if is_hip():
+        pytest.skip("hip is not well-supported yet")
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        # Note this is repeated in the test body; to initialize a tokenizer.
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "tensor_parallel_size": 2,
+
+        # Use AsyncLLM engine, so that the engine runs in its own process.
+        # Otherwise, since vLLM does not follow true SPMD, the test runner
+        # process will have both the engine and the rank0 worker. NCCL is not
+        # cleaned up properly, and its server host thread leaks, causing the
+        # second run of the test to fail with internal NCCL error.
+        "use_async": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "speculative_draft_tensor_parallel_size": 1,
+    },
+])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("seed", [1])
+def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
+                                            baseline_llm_generator,
+                                            batch_size: int):
+    """Verify spec decode works well with smaller tp for draft models.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=32,
+                                         force_output_len=True)
@@ -5,16 +5,16 @@ tensor parallelism.
 import pytest
 import torch
 
-from vllm.utils import is_hip
-
 from .conftest import run_greedy_equality_correctness_test
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2,
-                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
+        # Use a small model for a fast test.
+        # Note this is repeated in the test body; to initialize a tokenizer.
         "model": "JackFram/llama-68m",
 
         # Skip cuda graph recording for fast test.
@@ -22,7 +22,7 @@ from .conftest import run_greedy_equality_correctness_test
 
         # Required for spec decode.
         "use_v2_block_manager": True,
-        "tensor_parallel_size": 2,
+        "tensor_parallel_size": 4,
 
         # Use AsyncLLM engine, so that the engine runs in its own process.
         # Otherwise, since vLLM does not follow true SPMD, the test runner
@@ -31,35 +31,30 @@ from .conftest import run_greedy_equality_correctness_test
         # second run of the test to fail with internal NCCL error.
         "use_async": True,
     }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
         "speculative_model": "JackFram/llama-68m",
-        "num_speculative_tokens": 3,
-    },
-    {
-        "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
-        "ngram_prompt_lookup_max": 3,
     },
 ])
-@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize(
-    "output_len",
+    "test_llm_kwargs",
     [
-        # Use smaller output len for fast test.
-        32,
+        #TODO(wooyeon): add spec_draft_dp=2 case
+        {
+            "speculative_draft_tensor_parallel_size": 1,
+        },
     ])
+@pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
-def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
-                              batch_size: int, output_len: int):
-    """Verify greedy equality when tensor parallelism is used.
+def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
+                                            baseline_llm_generator,
+                                            batch_size: int):
+    """Verify spec decode works well with smaller tp for draft models.
     """
-    if is_hip():
-        pytest.skip("hip is not well-supported yet")
     run_greedy_equality_correctness_test(baseline_llm_generator,
                                          test_llm_generator,
                                          batch_size,
-                                         max_output_len=output_len,
+                                         max_output_len=32,
                                          force_output_len=True)
@@ -797,6 +797,7 @@ class SpeculativeConfig:
         target_parallel_config: ParallelConfig,
         target_dtype: str,
         speculative_model: Optional[str],
+        speculative_draft_tensor_parallel_size: Optional[int],
         num_speculative_tokens: Optional[int],
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,
@@ -819,6 +820,8 @@ class SpeculativeConfig:
             target_dtype (str): The data type used for the target model.
             speculative_model (Optional[str]): The name of the speculative
                 model, if provided.
+            speculative_draft_tensor_parallel_size (Optional[int]): The degree
+                of the tensor parallelism for the draft model.
             num_speculative_tokens (Optional[int]): The number of speculative
                 tokens, if provided. Will default to the number in the draft
                 model config if present, otherwise is required.
@@ -939,7 +942,8 @@ class SpeculativeConfig:
 
         draft_parallel_config = (
             SpeculativeConfig.create_draft_parallel_config(
-                target_parallel_config))
+                target_parallel_config,
+                speculative_draft_tensor_parallel_size))
 
         if num_speculative_tokens is None:
             raise ValueError(
@@ -993,16 +997,26 @@ class SpeculativeConfig:
 
     @staticmethod
     def create_draft_parallel_config(
-            target_parallel_config: ParallelConfig) -> ParallelConfig:
+            target_parallel_config: ParallelConfig,
+            speculative_draft_tensor_parallel_size: Optional[int]
+    ) -> ParallelConfig:
         """Create a parallel config for use by the draft worker.
 
-        This is mostly a copy of the target parallel config. In the future the
-        draft worker can have a different parallel strategy, e.g. TP=1.
+        This is mostly a copy of the target parallel config, except the tp_size.
         """
+        if speculative_draft_tensor_parallel_size is None:
+            speculative_draft_tensor_parallel_size = \
+                target_parallel_config.tensor_parallel_size
+        elif speculative_draft_tensor_parallel_size != 1:
+            # TODO(wooyeon): allow tp values larger than 1
+            raise ValueError(
+                f"{speculative_draft_tensor_parallel_size=} cannot be"
+                f"other value than 1")
+
         draft_parallel_config = ParallelConfig(
             pipeline_parallel_size=target_parallel_config.
             pipeline_parallel_size,
-            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
+            tensor_parallel_size=speculative_draft_tensor_parallel_size,
             distributed_executor_backend=target_parallel_config.
             distributed_executor_backend,
             max_parallel_loading_workers=target_parallel_config.
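For clarity, the rule added above can be restated outside of `SpeculativeConfig`: leave the value unset to inherit the target tensor-parallel size, or set it explicitly to 1; any other value is rejected for now. The helper name `resolve_draft_tp` below is hypothetical (not vLLM API) and only mirrors the branch added to `create_draft_parallel_config`:

from typing import Optional

def resolve_draft_tp(target_tp: int, draft_tp: Optional[int]) -> int:
    # Unset: the draft model uses the same TP degree as the target model.
    if draft_tp is None:
        return target_tp
    # Only TP=1 drafts are supported so far (see the TODO in the diff above).
    if draft_tp != 1:
        raise ValueError(f"speculative_draft_tensor_parallel_size={draft_tp} "
                         "cannot be other value than 1")
    return draft_tp

assert resolve_draft_tp(4, None) == 4  # inherit target TP
assert resolve_draft_tp(4, 1) == 1     # smaller draft TP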
@@ -676,6 +676,28 @@ def get_world_group() -> GroupCoordinator:
     return _WORLD
 
 
+def init_world_group(ranks: List[int], local_rank: int,
+                     backend: str) -> GroupCoordinator:
+    return GroupCoordinator(
+        group_ranks=[ranks],
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_pynccl=False,
+        use_custom_allreduce=False,
+    )
+
+
+def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
+                              backend: str) -> GroupCoordinator:
+    return GroupCoordinator(
+        group_ranks=group_ranks,
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_pynccl=True,
+        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
+    )
+
+
 _TP: Optional[GroupCoordinator] = None
 
 
@@ -764,13 +786,7 @@ def init_distributed_environment(
     global _WORLD
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
-        _WORLD = GroupCoordinator(
-            group_ranks=[ranks],
-            local_rank=local_rank,
-            torch_distributed_backend=backend,
-            use_pynccl=False,
-            use_custom_allreduce=False,
-        )
+        _WORLD = init_world_group(ranks, local_rank, backend)
     else:
         assert _WORLD.world_size == torch.distributed.get_world_size(), (
             "world group already initialized with a different world size")
@@ -827,13 +843,8 @@ def initialize_model_parallel(
             range(i * tensor_model_parallel_size,
                   (i + 1) * tensor_model_parallel_size))
         group_ranks.append(ranks)
-    _TP = GroupCoordinator(
-        group_ranks=group_ranks,
-        local_rank=get_world_group().local_rank,
-        torch_distributed_backend=backend,
-        use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
-    )
+    _TP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank, backend)
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //
@@ -845,13 +856,8 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        group_ranks.append(ranks)
-    _PP = GroupCoordinator(
-        group_ranks=group_ranks,
-        local_rank=get_world_group().local_rank,
-        torch_distributed_backend=backend,
-        use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
-    )
+    _PP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank, backend)
 
 
 def ensure_model_parallel_initialized(
@@ -887,6 +893,34 @@ def model_parallel_is_initialized():
     return (_TP is not None and _PP is not None)
 
 
+_TP_STATE_PATCHED = False
+
+
+@contextmanager
+def patch_tensor_parallel_group(tp_group: GroupCoordinator):
+    """Patch the tp group temporarily until this function ends.
+
+    This method is for draft workers of speculative decoding to run draft model
+    with different tp degree from that of target model workers.
+
+    Args:
+        tp_group (GroupCoordinator): the tp group coordinator
+    """
+    global _TP_STATE_PATCHED
+    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
+
+    _TP_STATE_PATCHED = True
+    old_tp_group = get_tp_group()
+    global _TP
+    _TP = tp_group
+    try:
+        yield
+    finally:
+        # restore the original state
+        _TP_STATE_PATCHED = False
+        _TP = old_tp_group
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return get_tp_group().world_size
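The two helpers above are meant to be used together: build a `GroupCoordinator` that spans only the draft ranks, then temporarily install it as the global TP group while the draft model runs. A hedged sketch of that pattern, mirroring the `SmallerTpProposerWorker` added later in this diff (`draft_ranks` and `draft_worker` are placeholders, and the wrapper function itself is illustrative, not vLLM API):

import torch

from vllm.distributed.parallel_state import (get_tp_group,
                                             init_model_parallel_group,
                                             patch_tensor_parallel_group)

def run_draft_with_smaller_tp(draft_ranks, draft_worker):
    # Build a TP group covering only the draft ranks, reusing the backend of
    # the existing (target) TP group.
    backend = torch.distributed.get_backend(get_tp_group().device_group)
    draft_tp_group = init_model_parallel_group([draft_ranks],
                                               get_tp_group().local_rank,
                                               backend)
    # While patched, tensor-parallel collectives issued by the draft model use
    # the smaller group; the target TP group is restored on exit.
    with patch_tensor_parallel_group(draft_tp_group):
        return draft_worker.execute_model()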
@@ -94,6 +94,7 @@ class EngineArgs:
     guided_decoding_backend: str = 'outlines'
     # Speculative decoding configuration.
     speculative_model: Optional[str] = None
+    speculative_draft_tensor_parallel_size: Optional[int] = None
     num_speculative_tokens: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
     speculative_disable_by_batch_size: Optional[int] = None
@@ -537,6 +538,13 @@ class EngineArgs:
             default=EngineArgs.num_speculative_tokens,
             help='The number of speculative tokens to sample from '
             'the draft model in speculative decoding.')
+        parser.add_argument(
+            '--speculative-draft-tensor-parallel-size',
+            '-spec-draft-tp',
+            type=int,
+            default=EngineArgs.speculative_draft_tensor_parallel_size,
+            help='Number of tensor parallel replicas for '
+            'the draft model in speculative decoding.')
 
         parser.add_argument(
             '--speculative-max-model-len',
@@ -686,6 +694,8 @@ class EngineArgs:
             target_parallel_config=parallel_config,
             target_dtype=self.dtype,
             speculative_model=self.speculative_model,
+            speculative_draft_tensor_parallel_size = \
+                self.speculative_draft_tensor_parallel_size,
             num_speculative_tokens=self.num_speculative_tokens,
             speculative_disable_by_batch_size=self.
             speculative_disable_by_batch_size,
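At the engine-args level the same option maps to the `--speculative-draft-tensor-parallel-size` (alias `-spec-draft-tp`) flag added above. A sketch of the programmatic equivalent, assuming only the dataclass fields shown in this diff:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Llama-2-7b-hf",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    tensor_parallel_size=4,
    speculative_draft_tensor_parallel_size=1,
    use_v2_block_manager=True,
)
# These args are later turned into the engine's config objects, including the
# draft ParallelConfig built by create_draft_parallel_config in the hunk above.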
@@ -6,7 +6,8 @@ import torch
 
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.interfaces import (SpeculativeProposals,
+                                         SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -28,9 +29,9 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         super().__init__(*args, **kwargs)
 
         # Lazy initialization list.
-        self._proposer: Top1Proposer
+        self._proposer: SpeculativeProposer
 
-    def init_device(self):
+    def init_device(self) -> None:
         super().init_device()
 
         self._proposer = Top1Proposer(
@@ -40,7 +41,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
             max_proposal_len=self.max_model_len,
         )
 
-    def set_include_gpu_probs_tensor(self):
+    def set_include_gpu_probs_tensor(self) -> None:
         # Need include_gpu_probs_tensor for multi_step_worker
         self.model_runner.model.sampler.include_gpu_probs_tensor = True
 
@@ -73,7 +74,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
         for _ in range(sample_len):
-            model_output = super().execute_model(
+            model_output: List[SamplerOutput] = super().execute_model(
                 execute_model_req=copied_execute_model_req)
             assert (len(model_output) == 1
                     ), "composing multistep workers not supported"
@@ -3,10 +3,10 @@ from typing import List, Optional, Tuple
 
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposer
-from vllm.worker.worker_base import WorkerBase
+from vllm.worker.worker_base import LoraNotSupportedWorkerBase
 
 
-class ProposerWorkerBase(WorkerBase, SpeculativeProposer):
+class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
     """Interface for proposer workers"""
 
     @abstractmethod

vllm/spec_decode/smaller_tp_proposer_worker.py (new file, 149 lines)
@@ -0,0 +1,149 @@
+from typing import List, Optional, Tuple
+
+import torch
+
+from vllm.distributed.parallel_state import (get_tp_group,
+                                             init_model_parallel_group,
+                                             patch_tensor_parallel_group)
+from vllm.logger import init_logger
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
+
+logger = init_logger(__name__)
+
+
+class SmallerTpProposerWorker(ProposerWorkerBase):
+    """Class which allows a speculative draft model to run with smaller tensor
+    parallel degree than target model.
+    This reduces the communication overhead of small draft models.
+
+    To implement this feature, this class differs behavior based on is_dummy
+    flag, where dummy means worker that does not participate draft generation.
+    Participating workers use a smaller tp group by patching vLLM's tensor
+    parallel group temporarily during forward passes of draft models.
+    """
+
+    @classmethod
+    def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int,
+                          target_tensor_parallel_size: int):
+        """Wrap the worker in a SmallerTpProposerWorker if necessary.
+        """
+        if draft_tensor_parallel_size == target_tensor_parallel_size:
+            return worker
+
+        # gpu ranks that will generate draft tokens together
+        draft_ranks = list(range(draft_tensor_parallel_size))
+
+        logger.info("Wrapping {%s} in {%s}", type(worker), cls)
+        return cls(worker, draft_ranks)
+
+    def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
+        """Create a SmallerTpProposerWorker.
+
+        Args:
+            worker (MultiStepWorker): an actual worker wrapped with this class
+            draft_ranks (List[int]): if this value is given, only the GPU ranks
+            written in this value participate in draft generation
+        """
+        self._worker = worker
+        self._draft_ranks = draft_ranks
+
+        # init during init_device
+        self._is_dummy = False
+        self._tp_group = None
+
+    def _patch_tensor_parallel_group(self):
+        """Temporarily patch the global tp group state with its own tp group
+        state.
+        """
+        return patch_tensor_parallel_group(self._tp_group)
+
+    def init_device(self) -> None:
+        self._is_dummy = get_tp_group().rank not in self._draft_ranks
+
+        # dummy workers do nothing
+        if self._is_dummy:
+            return
+
+        # creates tp process group containing only a subset of gpu ranks
+        local_rank = get_tp_group().local_rank
+        tp_backend = torch.distributed.get_backend(get_tp_group().device_group)
+        self._tp_group = init_model_parallel_group([self._draft_ranks],
+                                                   local_rank, tp_backend)
+
+        with self._patch_tensor_parallel_group():
+            self._worker.init_device()
+
+    def set_include_gpu_probs_tensor(self) -> None:
+        if self._is_dummy:
+            return
+
+        # Need include_gpu_probs_tensor for multi_step_worker
+        self._worker.set_include_gpu_probs_tensor()
+
+    def load_model(self) -> None:
+        if self._is_dummy:
+            return
+
+        with self._patch_tensor_parallel_group():
+            self._worker.load_model()
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        if self._is_dummy:
+            # this case is not used now
+            return -1, -1
+
+        with self._patch_tensor_parallel_group():
+            return self._worker.determine_num_available_blocks()
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        if self._is_dummy:
+            return
+
+        with self._patch_tensor_parallel_group():
+            self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+    ) -> Tuple[List[SamplerOutput], bool]:
+        # Do not check _is_dummy, as it's always called by get_spec_proposals
+        return self._worker.sampler_output(execute_model_req, sample_len)
+
+    def get_spec_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. The number of
+        speculative tokens per sequence is determined by max_proposal_len.
+        """
+        if self._is_dummy:
+            return SpeculativeProposals(None, None, None)
+
+        with self._patch_tensor_parallel_group():
+            return self._worker.get_spec_proposals(execute_model_req)
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        if self._is_dummy:
+            return []
+
+        with self._patch_tensor_parallel_group():
+            return self._worker.execute_model(execute_model_req)
+
+    def get_cache_block_size_bytes(self) -> int:
+        if self._is_dummy:
+            # by returning zero, target worker can use the entire kv cache space
+            return 0
+
+        return self._worker.get_cache_block_size_bytes()
+
+    @property
+    def vocab_size(self) -> int:
+        return self._worker.vocab_size
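A worked example of the wrapping rule above: with a target TP of 4 and a draft TP of 1, `maybe_wrap_worker` selects ranks `[0]` as the draft group, so rank 0 runs the draft model while ranks 1-3 become dummy workers that skip draft-model work (per `init_device`). A self-contained illustration of that rank split, not tied to any vLLM API:

draft_tp, target_tp = 1, 4
draft_ranks = list(range(draft_tp))  # [0], as in maybe_wrap_worker

for rank in range(target_tp):
    is_dummy = rank not in draft_ranks
    role = "dummy (skips draft work)" if is_dummy else "runs the draft model"
    print(f"rank {rank}: {role}")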
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 
-from vllm.config import SpeculativeConfig
+from vllm.config import ParallelConfig, SpeculativeConfig
 from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -18,6 +18,7 @@ from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
+from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
 from vllm.spec_decode.util import (create_sequence_group_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
@@ -90,7 +91,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     @classmethod
     def create_worker(
         cls,
-        scorer_worker: WorkerBase,
+        scorer_worker: Worker,
         draft_worker_kwargs: Dict[str, Any],
         disable_by_batch_size: Optional[int],
     ) -> "SpecDecodeWorker":
@@ -111,7 +112,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
             disable_bonus_tokens = False
         else:
+            draft_parallel_config: ParallelConfig = draft_worker_kwargs[
+                'parallel_config']
+            draft_tp = draft_parallel_config.tensor_parallel_size
+            target_tp = scorer_worker.parallel_config.tensor_parallel_size
+
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
+                proposer_worker, draft_tp, target_tp)
 
         logger.info("Configuring SpecDecodeWorker with proposer=%s",
                     type(proposer_worker))