diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 43c0bbdd72e44..4e3fdad8bb9cc 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -942,7 +942,9 @@ class EplbState:
             is_profile=is_profile,
         )
 
-    def _update_layer_mapping_from_new(self, model_state: EplbModelState) -> None:
+    def _update_layer_mapping_from_new(
+        self, model_state: EplbModelState, layer: int
+    ) -> None:
         if (
             model_state.new_physical_to_logical_map is None
             or model_state.new_logical_to_physical_map is None
@@ -951,33 +953,29 @@ class EplbState:
             return
 
         target_device = model_state.physical_to_logical_map.device
-        new_physical = model_state.new_physical_to_logical_map.to(
-            target_device, non_blocking=True
-        )
+        new_physical = model_state.new_physical_to_logical_map
         if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
-            model_state.physical_to_logical_map = new_physical
+            model_state.physical_to_logical_map = new_physical.to(
+                target_device, non_blocking=True
+            )
         else:
-            for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
-                model_state.physical_to_logical_map[layer_idx].copy_(
-                    new_physical[layer_idx]
-                )
+            model_state.physical_to_logical_map[layer].copy_(
+                new_physical[layer].to(target_device, non_blocking=True)
+            )
 
         logical_device = model_state.logical_to_physical_map.device
-        for layer_idx in range(model_state.logical_to_physical_map.shape[0]):
-            new_logical = model_state.new_logical_to_physical_map[layer_idx].to(
-                logical_device
+        new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device)
+        max_slots = model_state.logical_to_physical_map.shape[-1]
+        slot_delta = max_slots - new_logical.shape[-1]
+        if slot_delta > 0:
+            new_logical = torch.nn.functional.pad(
+                new_logical, (0, slot_delta), value=-1
             )
-            max_slots = model_state.logical_to_physical_map.shape[-1]
-            slot_delta = max_slots - new_logical.shape[-1]
-            if slot_delta > 0:
-                new_logical = torch.nn.functional.pad(
-                    new_logical, (0, slot_delta), value=-1
-                )
-            model_state.logical_to_physical_map[layer_idx].copy_(new_logical)
+        model_state.logical_to_physical_map[layer].copy_(new_logical)
 
         replica_device = model_state.logical_replica_count.device
-        model_state.logical_replica_count.copy_(
-            model_state.new_logical_replica_count.to(replica_device)
+        model_state.logical_replica_count[layer].copy_(
+            model_state.new_logical_replica_count[layer].to(replica_device)
         )
 
     def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool:
@@ -1035,7 +1033,7 @@ class EplbState:
                 ep_rank=ep_group.rank(),
             )
             transferred_layer = model_state.layer_to_transfer
-
+            self._update_layer_mapping_from_new(model_state, transferred_layer)
             # After the main thread consumes, advance layer_to_transfer
             model_state.layer_to_transfer += 1
             model_state.ep_buffer_ready = 0
@@ -1058,8 +1056,7 @@ class EplbState:
         assert model_state.new_physical_to_logical_map is not None
         assert model_state.new_logical_to_physical_map is not None
         assert model_state.new_logical_replica_count is not None
-        if not is_profile:
-            self._update_layer_mapping_from_new(model_state)
+
         model_state.new_physical_to_logical_map = None
         model_state.new_logical_to_physical_map = None
         model_state.new_logical_replica_count = None
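
Below is a minimal standalone sketch, separate from the patch itself, of the per-layer
update path that _update_layer_mapping_from_new now takes: a single layer's new
logical-to-physical row is right-padded with -1 up to the persistent buffer's slot
width and then copied in place. The names layer, max_slots, slot_delta, and the
pad(..., value=-1) call mirror the hunk above; the tensor shapes and example values
are invented for illustration and are not vLLM state.

import torch

# Hypothetical dimensions: 2 layers, 8 logical experts, up to 4 replica slots.
num_layers, num_logical_experts, max_slots = 2, 8, 4
logical_to_physical_map = torch.full(
    (num_layers, num_logical_experts, max_slots), -1, dtype=torch.int64
)

layer = 0
# The freshly computed mapping may carry fewer replica slots than the buffer.
new_logical = torch.tensor([[3], [7], [1], [0], [5], [2], [6], [4]])

slot_delta = max_slots - new_logical.shape[-1]
if slot_delta > 0:
    # Unused replica slots are marked -1, matching the patch's pad value.
    new_logical = torch.nn.functional.pad(new_logical, (0, slot_delta), value=-1)

# In-place copy updates only this layer's row, leaving other layers untouched —
# which is why the patch threads a layer index through instead of looping.
logical_to_physical_map[layer].copy_(new_logical)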