[Benchmark] Add retry support to fix workload bias in multi-turn benchmark (#28493)

ai-jz 2025-11-11 21:00:45 -08:00 committed by GitHub
parent b9ce9a3013
commit f31419ed8b

@@ -55,6 +55,7 @@ class ClientArgs(NamedTuple):
     verify_output: bool
     conversation_sampling: ConversationSampling
     request_rate: float
+    max_retries: int
 
 
 class RequestArgs(NamedTuple):
@@ -527,6 +528,25 @@ async def poisson_sleep(request_rate: float, verbose: bool = False) -> None:
         await asyncio.sleep(interval)
 
 
+async def exponential_backoff_sleep(
+    attempt_cnt: int,
+    base_rate: float = 1.0,
+    backoff_factor: float = 2.0,
+    jitter_fraction: float = 0.10,
+    verbose: bool = False,
+) -> None:
+    # Sleep with exponential backoff and jitter after a failed request.
+    backoff_delay = base_rate * (backoff_factor**attempt_cnt)
+    jittered_delay = backoff_delay * (
+        1 + np.random.uniform(-jitter_fraction, jitter_fraction)
+    )
+    if verbose:
+        logger.info(f"Backoff for {jittered_delay:.3f} seconds...")
+    await asyncio.sleep(jittered_delay)
+
+
 async def client_main(
     args: ClientArgs,
     req_args: RequestArgs,
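
For intuition, the backoff schedule grows geometrically: with the defaults above (base_rate=1.0, backoff_factor=2.0, jitter_fraction=0.10), attempt 0 waits about 1 s, attempt 1 about 2 s, attempt 2 about 4 s, each randomized by up to +/-10%. A minimal standalone sketch of the same formula (illustrative only, not part of the diff; the function name backoff_delay_sec is hypothetical):

import numpy as np

def backoff_delay_sec(
    attempt_cnt: int,
    base_rate: float = 1.0,
    backoff_factor: float = 2.0,
    jitter_fraction: float = 0.10,
) -> float:
    # Same formula as exponential_backoff_sleep: geometric growth plus jitter.
    delay = base_rate * (backoff_factor**attempt_cnt)
    return delay * (1 + np.random.uniform(-jitter_fraction, jitter_fraction))

for attempt in range(4):
    print(f"attempt {attempt}: ~{backoff_delay_sec(attempt):.2f}s")
# Typical output: ~1.0s, ~2.0s, ~4.0s, ~8.0s (each within +/-10%)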
@@ -655,59 +675,62 @@ async def client_main(
         )
         time_of_last_turn[conv_id] = curr_time_sec
 
-        success = True
-        try:
-            result = await send_turn(
-                session,
-                client_id,
-                conv_id,
-                messages,
-                current_turn,
-                tokenizer,
-                req_args,
-                args.print_content,
-                args.verify_output,
-            )
-            if result is not None:
-                result_queue.put(result)
-            else:
-                # None means that the request failed,
-                # and should not be added to the statistics.
-                success = False
-                num_failures += 1
-
-                logger.warning(
-                    f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}"  # noqa: E501
-                )
-
-                # Remove the conversation (should not be used again)
-                active_convs.pop(conv_id)
-
-        except asyncio.exceptions.TimeoutError:
-            num_failures += 1
-            logger.error(
-                "%sClient %d - Timeout during conversation ID %s (turn: %d). "
-                "Base timeout is %ss (set with --request-timeout-sec), but the "
-                "effective timeout may be longer based on max_tokens. If this "
-                "is unexpected, consider increasing the timeout or checking "
-                "model performance.%s",
-                Color.RED,
-                client_id,
-                conv_id,
-                current_turn,
-                req_args.timeout_sec,
-                Color.RESET,
-            )
-            break  # Exit gracefully instead of raising an error
-        except Exception:
-            num_failures += 1
-            logger.exception(
-                f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}"  # noqa: E501
-            )
-            break  # Exit gracefully instead of raising an error
-
-        if success:
+        success = False
+        for attempt_cnt in range(args.max_retries + 1):
+            try:
+                exception = False
+                result = await send_turn(
+                    session,
+                    client_id,
+                    conv_id,
+                    messages,
+                    current_turn,
+                    tokenizer,
+                    req_args,
+                    args.print_content,
+                    args.verify_output,
+                )
+                if result is not None:
+                    result_queue.put(result)
+                    success = True
+                    break
+                else:
+                    logger.warning(
+                        f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}"  # noqa: E501
+                    )
+            except asyncio.exceptions.TimeoutError:
+                exception = True
+                logger.error(
+                    "%sClient %d - Timeout during conversation ID %s (turn: %d). "
+                    "Base timeout is %ss (set with --request-timeout-sec), but the "
+                    "effective timeout may be longer based on max_tokens. If this "
+                    "is unexpected, consider increasing the timeout or checking "
+                    "model performance.%s",
+                    Color.RED,
+                    client_id,
+                    conv_id,
+                    current_turn,
+                    req_args.timeout_sec,
+                    Color.RESET,
+                )
+            except Exception:
+                exception = True
+                logger.exception(
+                    f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}"  # noqa: E501
+                )
+
+            # Sleep before retry if not last attempt
+            if not success and attempt_cnt < args.max_retries:
+                await exponential_backoff_sleep(attempt_cnt, verbose=args.verbose)
+
+        if not success:
+            num_failures += 1
+            # Remove the conversation (should not be used again)
+            active_convs.pop(conv_id)
+            if exception:
+                break  # Exit gracefully instead of raising an error
+        else:
             num_successes += 1
             # Update the turns counter to include the LLM response
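
Stripped of logging and benchmark bookkeeping, the new control flow is: attempt the turn up to max_retries + 1 times, back off between failed attempts, and record at most one failure per turn once all attempts are exhausted. A condensed, self-contained sketch of that pattern (hypothetical helper names, not the benchmark's actual code):

import asyncio
import random
from typing import Awaitable, Callable


async def send_with_retries(
    send_fn: Callable[[], Awaitable[object]], max_retries: int
) -> bool:
    # Condensed retry pattern from the diff above: a successful send ends
    # the loop early; exhausting every attempt counts as a single failure.
    for attempt_cnt in range(max_retries + 1):
        try:
            if await send_fn() is not None:
                return True  # success: stop retrying
        except Exception:
            pass  # timeout or error: fall through to the backoff below
        if attempt_cnt < max_retries:
            # Exponential backoff with +/-10% jitter, as in the diff.
            delay = (2.0**attempt_cnt) * (1 + random.uniform(-0.1, 0.1))
            await asyncio.sleep(delay)
    return False  # all attempts failed: the caller records one failure


async def _demo() -> None:
    calls = {"n": 0}

    async def flaky_send() -> str | None:
        calls["n"] += 1
        return "ok" if calls["n"] >= 3 else None  # fails twice, then succeeds

    print(await send_with_retries(flaky_send, max_retries=3))  # True


asyncio.run(_demo())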
@@ -822,6 +845,7 @@ def get_client_config(
         verify_output=args.verify_output,
         conversation_sampling=args.conversation_sampling,
         request_rate=args.request_rate,
+        max_retries=args.max_retries,
     )
 
     if args.limit_min_tokens > 0 or args.limit_max_tokens > 0:
@@ -1357,6 +1381,16 @@ async def main() -> None:
         help="Expected request rate (Poisson process) per client in requests/sec."
         "Set to 0 for no delay between requests.",
     )
+    parser.add_argument(
+        "--max-retries",
+        type=int,
+        default=int(os.environ.get("MULTITURN_BENCH_MAX_RETRIES", "0")),
+        help="Maximum number of retry attempts for timed-out requests. "
+        "Default is 0 (no retries). "
+        "Set to higher values to retry failed requests and maintain "
+        "fair workload distribution. "
+        "Can also be set via MULTITURN_BENCH_MAX_RETRIES environment variable.",
+    )
     parser.add_argument(
         "--conversation-sampling",
        type=ConversationSampling,
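
With the flag wired into the parser, retries can be enabled from the command line or the environment. An illustrative invocation (the script name and the elided flags are placeholders, not taken from the diff):

# Retry each timed-out request up to 3 times, with exponential backoff:
python benchmark_serving_multi_turn.py --max-retries 3 ...

# Equivalent, via the environment variable read by the argument default:
MULTITURN_BENCH_MAX_RETRIES=3 python benchmark_serving_multi_turn.py ...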