[Frontend] Don't log duplicate error stacktrace for every request in the batch (#9023)

Signed-off-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
Wallas Henrique 2024-10-21 18:49:41 -03:00 committed by GitHub
parent 15713e3b75
commit 711f3a7806
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 10 deletions

View File

@ -59,15 +59,7 @@ async def test_evil_forward(tmp_socket):
await asyncio.sleep(2.0)
await client.check_health()
# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored
# Engine is errored, should get ENGINE_DEAD_ERROR.
# Throws an error that should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
@ -149,7 +141,7 @@ async def test_failed_abort(tmp_socket):
client = await engine.make_client()
assert client.is_running
# Firsh check health should work.
# First check health should work.
await client.check_health()
# Trigger an abort on the client side.
@ -174,6 +166,45 @@ async def test_failed_abort(tmp_socket):
client.close()
@pytest.mark.asyncio
async def test_batch_error(tmp_socket):
"""Verify that when the engine dies while a batch is in flight, every
pending request receives the same MQEngineDeadError (wrapping the root
cause) rather than each logging its own duplicate stack trace."""
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_abort) as engine:
client = await engine.make_client()
assert client.is_running
# First check health should work.
await client.check_health()
# Batch of requests
async def do_generate(client):
# min_tokens=2048 to keep the engine busy
# long enough for it to pick up the request
# that will crash the engine
params = SamplingParams(min_tokens=2048, max_tokens=2048)
async for _ in client.generate(prompt="Hello my name is",
sampling_params=params,
request_id=uuid.uuid4()):
pass
tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)]
# This abort forces the in-flight batch to raise
# an exception, after which the engine becomes errored
await client.abort(request_id="foo")
# The batch of those requests failed, so they
# should all get the same exception, as a MQEngineDeadError.
errors = await asyncio.gather(*tasks, return_exceptions=True)
for e in errors:
assert isinstance(e, MQEngineDeadError)
assert "KeyError" in repr(e)
client.close()
@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,

View File

@ -204,8 +204,20 @@ class MQLLMEngineClient(EngineClient):
# (and record only the first one)
if is_engine_errored and not self._errored_with:
self._errored_with = exception
# If engine is errored, no matter the type of exception
# it will no longer be able to receive new requests,
# therefore we have to inform that the currently
# processed requests failed as well. Send back a dead
# engine error to give this feedback and also give a
# 'hint' to the server to shut down next.
exception = self.dead_error
if request_id is None:
# If request_id is None, then the engine raised an
# exception for a batch, and we may not know the
# request that caused it, nor whether it was actually
# caused by any of them (e.g. CUDA OOM). Therefore we
# broadcast the same exception to all requests.
for queue_i in tuple(self.output_queues.values()):
queue_i.put_nowait(exception)
else: