Fix accuracy issue

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-17 16:20:54 +00:00
parent 777c90e954
commit 67ff54997d

View File

@ -942,7 +942,9 @@ class EplbState:
is_profile=is_profile,
)
def _update_layer_mapping_from_new(self, model_state: EplbModelState) -> None:
def _update_layer_mapping_from_new(
self, model_state: EplbModelState, layer: int
) -> None:
if (
model_state.new_physical_to_logical_map is None
or model_state.new_logical_to_physical_map is None
@ -951,33 +953,29 @@ class EplbState:
return
target_device = model_state.physical_to_logical_map.device
new_physical = model_state.new_physical_to_logical_map.to(
target_device, non_blocking=True
)
new_physical = model_state.new_physical_to_logical_map
if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
model_state.physical_to_logical_map = new_physical
model_state.physical_to_logical_map = new_physical.to(
target_device, non_blocking=True
)
else:
for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
model_state.physical_to_logical_map[layer_idx].copy_(
new_physical[layer_idx]
)
model_state.physical_to_logical_map[layer].copy_(
new_physical[layer].to(target_device, non_blocking=True)
)
logical_device = model_state.logical_to_physical_map.device
for layer_idx in range(model_state.logical_to_physical_map.shape[0]):
new_logical = model_state.new_logical_to_physical_map[layer_idx].to(
logical_device
new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
max_slots = model_state.logical_to_physical_map.shape[-1]
slot_delta = max_slots - new_logical.shape[-1]
if slot_delta > 0:
new_logical = torch.nn.functional.pad(
new_logical, (0, slot_delta), value=-1
)
max_slots = model_state.logical_to_physical_map.shape[-1]
slot_delta = max_slots - new_logical.shape[-1]
if slot_delta > 0:
new_logical = torch.nn.functional.pad(
new_logical, (0, slot_delta), value=-1
)
model_state.logical_to_physical_map[layer_idx].copy_(new_logical)
model_state.logical_to_physical_map[layer].copy_(new_logical)
replica_device = model_state.logical_replica_count.device
model_state.logical_replica_count.copy_(
model_state.new_logical_replica_count.to(replica_device)
model_state.logical_replica_count[layer].copy_(
model_state.new_logical_replica_count[layer].to(replica_device)
)
def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
@ -1035,7 +1033,7 @@ class EplbState:
ep_rank=ep_group.rank(),
)
transferred_layer = model_state.layer_to_transfer
self._update_layer_mapping_from_new(model_state, transferred_layer)
# After the main thread consumes, advance layer_to_transfer
model_state.layer_to_transfer += 1
model_state.ep_buffer_ready = 0
@ -1058,8 +1056,7 @@ class EplbState:
assert model_state.new_physical_to_logical_map is not None
assert model_state.new_logical_to_physical_map is not None
assert model_state.new_logical_replica_count is not None
if not is_profile:
self._update_layer_mapping_from_new(model_state)
model_state.new_physical_to_logical_map = None
model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None