mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-27 10:10:18 +08:00
afd use ubatch without thread
Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
This commit is contained in:
parent
bd8fe276f5
commit
28cba040c7
@ -4,7 +4,7 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
@ -14,6 +14,7 @@ from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.distributed.afd_transfer import AFDConnectorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlices
|
||||
|
||||
@ -232,6 +233,13 @@ class AFDMetadata:
|
||||
afd_tokens_lens: list[int] # padded lengths for tensor slicing
|
||||
num_of_stages: int
|
||||
|
||||
input_ids_list: list[torch.Tensor] = field(default_factory=list)
|
||||
positions_list: list[torch.Tensor] = field(default_factory=list)
|
||||
inputs_embeds_list: list[torch.Tensor] = field(default_factory=list)
|
||||
intermediate_tensors_list: list[IntermediateTensors] = field(default_factory=list)
|
||||
attn_metadata_list: list[AttentionMetadata] = field(default_factory=list)
|
||||
dp_metadata_list: list[DPMetadata] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardContext:
|
||||
|
||||
@ -1361,7 +1361,9 @@ class DeepseekV2Model(nn.Module):
|
||||
if recv_handle is not None:
|
||||
for work in recv_handle:
|
||||
work.wait()
|
||||
current_hidden, residual = layer(positions, hidden_states, residual, llama_4_scaling)
|
||||
current_hidden, residual = layer(
|
||||
positions, hidden_states, residual, llama_4_scaling
|
||||
)
|
||||
metadata = AFDConnectorMetadata.create_attention_metadata(
|
||||
layer_idx=layer.layer_idx,
|
||||
stage_idx=afd_metadata.afd_stage_idx,
|
||||
@ -1385,6 +1387,96 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def forward_with_afd_v2(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    afd_metadata: AFDMetadata,
    llama_4_scaling: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the decoder layers with AFD micro-batch (ubatch) staging.

    The flat batch is split into per-stage micro-batches sized by
    ``afd_metadata.positions_list``. For every layer, each stage receives
    the previous layer's FFN output from the AFD connector (except at the
    first layer, where the locally-sliced hidden states are used), runs the
    layer, and sends its attention output back through the connector.
    After the last layer, one more FFN output is received and the
    micro-batches are concatenated back into a single batch.

    Args:
        hidden_states: Flat hidden states for the whole batch.
        residual: Flat residual tensor; sliced per stage (may be ``None``).
        positions: Flat positions for the whole batch (per-stage positions
            are taken from ``afd_metadata.positions_list`` instead).
        afd_metadata: Per-stage AFD metadata (positions, attention and DP
            metadata lists, connector handle, stage count).
        llama_4_scaling: Optional scaling tensor forwarded to each layer.

    Returns:
        Tuple of ``(hidden_states, residual)`` re-assembled over all stages.
    """
    fwd_ctx = get_forward_context()
    recv_handle = None

    # Slice the flat batch into per-stage micro-batches; slice sizes come
    # from the per-stage position tensors.
    ubatch_hidden_states: list[torch.Tensor] = []
    ubatch_residual: list[torch.Tensor | None] = []
    start_idx = 0
    for pos in afd_metadata.positions_list:
        # MROPE positions are 2-D (3, num_tokens); plain positions are 1-D
        # (num_tokens,) — pick the token count accordingly.
        num_tokens = pos.shape[1] if pos.ndim == 2 else pos.shape[0]
        end_idx = start_idx + num_tokens
        ubatch_hidden_states.append(hidden_states[start_idx:end_idx])
        ubatch_residual.append(
            residual[start_idx:end_idx] if residual is not None else None
        )
        start_idx = end_idx

    # Loop-invariants hoisted out of the per-stage loop.
    afd_connector = afd_metadata.afd_connector
    num_of_stages = afd_metadata.num_of_stages

    for layer in islice(self.layers, self.start_layer, self.end_layer):
        for stage_i in range(num_of_stages):
            # Debug-level, lazily-formatted logging (was leftover INFO
            # f-string debug output on the hot path).
            logger.debug(
                "forward_with_afd_v2 layer_idx: %s, stage_i: %s",
                layer.layer_idx,
                stage_i,
            )
            # Point the forward context at this stage's metadata so the
            # attention backend / DP machinery sees the right slice.
            fwd_ctx.attn_metadata = afd_metadata.attn_metadata_list[stage_i]
            fwd_ctx.dp_metadata = afd_metadata.dp_metadata_list[stage_i]

            residual = ubatch_residual[stage_i]

            if layer.layer_idx > 0:
                # Receive this stage's FFN output from the previous layer.
                hidden_states, recv_metadata = afd_connector.recv_ffn_output()
                if recv_metadata.recv_handle_list is not None:
                    recv_handle = recv_metadata.recv_handle_list
            else:
                hidden_states = ubatch_hidden_states[stage_i]

            if recv_handle is not None:
                for work in recv_handle:
                    work.wait()
                # Clear the handles so a later stage whose recv returned no
                # handle list does not re-wait on already-completed works.
                recv_handle = None

            current_positions = afd_metadata.positions_list[stage_i]
            logger.debug(
                "forward_with_afd_v2 hidden_states: %s positions: %s",
                hidden_states.shape,
                # Log the positions actually fed to the layer (the original
                # logged the unsliced `positions` tensor here).
                current_positions.shape,
            )
            hidden_states, residual = layer(
                current_positions, hidden_states, residual, llama_4_scaling
            )

            ubatch_hidden_states[stage_i] = hidden_states
            ubatch_residual[stage_i] = residual

            metadata = AFDConnectorMetadata.create_attention_metadata(
                layer_idx=layer.layer_idx,
                stage_idx=stage_i,
                seq_len=hidden_states.shape[0],
                dtype=hidden_states.dtype,
                device=hidden_states.device,
                num_of_stages=num_of_stages,
                afd_tokens_lens=afd_metadata.afd_tokens_lens,
            )
            afd_connector.send_attn_output(hidden_states, metadata)

    # Recv last layer and last stage FFN output.
    # NOTE(review): only a single FFN output is received after the layer
    # loop and stored in the final stage's slot — confirm the connector
    # delivers the remaining stages' last-layer outputs elsewhere.
    ubatch_hidden_states[num_of_stages - 1], recv_metadata = (
        afd_connector.recv_ffn_output()
    )
    if recv_metadata.recv_handle_list is not None:
        recv_handle = recv_metadata.recv_handle_list
    if recv_handle is not None:
        for work in recv_handle:
            work.wait()

    # Re-assemble the batch from the per-stage micro-batches.
    hidden_states = torch.cat(ubatch_hidden_states, dim=0)
    if any(r is not None for r in ubatch_residual):
        residual = torch.cat(ubatch_residual, dim=0)
    else:
        residual = None

    return hidden_states, residual
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -1421,12 +1513,14 @@ class DeepseekV2Model(nn.Module):
|
||||
afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
|
||||
|
||||
if afd_metadata != None:
|
||||
hidden_states, residual = self.forward_with_afd(
|
||||
hidden_states, residual = self.forward_with_afd_v2(
|
||||
hidden_states, residual, positions, afd_metadata, llama_4_scaling
|
||||
)
|
||||
else:
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(positions, hidden_states, residual, llama_4_scaling)
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, residual, llama_4_scaling
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
|
||||
@ -14,6 +14,7 @@ from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import (
|
||||
AFDMetadata,
|
||||
DPMetadata,
|
||||
create_forward_context,
|
||||
get_forward_context,
|
||||
@ -358,6 +359,59 @@ class UBatchWrapper:
|
||||
|
||||
return ubatch_metadata
|
||||
|
||||
def _make_afd_ubatch_metadata(
    self,
    ubatch_slices,
    attn_metadata,
    input_ids,
    positions,
    inputs_embeds,
    intermediate_tensors,
    dp_metadata,
    afd_metadata,
) -> AFDMetadata:
    """Populate *afd_metadata*'s per-stage lists from the model inputs.

    When no ubatch slicing is active (``ubatch_slices is None``), the whole
    batch is appended as a single stage. Otherwise each ubatch slice is cut
    out of the flat inputs via ``self._slice_model_inputs`` and appended as
    its own stage, with a freshly built per-ubatch ``DPMetadata``.

    Mutates and returns the same *afd_metadata* object.
    """
    if ubatch_slices is None:
        # Single-stage path: the unsliced inputs become stage 0.
        afd_metadata.input_ids_list.append(input_ids)
        afd_metadata.positions_list.append(positions)
        afd_metadata.inputs_embeds_list.append(inputs_embeds)
        afd_metadata.intermediate_tensors_list.append(intermediate_tensors)
        afd_metadata.attn_metadata_list.append(attn_metadata)
        afd_metadata.dp_metadata_list.append(dp_metadata)
    else:
        for i, ubatch_slice in enumerate(ubatch_slices):
            (
                sliced_input_ids,
                sliced_positions,
                sliced_inputs_embeds,
                sliced_intermediate_tensors,
            ) = self._slice_model_inputs(
                ubatch_slice.token_slice,
                input_ids,
                positions,
                inputs_embeds,
                intermediate_tensors,
            )

            # Build DP metadata for this ubatch. NOTE(review): every DP rank
            # is assumed to run the same per-ubatch token count
            # (num_tokens replicated dp_size times) — confirm this holds
            # when DP ranks have uneven batches.
            dp_size = self.vllm_config.parallel_config.data_parallel_size
            ubatch_num_tokens_across_dp = torch.tensor(
                [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
            )
            ubatch_dp_metadata = DPMetadata.make(
                self.vllm_config.parallel_config,
                ubatch_slice.num_tokens,
                ubatch_num_tokens_across_dp,
            )

            afd_metadata.input_ids_list.append(sliced_input_ids)
            afd_metadata.positions_list.append(sliced_positions)
            afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds)
            afd_metadata.intermediate_tensors_list.append(sliced_intermediate_tensors)
            # attn_metadata, when present, is indexed per ubatch here.
            afd_metadata.attn_metadata_list.append(
                attn_metadata[i] if attn_metadata is not None else None)
            afd_metadata.dp_metadata_list.append(ubatch_dp_metadata)

    return afd_metadata
|
||||
|
||||
def _slice_model_inputs(
|
||||
self,
|
||||
tokens_slice: slice,
|
||||
@ -392,6 +446,33 @@ class UBatchWrapper:
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
afd_metadata = forward_context.afd_metadata
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
input_ids = kwargs["input_ids"]
|
||||
positions = kwargs["positions"]
|
||||
intermediate_tensors = kwargs["intermediate_tensors"]
|
||||
inputs_embeds = kwargs["inputs_embeds"]
|
||||
compute_stream = torch.cuda.current_stream()
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
|
||||
if self.vllm_config.afd_config:
|
||||
afd_metadata = self._make_afd_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
dp_metadata=dp_metadata,
|
||||
afd_metadata=afd_metadata,
|
||||
)
|
||||
forward_context.afd_metadata = afd_metadata
|
||||
if cudagraph_runtime_mode is CUDAGraphMode.NONE:
|
||||
return self.runnable(*args, **kwargs)
|
||||
else:
|
||||
assert self.cudagraph_wrapper is not None
|
||||
return self.cudagraph_wrapper(*args, **kwargs)
|
||||
|
||||
# If there's no ubatching, just run the runnable object
|
||||
if ubatch_slices is None:
|
||||
# This is to account for the case where ubatching was aborted.
|
||||
@ -400,6 +481,7 @@ class UBatchWrapper:
|
||||
# num_tokens, we don't have a non-ubatched one. Without this
|
||||
# check, the cudagraph wrapper will try to capture a cudagraph
|
||||
# for this shape during a normal run.
|
||||
|
||||
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
|
||||
assert batch_descriptor is not None
|
||||
if batch_descriptor.num_tokens in self.cudagraphs:
|
||||
@ -411,18 +493,9 @@ class UBatchWrapper:
|
||||
assert self.cudagraph_wrapper is not None
|
||||
return self.cudagraph_wrapper(*args, **kwargs)
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
num_tokens = (
|
||||
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
|
||||
) * 2
|
||||
input_ids = kwargs["input_ids"]
|
||||
positions = kwargs["positions"]
|
||||
intermediate_tensors = kwargs["intermediate_tensors"]
|
||||
inputs_embeds = kwargs["inputs_embeds"]
|
||||
compute_stream = torch.cuda.current_stream()
|
||||
|
||||
dp_metadata = forward_context.dp_metadata
|
||||
|
||||
# We shouldn't be here unless we are running with multiple DP ranks
|
||||
assert dp_metadata is not None
|
||||
ubatch_dp_metadata = []
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user