[core][distributed] add ep group and all2all interface (#18077)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in: parent 754b699cbe, commit 6266c57bae

vllm/distributed/device_communicators/all2all.py  +93  (new file)
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+from vllm.forward_context import get_forward_context
+
+
+class All2AllBase:
+
+    def __init__(self, cpu_group, model):
+        self.cpu_group = cpu_group
+
+        # compute some common properties
+        from vllm.distributed.parallel_state import (get_dp_group,
+                                                     get_ep_group,
+                                                     get_tp_group,
+                                                     in_the_same_node_as)
+
+        # all2all lives in ep group, which is merged from dp and tp group
+        self.dp_group = get_dp_group()
+        self.tp_group = get_tp_group()
+        self.ep_group = get_ep_group()
+        self.dp_rank = self.dp_group.rank_in_group
+        self.dp_world_size = self.dp_group.world_size
+
+        # all2all communication often has separate implementations for
+        # intra-node and inter-node communication
+        self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
+        self.internode = not self.intranode
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        raise NotImplementedError
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    def destroy(self):
+        pass
+
+
+class NaiveAll2All(All2AllBase):
+    """
+    A naive implementation of all2all communication.
+    It uses all-reduce under the hood, which is not
+    efficient at all. The main purpose is for testing and
+    debugging.
+    """
+
+    def __init__(self, cpu_group, model):
+        super().__init__(cpu_group, model)
+
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_world_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            self.dp_group.broadcast(buffer[start:end, :], idx)
+
+        return buffer
+
+    def dispatch(self, hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor):
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_dp_cpu)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_dp_cpu)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+
+        all_hidden_states = self.dp_group.all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
+        return hidden_states
+
+    def destroy(self):
+        pass
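NaiveAll2All keys everything off cu_tokens_across_dp_cpu, the cumulative token count across DP ranks: DP rank i owns rows [cu[i-1], cu[i]) of the multicast buffer. A minimal single-process sketch of that bookkeeping (illustrative only, made-up token counts, not vLLM code):

import torch

# Hypothetical per-rank token counts for a DP world size of 3.
num_tokens_per_dp_rank = [3, 5, 2]
cu_tokens_across_dp_cpu = torch.cumsum(
    torch.tensor(num_tokens_per_dp_rank), dim=0)  # tensor([ 3,  8, 10])

hidden_size = 4
# The multicast buffer holds every rank's tokens, so its row count is the
# total number of tokens across the DP group.
buffer = torch.empty((int(cu_tokens_across_dp_cpu[-1]), hidden_size))

for dp_rank in range(len(num_tokens_per_dp_rank)):
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    # In naive_multicast, rank `dp_rank` copies its local tokens into
    # buffer[start:end] and that slice is then broadcast from `dp_rank`.
    print(f"DP rank {dp_rank} owns rows [{start}, {end}) of {buffer.shape[0]}")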
vllm/distributed/device_communicators/base_device_communicator.py

@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional
+from typing import Optional, Tuple

 import torch
 import torch.distributed as dist

@@ -149,3 +149,27 @@ class DeviceCommunicatorBase:

     def destroy(self):
         pass
+
+    def prepare_communication_buffer_for_model(self,
+                                               model: torch.nn.Module) -> None:
+        """
+        Prepare the communication buffer for the model.
+        This is a no-op in the base class.
+        """
+        pass
+
+    def dispatch(
+        self, hidden_states: torch.Tensor,
+        router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Dispatch the hidden states and router logits to the appropriate device.
+        This is a no-op in the base class.
+        """
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Combine the hidden states and router logits from the appropriate device.
+        This is a no-op in the base class.
+        """
+        return hidden_states
vllm/distributed/device_communicators/cuda_communicator.py

@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Optional
+from typing import Optional, Tuple

 import torch
 from torch.distributed import ProcessGroup

+import vllm.envs as envs
+
+from .all2all import All2AllBase
 from .base_device_communicator import DeviceCommunicatorBase

@@ -23,9 +26,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
         from vllm.distributed.parallel_state import (
             _ENABLE_CUSTOM_ALL_REDUCE)
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
-        use_pynccl = True
+
+        # ep does not use pynccl
+        use_pynccl = "ep" not in unique_name

         self.use_pynccl = use_pynccl
+        self.use_all2all = "ep" in unique_name
+        self.all2all_impl: Optional[All2AllBase] = None
         self.use_custom_allreduce = use_custom_allreduce

         # lazy import to avoid documentation build error

@@ -129,3 +136,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
             self.pynccl_comm = None
         if self.ca_comm is not None:
             self.ca_comm = None
+        if self.all2all_impl is not None:
+            self.all2all_impl.destroy()
+            self.all2all_impl = None
+
+    def prepare_communication_buffer_for_model(self,
+                                               model: torch.nn.Module) -> None:
+        """
+        Prepare the communication buffer for the model.
+        """
+        if not self.use_all2all:
+            return
+        all2all_backend = envs.VLLM_ALL2ALL_BACKEND
+        if all2all_backend == "naive":
+            from .all2all import NaiveAll2All
+            self.all2all_impl = NaiveAll2All(self.cpu_group, model)
+
+    def dispatch(
+        self, hidden_states: torch.Tensor,
+        router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_impl is not None
+        hidden_states, router_logits = self.all2all_impl.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_impl is not None
+        hidden_states = self.all2all_impl.combine(hidden_states)
+        return hidden_states
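The CudaCommunicator changes above wire up a delegation chain: prepare_communication_buffer_for_model picks an All2All implementation based on VLLM_ALL2ALL_BACKEND, and dispatch/combine forward to it. A single-process sketch of that chain with stand-in classes (the Stub* names are hypothetical, not vLLM classes):

import os
from typing import Optional, Tuple

import torch


class StubAll2All:
    """Stand-in for an All2AllBase implementation (identity ops only)."""

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states

    def destroy(self) -> None:
        pass


class StubDeviceCommunicator:
    """Stand-in for CudaCommunicator's new all2all plumbing."""

    def __init__(self) -> None:
        self.all2all_impl: Optional[StubAll2All] = None

    def prepare_communication_buffer_for_model(self, model) -> None:
        # vLLM selects the implementation from VLLM_ALL2ALL_BACKEND here.
        if os.getenv("VLLM_ALL2ALL_BACKEND", "naive") == "naive":
            self.all2all_impl = StubAll2All()

    def dispatch(self, hidden_states, router_logits):
        assert self.all2all_impl is not None
        return self.all2all_impl.dispatch(hidden_states, router_logits)

    def combine(self, hidden_states):
        assert self.all2all_impl is not None
        return self.all2all_impl.combine(hidden_states)


comm = StubDeviceCommunicator()
comm.prepare_communication_buffer_for_model(model=None)
h, logits = comm.dispatch(torch.randn(2, 8), torch.randn(2, 4))
h = comm.combine(h)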
vllm/distributed/parallel_state.py

@@ -757,6 +757,22 @@ class GroupCoordinator:
         if self.mq_broadcaster is not None:
             self.mq_broadcaster = None

+    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
+        if self.device_communicator is not None:
+            self.device_communicator.prepare_communication_buffer_for_model(
+                model)
+
+    def dispatch(
+        self, hidden_states: torch.Tensor,
+        router_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.device_communicator is not None:
+            return self.device_communicator.dispatch(hidden_states,
+                                                     router_logits)
+
+    def combine(self, hidden_states) -> torch.Tensor:
+        if self.device_communicator is not None:
+            return self.device_communicator.combine(hidden_states)
+

 _WORLD: Optional[GroupCoordinator] = None

@@ -816,6 +832,14 @@ def get_dp_group() -> GroupCoordinator:
     return _DP


+_EP: Optional[GroupCoordinator] = None
+
+
+def get_ep_group() -> GroupCoordinator:
+    assert _EP is not None, ("expert parallel group is not initialized")
+    return _EP
+
+
 def get_pp_group() -> GroupCoordinator:
     assert _PP is not None, (
         "pipeline model parallel group is not initialized")
@@ -1001,10 +1025,21 @@ def initialize_model_parallel(
                                     backend,
                                     group_name="dp")

+    global _EP
+    assert _EP is None, ("expert parallel group is already initialized")
+    group_ranks = all_ranks.transpose(1, 2).reshape(
+        -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _EP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="ep")
+
     logger.info(
         "rank %s in world size %s is assigned as "
-        "DP rank %s, PP rank %s, TP rank %s", rank, world_size,
-        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
+        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
+        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
+        _EP.rank_in_group)


 def ensure_model_parallel_initialized(
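The group_ranks expression above decides which ranks share an EP group. A standalone sketch for DP=2, PP=2, TP=2, assuming all_ranks is laid out as (outer, DP, PP, TP) consistently with the surrounding group construction (that layout is an assumption, not shown in this hunk):

import torch

data_parallel_size, pipeline_parallel_size, tensor_parallel_size = 2, 2, 2
world_size = data_parallel_size * pipeline_parallel_size * tensor_parallel_size

# Assumed layout (outer, DP, PP, TP): rank = dp * 4 + pp * 2 + tp here.
all_ranks = torch.arange(world_size).reshape(
    -1, data_parallel_size, pipeline_parallel_size, tensor_parallel_size)

# Same expression as in the diff: swap the DP and PP axes, then flatten the
# DP x TP block of each pipeline stage into one group.
group_ranks = all_ranks.transpose(1, 2).reshape(
    -1, data_parallel_size * tensor_parallel_size).unbind(0)
print([x.tolist() for x in group_ranks])
# -> [[0, 1, 4, 5], [2, 3, 6, 7]]: one EP group per PP stage, covering all DP
#    and TP ranks of that stage ("merged from dp and tp group").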
@@ -1035,6 +1070,23 @@ def ensure_model_parallel_initialized(
         f"{pipeline_model_parallel_size=}")


+def prepare_communication_buffer_for_model(model: torch.nn.Module):
+    """Prepare the communication buffer for the model.
+    Traditional communication libraries like NCCL are almost
+    model agnostic. However, emerging new communication libraries like
+    MoE all2all (DeepEP) usually allocate the communication buffer
+    based on the model shape for optimal performance.
+    """
+    if _TP is not None:
+        _TP.prepare_communication_buffer_for_model(model)
+    if _PP is not None:
+        _PP.prepare_communication_buffer_for_model(model)
+    if _DP is not None:
+        _DP.prepare_communication_buffer_for_model(model)
+    if _EP is not None:
+        _EP.prepare_communication_buffer_for_model(model)
+
+
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
     return (_TP is not None and _PP is not None)

@@ -1095,6 +1147,11 @@ def destroy_model_parallel():
         _DP.destroy()
     _DP = None

+    global _EP
+    if _EP:
+        _EP.destroy()
+    _EP = None
+

 def destroy_distributed_environment():
     global _WORLD
vllm/distributed/utils.py

@@ -362,6 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
     stateless_init_torch_distributed_process_group().
     """
     # Lazy import for non-CUDA backends.
-    from torch.distributed.distributed_c10d import _shutdown_backend
-    _shutdown_backend(pg)
+    try:
+        # pytorch <= 2.6
+        from torch.distributed.distributed_c10d import _shutdown_backend
+        _shutdown_backend(pg)
+    except ImportError:
+        # pytorch >= 2.7
+        pg.shutdown()
     _unregister_process_group(pg.group_name)
vllm/envs.py

@@ -115,6 +115,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
     VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
     VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
+    VLLM_ALL2ALL_BACKEND: str = "naive"


 def get_default_cache_root():

@@ -764,6 +765,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Port used for NIXL handshake between remote agents.
     "VLLM_NIXL_SIDE_CHANNEL_PORT":
     lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
+
+    # all2all backend for vllm's expert parallel communication
+    "VLLM_ALL2ALL_BACKEND":
+    lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
 }

 # end-env-vars-definition
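The envs.py hunk follows the module's existing pattern: each variable is exposed through a lazy getter so the value is read at access time rather than import time. A reduced standalone rendition of just the new entry (not the real vllm.envs module):

import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    # all2all backend for vllm's expert parallel communication
    "VLLM_ALL2ALL_BACKEND":
    lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
}

# The value is resolved only when the getter is called.
os.environ["VLLM_ALL2ALL_BACKEND"] = "naive"
print(environment_variables["VLLM_ALL2ALL_BACKEND"]())  # -> naive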
vllm/model_executor/layers/fused_moe/layer.py

@@ -10,7 +10,8 @@ from torch.nn.parameter import UninitializedParameter

 import vllm.envs as envs
 from vllm.config import get_current_vllm_config
-from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.forward_context import ForwardContext, get_forward_context

@@ -832,24 +833,6 @@ class FusedMoE(torch.nn.Module):

         return topk_weights, topk_ids

-    def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
-        assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
-                             device=x.device,
-                             dtype=x.dtype)
-
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
-        buffer[start:end, :].copy_(x)
-        for idx in range(get_dp_group().world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            get_dp_group().broadcast(buffer[start:end, :], idx)
-
-        return buffer
-
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
         if self.use_direct_call:

@@ -863,14 +846,8 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None

         if self.dp_size > 1:
-            cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
-
-            hidden_states = self.naive_multicast(hidden_states,
-                                                 cu_tokens_across_dp_cpu)
-            router_logits = self.naive_multicast(router_logits,
-                                                 cu_tokens_across_dp_cpu)
+            hidden_states, router_logits = get_ep_group().dispatch(
+                hidden_states, router_logits)

         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,

@@ -891,12 +868,7 @@ class FusedMoE(torch.nn.Module):
         )

         if self.dp_size > 1:
-            start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-                self.dp_rank - 1]
-            end = cu_tokens_across_dp_cpu[self.dp_rank]
-
-            all_hidden_states = get_dp_group().all_reduce(final_hidden_states)
-            final_hidden_states = all_hidden_states[start:end, :]
+            final_hidden_states = get_ep_group().combine(final_hidden_states)

         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
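With these changes the DP>1 path in FusedMoE.forward becomes dispatch, expert computation, then combine; with the naive backend, combine is an all-reduce over the DP group followed by slicing out this rank's rows (see NaiveAll2All.combine). A toy single-process rendition of that sum-then-slice step, with made-up per-rank contributions and no real collectives:

import torch

# Hypothetical contributions for ALL 5 tokens as each DP rank would hold them
# after dispatch plus local expert computation (DP world size 2).
partial_from_rank0 = torch.full((5, 4), 1.0)
partial_from_rank1 = torch.full((5, 4), 2.0)

# "combine", naive backend: all-reduce (sum) across the DP group ...
all_hidden_states = partial_from_rank0 + partial_from_rank1

# ... then each rank keeps only its own token rows, using the same
# cu_tokens_across_dp_cpu bookkeeping as NaiveAll2All.combine.
cu_tokens_across_dp_cpu = torch.tensor([2, 5])  # rank 0: rows 0-1, rank 1: rows 2-4
dp_rank = 1
start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
end = int(cu_tokens_across_dp_cpu[dp_rank])
final_hidden_states = all_hidden_states[start:end, :]
print(final_hidden_states.shape)  # torch.Size([3, 4])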
vllm/v1/worker/gpu_model_runner.py

@@ -19,7 +19,8 @@ from vllm.config import (CompilationLevel, VllmConfig,
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
-from vllm.distributed.parallel_state import get_pp_group, graph_capture
+from vllm.distributed.parallel_state import (
+    get_pp_group, graph_capture, prepare_communication_buffer_for_model)
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding

@@ -1457,6 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logger.info("Model loading took %.4f GiB and %.6f seconds",
                         self.model_memory_usage / GiB_bytes,
                         time_after_load - time_before_load)
+        prepare_communication_buffer_for_model(self.model)

     def _get_prompt_logprobs_dict(
         self,