[Misc] Begin deprecation of get_tensor_model_*_group (#22494)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 1712543df6
commit 43c4f3d77c
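For downstream callers the migration is a one-line swap of the accessor; both names return the tensor-parallel GroupCoordinator. A minimal sketch, illustrative only and assuming vLLM's distributed environment is already initialized:

    # Before: deprecated accessor; calling it now emits a DeprecationWarning
    from vllm.distributed.parallel_state import get_tensor_model_parallel_group
    tp_group = get_tensor_model_parallel_group()

    # After: the replacement accessor
    from vllm.distributed.parallel_state import get_tp_group
    tp_group = get_tp_group()

    # Either way, the GroupCoordinator exposes the underlying process group
    device_group = tp_group.device_group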
@@ -10,8 +10,7 @@ import torch.distributed as dist
 
 from vllm.distributed.communication_op import (  # noqa
     tensor_model_parallel_all_reduce)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group, graph_capture)
+from vllm.distributed.parallel_state import get_tp_group, graph_capture
 
 from ..utils import (ensure_model_parallel_initialized,
                      init_test_distributed_environment, multi_process_parallel)
@@ -37,7 +36,7 @@ def graph_allreduce(
     init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     ensure_model_parallel_initialized(tp_size, pp_size)
-    group = get_tensor_model_parallel_group().device_group
+    group = get_tp_group().device_group
 
     # A small all_reduce for warmup.
     # this is needed because device communicators might be created lazily
@@ -10,8 +10,7 @@ import torch.distributed as dist
 
 from vllm.distributed.communication_op import (  # noqa
     tensor_model_parallel_all_reduce)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group, graph_capture)
+from vllm.distributed.parallel_state import get_tp_group, graph_capture
 from vllm.platforms import current_platform
 
 from ..utils import (ensure_model_parallel_initialized,
@@ -42,7 +41,7 @@ def graph_quickreduce(
     init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     ensure_model_parallel_initialized(tp_size, pp_size)
-    group = get_tensor_model_parallel_group().device_group
+    group = get_tp_group().device_group
 
     # A small all_reduce for warmup.
     # this is needed because device communicators might be created lazily
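Both test hunks make the same substitution: the warmup all_reduce now takes its process group from get_tp_group(). A sketch of that pattern, assuming the test helpers have already initialized the distributed environment and model parallelism:

    import torch
    import torch.distributed as dist

    from vllm.distributed.parallel_state import get_tp_group

    # .device_group is the underlying torch.distributed ProcessGroup
    group = get_tp_group().device_group

    # A small all_reduce for warmup; device communicators may be created
    # lazily, so this forces them into existence before the real test runs.
    data = torch.zeros(1, device="cuda")
    dist.all_reduce(data, group=group)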
@@ -36,6 +36,7 @@ from unittest.mock import patch
 import torch
 import torch.distributed
 from torch.distributed import Backend, ProcessGroup
+from typing_extensions import deprecated
 
 import vllm.envs as envs
 from vllm.distributed.device_communicators.base_device_communicator import (
@@ -894,8 +895,12 @@ def get_tp_group() -> GroupCoordinator:
     return _TP
 
 
-# kept for backward compatibility
-get_tensor_model_parallel_group = get_tp_group
+@deprecated("`get_tensor_model_parallel_group` has been replaced with "
+            "`get_tp_group` and may be removed after v0.12. Please use "
+            "`get_tp_group` instead.")
+def get_tensor_model_parallel_group():
+    return get_tp_group()
 
 
 _PP: Optional[GroupCoordinator] = None
@@ -921,8 +926,11 @@ def get_pp_group() -> GroupCoordinator:
     return _PP
 
 
-# kept for backward compatibility
-get_pipeline_model_parallel_group = get_pp_group
+@deprecated("`get_pipeline_model_parallel_group` has been replaced with "
+            "`get_pp_group` and may be removed in v0.12. Please use "
+            "`get_pp_group` instead.")
+def get_pipeline_model_parallel_group():
+    return get_pp_group()
 
 
 @contextmanager
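The shim relies on typing_extensions.deprecated (PEP 702): besides flagging the symbol for type checkers, it makes each call emit a DeprecationWarning at runtime while delegating to the new accessor. A self-contained sketch of the same pattern, using a stand-in function rather than vLLM code:

    import warnings

    from typing_extensions import deprecated


    def get_tp_group():
        return "tp-group"  # stand-in for the real GroupCoordinator


    @deprecated("`get_tensor_model_parallel_group` has been replaced with "
                "`get_tp_group`. Please use `get_tp_group` instead.")
    def get_tensor_model_parallel_group():
        return get_tp_group()


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        group = get_tensor_model_parallel_group()

    # The old name still works and returns the same object...
    assert group == "tp-group"
    # ...but each call now raises a DeprecationWarning.
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)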