mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 20:57:09 +08:00
[Core] Enable CUDA graphs for DP + All2All kernels (#18724)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
parent
6dbe5b5c93
commit
7951d78738
@ -10,7 +10,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import ParallelConfig, VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -30,6 +30,44 @@ class DPMetadata:
|
|||||||
max_tokens_across_dp_cpu: torch.Tensor
|
max_tokens_across_dp_cpu: torch.Tensor
|
||||||
cu_tokens_across_dp_cpu: torch.Tensor
|
cu_tokens_across_dp_cpu: torch.Tensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def num_tokens_across_dp(num_tokens: int, dp_size: int,
|
||||||
|
dp_rank: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Gather the num_tokens across all DP ranks and return results in a
|
||||||
|
CPU tensor of size dp_size.
|
||||||
|
"""
|
||||||
|
num_tokens_across_dp = [0] * dp_size
|
||||||
|
num_tokens_across_dp[dp_rank] = num_tokens
|
||||||
|
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.int32)
|
||||||
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
|
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
|
||||||
|
return num_tokens_tensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make(parallel_config: ParallelConfig, attn_metadata: Any,
|
||||||
|
num_tokens: int) -> "DPMetadata":
|
||||||
|
|
||||||
|
assert parallel_config.data_parallel_size > 1
|
||||||
|
dp_size = parallel_config.data_parallel_size
|
||||||
|
dp_rank = parallel_config.data_parallel_rank
|
||||||
|
if attn_metadata is not None and hasattr(attn_metadata,
|
||||||
|
"num_prefill_tokens"):
|
||||||
|
# for v0 attention backends
|
||||||
|
batchsize = attn_metadata.num_prefill_tokens + \
|
||||||
|
attn_metadata.num_decode_tokens
|
||||||
|
else:
|
||||||
|
# for v1 attention backends or no attn_metadata
|
||||||
|
batchsize = num_tokens
|
||||||
|
|
||||||
|
num_tokens_tensor = DPMetadata.num_tokens_across_dp(
|
||||||
|
batchsize, dp_size, dp_rank)
|
||||||
|
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
|
||||||
|
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
|
||||||
|
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ForwardContext:
|
class ForwardContext:
|
||||||
@ -74,27 +112,8 @@ def set_forward_context(attn_metadata: Any,
|
|||||||
forward_start_time = time.perf_counter()
|
forward_start_time = time.perf_counter()
|
||||||
dp_metadata: Optional[DPMetadata] = None
|
dp_metadata: Optional[DPMetadata] = None
|
||||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
dp_metadata = DPMetadata.make(vllm_config.parallel_config,
|
||||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
attn_metadata, num_tokens)
|
||||||
if attn_metadata is not None and hasattr(attn_metadata,
|
|
||||||
"num_prefill_tokens"):
|
|
||||||
# for v0 attention backends
|
|
||||||
batchsize = attn_metadata.num_prefill_tokens + \
|
|
||||||
attn_metadata.num_decode_tokens
|
|
||||||
else:
|
|
||||||
# for v1 attention backends or no attn_metadata
|
|
||||||
batchsize = num_tokens
|
|
||||||
num_tokens_across_dp = [0] * dp_size
|
|
||||||
num_tokens_across_dp[dp_rank] = batchsize
|
|
||||||
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
|
|
||||||
device="cpu",
|
|
||||||
dtype=torch.int32)
|
|
||||||
from vllm.distributed.parallel_state import get_dp_group
|
|
||||||
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
|
|
||||||
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
|
|
||||||
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
|
|
||||||
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
|
|
||||||
cu_tokens_across_dp_cpu)
|
|
||||||
|
|
||||||
global _forward_context
|
global _forward_context
|
||||||
prev_context = _forward_context
|
prev_context = _forward_context
|
||||||
|
|||||||
@ -828,6 +828,21 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||||
|
|
||||||
|
# Chunked all2all staging tensor
|
||||||
|
self.batched_hidden_states: Optional[torch.Tensor] = None
|
||||||
|
self.batched_router_logits: Optional[torch.Tensor] = None
|
||||||
|
if self.moe_parallel_config.use_pplx_kernels:
|
||||||
|
act_dtype = vllm_config.model_config.dtype
|
||||||
|
self.batched_hidden_states = torch.zeros(
|
||||||
|
(MOE_DP_CHUNK_SIZE, self.hidden_size),
|
||||||
|
dtype=act_dtype,
|
||||||
|
device=torch.cuda.current_device())
|
||||||
|
|
||||||
|
self.batched_router_logits = torch.zeros(
|
||||||
|
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
|
||||||
|
dtype=act_dtype,
|
||||||
|
device=torch.cuda.current_device())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tp_size(self):
|
def tp_size(self):
|
||||||
return self.moe_parallel_config.tp_size
|
return self.moe_parallel_config.tp_size
|
||||||
@ -1217,18 +1232,39 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
||||||
full_router_logits: torch.Tensor):
|
full_router_logits: torch.Tensor):
|
||||||
|
assert self.batched_hidden_states is not None
|
||||||
|
assert self.batched_router_logits is not None
|
||||||
|
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
|
||||||
|
assert self.batched_router_logits.dtype == full_router_logits.dtype
|
||||||
|
# Check size compatibility.
|
||||||
|
assert (
|
||||||
|
self.batched_hidden_states.size(-1) == full_hidden_states.size(-1))
|
||||||
|
assert (
|
||||||
|
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
|
||||||
|
|
||||||
full_final_hidden_states = torch.empty_like(full_hidden_states)
|
full_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||||
|
|
||||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||||
|
chunk_size = chunk_end - chunk_start
|
||||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||||
|
|
||||||
|
assert (self.batched_hidden_states.size(0) # type: ignore
|
||||||
|
>= chunk_size)
|
||||||
|
assert (self.batched_router_logits.size(0) # type: ignore
|
||||||
|
>= chunk_size)
|
||||||
|
staged_hidden_states = self.batched_hidden_states[:
|
||||||
|
chunk_size, :] # type: ignore
|
||||||
|
staged_router_logits = self.batched_router_logits[:
|
||||||
|
chunk_size, :] # type: ignore
|
||||||
|
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||||
|
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply(
|
final_hidden_states = self.quant_method.apply(
|
||||||
layer=self,
|
layer=self,
|
||||||
x=hidden_states,
|
x=staged_hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=staged_router_logits,
|
||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
renormalize=self.renormalize,
|
renormalize=self.renormalize,
|
||||||
use_grouped_topk=self.use_grouped_topk,
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
@ -1244,7 +1280,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
if not skip_result_store:
|
if not skip_result_store:
|
||||||
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||||
final_hidden_states)
|
final_hidden_states, non_blocking=True)
|
||||||
|
|
||||||
ctx = get_forward_context()
|
ctx = get_forward_context()
|
||||||
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
|
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||||
|
|||||||
@ -106,7 +106,6 @@ class CudaPlatformBase(Platform):
|
|||||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
scheduler_config = vllm_config.scheduler_config
|
scheduler_config = vllm_config.scheduler_config
|
||||||
compilation_config = vllm_config.compilation_config
|
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
|
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
@ -154,16 +153,6 @@ class CudaPlatformBase(Platform):
|
|||||||
logger.info(
|
logger.info(
|
||||||
"Forcing kv cache block size to 64 for FlashMLA backend.")
|
"Forcing kv cache block size to 64 for FlashMLA backend.")
|
||||||
|
|
||||||
if (parallel_config.data_parallel_size > 1
|
|
||||||
and compilation_config.use_cudagraph):
|
|
||||||
logger.info(
|
|
||||||
"Data Parallel: Forcing enforce eager to be True since DP is "
|
|
||||||
"currently not supported with CUDA Graphs.")
|
|
||||||
vllm_config.model_config.enforce_eager = True
|
|
||||||
compilation_config.use_cudagraph = False
|
|
||||||
# FIXME: inductor breaks cudagraph (from @bnell)
|
|
||||||
compilation_config.use_inductor = False
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_current_memory_usage(cls,
|
def get_current_memory_usage(cls,
|
||||||
device: Optional[torch.types.Device] = None
|
device: Optional[torch.types.Device] = None
|
||||||
|
|||||||
@ -24,7 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
|||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_pp_group, get_tp_group, graph_capture,
|
get_pp_group, get_tp_group, graph_capture,
|
||||||
prepare_communication_buffer_for_model)
|
prepare_communication_buffer_for_model)
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import (DPMetadata, get_forward_context,
|
||||||
|
set_forward_context)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model
|
from vllm.model_executor.model_loader import TensorizerLoader, get_model
|
||||||
@ -1104,6 +1105,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for k, v in self.intermediate_tensors.items()
|
for k, v in self.intermediate_tensors.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def get_dp_padding(self, num_tokens: int):
|
||||||
|
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||||
|
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||||
|
if dp_size == 1:
|
||||||
|
# Early exit.
|
||||||
|
return 0
|
||||||
|
|
||||||
|
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
||||||
|
num_tokens, dp_size, dp_rank)
|
||||||
|
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
|
||||||
|
return max_tokens_across_dp_cpu - num_tokens
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
@ -1141,6 +1154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
num_input_tokens = num_scheduled_tokens
|
num_input_tokens = num_scheduled_tokens
|
||||||
|
|
||||||
|
# Padding for DP
|
||||||
|
num_input_tokens += self.get_dp_padding(num_input_tokens)
|
||||||
|
|
||||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||||
# modal outputs after that to ensure the correct order
|
# modal outputs after that to ensure the correct order
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
@ -1658,6 +1674,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
skip_attn: bool = True,
|
skip_attn: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# Padding for DP
|
||||||
|
num_tokens += self.get_dp_padding(num_tokens)
|
||||||
|
|
||||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||||
# for dummy run with LoRA so that the num_reqs collectively
|
# for dummy run with LoRA so that the num_reqs collectively
|
||||||
# has num_tokens in total.
|
# has num_tokens in total.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user