Remove unnecessary copies

Signed-off-by: ilmarkov <markovilya197@gmail.com>
ilmarkov 2025-12-16 14:50:28 +00:00
parent c761ce527a
commit 720fe99051
2 changed files with 25 additions and 20 deletions


@@ -64,6 +64,8 @@ def run_rebalance_experts(
     assert model_state.eplb_stats is not None
     eplb_stats = model_state.eplb_stats
     # Move the global expert load window to CPU for computation.
+    # It has to be done in the main stream to avoid a race condition
+    # with the main thread.
     global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
     # Compute new expert mappings for the model
     (

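The two added comment lines record an ordering requirement rather than a style preference: the device-to-host copy of the load window must be enqueued on the main stream so that it is ordered after the kernels that last updated the tensor, and the blocking .cpu() call only returns once the data has actually landed on the host. Below is a minimal, generic PyTorch sketch of that idea; the threading layout and the names (snapshot_for_worker, rebalance_worker) are illustrative, not taken from this codebase.

import threading

import torch


def snapshot_for_worker(load_window: torch.Tensor) -> torch.Tensor:
    # The copy is issued on the current (main) stream, so it runs after any
    # kernels that previously wrote load_window. With the default
    # non_blocking=False, .cpu() also blocks the host until the data arrives,
    # making the returned tensor safe to hand to another thread.
    return load_window.cpu()


def rebalance_worker(load_window_cpu: torch.Tensor) -> None:
    # CPU-only work; it never touches the device tensor, so it cannot race
    # with kernels the main thread keeps launching.
    print(load_window_cpu.sum().item())


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    window = torch.rand(8, 64, device=device)
    worker = threading.Thread(
        target=rebalance_worker, args=(snapshot_for_worker(window),)
    )
    worker.start()
    worker.join()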

@@ -948,9 +948,7 @@ class EplbState:
             is_profile=is_profile,
         )
 
-    def _update_layer_mapping_from_new(
-        self, model_state: EplbModelState, layer: int
-    ) -> None:
+    def _update_layer_mapping_from_new(self, model_state: EplbModelState) -> None:
         if (
             model_state.new_physical_to_logical_map is None
             or model_state.new_logical_to_physical_map is None
@@ -959,27 +957,33 @@
             return
 
         target_device = model_state.physical_to_logical_map.device
-        new_physical = model_state.new_physical_to_logical_map
+        new_physical = model_state.new_physical_to_logical_map.to(
+            target_device, non_blocking=True
+        )
         if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
-            model_state.physical_to_logical_map = new_physical.to(target_device)
+            model_state.physical_to_logical_map = new_physical
         else:
-            model_state.physical_to_logical_map[layer].copy_(
-                new_physical[layer].to(target_device)
-            )
+            for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
+                model_state.physical_to_logical_map[layer_idx].copy_(
+                    new_physical[layer_idx]
+                )
 
         logical_device = model_state.logical_to_physical_map.device
-        new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
-        max_slots = model_state.logical_to_physical_map.shape[-1]
-        slot_delta = max_slots - new_logical.shape[-1]
-        if slot_delta > 0:
-            new_logical = torch.nn.functional.pad(
-                new_logical, (0, slot_delta), value=-1
+        for layer_idx in range(model_state.logical_to_physical_map.shape[0]):
+            new_logical = model_state.new_logical_to_physical_map[layer_idx].to(
+                logical_device
             )
-        model_state.logical_to_physical_map[layer].copy_(new_logical)
+            max_slots = model_state.logical_to_physical_map.shape[-1]
+            slot_delta = max_slots - new_logical.shape[-1]
+            if slot_delta > 0:
+                new_logical = torch.nn.functional.pad(
+                    new_logical, (0, slot_delta), value=-1
+                )
+            model_state.logical_to_physical_map[layer_idx].copy_(new_logical)
 
         replica_device = model_state.logical_replica_count.device
-        model_state.logical_replica_count[layer].copy_(
-            model_state.new_logical_replica_count[layer].to(replica_device)
+        model_state.logical_replica_count.copy_(
+            model_state.new_logical_replica_count.to(replica_device)
         )
 
     def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
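The reworked helper is where the unnecessary copies go away: the new mapping is moved to the target device once, with non_blocking=True, and the per-layer updates then become in-place device-side copies instead of one small host-to-device .to() per layer. A self-contained sketch of the same pattern with generic names (update_mapping_per_row and update_mapping_bulk are illustrative, not the vLLM API):

import torch


def update_mapping_per_row(dst: torch.Tensor, src_cpu: torch.Tensor) -> None:
    # Pattern being replaced: one small host-to-device transfer per row.
    for row in range(dst.shape[0]):
        dst[row].copy_(src_cpu[row].to(dst.device))


def update_mapping_bulk(dst: torch.Tensor, src_cpu: torch.Tensor) -> None:
    # Replacement pattern: a single transfer of the whole tensor (non_blocking
    # is only truly asynchronous for pinned host memory), followed by cheap
    # device-side row copies. The copies are in place, so any code already
    # holding a reference to dst keeps seeing the updated mapping.
    src = src_cpu.to(dst.device, non_blocking=True)
    for row in range(dst.shape[0]):
        dst[row].copy_(src[row])


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dst = torch.zeros(4, 16, dtype=torch.int64, device=device)
    src = torch.randint(0, 16, (4, 16), dtype=torch.int64)
    update_mapping_bulk(dst, src)
    assert torch.equal(dst.cpu(), src)

In the hunk above, the replica-count update collapses further into a single whole-tensor copy_; for the logical-to-physical map the loop additionally pads each layer's row out to the preallocated slot width before the in-place copy (see the padding sketch further below).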
@@ -1037,7 +1041,7 @@
                 ep_rank=ep_group.rank(),
             )
             transferred_layer = model_state.layer_to_transfer
-            self._update_layer_mapping_from_new(model_state, transferred_layer)
             # After the main thread consumes, advance layer_to_transfer
             model_state.layer_to_transfer += 1
             model_state.ep_buffer_ready = 0
@@ -1061,8 +1065,7 @@
         assert model_state.new_logical_to_physical_map is not None
         assert model_state.new_logical_replica_count is not None
         if not is_profile:
-            for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
-                self._update_layer_mapping_from_new(model_state, layer_idx)
+            self._update_layer_mapping_from_new(model_state)
         model_state.new_physical_to_logical_map = None
         model_state.new_logical_to_physical_map = None
         model_state.new_logical_replica_count = None
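One detail the per-layer loop in the updated helper keeps is the slot padding: when a newly computed logical-to-physical row has fewer replica slots than the preallocated tensor, it is padded with -1 (which the surrounding code appears to use as the marker for an unused slot) before the in-place copy. A small self-contained sketch of just that step, with an illustrative helper name (pad_logical_row is not part of the codebase):

import torch


def pad_logical_row(new_logical: torch.Tensor, max_slots: int) -> torch.Tensor:
    # Pad the trailing (slot) dimension out to the preallocated width, filling
    # the extra positions with -1 so they read as "no replica".
    slot_delta = max_slots - new_logical.shape[-1]
    if slot_delta > 0:
        new_logical = torch.nn.functional.pad(new_logical, (0, slot_delta), value=-1)
    return new_logical


if __name__ == "__main__":
    # Preallocated map: 2 layers, 8 logical experts, up to 6 replica slots.
    logical_to_physical = torch.full((2, 8, 6), -1, dtype=torch.int64)
    # Newly computed map only uses 4 replica slots per logical expert.
    new_map = torch.randint(0, 32, (2, 8, 4), dtype=torch.int64)
    for layer_idx in range(logical_to_physical.shape[0]):
        padded = pad_logical_row(new_map[layer_idx], logical_to_physical.shape[-1])
        logical_to_physical[layer_idx].copy_(padded)
    assert torch.equal(logical_to_physical[..., :4], new_map)
    assert (logical_to_physical[..., 4:] == -1).all()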