mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-05 10:56:33 +08:00
[Core] Enable CUDA graphs for DP + All2All kernels (#18724)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
parent
6dbe5b5c93
commit
7951d78738
@ -10,7 +10,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -30,6 +30,44 @@ class DPMetadata:
|
||||
max_tokens_across_dp_cpu: torch.Tensor
|
||||
cu_tokens_across_dp_cpu: torch.Tensor
|
||||
|
||||
@staticmethod
|
||||
def num_tokens_across_dp(num_tokens: int, dp_size: int,
|
||||
dp_rank: int) -> torch.Tensor:
|
||||
"""
|
||||
Gather the num_tokens across all DP ranks and return results in a
|
||||
CPU tensor of size dp_size.
|
||||
"""
|
||||
num_tokens_across_dp = [0] * dp_size
|
||||
num_tokens_across_dp[dp_rank] = num_tokens
|
||||
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
|
||||
return num_tokens_tensor
|
||||
|
||||
@staticmethod
|
||||
def make(parallel_config: ParallelConfig, attn_metadata: Any,
|
||||
num_tokens: int) -> "DPMetadata":
|
||||
|
||||
assert parallel_config.data_parallel_size > 1
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
if attn_metadata is not None and hasattr(attn_metadata,
|
||||
"num_prefill_tokens"):
|
||||
# for v0 attention backends
|
||||
batchsize = attn_metadata.num_prefill_tokens + \
|
||||
attn_metadata.num_decode_tokens
|
||||
else:
|
||||
# for v1 attention backends or no attn_metadata
|
||||
batchsize = num_tokens
|
||||
|
||||
num_tokens_tensor = DPMetadata.num_tokens_across_dp(
|
||||
batchsize, dp_size, dp_rank)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
|
||||
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
|
||||
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardContext:
|
||||
@ -74,27 +112,8 @@ def set_forward_context(attn_metadata: Any,
|
||||
forward_start_time = time.perf_counter()
|
||||
dp_metadata: Optional[DPMetadata] = None
|
||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
if attn_metadata is not None and hasattr(attn_metadata,
|
||||
"num_prefill_tokens"):
|
||||
# for v0 attention backends
|
||||
batchsize = attn_metadata.num_prefill_tokens + \
|
||||
attn_metadata.num_decode_tokens
|
||||
else:
|
||||
# for v1 attention backends or no attn_metadata
|
||||
batchsize = num_tokens
|
||||
num_tokens_across_dp = [0] * dp_size
|
||||
num_tokens_across_dp[dp_rank] = batchsize
|
||||
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
|
||||
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
|
||||
dp_metadata = DPMetadata(max_tokens_across_dp_cpu,
|
||||
cu_tokens_across_dp_cpu)
|
||||
dp_metadata = DPMetadata.make(vllm_config.parallel_config,
|
||||
attn_metadata, num_tokens)
|
||||
|
||||
global _forward_context
|
||||
prev_context = _forward_context
|
||||
|
||||
@ -828,6 +828,21 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
# Chunked all2all staging tensor
|
||||
self.batched_hidden_states: Optional[torch.Tensor] = None
|
||||
self.batched_router_logits: Optional[torch.Tensor] = None
|
||||
if self.moe_parallel_config.use_pplx_kernels:
|
||||
act_dtype = vllm_config.model_config.dtype
|
||||
self.batched_hidden_states = torch.zeros(
|
||||
(MOE_DP_CHUNK_SIZE, self.hidden_size),
|
||||
dtype=act_dtype,
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
self.batched_router_logits = torch.zeros(
|
||||
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
|
||||
dtype=act_dtype,
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
@ -1217,18 +1232,39 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor):
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
|
||||
assert self.batched_router_logits.dtype == full_router_logits.dtype
|
||||
# Check size compatibility.
|
||||
assert (
|
||||
self.batched_hidden_states.size(-1) == full_hidden_states.size(-1))
|
||||
assert (
|
||||
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
|
||||
|
||||
full_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
|
||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||
chunk_size = chunk_end - chunk_start
|
||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||
|
||||
assert (self.batched_hidden_states.size(0) # type: ignore
|
||||
>= chunk_size)
|
||||
assert (self.batched_router_logits.size(0) # type: ignore
|
||||
>= chunk_size)
|
||||
staged_hidden_states = self.batched_hidden_states[:
|
||||
chunk_size, :] # type: ignore
|
||||
staged_router_logits = self.batched_router_logits[:
|
||||
chunk_size, :] # type: ignore
|
||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
x=staged_hidden_states,
|
||||
router_logits=staged_router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
@ -1244,7 +1280,7 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
if not skip_result_store:
|
||||
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states)
|
||||
final_hidden_states, non_blocking=True)
|
||||
|
||||
ctx = get_forward_context()
|
||||
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
|
||||
|
||||
@ -106,7 +106,6 @@ class CudaPlatformBase(Platform):
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
compilation_config = vllm_config.compilation_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
@ -154,16 +153,6 @@ class CudaPlatformBase(Platform):
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLA backend.")
|
||||
|
||||
if (parallel_config.data_parallel_size > 1
|
||||
and compilation_config.use_cudagraph):
|
||||
logger.info(
|
||||
"Data Parallel: Forcing enforce eager to be True since DP is "
|
||||
"currently not supported with CUDA Graphs.")
|
||||
vllm_config.model_config.enforce_eager = True
|
||||
compilation_config.use_cudagraph = False
|
||||
# FIXME: inductor breaks cudagraph (from @bnell)
|
||||
compilation_config.use_inductor = False
|
||||
|
||||
@classmethod
|
||||
def get_current_memory_usage(cls,
|
||||
device: Optional[torch.types.Device] = None
|
||||
|
||||
@ -24,7 +24,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group, get_tp_group, graph_capture,
|
||||
prepare_communication_buffer_for_model)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.forward_context import (DPMetadata, get_forward_context,
|
||||
set_forward_context)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model
|
||||
@ -1104,6 +1105,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
|
||||
def get_dp_padding(self, num_tokens: int):
|
||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||
if dp_size == 1:
|
||||
# Early exit.
|
||||
return 0
|
||||
|
||||
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
||||
num_tokens, dp_size, dp_rank)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
|
||||
return max_tokens_across_dp_cpu - num_tokens
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@ -1141,6 +1154,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
num_input_tokens = num_scheduled_tokens
|
||||
|
||||
# Padding for DP
|
||||
num_input_tokens += self.get_dp_padding(num_input_tokens)
|
||||
|
||||
# _prepare_inputs may reorder the batch, so we must gather multi
|
||||
# modal outputs after that to ensure the correct order
|
||||
if self.is_multimodal_model:
|
||||
@ -1658,6 +1674,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
skip_attn: bool = True,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Padding for DP
|
||||
num_tokens += self.get_dp_padding(num_tokens)
|
||||
|
||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||
# for dummy run with LoRA so that the num_reqs collectively
|
||||
# has num_tokens in total.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user