Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-04-20 18:27:03 +08:00)
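Correct EPLB state logs: time expert rearrangement with CUDA events instead of time.perf_counter(), and log per-layer move_to_workspace completion at debug level instead of info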
Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
parent
6b2a1de500
commit
fc54d760a6
@ -27,7 +27,6 @@ physical experts.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@ -699,11 +698,14 @@ class EplbState:
|
|||||||
ep_group = get_ep_group().device_group
|
ep_group = get_ep_group().device_group
|
||||||
ep_rank = ep_group.rank()
|
ep_rank = ep_group.rank()
|
||||||
|
|
||||||
time_start = None
|
start_event = None
|
||||||
|
end_event = None
|
||||||
is_main_rank = ep_rank == 0
|
is_main_rank = ep_rank == 0
|
||||||
if is_main_rank:
|
if is_main_rank:
|
||||||
torch.cuda.synchronize()
|
if not self.is_async or is_profile:
|
||||||
time_start = time.perf_counter()
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
start_event.record()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Rearranging experts %s %s...",
|
"Rearranging experts %s %s...",
|
||||||
"(async mode)" if self.is_async else "sync mode",
|
"(async mode)" if self.is_async else "sync mode",
|
||||||
@ -864,13 +866,15 @@ class EplbState:
|
|||||||
new_logical_replica_count
|
new_logical_replica_count
|
||||||
)
|
)
|
||||||
if is_main_rank:
|
if is_main_rank:
|
||||||
assert time_start is not None
|
assert start_event is not None
|
||||||
torch.cuda.synchronize()
|
assert end_event is not None
|
||||||
time_end = time.perf_counter()
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0
|
||||||
logger.info(
|
logger.info(
|
||||||
"Rearranged experts%sin %.2f seconds.",
|
"Rearranged experts %s in %.2f s.",
|
||||||
" (profile) " if is_profile else " ",
|
" (profile) " if is_profile else " ",
|
||||||
time_end - time_start,
|
gpu_elapsed,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
|
max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
|
||||||
@ -1010,7 +1014,7 @@ class EplbState:
|
|||||||
# After the main thread consumes, advance layer_to_transfer
|
# After the main thread consumes, advance layer_to_transfer
|
||||||
model_state.layer_to_transfer += 1
|
model_state.layer_to_transfer += 1
|
||||||
model_state.ep_buffer_ready = 0
|
model_state.ep_buffer_ready = 0
|
||||||
logger.info(
|
logger.debug(
|
||||||
"model %s successfully move_to_workspace layer %d",
|
"model %s successfully move_to_workspace layer %d",
|
||||||
model_state.model_name,
|
model_state.model_name,
|
||||||
transferred_layer,
|
transferred_layer,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user