[Benchmark] Add retry support to fix workload bias in multi-turn benchmark (#28493)

This commit is contained in:
ai-jz 2025-11-11 21:00:45 -08:00 committed by GitHub
parent b9ce9a3013
commit f31419ed8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -55,6 +55,7 @@ class ClientArgs(NamedTuple):
verify_output: bool verify_output: bool
conversation_sampling: ConversationSampling conversation_sampling: ConversationSampling
request_rate: float request_rate: float
max_retries: int
class RequestArgs(NamedTuple): class RequestArgs(NamedTuple):
@ -527,6 +528,25 @@ async def poisson_sleep(request_rate: float, verbose: bool = False) -> None:
await asyncio.sleep(interval) await asyncio.sleep(interval)
async def exponential_backoff_sleep(
attempt_cnt: int,
base_rate: float = 1.0,
backoff_factor: float = 2.0,
jitter_fraction: float = 0.10,
verbose: bool = False,
) -> None:
# Sleep with exponential backoff and jitter after a failed request.
backoff_delay = base_rate * (backoff_factor**attempt_cnt)
jittered_delay = backoff_delay * (
1 + np.random.uniform(-jitter_fraction, jitter_fraction)
)
if verbose:
logger.info(f"Backoff for {jittered_delay:.3f} seconds...")
await asyncio.sleep(jittered_delay)
async def client_main( async def client_main(
args: ClientArgs, args: ClientArgs,
req_args: RequestArgs, req_args: RequestArgs,
@ -655,8 +675,10 @@ async def client_main(
) )
time_of_last_turn[conv_id] = curr_time_sec time_of_last_turn[conv_id] = curr_time_sec
success = True success = False
for attempt_cnt in range(args.max_retries + 1):
try: try:
exception = False
result = await send_turn( result = await send_turn(
session, session,
client_id, client_id,
@ -670,21 +692,14 @@ async def client_main(
) )
if result is not None: if result is not None:
result_queue.put(result) result_queue.put(result)
success = True
break
else: else:
# None means that the request failed,
# and should not be added to the statistics.
success = False
num_failures += 1
logger.warning( logger.warning(
f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
) )
# Remove the conversation (should not be used again)
active_convs.pop(conv_id)
except asyncio.exceptions.TimeoutError: except asyncio.exceptions.TimeoutError:
num_failures += 1 exception = True
logger.error( logger.error(
"%sClient %d - Timeout during conversation ID %s (turn: %d). " "%sClient %d - Timeout during conversation ID %s (turn: %d). "
"Base timeout is %ss (set with --request-timeout-sec), but the " "Base timeout is %ss (set with --request-timeout-sec), but the "
@ -698,16 +713,24 @@ async def client_main(
req_args.timeout_sec, req_args.timeout_sec,
Color.RESET, Color.RESET,
) )
break # Exit gracefully instead of raising an error
except Exception: except Exception:
num_failures += 1 exception = True
logger.exception( logger.exception(
f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
) )
# Sleep before retry if not last attempt
if not success and attempt_cnt < args.max_retries:
await exponential_backoff_sleep(attempt_cnt, verbose=args.verbose)
if not success:
num_failures += 1
# Remove the conversation (should not be used again)
active_convs.pop(conv_id)
if exception:
break # Exit gracefully instead of raising an error break # Exit gracefully instead of raising an error
if success: else:
num_successes += 1 num_successes += 1
# Update the turns counter to include the LLM response # Update the turns counter to include the LLM response
@ -822,6 +845,7 @@ def get_client_config(
verify_output=args.verify_output, verify_output=args.verify_output,
conversation_sampling=args.conversation_sampling, conversation_sampling=args.conversation_sampling,
request_rate=args.request_rate, request_rate=args.request_rate,
max_retries=args.max_retries,
) )
if args.limit_min_tokens > 0 or args.limit_max_tokens > 0: if args.limit_min_tokens > 0 or args.limit_max_tokens > 0:
@ -1357,6 +1381,16 @@ async def main() -> None:
help="Expected request rate (Poisson process) per client in requests/sec." help="Expected request rate (Poisson process) per client in requests/sec."
"Set to 0 for no delay between requests.", "Set to 0 for no delay between requests.",
) )
parser.add_argument(
"--max-retries",
type=int,
default=int(os.environ.get("MULTITURN_BENCH_MAX_RETRIES", "0")),
help="Maximum number of retry attempts for timed-out requests. "
"Default is 0 (no retries). "
"Set to higher values to retry failed requests and maintain "
"fair workload distribution. "
"Can also be set via MULTITURN_BENCH_MAX_RETRIES environment variable.",
)
parser.add_argument( parser.add_argument(
"--conversation-sampling", "--conversation-sampling",
type=ConversationSampling, type=ConversationSampling,