Fix accuracy issue

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Author: ilmarkov
Date:   2025-12-17 16:20:54 +00:00
parent 777c90e954
commit 67ff54997d


@@ -942,7 +942,9 @@ class EplbState:
             is_profile=is_profile,
         )
 
-    def _update_layer_mapping_from_new(self, model_state: EplbModelState) -> None:
+    def _update_layer_mapping_from_new(
+        self, model_state: EplbModelState, layer: int
+    ) -> None:
         if (
             model_state.new_physical_to_logical_map is None
             or model_state.new_logical_to_physical_map is None
@@ -951,33 +953,29 @@ class EplbState:
             return
 
         target_device = model_state.physical_to_logical_map.device
-        new_physical = model_state.new_physical_to_logical_map.to(
-            target_device, non_blocking=True
-        )
+        new_physical = model_state.new_physical_to_logical_map
         if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
-            model_state.physical_to_logical_map = new_physical
+            model_state.physical_to_logical_map = new_physical.to(
+                target_device, non_blocking=True
+            )
         else:
-            for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
-                model_state.physical_to_logical_map[layer_idx].copy_(
-                    new_physical[layer_idx]
-                )
+            model_state.physical_to_logical_map[layer].copy_(
+                new_physical[layer].to(target_device, non_blocking=True)
+            )
 
         logical_device = model_state.logical_to_physical_map.device
-        for layer_idx in range(model_state.logical_to_physical_map.shape[0]):
-            new_logical = model_state.new_logical_to_physical_map[layer_idx].to(
-                logical_device
-            )
-            max_slots = model_state.logical_to_physical_map.shape[-1]
-            slot_delta = max_slots - new_logical.shape[-1]
-            if slot_delta > 0:
-                new_logical = torch.nn.functional.pad(
-                    new_logical, (0, slot_delta), value=-1
-                )
-            model_state.logical_to_physical_map[layer_idx].copy_(new_logical)
+        new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
+        max_slots = model_state.logical_to_physical_map.shape[-1]
+        slot_delta = max_slots - new_logical.shape[-1]
+        if slot_delta > 0:
+            new_logical = torch.nn.functional.pad(
+                new_logical, (0, slot_delta), value=-1
+            )
+        model_state.logical_to_physical_map[layer].copy_(new_logical)
 
         replica_device = model_state.logical_replica_count.device
-        model_state.logical_replica_count.copy_(
-            model_state.new_logical_replica_count.to(replica_device)
+        model_state.logical_replica_count[layer].copy_(
+            model_state.new_logical_replica_count[layer].to(replica_device)
         )
 
     def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
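
For context, the per-layer update in this hunk pads the freshly computed logical-to-physical row for one layer out to the preallocated slot count (with -1 marking empty replica slots) and copies it in place, leaving every other layer's mapping untouched until its own weights arrive. A minimal standalone sketch of that step; the tensor shapes and values here are invented for illustration, only the pad/copy_ pattern comes from the diff:

    import torch

    # Preallocated map: 2 layers x 4 logical experts x 3 replica slots, -1 = empty.
    logical_to_physical_map = torch.full((2, 4, 3), -1, dtype=torch.int64)

    # Freshly rebalanced row for layer 0, computed with only 2 replica slots.
    new_logical = torch.tensor([[0, 4], [1, -1], [2, 5], [3, -1]], dtype=torch.int64)

    layer = 0
    max_slots = logical_to_physical_map.shape[-1]   # 3
    slot_delta = max_slots - new_logical.shape[-1]  # 1
    if slot_delta > 0:
        # Right-pad the last dim with -1 so shapes match the buffer.
        new_logical = torch.nn.functional.pad(new_logical, (0, slot_delta), value=-1)

    # In-place copy updates only this layer's row; the other layer keeps
    # its previous mapping.
    logical_to_physical_map[layer].copy_(new_logical)
    print(logical_to_physical_map[layer])
    # tensor([[ 0,  4, -1],
    #         [ 1, -1, -1],
    #         [ 2,  5, -1],
    #         [ 3, -1, -1]])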
@@ -1035,7 +1033,7 @@ class EplbState:
                 ep_rank=ep_group.rank(),
             )
             transferred_layer = model_state.layer_to_transfer
-
+            self._update_layer_mapping_from_new(model_state, transferred_layer)
             # After the main thread consumes, advance layer_to_transfer
             model_state.layer_to_transfer += 1
             model_state.ep_buffer_ready = 0
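
The placement of the new call encodes an ordering constraint: a layer's mapping may only become visible after that layer's expert weights have landed, and before layer_to_transfer advances and the buffer is released back to the transfer thread. A simplified, hypothetical sketch of that handshake; ep_buffer_ready and layer_to_transfer mirror the fields in the diff, while the polling loop, apply_mapping, and num_layers are stand-ins invented for illustration:

    # Hypothetical consumer-side sketch. `apply_mapping` stands in for
    # self._update_layer_mapping_from_new and `num_layers` for the
    # model's MoE layer count.
    def consume_transferred_layers(model_state, apply_mapping, num_layers):
        while model_state.layer_to_transfer < num_layers:
            if not model_state.ep_buffer_ready:
                continue  # this layer's weights have not landed yet
            transferred_layer = model_state.layer_to_transfer
            # 1) Publish the new mapping for exactly the layer whose
            #    weights were just transferred; all other layers keep
            #    routing with their old mapping.
            apply_mapping(model_state, transferred_layer)
            # 2) Only then advance to the next layer ...
            model_state.layer_to_transfer += 1
            # 3) ... and hand the buffer back to the transfer thread.
            model_state.ep_buffer_ready = 0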
@@ -1058,8 +1056,7 @@ class EplbState:
         assert model_state.new_physical_to_logical_map is not None
         assert model_state.new_logical_to_physical_map is not None
         assert model_state.new_logical_replica_count is not None
-        if not is_profile:
-            self._update_layer_mapping_from_new(model_state)
+
         model_state.new_physical_to_logical_map = None
         model_state.new_logical_to_physical_map = None
         model_state.new_logical_replica_count = None