diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index a911ddc56b023..869e80a1af88c 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -839,6 +839,85 @@ def test_multi_kv_connector_stats_aggregation():
     assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
 
 
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper,
+)
+def test_scheduler_kv_connector_stats_aggregation():
+    """Check that scheduler-side and worker-side KV connector stats are merged.
+
+    The worker stats arrive through the model runner output, the scheduler
+    stats come from the (stubbed) scheduler connector, and the combined
+    result is read back from the scheduler stats of the engine core outputs.
+    """
+    from vllm.v1.core.sched.output import SchedulerOutput
+
+    connector_name = "NixlConnector"
+    request_id = "req_0"
+    scheduler = create_scheduler(create_vllm_config())
+
+    # Worker side: one recorded transfer and no remote tokens.
+    stats_from_worker = NixlKVConnectorStats()
+    stats_from_worker.record_transfer(get_default_xfer_telemetry())
+    stats_from_worker.data["remote_tokens"] = []
+
+    # Scheduler side: carries the custom "remote_tokens" metric. The zeroed
+    # entries are a dummy transfer, present only so that these stats are not
+    # skipped by the is_empty() check.
+    stats_from_scheduler = NixlKVConnectorStats()
+    dummy_transfer = {
+        "transfer_duration": [0],
+        "post_duration": [0],
+        "bytes_transferred": [0],
+        "num_descriptors": [0],
+        "remote_tokens": [128],
+    }
+    stats_from_scheduler.data.update(dummy_transfer)
+
+    def fake_get_kv_connector_stats():
+        """Stand in for the scheduler connector's stats getter."""
+        return MultiKVConnectorStats(data={connector_name: stats_from_scheduler})
+
+    scheduler.connector.get_kv_connector_stats = fake_get_kv_connector_stats
+
+    worker_side_output = KVConnectorOutput(
+        kv_connector_stats=MultiKVConnectorStats(
+            data={connector_name: stats_from_worker}
+        )
+    )
+    model_output = ModelRunnerOutput(
+        req_ids=[request_id],
+        req_id_to_index={request_id: 0},
+        sampled_token_ids=[[123]],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[None],
+        kv_connector_output=worker_side_output,
+    )
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=None,
+        num_scheduled_tokens={request_id: 1},
+        total_num_scheduled_tokens=1,
+        scheduled_spec_decode_tokens={},
+        scheduled_encoder_inputs={},
+        num_common_prefix_blocks=[0],
+        finished_req_ids=set(),
+        free_encoder_mm_hashes=set(),
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )
+
+    engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
+
+    first_output = next(iter(engine_core_outputs.values()))
+    merged_stats = first_output.scheduler_stats.kv_connector_stats
+    nixl_stats = merged_stats[connector_name]
+    # Expect one transfer per side, plus the scheduler-only remote tokens.
+    assert nixl_stats.num_successful_transfers == 2
+    assert nixl_stats.data["remote_tokens"] == [128]
+
+
 @pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
 @patch(
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index cbbdf48c6e0cd..55d7f17d5081e 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -924,6 +924,14 @@ class Scheduler(SchedulerInterface):
         kv_connector_stats = (
             kv_connector_output.kv_connector_stats if kv_connector_output else None
         )
+        # Merge the scheduler connector's own stats into the worker-reported ones.
+        # NOTE(review): this only happens when the worker reported stats in this
+        # step -- confirm that is the intended behavior.
+        if kv_connector_stats and self.connector:
+            stats = self.connector.get_kv_connector_stats()
+            kv_connector_stats = (
+                kv_connector_stats.aggregate(stats) if stats else kv_connector_stats
+            )
 
         failed_kv_load_req_ids = None
         if kv_connector_output and kv_connector_output.invalid_block_ids: