[Feat] adapt step3 text model

This commit is contained in:
i-yuanyukun 2025-12-18 14:30:55 +08:00
parent 36f9c3d6b5
commit d306d01dd7
6 changed files with 580 additions and 44 deletions

View File

@ -7,6 +7,316 @@ from datetime import timedelta
import torch
from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import (
GroupCoordinator,
TensorMetadata,
init_afd_process_group,
init_model_parallel_group,
)
from vllm.logger import init_logger
from .base import AFDConnectorBase
from .metadata import AFDConnectorMetadata
logger = init_logger(__name__)
class DefaultProcessGroupSwitcher:
def __init__(self, default_group, new_default_group):
self.default_group = default_group
self.new_default_group = new_default_group
def __enter__(self):
_update_default_pg(self.new_default_group)
def __exit__(self, exc_type, exc_value, traceback):
_update_default_pg(self.default_group)
class P2PAFDConnector(AFDConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
) -> None:
self.rank = rank
self.local_rank = local_rank
self.config = config
self._initialized: bool = False
self._need_recv_metadata: bool = True
self._tensor_metadata_list: dict[int, TensorMetadata] = {}
self._current_afd_connector_metadata: AFDConnectorMetadata | None = None
if getattr(self.config.model_config.hf_config, "text_config", None) is not None:
self.num_hidden_layers: int = (
self.config.model_config.hf_config.text_config.num_hidden_layers
)
else:
self.num_hidden_layers: int = (
self.config.model_config.hf_config.num_hidden_layers
)
self.recv_attn_output_counter: int = 0
self.recv_ffn_output_counter: int = 0
def close(self) -> None:
"""Close the connector and release resources."""
# TODO: Implement proper resource clean up if needed.
pass
def init_afd_connector(self) -> None:
"""Initialize the AFD connector."""
afd_size = self.config.afd_config.afd_extra_config.get("afd_size")
role = self.config.afd_config.afd_role
attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups())
world_rank = self.rank if role == "attention" else self.rank + attn_size
afd_pg = init_afd_process_group(
backend="nccl",
init_method=(
f"tcp://{self.config.afd_config.afd_host}"
f":{self.config.afd_config.afd_port}"
),
world_size=ffn_size + attn_size,
rank=world_rank,
group_name="afd",
timeout=timedelta(minutes=2),
)
# Construct rank lists for sub groups.
# Each group contains one attention and one ffn rank.
ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)]
attn_ranks = [i for i in range(attn_size)]
assert len(ffn_ranks) == len(attn_ranks), (
"ffn_ranks and attn_ranks must have the same length"
)
default_pg_switcher = DefaultProcessGroupSwitcher(_get_default_group(), afd_pg)
with default_pg_switcher:
sub_group_ranks = []
for i in range(len(ffn_ranks)):
ranks = [attn_ranks[i], ffn_ranks[i]]
sub_group_ranks.append(ranks)
# Create two independent groups:
# a2e_group: for attention -> expert/ffn communication (send_attn, recv_attn)
# e2a_group: for expert/ffn -> attention communication (send_ffn, recv_ffn)
# The communication domain (rank range) is the same, but different group_name
# creates independent groups.
self.a2e_group = init_model_parallel_group(
sub_group_ranks,
self.local_rank,
backend="nccl",
group_name="a2e",
)
self.e2a_group = init_model_parallel_group(
sub_group_ranks,
self.local_rank,
backend="nccl",
group_name="e2a",
)
self._initialized = True
def is_initialized(self) -> bool:
"""Check if the connector is initialized and ready to use.
Returns:
bool: True if the connector is initialized, False otherwise.
"""
return self._initialized
def _build_tensor_metadata_list(
self,
tensor_metadata: TensorMetadata,
connector_metadata: AFDConnectorMetadata,
) -> dict[int, TensorMetadata]:
tensor_metadata_list = {}
num_of_stages = connector_metadata.num_of_stages
for idx in range(num_of_stages):
if idx == 0:
tensor_metadata_list[0] = tensor_metadata
else:
new_size = list(tensor_metadata.size)
new_size[0] = connector_metadata.afd_tokens_lens[idx]
tensor_metadata_list[idx] = TensorMetadata(
tensor_metadata.device,
tensor_metadata.dtype,
torch.Size(new_size),
)
return tensor_metadata_list
def _send_metadata(
self,
metadata: AFDConnectorMetadata,
hidden_states: torch.Tensor,
dst: int,
process_group: GroupCoordinator,
) -> None:
if not torch.distributed.is_initialized() or process_group.world_size == 1:
return []
assert dst < process_group.world_size, f"Invalid dst rank ({dst})"
tensor_metadata = TensorMetadata(
hidden_states.device.type, hidden_states.dtype, hidden_states.size()
)
metadata_tuple = (metadata, tensor_metadata)
process_group.send_object(metadata_tuple, dst=dst)
self._tensor_metadata_list = self._build_tensor_metadata_list(
tensor_metadata, metadata
)
def _recv_metadata(
self,
src: int,
process_group: GroupCoordinator,
) -> None:
(self._current_afd_connector_metadata, tensor_metadata) = (
process_group.recv_object(src=src)
)
self._tensor_metadata_list = self._build_tensor_metadata_list(
tensor_metadata, self._current_afd_connector_metadata
)
def _send_hidden_states(
self,
hidden_states: torch.Tensor,
dst: int,
process_group: GroupCoordinator,
) -> None:
if not torch.distributed.is_initialized() or process_group.world_size == 1:
return []
assert dst < process_group.world_size, f"Invalid dst rank ({dst})"
assert not hidden_states.is_cpu, "Hidden states must be on GPU"
torch.distributed.send(
hidden_states,
dst=process_group.ranks[dst],
group=process_group.device_group,
)
def _recv_hidden_states(
self,
src: int,
process_group: GroupCoordinator,
tensor_metadata: TensorMetadata,
) -> tuple[torch.Tensor, list]:
if not torch.distributed.is_initialized() or process_group.world_size == 1:
return {}, []
assert src < process_group.world_size, f"Invalid src rank ({src})"
hidden_states = torch.empty(
tensor_metadata.size,
dtype=tensor_metadata.dtype,
device=tensor_metadata.device,
)
torch.distributed.recv(
hidden_states,
src=process_group.ranks[src],
group=process_group.device_group,
)
return hidden_states, []
# -------------------------------------------------------------------------
# attn -> ffn
# -------------------------------------------------------------------------
def send_attn_output(
self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata
) -> None:
"""
Called by ATTN side to send intermediate tensors
generated by ATTN instances to FFN.
"""
try:
dst = (self.a2e_group.rank_in_group + 1) % self.a2e_group.world_size
if metadata.layer_idx == 0 and metadata.stage_idx == 0:
self._send_metadata(metadata, hidden_states, dst, self.a2e_group)
self._current_afd_connector_metadata = metadata
self._send_hidden_states(hidden_states, dst, self.a2e_group)
except Exception as e:
raise RuntimeError(f"Communication error: {e}")
def recv_ffn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]:
"""
Called by the ATTN side to receive MOE output intermediate tensors,
possibly dispatching from the receiver to other GPUs.
"""
src = (self.e2a_group.rank_in_group - 1) % self.e2a_group.world_size
stage_idx = (
self.recv_ffn_output_counter
% self._current_afd_connector_metadata.num_of_stages
)
hidden_states, work_list = self._recv_hidden_states(
src,
self.e2a_group,
self._tensor_metadata_list[stage_idx],
)
self._current_afd_connector_metadata.recv_handle_list = work_list
self.recv_ffn_output_counter = (
self.recv_ffn_output_counter + 1
) % self._current_afd_connector_metadata.num_of_stages
return hidden_states, self._current_afd_connector_metadata
# -------------------------------------------------------------------------
# ffn -> attn
# -------------------------------------------------------------------------
def send_ffn_output(
self,
hidden_states: torch.Tensor,
metadata: AFDConnectorMetadata,
) -> None:
"""
Called by FFN side to send intermediate tensors generated by FFN
instances back to the sender (should be the same GPU as source).
"""
dst = (self.e2a_group.rank_in_group + 1) % self.e2a_group.world_size
self._send_hidden_states(hidden_states, dst, self.e2a_group)
self.recv_attn_output_counter += 1
if (
self.recv_attn_output_counter
% (
self._current_afd_connector_metadata.num_of_stages
* self.num_hidden_layers
)
== 0
):
self._need_recv_metadata = True
self.recv_attn_output_counter = 0
def recv_attn_output(self) -> tuple[torch.Tensor, AFDConnectorMetadata]:
"""
Called by the FFN side to receive intermediate tensors from ATTN.
Handles receiving and possibly dispatching tensors.
"""
src = (self.a2e_group.rank_in_group - 1) % self.a2e_group.world_size
if self._need_recv_metadata:
self._recv_metadata(src, self.a2e_group)
self._need_recv_metadata = False
stage_idx = (
self.recv_attn_output_counter
% self._current_afd_connector_metadata.num_of_stages
)
layer_idx = (
self.recv_attn_output_counter
// self._current_afd_connector_metadata.num_of_stages
)
hidden_states, work_list = self._recv_hidden_states(
src,
self.a2e_group,
self._tensor_metadata_list[stage_idx],
)
self._current_afd_connector_metadata.recv_handle_list = work_list
self._current_afd_connector_metadata.layer_idx = layer_idx
return hidden_states, self._current_afd_connector_metadata
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import re
from datetime import timedelta
import torch
from torch.distributed.distributed_c10d import _get_default_group, _update_default_pg
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import (
GroupCoordinator,

View File

@ -1670,7 +1670,7 @@ class DeepseekV2ForCausalLM(
return hidden_states
def compute_ffn_output(
self, current_layer_idx, hidden_states
self, hidden_states, current_layer_idx
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx)
return hidden_states

View File

@ -11,12 +11,16 @@ from torch import nn
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import AFDConfig, CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.distributed.afd_transfer.afd_connector.metadata import (
AFDConnectorMetadata,
)
from vllm.forward_context import AFDMetadata, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -37,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.step3_vl import Step3TextConfig
from vllm.v1.worker.ubatching import dbo_current_ubatch_id, dbo_enabled, dbo_yield
from .interfaces import SupportsPP
from .utils import (
@ -228,54 +233,59 @@ class Step3TextDecoderLayer(nn.Module):
config: Step3TextConfig,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
afd_config: AFDConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.afd_role = afd_config.afd_role if afd_config is not None else None
self.self_attn = Step3TextAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
norm_eps=config.rms_norm_eps,
max_position_embedding=config.max_position_embedding,
head_dim=config.head_dim,
share_q_dim=config.share_q_dim,
rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
)
layer_idx = int(prefix.split("layers.")[1].split(".")[0])
moe_layers_enum = getattr(config, "moe_layers_enum", None)
if moe_layers_enum is not None:
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
else:
# Default to 1dense.
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
if layer_idx in moe_layers_idx:
self.moe = FusedMoEBlock(
config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
)
self.share_expert = Step3TextMLP(
if self.afd_role is None or self.afd_role == "attention":
self.self_attn = Step3TextAttention(
hidden_size=self.hidden_size,
intermediate_size=config.share_expert_dim,
hidden_act="silu",
num_heads=config.num_attention_heads,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.share_expert",
norm_eps=config.rms_norm_eps,
max_position_embedding=config.max_position_embedding,
head_dim=config.head_dim,
share_q_dim=config.share_q_dim,
rope_parameters=config.rope_parameters,
prefix=f"{prefix}.self_attn",
)
self.use_moe = True
else:
self.mlp = Step3TextMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act="silu",
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.use_moe = False
self.layer_idx = int(prefix.split("layers.")[1].split(".")[0])
if self.afd_role is None or self.afd_role == "ffn":
moe_layers_enum = getattr(config, "moe_layers_enum", None)
if moe_layers_enum is not None:
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
else:
# Default to 1dense.
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
if self.layer_idx in moe_layers_idx:
self.moe = FusedMoEBlock(
config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
)
self.share_expert = Step3TextMLP(
hidden_size=self.hidden_size,
intermediate_size=config.share_expert_dim,
hidden_act="silu",
quant_config=quant_config,
prefix=f"{prefix}.share_expert",
)
self.use_moe = True
else:
self.mlp = Step3TextMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act="silu",
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.use_moe = False
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
@ -300,6 +310,9 @@ class Step3TextDecoderLayer(nn.Module):
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
if self.afd_role == "attention":
return hidden_states, residual
if self.use_moe:
share_output = self.share_expert(hidden_states)
moe_output = self.moe(hidden_states)
@ -309,6 +322,25 @@ class Step3TextDecoderLayer(nn.Module):
return hidden_states, residual
def compute_attn_output(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
):
pass
def compute_ffn_output(self, hidden_states):
assert self.afd_role == "ffn"
if self.use_moe:
share_output = self.share_expert(hidden_states)
moe_output = self.moe(hidden_states)
hidden_states = share_output + moe_output
else:
hidden_states = self.mlp(hidden_states)
logger.info(f"{type(hidden_states)=}")
return hidden_states
@support_torch_compile
class Step3TextModel(nn.Module):
@ -317,6 +349,8 @@ class Step3TextModel(nn.Module):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
logger.info(f"{quant_config=}")
afd_config = vllm_config.afd_config
self.vocab_size = config.vocab_size
self.config = config
@ -336,6 +370,7 @@ class Step3TextModel(nn.Module):
config=config,
cache_config=cache_config,
quant_config=quant_config,
afd_config=afd_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
@ -352,6 +387,51 @@ class Step3TextModel(nn.Module):
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward_with_afd(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
positions: torch.Tensor,
afd_metadata: AFDMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
recv_handle = None
logger.info(f"{__file__}: forward with afd called, may blocked here")
for layer in islice(self.layers, self.start_layer, self.end_layer):
afd_connector = afd_metadata.afd_connector
afd_metadata.afd_stage_idx = dbo_current_ubatch_id()
if layer.layer_idx > 0:
hidden_states, recv_metadata = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
if recv_handle is not None:
for work in recv_handle:
work.wait()
current_hidden, residual = layer(positions, hidden_states, residual)
metadata = AFDConnectorMetadata.create_attention_metadata(
layer_idx=layer.layer_idx,
stage_idx=afd_metadata.afd_stage_idx,
seq_len=current_hidden.shape[0],
dtype=current_hidden.dtype,
device=current_hidden.device,
num_of_stages=afd_metadata.num_of_stages,
afd_tokens_lens=afd_metadata.afd_tokens_lens,
)
afd_connector.send_attn_output(current_hidden, metadata)
if dbo_enabled():
dbo_yield()
hidden_states, recv_metadata = afd_connector.recv_ffn_output()
if recv_metadata.recv_handle_list is not None:
recv_handle = recv_metadata.recv_handle_list
if recv_handle is not None:
for work in recv_handle:
work.wait()
return hidden_states, residual
def forward(
self,
input_ids: torch.Tensor,
@ -370,8 +450,19 @@ class Step3TextModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
forward_ctx = get_forward_context()
afd_metadata = forward_ctx.afd_metadata if forward_ctx is not None else None
if afd_metadata is not None:
hidden_states, residual = self.forward_with_afd(
hidden_states,
residual,
positions,
afd_metadata,
)
else:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
@ -384,6 +475,15 @@ class Step3TextModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def compute_ffn_output(
self,
hidden_states,
layer_idx,
) -> torch.Tensor | IntermediateTensors:
logger.info(f"{type(self.layers)=}, {type(layer_idx)=}")
hidden_states = self.layers[layer_idx].compute_ffn_output(hidden_states)
return hidden_states
class Step3TextForCausalLM(nn.Module, SupportsPP):
def __init__(
@ -398,6 +498,11 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
self.config = config
self.vllm_config = vllm_config
self.afd_config = vllm_config.afd_config
self.afd_role = (
self.afd_config.afd_role if self.afd_config is not None else None
)
self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)
if get_pp_group().is_last_rank:
@ -429,11 +534,20 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
)
return hidden_states
def compute_ffn_output(
self,
hidden_states,
current_layer_idx,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model.compute_ffn_output(hidden_states, current_layer_idx)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
logger.info(f"{__file__}: load_weights!")
qkv_params_mapping = [
# (param_name, shard_name, relative_start_idx, relative_end_idx)
(
@ -466,6 +580,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
logger.info(f"{params_dict.keys()=}")
loaded_params: set[str] = set()
expert_params_mapping = [
@ -477,9 +592,17 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
disable_moe_stacked_params = [data[1] for data in expert_params_mapping]
for name, loaded_weight in weights:
logger.info(
f"{self.afd_role=}, {name=}, is_moe: {self.is_moe_weight(name)}, "
f"is_common: {self.is_common_weight(name)}"
)
if self.afd_role == "attention" and self.is_moe_weight(name):
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if any(
disable_moe_stacked_param in name
for disable_moe_stacked_param in disable_moe_stacked_params
@ -498,6 +621,10 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
param_name, weight_name, shard_id = mapping
if weight_name not in name:
continue
if self.afd_role is not None and self.afd_role == "attention":
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
@ -521,12 +648,19 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
loaded_params.add(name)
break
else:
if (
self.afd_role == "ffn"
and not self.is_moe_weight(name)
and not self.is_common_weight(name)
):
continue
for (
param_name,
weight_name,
start_idx,
end_idx,
) in qkv_params_mapping:
logger.info(f"{weight_name=}, {name=}")
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@ -552,3 +686,25 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def is_moe_weight(self, name):
if (
"shared_expert" in name
or "experts" in name
or "gate" in name
or "up" in name
or "down" in name
):
return True
return False
def is_common_weight(self, name):
if (
"lm_head" in name
or "model.norm.weight" in name
or "embed" in name
or "input_layernorm" in name
or "post_attention_layernorm" in name
):
return True
return False

View File

@ -1126,6 +1126,16 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return hidden_states
def compute_ffn_output(
self,
hidden_states,
current_layer_idx,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.language_model.compute_ffn_output(
hidden_states, current_layer_idx
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@ -642,6 +642,25 @@ class GPUModelRunner(
with_stack=False,
)
profile_dir = (
"./profiler_logs/attn"
if self.afd_config is not None and self.afd_config.afd_role == "attention"
else "./profiler_logs/normal"
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=6000 + 4000, warmup=1, active=30, repeat=1
),
on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
record_shapes=True,
profile_memory=False,
with_stack=False,
)
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@ -2969,6 +2988,38 @@ class GPUModelRunner(
)
return afd_metadata
def _build_afd_metadata(
self, ubatch_slices: UBatchSlices | None, num_tokens_unpadded: int
):
afd_metadata = None
if self.afd_config:
# For prefill, compute tokens per stage based on actual token
# counts
afd_tokens_start_loc = [0]
afd_tokens_lens = []
if ubatch_slices and len(ubatch_slices) > 1:
afd_tokens_start_loc = [ub.token_slice.start for ub in ubatch_slices]
afd_reqs_start_loc = [ub.request_slice.start for ub in ubatch_slices]
logger.info(
f"afd_tokens_start_loc: {afd_tokens_start_loc} "
f"afd_reqs_start_loc: {afd_reqs_start_loc} "
f"ubatch_slices: {ubatch_slices}"
)
afd_tokens_lens = [ub.num_tokens for ub in ubatch_slices]
else:
afd_tokens_start_loc = [0]
afd_reqs_start_loc = [0]
afd_tokens_lens = [num_tokens_unpadded]
afd_metadata = AFDMetadata(
afd_tokens_start_loc=afd_tokens_start_loc,
afd_reqs_start_loc=afd_reqs_start_loc,
afd_stage_idx=0,
afd_connector=self.afd_connector,
afd_tokens_lens=afd_tokens_lens,
num_of_stages=len(ubatch_slices) if ubatch_slices else 1,
)
return afd_metadata
@torch.inference_mode()
def execute_model(
self,
@ -5517,6 +5568,11 @@ class GPUModelRunner(
if hasattr(self, "afd_connector") and self.afd_connector:
self.afd_connector.init_afd_connector()
def initialize_afd_connector(self) -> None:
"""Initialize AFD connector if available."""
if hasattr(self, "afd_connector") and self.afd_connector:
self.afd_connector.init_afd_connector()
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.

View File

@ -127,6 +127,10 @@ class UBatchWrapper:
comm_sms: int = envs.VLLM_DBO_COMM_SMS
set_comm_sms = lambda sms: None
if (
vllm_config.parallel_config.enable_expert_parallel
and not vllm_config.afd_config
):
if (
vllm_config.parallel_config.enable_expert_parallel
and not vllm_config.afd_config