diff --git a/tests/v1/metrics/test_perf_metrics.py b/tests/v1/metrics/test_perf_metrics.py new file mode 100644 index 0000000000000..b6cda7bef3d41 --- /dev/null +++ b/tests/v1/metrics/test_perf_metrics.py @@ -0,0 +1,897 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for the analytic estimators in metrics/flops.py. +""" + +import types +from types import SimpleNamespace + +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config +from transformers.models.llama4.configuration_llama4 import ( + Llama4Config, + Llama4TextConfig, +) +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig + +from vllm.config.model import ModelConfig, get_hf_text_config +from vllm.v1.metrics.perf import ( + AttentionMetrics, + BaseConfigParser, + ExecutionContext, + FfnMetrics, + ModelMetrics, + ParsedArgs, + UnembedMetrics, +) + + +class MockModelConfig: + """Mock ModelConfig that implements the getter methods used by parsers.""" + + def __init__(self, hf_config, dtype): + self.hf_config = hf_config + self.hf_text_config = get_hf_text_config(hf_config) + self.dtype = dtype + self.is_attention_free = False + + def __getattr__(self, name): + # 1. Check if ModelConfig actually has this attribute + if not hasattr(ModelConfig, name): + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}' " + f"and neither does 'ModelConfig'." + ) + + # 2. Fetch the attribute from the ModelConfig CLASS + attr = getattr(ModelConfig, name) + + # 3. Case A: It is a @property + if isinstance(attr, property): + # Manually invoke the property's getter, passing 'self' (this mock instance) + return attr.__get__(self, self.__class__) + + # 4. Case B: It is a standard method (function) + if isinstance(attr, types.FunctionType): + # Bind the function to 'self' so it acts like a method of + # this instance. This creates a bound method where 'self' is + # automatically passed as the first arg. + return types.MethodType(attr, self) + + # 5. 
Case C: It is a class attribute / static variable + return attr + + +def create_mock_vllm_config( + hf_config, + model_dtype="bfloat16", + cache_dtype="auto", + quant_config=None, + data_parallel_size=1, + tensor_parallel_size=1, + pipeline_parallel_size=1, + enable_expert_parallel=False, +) -> SimpleNamespace: + vllm_config = SimpleNamespace() + vllm_config.model_config = MockModelConfig(hf_config, model_dtype) + + vllm_config.cache_config = SimpleNamespace() + vllm_config.cache_config.cache_dtype = cache_dtype + + vllm_config.quant_config = quant_config + + vllm_config.parallel_config = SimpleNamespace() + vllm_config.parallel_config.data_parallel_size = data_parallel_size + vllm_config.parallel_config.tensor_parallel_size = tensor_parallel_size + vllm_config.parallel_config.pipeline_parallel_size = pipeline_parallel_size + vllm_config.parallel_config.enable_expert_parallel = enable_expert_parallel + + return vllm_config + + +#### Parser Tests #### + + +def test_base_config_parser(): + """Test BaseConfigParser extracts base model attributes correctly.""" + hf_config = Qwen3Config( + vocab_size=50000, + hidden_size=2048, + num_attention_heads=16, + num_hidden_layers=24, + ) + vllm_config = create_mock_vllm_config(hf_config, model_dtype="float16") + + parser = BaseConfigParser() + args = ParsedArgs() + result = parser.parse(args, vllm_config) + + assert result.vocab_size == 50000 + assert result.hidden_size == 2048 + assert result.num_attention_heads == 16 + assert result.num_hidden_layers == 24 + assert result.weight_byte_size == 2 # float16 is 2 bytes + assert result.activation_byte_size == 2 # default activation size + + +def test_base_attention_config_parser_with_gqa(): + """Test BaseAttentionConfigParser with grouped query attention.""" + hf_config = Qwen3Config( + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, # GQA with 4:1 ratio + head_dim=128, + ) + vllm_config = create_mock_vllm_config(hf_config) + + parser_chain = AttentionMetrics.get_parser() + result = parser_chain.parse(vllm_config) + + assert result.num_key_value_heads == 8 + assert result.head_dim == 128 + + +def test_base_attention_config_parser_without_gqa(): + """ + Test BaseAttentionConfigParser defaults to MHA when num_key_value_heads not + specified. 
+ """ + hf_config = Qwen3Config( + hidden_size=4096, + num_attention_heads=32, + # No num_key_value_heads specified + ) + vllm_config = create_mock_vllm_config(hf_config) + + parser_chain = AttentionMetrics.get_parser() + result = parser_chain.parse(vllm_config) + + # Should default to MHA (num_key_value_heads = num_attention_heads) + assert result.num_key_value_heads == 32 + + +def test_base_ffn_config_parser_dense(): + """Test BaseFfnConfigParser for dense FFN.""" + hf_config = Qwen3Config( + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + ) + vllm_config = create_mock_vllm_config(hf_config) + + parser_chain = FfnMetrics.get_parser() + result = parser_chain.parse(vllm_config) + + assert result.intermediate_size == 11008 + assert result.num_experts == 0 + assert result.num_experts_per_tok == 0 + assert result.num_moe_layers == 0 # No MoE + + +def test_base_ffn_config_parser_moe(): + """Test BaseFfnConfigParser for MoE FFN.""" + hf_config = Qwen3MoeConfig( + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_experts=64, + num_experts_per_tok=8, + moe_intermediate_size=14336, + n_shared_experts=2, + ) + vllm_config = create_mock_vllm_config(hf_config) + + parser_chain = FfnMetrics.get_parser() + result = parser_chain.parse(vllm_config) + + assert result.num_experts == 64 + assert result.num_experts_per_tok == 8 + assert result.moe_intermediate_size == 14336 + assert result.num_shared_experts == 2 + assert result.num_moe_layers == 32 # All layers are MoE by default + + +def test_interleave_moe_layer_step_parser(): + """Test InterleaveMoeLayerStepParser correctly computes MoE layer count.""" + hf_config = Llama4Config( + text_config=Llama4TextConfig( + num_hidden_layers=32, + num_local_experts=64, + interleave_moe_layer_step=4, # Every 4th layer is MoE + ), + ) + + vllm_config = create_mock_vllm_config(hf_config) + + parser_chain = FfnMetrics.get_parser() + result = parser_chain.parse(vllm_config) + + assert result.num_moe_layers == 8 + + +def test_moe_layer_freq_parser(): + """Test MoeLayerFreqParser correctly computes MoE layer count.""" + hf_config = DeepseekV3Config( + num_hidden_layers=30, + n_routed_experts=64, + moe_layer_freq=3, # Every 3rd layer after first_k_dense_replace + first_k_dense_replace=6, # First 6 layers are dense + ) + vllm_config = create_mock_vllm_config(hf_config) + + parser_chain = FfnMetrics.get_parser() + result = parser_chain.parse(vllm_config) + + # Layers >= 6 and divisible by 3: 6, 9, 12, 15, 18, 21, 24, 27 + expected_moe_layers = len( + [layer for layer in range(30) if layer >= 6 and layer % 3 == 0] + ) + assert expected_moe_layers == 8 + assert result.num_moe_layers == expected_moe_layers + + +#### ComponentMetrics Tests #### + + +def test_attention_metrics_scaling(): + """Test that attention metrics scale proportionally with model dimensions.""" + base_hf_config = Qwen3Config( + hidden_size=2048, + num_attention_heads=16, + num_key_value_heads=16, + num_hidden_layers=12, + head_dim=128, + ) + + base_vllm_config = create_mock_vllm_config(base_hf_config) + base_metrics = AttentionMetrics.from_vllm_config(base_vllm_config) + + # Test scaling with number of layers + double_layers_hf_config = Qwen3Config( + hidden_size=2048, + num_attention_heads=16, + num_key_value_heads=16, + num_hidden_layers=24, # Double the layers + head_dim=128, + ) + double_layers_vllm_config = create_mock_vllm_config(double_layers_hf_config) + double_layers_metrics = AttentionMetrics.from_vllm_config(double_layers_vllm_config) + + ctx = 
ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=True + ) + + # FLOPS should double when layers double + base_flops = base_metrics.get_num_flops(ctx) + double_flops = double_layers_metrics.get_num_flops(ctx) + assert double_flops == 2 * base_flops + + # Read/write bytes should also scale proportionally + base_read = base_metrics.get_read_bytes(ctx) + double_read = double_layers_metrics.get_read_bytes(ctx) + assert double_read == 2 * base_read + + base_write = base_metrics.get_write_bytes(ctx) + double_write = double_layers_metrics.get_write_bytes(ctx) + assert double_write == 2 * base_write + + +def test_attention_metrics_grouped_query(): + """Test attention metrics handle grouped query attention correctly.""" + mha_hf_config = Qwen3Config( + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=32, # MHA + num_hidden_layers=1, + ) + mha_config = create_mock_vllm_config(mha_hf_config) + + gqa_hf_config = Qwen3Config( + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, # GQA with 4:1 ratio + num_hidden_layers=1, + ) + gqa_config = create_mock_vllm_config(gqa_hf_config) + + mha_metrics = AttentionMetrics.from_vllm_config(mha_config) + gqa_metrics = AttentionMetrics.from_vllm_config(gqa_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=1, context_len=1024, is_prefill=False + ) + + # GQA should have less KV cache reads since fewer KV heads + mha_read = mha_metrics.get_read_bytes(ctx) + gqa_read = gqa_metrics.get_read_bytes(ctx) + assert gqa_read < mha_read + + +def test_ffn_metrics_scaling(): + """Test FFN metrics scale proportionally with model dimensions.""" + base_hf_config = Qwen3Config( + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=12, + ) + base_vllm_config = create_mock_vllm_config(base_hf_config) + base_metrics = FfnMetrics.from_vllm_config(base_vllm_config) + + # Test scaling with intermediate size + larger_ffn_hf_config = Qwen3Config( + hidden_size=2048, + intermediate_size=16384, # Double intermediate size + num_hidden_layers=12, + ) + larger_ffn_vllm_config = create_mock_vllm_config(larger_ffn_hf_config) + larger_ffn_metrics = FfnMetrics.from_vllm_config(larger_ffn_vllm_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=True + ) + + # FLOPS should double when intermediate size doubles + base_flops = base_metrics.get_num_flops(ctx) + larger_flops = larger_ffn_metrics.get_num_flops(ctx) + assert larger_flops == base_flops * 2 + + +def test_moe_metrics_vs_dense(): + """Test MoE metrics versus dense metrics.""" + dense_hf_config = Qwen3Config( + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=12, + ) + dense_config = create_mock_vllm_config(dense_hf_config) + + moe_hf_config = Qwen3MoeConfig( + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=12, + num_experts=64, + num_experts_per_tok=2, # 2 routed expert + moe_intermediate_size=8192, + n_shared_experts=0, + ) + moe_config = create_mock_vllm_config(moe_hf_config) + + dense_metrics = FfnMetrics.from_vllm_config(dense_config) + moe_metrics = FfnMetrics.from_vllm_config(moe_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=True + ) + + # MoE should have different compute/memory characteristics + dense_flops = dense_metrics.get_num_flops(ctx) + moe_flops = moe_metrics.get_num_flops(ctx) + + # 2 routed experts vs 1 dense. 
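+    # (The configs above set moe_intermediate_size equal to the dense
+    # intermediate_size and use no shared experts, so per-token FFN FLOPs
+    # scale directly with num_experts_per_tok.)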
+ assert moe_flops == dense_flops * 2 + + +def test_unembed_metrics_scaling(): + """Test unembedding metrics scale with vocab size.""" + small_vocab_hf_config = Qwen3Config( + hidden_size=2048, + vocab_size=32000, + ) + small_vocab_config = create_mock_vllm_config(small_vocab_hf_config) + + large_vocab_hf_config = Qwen3Config( + hidden_size=2048, + vocab_size=64000, # Double vocab size + ) + large_vocab_config = create_mock_vllm_config(large_vocab_hf_config) + + small_vocab_metrics = UnembedMetrics.from_vllm_config(small_vocab_config) + large_vocab_metrics = UnembedMetrics.from_vllm_config(large_vocab_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=True + ) + + # FLOPS should double when vocab size doubles + small_flops = small_vocab_metrics.get_num_flops(ctx) + large_flops = large_vocab_metrics.get_num_flops(ctx) + assert large_flops == 2 * small_flops + + +def test_prefill_vs_decode_differences(): + """Test that prefill and decode have different memory access patterns.""" + hf_config = Qwen3Config( + hidden_size=2048, + num_attention_heads=16, + num_key_value_heads=16, + num_hidden_layers=1, + ) + config = create_mock_vllm_config(hf_config) + + metrics = AttentionMetrics.from_vllm_config(config) + + prefill_ctx = ExecutionContext.from_single_request( + num_tokens=512, context_len=512, is_prefill=True + ) + decode_ctx = ExecutionContext.from_single_request( + num_tokens=1, context_len=512, is_prefill=False + ) + + prefill_read = metrics.get_read_bytes(prefill_ctx) + decode_read = metrics.get_read_bytes(decode_ctx) + + assert prefill_read != decode_read + + +def test_model_metrics_aggregation(): + """Test ModelMetrics correctly aggregates across components.""" + hf_config = Qwen3Config( + hidden_size=2048, + num_attention_heads=16, + num_hidden_layers=12, + vocab_size=32000, + intermediate_size=8192, + ) + config = create_mock_vllm_config(hf_config) + + model_metrics = ModelMetrics(config) + ctx = ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=True + ) + + # Should have metrics for attention, ffn, and unembed + total_flops = model_metrics.get_num_flops(ctx) + breakdown = model_metrics.get_num_flops_breakdown(ctx) + + # Breakdown should sum to total + assert total_flops == sum(breakdown.values()) + + +def test_moe_expert_activation_proportional_scaling(): + """Test that routed expert metrics scale proportionally with num_experts_per_tok.""" + base_moe_config = Qwen3MoeConfig( + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=12, + num_experts=64, + num_experts_per_tok=1, # 1 expert per token + moe_intermediate_size=8192, + n_shared_experts=2, + ) + + double_experts_config = Qwen3MoeConfig( + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=12, + num_experts=64, + num_experts_per_tok=2, # 2 experts per token (double) + moe_intermediate_size=8192, + n_shared_experts=2, # Same shared experts + ) + + triple_experts_config = Qwen3MoeConfig( + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=12, + num_experts=64, + num_experts_per_tok=3, # 3 experts per token (triple) + moe_intermediate_size=8192, + n_shared_experts=2, # Same shared experts + ) + + base_vllm_config = create_mock_vllm_config(base_moe_config) + double_vllm_config = create_mock_vllm_config(double_experts_config) + triple_vllm_config = create_mock_vllm_config(triple_experts_config) + + base_metrics = FfnMetrics.from_vllm_config(base_vllm_config) + double_metrics = 
FfnMetrics.from_vllm_config(double_vllm_config) + triple_metrics = FfnMetrics.from_vllm_config(triple_vllm_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=True + ) + + # Get total metrics - the key insight is that differences should be proportional + base_flops = base_metrics.get_num_flops(ctx) + double_flops = double_metrics.get_num_flops(ctx) + triple_flops = triple_metrics.get_num_flops(ctx) + + # The difference between double and base should equal one additional expert + one_expert_diff = double_flops - base_flops + + # The difference between triple and base should equal two additional experts + two_expert_diff = triple_flops - base_flops + + # Proportional scaling: 2 * (1 expert diff) should equal (2 expert diff) + assert two_expert_diff == 2 * one_expert_diff + + # Same logic applies to memory operations + base_read = base_metrics.get_read_bytes(ctx) + double_read = double_metrics.get_read_bytes(ctx) + triple_read = triple_metrics.get_read_bytes(ctx) + + one_expert_read_diff = double_read - base_read + two_expert_read_diff = triple_read - base_read + + assert two_expert_read_diff == 2 * one_expert_read_diff + + # Same for write bytes + base_write = base_metrics.get_write_bytes(ctx) + double_write = double_metrics.get_write_bytes(ctx) + triple_write = triple_metrics.get_write_bytes(ctx) + + one_expert_write_diff = double_write - base_write + two_expert_write_diff = triple_write - base_write + + assert two_expert_write_diff == 2 * one_expert_write_diff + + +def test_quantization_config_parser_fp8(): + """Test quantization parsers with fp8.""" + + class MockQuantConfig: + def get_name(self): + return "fp8" + + hf_config = Qwen3Config( + hidden_size=2048, num_attention_heads=16, num_hidden_layers=1 + ) + vllm_config = create_mock_vllm_config(hf_config, quant_config=MockQuantConfig()) + + attn_result = AttentionMetrics.get_parser().parse(vllm_config) + assert attn_result.weight_byte_size == 1 # fp8 + + ffn_result = FfnMetrics.get_parser().parse(vllm_config) + assert ffn_result.weight_byte_size == 1 # fp8 + + +def test_quantization_config_parser_mxfp4(): + """Test quantization parsers with mxfp4.""" + + class MockQuantConfig: + def get_name(self): + return "mxfp4" + + hf_config = Qwen3Config( + hidden_size=2048, intermediate_size=8192, num_hidden_layers=1 + ) + vllm_config = create_mock_vllm_config(hf_config, quant_config=MockQuantConfig()) + + ffn_result = FfnMetrics.get_parser().parse(vllm_config) + assert ffn_result.weight_byte_size == 0.5 # mxfp4 + + +#### Per-GPU Tests #### + + +def test_attention_per_gpu_with_tensor_parallelism(): + """Test attention metrics with tensor parallelism - per_gpu vs global.""" + hf_config = Qwen3Config( + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=8, + num_hidden_layers=24, + ) + + # Test with TP=4 + vllm_config = create_mock_vllm_config(hf_config, tensor_parallel_size=4) + metrics = AttentionMetrics.from_vllm_config(vllm_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=128, context_len=1024, is_prefill=True + ) + + # Get global and per-gpu metrics + global_flops = metrics.get_num_flops(ctx, per_gpu=False) + per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True) + + # With TP=4, global flops should be 4x per-gpu flops (heads divided by 4) + assert global_flops == 4 * per_gpu_flops + + # Same for read/write bytes + global_read = metrics.get_read_bytes(ctx, per_gpu=False) + per_gpu_read = metrics.get_read_bytes(ctx, per_gpu=True) + # Reads should scale 
similarly (weight reads are divided by TP) + assert global_read > per_gpu_read + + global_write = metrics.get_write_bytes(ctx, per_gpu=False) + per_gpu_write = metrics.get_write_bytes(ctx, per_gpu=True) + assert global_write > per_gpu_write + + +def test_attention_per_gpu_with_pipeline_parallelism(): + """Test attention metrics with pipeline parallelism - per_gpu vs global.""" + hf_config = Qwen3Config( + hidden_size=2048, + num_attention_heads=16, + num_hidden_layers=32, + ) + + # Test with PP=4 + vllm_config = create_mock_vllm_config(hf_config, pipeline_parallel_size=4) + metrics = AttentionMetrics.from_vllm_config(vllm_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=False + ) + + # Get global and per-gpu metrics + global_flops = metrics.get_num_flops(ctx, per_gpu=False) + per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True) + + # With PP=4, global flops should be 4x per-gpu flops (layers divided by 4) + assert global_flops == 4 * per_gpu_flops + + global_read = metrics.get_read_bytes(ctx, per_gpu=False) + per_gpu_read = metrics.get_read_bytes(ctx, per_gpu=True) + assert global_read == 4 * per_gpu_read + + +def test_ffn_per_gpu_with_tensor_parallelism(): + """Test FFN metrics with tensor parallelism - per_gpu vs global.""" + hf_config = Qwen3Config( + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + ) + + # Test with DP=2, TP=4 (ffn_tp_size will be 8) + vllm_config = create_mock_vllm_config( + hf_config, + data_parallel_size=2, + tensor_parallel_size=4, + ) + metrics = FfnMetrics.from_vllm_config(vllm_config) + + # ffn_tp_size should be dp_size * tp_size = 8 (when EP not enabled) + assert metrics.ffn_tp_size == 8 + + ctx = ExecutionContext.from_single_request( + num_tokens=128, context_len=2048, is_prefill=True + ) + + # Get global and per-gpu metrics + global_flops = metrics.get_num_flops(ctx, per_gpu=False) + per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True) + + # With ffn_tp_size=8, global should be 8x per-gpu + assert global_flops == 8 * per_gpu_flops + + +def test_ffn_per_gpu_with_pipeline_parallelism(): + """Test FFN metrics with pipeline parallelism - per_gpu vs global.""" + hf_config = Qwen3Config( + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=24, + ) + + # Test with PP=6 + vllm_config = create_mock_vllm_config(hf_config, pipeline_parallel_size=6) + metrics = FfnMetrics.from_vllm_config(vllm_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=True + ) + + # Get global and per-gpu metrics + global_flops = metrics.get_num_flops(ctx, per_gpu=False) + per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True) + + # With PP=6, global should be 6x per-gpu (layers divided by 6) + assert global_flops == 6 * per_gpu_flops + + +def test_moe_per_gpu_with_expert_parallelism(): + """ + Test MoE metrics with expert parallelism - verifies num_activated_experts bug fix. 
+ """ + hf_config = Qwen3MoeConfig( + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=24, + num_experts=64, + num_experts_per_tok=8, + moe_intermediate_size=14336, + n_shared_experts=2, + ) + + # Test with DP=2, TP=4, EP enabled (ffn_ep_size will be 8) + vllm_config = create_mock_vllm_config( + hf_config, + data_parallel_size=2, + tensor_parallel_size=4, + enable_expert_parallel=True, + ) + metrics = FfnMetrics.from_vllm_config(vllm_config) + + # When EP enabled, ffn_ep_size = dp_size * tp_size = 8 + assert metrics.ffn_ep_size == 8 + assert metrics.ffn_tp_size == 1 + + ctx = ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=True + ) + + # Get per-gpu metrics + per_gpu_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=True) + global_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=False) + + # Verify that routed expert weight reads are reasonable + # With per_gpu=True, each GPU has 64/8 = 8 experts + # T=100, E_per_gpu=8/8=1, so T*E=100 expert activations + # num_activated_experts should be min(100, 8) = 8 + + # Check that weight reads scale appropriately + # Global has all 64 experts, per-gpu has 8 experts + # So weight reads should reflect this difference + if "routed_up_gate_weights" in per_gpu_read_breakdown: + per_gpu_weight_reads = per_gpu_read_breakdown["routed_up_gate_weights"] + global_weight_reads = global_read_breakdown["routed_up_gate_weights"] + + # The ratio should reflect the expert count difference + # This verifies the bug fix works correctly + assert per_gpu_weight_reads < global_weight_reads + + # Global should read more experts than per-gpu + # Exact ratio depends on num_activated_experts calculation + ratio = global_weight_reads / per_gpu_weight_reads + # Should be > 1 since global has more experts to read + assert ratio > 1 + + +def test_moe_per_gpu_expert_activation_accounting(): + """ + Test that MoE correctly accounts for expert activations with small batch sizes. 
+ """ + hf_config = Qwen3MoeConfig( + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=12, + num_experts=64, + num_experts_per_tok=8, + moe_intermediate_size=14336, + n_shared_experts=0, # No shared experts for this test + ) + + # Test with EP=8 + vllm_config = create_mock_vllm_config( + hf_config, + data_parallel_size=8, + enable_expert_parallel=True, + ) + metrics = FfnMetrics.from_vllm_config(vllm_config) + + # Small batch: T=10, E_per_gpu=8/8=1 + # Each GPU: T*E = 10*1 = 10 activations + # Experts per GPU: 64/8 = 8 + # So num_activated_experts should be min(10, 8) = 8 + small_ctx = ExecutionContext.from_single_request( + num_tokens=10, context_len=512, is_prefill=True + ) + small_read = metrics.get_read_bytes_breakdown(small_ctx, per_gpu=True) + + # Large batch: T=1000, E_per_gpu=1 + # Each GPU: T*E = 1000*1 = 1000 activations + # Experts per GPU: 8 + # So num_activated_experts should be min(1000, 8) = 8 (all experts activated) + large_ctx = ExecutionContext.from_single_request( + num_tokens=1000, context_len=512, is_prefill=True + ) + large_read = metrics.get_read_bytes_breakdown(large_ctx, per_gpu=True) + + # Weight reads should be similar (both activate all 8 experts per GPU) + # But activation reads should differ (proportional to T*E) + if "routed_up_gate_weights" in small_read: + small_weight = small_read["routed_up_gate_weights"] + large_weight = large_read["routed_up_gate_weights"] + + # Weight reads should be the same (both read all 8 experts) + assert small_weight == large_weight + + # But input activation reads should scale with T*E + small_input = small_read["routed_up_gate_input"] + large_input = large_read["routed_up_gate_input"] + assert large_input == 100 * small_input # 1000/10 = 100x + + +def test_unembed_per_gpu_with_tensor_parallelism(): + """Test unembed metrics with tensor parallelism - per_gpu vs global.""" + hf_config = Qwen3Config( + hidden_size=4096, + vocab_size=128000, + ) + + # Test with TP=8 + vllm_config = create_mock_vllm_config(hf_config, tensor_parallel_size=8) + metrics = UnembedMetrics.from_vllm_config(vllm_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=100, context_len=512, is_prefill=True + ) + + # Get global and per-gpu metrics + global_flops = metrics.get_num_flops(ctx, per_gpu=False) + per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True) + + # With TP=8, vocab is divided by 8, so global should be 8x per-gpu + assert global_flops == 8 * per_gpu_flops + + # For read bytes, weight reads scale with TP but input reads don't (replicated) + global_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=False) + per_gpu_read_breakdown = metrics.get_read_bytes_breakdown(ctx, per_gpu=True) + + # Input reads should be the same (replicated across TP ranks) + assert global_read_breakdown["input"] == per_gpu_read_breakdown["input"] + + # Weight reads should scale 8x (divided by TP) + assert global_read_breakdown["weight"] == 8 * per_gpu_read_breakdown["weight"] + + +def test_model_metrics_per_gpu_aggregation(): + """Test ModelMetrics correctly aggregates per_gpu metrics across components.""" + hf_config = Qwen3Config( + hidden_size=2048, + num_attention_heads=16, + num_hidden_layers=12, + vocab_size=32000, + intermediate_size=8192, + ) + + # Test with mixed parallelism: TP=2, PP=2 + vllm_config = create_mock_vllm_config( + hf_config, + tensor_parallel_size=2, + pipeline_parallel_size=2, + ) + + model_metrics = ModelMetrics(vllm_config) + ctx = ExecutionContext.from_single_request( + num_tokens=100, 
context_len=512, is_prefill=True + ) + + # Get breakdowns for both modes + per_gpu_breakdown = model_metrics.get_num_flops_breakdown(ctx, per_gpu=True) + global_breakdown = model_metrics.get_num_flops_breakdown(ctx, per_gpu=False) + + # Verify breakdown sums match totals + per_gpu_total = model_metrics.get_num_flops(ctx, per_gpu=True) + global_total = model_metrics.get_num_flops(ctx, per_gpu=False) + + assert per_gpu_total == sum(per_gpu_breakdown.values()) + assert global_total == sum(global_breakdown.values()) + + # Global should be larger than per-gpu due to parallelism + assert global_total > per_gpu_total + + # With TP=2 and PP=2, the ratio depends on which parallelism applies to + # which component but we can verify that global is reasonably larger + ratio = global_total / per_gpu_total + assert ratio > 1 # Should be between PP and TP*PP depending on component mix + + +def test_attention_per_gpu_heads_not_evenly_divisible(): + """Test attention with heads not evenly divisible by TP.""" + hf_config = Qwen3Config( + hidden_size=2048, + num_attention_heads=17, # Not divisible by 4 + num_key_value_heads=5, # Not divisible by 4 + num_hidden_layers=8, + ) + + vllm_config = create_mock_vllm_config(hf_config, tensor_parallel_size=4) + metrics = AttentionMetrics.from_vllm_config(vllm_config) + + ctx = ExecutionContext.from_single_request( + num_tokens=64, context_len=256, is_prefill=True + ) + + # Should not crash and should handle max(1, ...) correctly + per_gpu_flops = metrics.get_num_flops(ctx, per_gpu=True) + global_flops = metrics.get_num_flops(ctx, per_gpu=False) + + # Both should be positive + assert per_gpu_flops > 0 + assert global_flops > 0 + assert global_flops > per_gpu_flops diff --git a/vllm/config/observability.py b/vllm/config/observability.py index e40bf18a00ce2..4aca6b15684ac 100644 --- a/vllm/config/observability.py +++ b/vllm/config/observability.py @@ -64,6 +64,9 @@ class ObservabilityConfig: module in the model and attach informations such as input/output shapes to nvtx range markers. 
Noted that this doesn't work with CUDA graphs enabled.""" + enable_mfu_metrics: bool = False + """Enable Model FLOPs Utilization (MFU) metrics.""" + @cached_property def collect_model_forward_time(self) -> bool: """Whether to collect model forward time for the request.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 03720bd2516d4..64510bdcaf8a8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -523,6 +523,7 @@ class EngineArgs: enable_layerwise_nvtx_tracing: bool = ( ObservabilityConfig.enable_layerwise_nvtx_tracing ) + enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls @@ -1042,6 +1043,10 @@ class EngineArgs: "--enable-layerwise-nvtx-tracing", **observability_kwargs["enable_layerwise_nvtx_tracing"], ) + observability_group.add_argument( + "--enable-mfu-metrics", + **observability_kwargs["enable_mfu_metrics"], + ) # Scheduler arguments scheduler_kwargs = get_kwargs(SchedulerConfig) @@ -1689,6 +1694,7 @@ class EngineArgs: kv_cache_metrics_sample=self.kv_cache_metrics_sample, cudagraph_metrics=self.cudagraph_metrics, enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing, + enable_mfu_metrics=self.enable_mfu_metrics, ) # Compilation config overrides diff --git a/vllm/envs.py b/vllm/envs.py index 2f8158d88d6c5..b59991aa6523a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -244,6 +244,7 @@ if TYPE_CHECKING: VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False + VLLM_DEBUG_MFU_METRICS: bool = False def get_default_cache_root(): @@ -1565,6 +1566,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), + # Debug logging for --enable-mfu-metrics + "VLLM_DEBUG_MFU_METRICS": lambda: bool( + int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0")) + ), } # --8<-- [end:env-vars-definition] diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8e835ad096405..da8339558b143 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -43,6 +43,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.perf import ModelMetrics, PerfStats from vllm.v1.metrics.stats import ( PrefixCacheStats, SchedulerStats, @@ -219,6 +220,10 @@ class Scheduler(SchedulerInterface): self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER + self.perf_metrics: ModelMetrics | None = None + if self.log_stats and vllm_config.observability_config.enable_mfu_metrics: + self.perf_metrics = ModelMetrics(vllm_config) + def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. 
@@ -1066,6 +1071,10 @@ class Scheduler(SchedulerInterface): kv_connector_output = model_runner_output.kv_connector_output cudagraph_stats = model_runner_output.cudagraph_stats + perf_stats: PerfStats | None = None + if self.perf_metrics and self.perf_metrics.is_enabled(): + perf_stats = self.perf_metrics.get_step_perf_stats_per_gpu(scheduler_output) + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None kv_connector_stats: KVConnectorStats | None = ( @@ -1262,7 +1271,7 @@ class Scheduler(SchedulerInterface): if ( stats := self.make_stats( - spec_decoding_stats, kv_connector_stats, cudagraph_stats + spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats ) ) is not None: # Return stats to only one of the front-ends. @@ -1485,6 +1494,7 @@ class Scheduler(SchedulerInterface): spec_decoding_stats: SpecDecodingStats | None = None, kv_connector_stats: KVConnectorStats | None = None, cudagraph_stats: CUDAGraphStat | None = None, + perf_stats: PerfStats | None = None, ) -> SchedulerStats | None: if not self.log_stats: return None @@ -1510,6 +1520,7 @@ class Scheduler(SchedulerInterface): spec_decoding_stats=spec_stats, kv_connector_stats=connector_stats_payload, cudagraph_stats=cudagraph_stats, + perf_stats=perf_stats, ) def make_spec_decoding_stats( diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 9eaee1bb97bb9..2213b952c7a89 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( from vllm.logger import init_logger from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group from vllm.v1.engine import FinishReason +from vllm.v1.metrics.perf import PerfMetricsLogging from vllm.v1.metrics.prometheus import unregister_vllm_metrics from vllm.v1.metrics.stats import ( CachingMetrics, @@ -118,6 +119,9 @@ class LoggingStatLogger(StatLoggerBase): self.engine_is_idle = False self.aggregated = False + if self._enable_perf_stats(): + self.perf_metrics_logging = PerfMetricsLogging(vllm_config) + def _reset(self, now): self.last_log_time = now @@ -127,6 +131,9 @@ class LoggingStatLogger(StatLoggerBase): self.num_corrupted_reqs: int = 0 self.num_preemptions: int = 0 + def _enable_perf_stats(self) -> bool: + return self.vllm_config.observability_config.enable_mfu_metrics + def _track_iteration_stats(self, iteration_stats: IterationStats): # Save tracked stats for token counters. 
self.num_prompt_tokens += iteration_stats.num_prompt_tokens @@ -175,6 +182,8 @@ class LoggingStatLogger(StatLoggerBase): self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats) if not self.aggregated: self.last_scheduler_stats = scheduler_stats + if (perf_stats := scheduler_stats.perf_stats) and self._enable_perf_stats(): + self.perf_metrics_logging.observe(perf_stats) if mm_cache_stats: self.mm_caching_metrics.observe(mm_cache_stats) @@ -211,7 +220,7 @@ class LoggingStatLogger(StatLoggerBase): "Running: %d reqs", "Waiting: %d reqs", ] - log_args = [ + log_args: list[int | float | str] = [ self.last_prompt_throughput, self.last_generation_throughput, self.last_scheduler_stats.num_running_reqs, @@ -254,6 +263,8 @@ class LoggingStatLogger(StatLoggerBase): self.kv_connector_logging.log(log_fn=log_fn) if self.cudagraph_logging is not None: self.cudagraph_logging.log(log_fn=log_fn) + if self._enable_perf_stats(): + self.perf_metrics_logging.log(log_fn=log_fn, log_prefix=self.log_prefix) def log_engine_initialized(self): if self.vllm_config.cache_config.num_gpu_blocks: @@ -282,6 +293,10 @@ class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase): def log_prefix(self): return "{} Engines Aggregated: ".format(len(self.engine_indexes)) + def _enable_perf_stats(self) -> bool: + # Adding per_gpu perf stats across engines can lead to misleading numbers. + return False + def record( self, scheduler_stats: SchedulerStats | None, diff --git a/vllm/v1/metrics/perf.py b/vllm/v1/metrics/perf.py new file mode 100644 index 0000000000000..446a81fc4855d --- /dev/null +++ b/vllm/v1/metrics/perf.py @@ -0,0 +1,1244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Analytic flops/memory estimation module for transformer components, +to help derive MFU (Model Flops Utilization) stats for a running model. +""" + +import json +import time +from abc import ABC, abstractmethod +from collections.abc import Iterable +from dataclasses import asdict, dataclass +from typing import Any, Protocol + +import torch +from pydantic import BaseModel, Field, ValidationError, model_validator +from typing_extensions import Self + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + get_dtype_size, + get_kv_cache_torch_dtype, +) +from vllm.v1.core.sched.output import SchedulerOutput + +logger = init_logger(__name__) + + +class InvalidComponent(Exception): + """ + Custom exception to indicate that a certain ComponentMetric is not + applicable to the given VllmConfig. + """ + + pass + + +#### Basic Data Types #### + + +@dataclass +class DebugPerfStats: + ## Stats for debugging the metrics calculation + calc_duration: float = 0.0 # time spent calculating these stats + num_prefill_requests: int = 0 + num_decode_requests: int = 0 + context_breakdown: dict[str, int] | None = None + num_flops_per_gpu_breakdown: dict[str, int] | None = None + num_read_bytes_per_gpu_breakdown: dict[str, int] | None = None + num_write_bytes_per_gpu_breakdown: dict[str, int] | None = None + + +@dataclass +class PerfStats: + num_flops_per_gpu: int = 0 + num_read_bytes_per_gpu: int = 0 + num_write_bytes_per_gpu: int = 0 + debug_stats: DebugPerfStats | None = None + + +@dataclass +class ExecutionContext: + """ + Represents an execution context for a batch of requests. 
+ + This class aggregates statistics across multiple requests in a batch, + separately tracking prefill and decode phases. + + Example) + - Batch with one full prefill (2048 tokens) and one decode (1 token, 8192 context): + ctx = ExecutionContext() + ctx.add(2048, 2048, is_prefill=True) + ctx.add(1, 8192, is_prefill=False) + """ + + # Prefill phase statistics + num_prefill_requests: int = 0 + prefill_num_tokens: int = 0 # sum of num_tokens for prefill requests + prefill_context_len: int = 0 # sum of context_len for prefill requests + prefill_token_context_product: int = 0 # sum of (num_tokens * context_len) + + # Decode phase statistics + num_decode_requests: int = 0 + decode_num_tokens: int = 0 # sum of num_tokens for decode requests + decode_context_len: int = 0 # sum of context_len for decode requests + decode_token_context_product: int = 0 # sum of (num_tokens * context_len) + + def add(self, num_tokens: int, context_len: int, is_prefill: bool) -> None: + """Add a single request's statistics to this batch context.""" + if is_prefill: + self.num_prefill_requests += 1 + self.prefill_num_tokens += num_tokens + self.prefill_context_len += context_len + self.prefill_token_context_product += num_tokens * context_len + else: + self.num_decode_requests += 1 + self.decode_num_tokens += num_tokens + self.decode_context_len += context_len + self.decode_token_context_product += num_tokens * context_len + + def total_num_tokens(self) -> int: + """Total number of tokens across all requests in the batch.""" + return self.prefill_num_tokens + self.decode_num_tokens + + def total_token_context_product(self) -> int: + """Total sum of (num_tokens * context_len) across all requests.""" + return self.prefill_token_context_product + self.decode_token_context_product + + @classmethod + def from_single_request( + cls, num_tokens: int, context_len: int, is_prefill: bool + ) -> "ExecutionContext": + """Create an ExecutionContext from a single request. + + This is a convenience method primarily for testing. + """ + ctx = cls() + ctx.add(num_tokens, context_len, is_prefill) + return ctx + + +class ParsedArgs: + """ + Syntactic sugar so that Parsers can use dot notations + to access/update the parsed arguments. + + e.g.) + args = ParsedArgs() + args.x = 3 + args.y = args.x + 1 + """ + + def __getattr__(self, name: str) -> Any: + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: + object.__setattr__(self, name, value) + + def model_dump(self) -> dict[str, Any]: + return vars(self).copy() + + +#### Abstract #### + + +class Parser(Protocol): + def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs: + """ + Parse the vllm config and update the current ParsedArgs and pass it on. + If the parser isn't applicable to the vllm_config, it will do nothing. + """ + ... + + +class ParserChain: + """ + Applies chain of parser in a sequential order. + Later parsers might overwrite results from previous parsers, + so parsers should be chained in the appropriate order if they + are not mutually exclusive. 
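+
+    e.g.)
+        # Illustrative chain only; the quantization parser may override the
+        # weight_byte_size provided by BaseConfigParser.
+        chain = ParserChain(BaseConfigParser(), AttentionQuantizationConfigParser())
+        args = chain.parse(vllm_config)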
+ """ + + def __init__(self, *parsers: Parser) -> None: + self.parsers = list(parsers) + + def add_parser(self, parser: Parser) -> None: + self.parsers.append(parser) + + def parse(self, vllm_config: VllmConfig) -> ParsedArgs: + args = ParsedArgs() + for parser in self.parsers: + args = parser.parse(args, vllm_config) + return args + + +_COMPONENT_METRICS_REGISTRY: dict[str, type["ComponentMetrics"]] = {} + + +class ComponentMetrics(BaseModel, ABC): + """ + Each concrete ComponentMetrics class is associated with: + - fields that are required for metric derivation + (fields are specified/validated through pydantic model) + - parser to parse VllmConfig into fields + - metric methods that derive flops/bytes for a given execution context + """ + + @classmethod + @abstractmethod + def component_type(cls) -> str: ... + + @classmethod + @abstractmethod + def get_parser(cls) -> ParserChain: + """ + Return a ParserChain that provides values for all required fields. + The returned parser chain must populate ParsedArgs with values for every + field defined on this ComponentMetrics class. Missing fields will cause + a ValidationError when from_vllm_config() is called. + See individual Parser docstrings for which args they provide, and field + comments on ComponentMetrics subclasses for which parser provides each field. + """ + ... + + def __init_subclass__(cls): + _COMPONENT_METRICS_REGISTRY[cls.component_type()] = cls + + @classmethod + def from_vllm_config(cls, vllm_config: VllmConfig) -> Self: + """ + Instantiate this class from VllmConfig. + Raises ValidationError if parsing fails. + """ + + parser = cls.get_parser() + parsed_args = parser.parse(vllm_config) + try: + return cls.model_validate(parsed_args.model_dump()) + except ValidationError as e: + raise InvalidComponent(f"Invalid {cls.component_type()} config: {e}") from e + + @classmethod + def registered_metrics(cls) -> Iterable[type["ComponentMetrics"]]: + return iter(_COMPONENT_METRICS_REGISTRY.values()) + + @abstractmethod + def get_num_flops_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: ... + + @abstractmethod + def get_read_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: ... + + @abstractmethod + def get_write_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: ... + + def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int: + return sum(self.get_num_flops_breakdown(ctx, per_gpu).values()) + + def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int: + return sum(self.get_read_bytes_breakdown(ctx, per_gpu).values()) + + def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int: + return sum(self.get_write_bytes_breakdown(ctx, per_gpu).values()) + + +#### parsers #### + + +class BaseConfigParser(Parser): + """ + Parses base model configuration. 
+ Provides: vocab_size, hidden_size, num_attention_heads, num_hidden_layers, + weight_byte_size, activation_byte_size, dp_size, tp_size, pp_size, enable_ep + """ + + def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs: + model_config = vllm_config.model_config + + args.vocab_size = model_config.get_vocab_size() + args.hidden_size = model_config.get_hidden_size() + # NOTE: model_config.get_attention_heads() divide by TP + # so we access field manually here to get total num_heads + args.num_attention_heads = get_required( + model_config.hf_text_config, "num_attention_heads" + ) + args.num_hidden_layers = get_required( + model_config.hf_text_config, "num_hidden_layers" + ) + + model_dtype = vllm_config.model_config.dtype + + if isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + elif isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + else: + # FIXME: handle this better + logger.warning( + "Unknown model_dtype %s, defaulting to bfloat16", + model_dtype, + ) + torch_dtype = torch.bfloat16 + + args.weight_byte_size = get_dtype_size(torch_dtype) + + # FIXME: handle this better by parsing whether activations use + # bf16, fp32, etc... + args.activation_byte_size = 2 + + args.dp_size = vllm_config.parallel_config.data_parallel_size + args.tp_size = vllm_config.parallel_config.tensor_parallel_size + args.pp_size = vllm_config.parallel_config.pipeline_parallel_size + args.enable_ep = vllm_config.parallel_config.enable_expert_parallel + + return args + + +#### Attention #### + + +class BaseAttentionConfigParser(Parser): + """ + Parses attention-specific configuration. + Provides: num_key_value_heads, head_dim, cache_byte_size + """ + + def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs: + model_config = vllm_config.model_config + + args.num_key_value_heads = model_config.get_total_num_kv_heads() + args.head_dim = model_config.get_head_size() + + model_dtype = vllm_config.model_config.dtype + cache_dtype = vllm_config.cache_config.cache_dtype + + kv_cache_torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + args.cache_byte_size = get_dtype_size(kv_cache_torch_dtype) + + return args + + +class AttentionQuantizationConfigParser(Parser): + """ + Parses quantization configuration for attention layers. + Overrides: weight_byte_size + """ + + def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs: + cfg = vllm_config.quant_config + + if cfg is None: + return args + + quant_method = cfg.get_name() + if quant_method in ["fp8", "fbgemm_fp8"]: + # FIXME: This is a hacky coarse-grained fp8 quantization detection. + # FIXME: These configs also have concept of "ignored layers" and we + # need to solve the same problem as above. + args.weight_byte_size = 1 + elif quant_method == "mxfp4": + # FIXME: Also has "ignored layers" issue above + args.weight_byte_size = 0.5 + else: + # FIXME: Add more parsing logic for different quant methods. 
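+            # An unrecognized quantization scheme makes the weight byte size
+            # unreliable, so flag this component as not applicable rather than guess.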
+ raise InvalidComponent + + return args + + +class AttentionMetrics(ComponentMetrics): + # From BaseConfigParser + num_hidden_layers: int = Field(..., gt=0) + hidden_size: int = Field(..., gt=0) + num_attention_heads: int = Field(..., gt=0) + activation_byte_size: int = Field(..., gt=0) + tp_size: int = Field(..., gt=0) + pp_size: int = Field(..., gt=0) + + # From BaseAttentionConfigParser + num_key_value_heads: int = Field(..., gt=0) + head_dim: int = Field(..., gt=0) + cache_byte_size: int = Field(..., gt=0) + + # From BaseConfig Parser, overridden by AttentionQuantizationConfigParser + weight_byte_size: int | float = Field(..., gt=0) + + # TODO: discern cases where we have mixture of different attention layer types + # such as SWA, MLA, etc. + + @classmethod + def component_type(cls) -> str: + return "attn" + + @classmethod + def get_parser(cls) -> ParserChain: + return ParserChain( + BaseConfigParser(), + BaseAttentionConfigParser(), + AttentionQuantizationConfigParser(), + ) + + def get_num_flops_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + L, D, q, kv, d = ( + self.num_hidden_layers, + self.hidden_size, + self.num_attention_heads, + self.num_key_value_heads, + self.head_dim, + ) + T = ctx.total_num_tokens() + TC = ctx.total_token_context_product() + + if per_gpu: + L //= self.pp_size + # tensor parallel along heads + q = max(1, q // self.tp_size) + kv = max(1, kv // self.tp_size) + + return { + "qkv_proj": 2 * T * D * (q + 2 * kv) * d * L, + "attn_qk": 2 * q * TC * d * L, + "attn_av": 2 * q * TC * d * L, + "out_proj": 2 * T * D * q * d * L, + } + + def get_read_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + L, D, q, kv, d = ( + self.num_hidden_layers, + self.hidden_size, + self.num_attention_heads, + self.num_key_value_heads, + self.head_dim, + ) + T = ctx.total_num_tokens() + + if per_gpu: + L //= self.pp_size + # tensor parallel along heads + q = max(1, q // self.tp_size) + kv = max(1, kv // self.tp_size) + + read_bytes = {} + + read_bytes["qkv_input"] = T * D * self.activation_byte_size * L + read_bytes["qkv_weight"] = int(D * (q + 2 * kv) * d * self.weight_byte_size * L) + + # Attention input reads differ between prefill and decode + # Prefill: read Q, K, V activations (all in activation_byte_size) + if ctx.prefill_num_tokens > 0: + read_bytes["attn_input"] = ( + (ctx.prefill_num_tokens * q + 2 * ctx.prefill_context_len * kv) + * d + * self.activation_byte_size + * L + ) + + # Decode: read Q activations + read K, V from cache (in cache_byte_size) + if ctx.decode_num_tokens > 0: + read_bytes["attn_input"] = read_bytes.get("attn_input", 0) + ( + ctx.decode_num_tokens * q * d * self.activation_byte_size * L + + 2 * ctx.decode_context_len * kv * d * self.cache_byte_size * L + ) + + read_bytes["out_input"] = T * q * d * self.activation_byte_size * L + read_bytes["out_weight"] = int(q * d * D * self.weight_byte_size * L) + + return read_bytes + + def get_write_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + """Calculate write memory traffic for attention layers.""" + L, D, q, kv, d = ( + self.num_hidden_layers, + self.hidden_size, + self.num_attention_heads, + self.num_key_value_heads, + self.head_dim, + ) + T = ctx.total_num_tokens() + + if per_gpu: + L //= self.pp_size + # tensor parallel along heads + q = max(1, q // self.tp_size) + kv = max(1, kv // self.tp_size) + + return { + "qkv_output": T * (q + 2 * kv) * d * self.activation_byte_size * L, + 
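+            # New K and V entries (hence the factor of 2) are appended to the
+            # cache for each of the T scheduled tokens, at cache_byte_size precision.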
"kv_cache": 2 * T * kv * d * self.cache_byte_size * L, + "out_output": T * D * self.activation_byte_size * L, + } + + +#### Ffn #### + + +class BaseFfnConfigParser(Parser): + """ + Parses FFN and MoE configuration. + Provides: intermediate_size, num_experts, num_experts_per_tok, + moe_intermediate_size, num_shared_experts, num_moe_layers + """ + + def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs: + cfg = vllm_config.model_config.hf_config + if hasattr(cfg, "text_config") and cfg.text_config is not None: + cfg = cfg.text_config + + args.intermediate_size = getattr(cfg, "intermediate_size", args.hidden_size * 4) + + # Try different naming conventions. + args.num_experts = vllm_config.model_config.get_num_experts() + args.num_experts_per_tok = getattr_from_list( + cfg, ["num_experts_per_tok", "moe_topk"], 0 + ) + args.moe_intermediate_size = getattr_from_list( + cfg, ["moe_intermediate_size", "intermediate_size"], 0 + ) + args.num_shared_experts = getattr_from_list( + cfg, ["n_shared_experts", "num_shared_experts"], 0 + ) + + is_moe = args.num_experts != 0 + # Assume all MoE layers by default + args.num_moe_layers = args.num_hidden_layers if is_moe else 0 + + return args + + +class FfnParallelParser(Parser): + """ + Parses FFN parallelism configuration. + + Provides: ffn_tp_size, ffn_ep_size + """ + + def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs: + # NOTE: ffn tp_size does not equal the tp_size parameter directly. + # e.g.) If we use DP2TP4, ffn will use TP8 (or EP8 if EP is enabled.) + if args.enable_ep: + ffn_tp_size, ffn_ep_size = 1, args.dp_size * args.tp_size + else: + ffn_tp_size, ffn_ep_size = args.dp_size * args.tp_size, 1 + + args.ffn_tp_size = ffn_tp_size + args.ffn_ep_size = ffn_ep_size + + return args + + +class InterleaveMoeLayerStepParser(Parser): + """ + Parses interleave_moe_layer_step field for models like Llama4. + + Overrides: num_moe_layers + """ + + def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs: + cfg = vllm_config.model_config.hf_config + if hasattr(cfg, "text_config") and cfg.text_config is not None: + cfg = cfg.text_config + + if ( + hasattr(cfg, "interleave_moe_layer_step") + and cfg.interleave_moe_layer_step > 0 + ): + args.num_moe_layers = len( + [ + layer + for layer in range(args.num_hidden_layers) + if (layer + 1) % cfg.interleave_moe_layer_step == 0 + ] + ) + + return args + + +class MoeLayerFreqParser(Parser): + """ + Parses moe_layer_freq and first_k_dense_replace fields for models like Deepseek. + + Overrides: num_moe_layers + """ + + def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs: + cfg = vllm_config.model_config.hf_config + if hasattr(cfg, "text_config") and cfg.text_config is not None: + cfg = cfg.text_config + + if hasattr(cfg, "moe_layer_freq") and hasattr(cfg, "first_k_dense_replace"): + args.num_moe_layers = len( + [ + layer + for layer in range(args.num_hidden_layers) + if layer >= cfg.first_k_dense_replace + and layer % cfg.moe_layer_freq == 0 + ] + ) + + return args + + +class FfnQuantizationConfigParser(Parser): + """ + Parses quantization configuration for FFN layers. + + Overrides: weight_byte_size + """ + + def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs: + cfg = vllm_config.quant_config + + if cfg is None: + return args + + quant_method = cfg.get_name() + if quant_method in ["fp8", "fbgemm_fp8"]: + # FIXME: This is a hacky coarse-grained fp8 quantization detection. 
+ # (there might be more quantization methods for fp8). + # FIXME: These configs also have concept of "ignored layers" and we + # need to solve the same problem as above. + args.weight_byte_size = 1 + pass + elif quant_method == "mxfp4": + # FIXME: Also has "ignored layers" issue above + args.weight_byte_size = 0.5 + else: + # FIXME: Add more parsing logic for different quant methods. + raise InvalidComponent + + return args + + +class FfnMetrics(ComponentMetrics): + # From BaseConfigParser + num_hidden_layers: int = Field(..., gt=0) + hidden_size: int = Field(..., gt=0) + activation_byte_size: int = Field(..., gt=0) + pp_size: int = Field(..., gt=0) + + # From FfnParallelParser + ffn_tp_size: int = Field(..., gt=0) + ffn_ep_size: int = Field(..., gt=0) + + # From BaseFfnConfigParser + intermediate_size: int = Field(..., gt=0) + num_experts: int = Field(0) + num_experts_per_tok: int = Field(1) + moe_intermediate_size: int = Field(0) + num_shared_experts: int = Field(0) + + # From BaseConfigParser, can be overridden InterleaveMoeLayerStep or MoeLayerFreq + num_moe_layers: int = Field(..., ge=0) + + # FIXME: might have to make this more granular + # (i.e. dense_weight_byte_size, moe_routed_weight_byte_size, + # moe_shared_weight_byte_size) + # since it can differ from byte size of other components (e.g. attn) + # and can differ even from each other. + + # From BaseConfigParser, can be overridden by FfnQuantizationConfigParser + weight_byte_size: int | float = Field(..., gt=0) + + @model_validator(mode="after") + def validate_moe_fields(self) -> Self: + """Validate that MoE-related fields are properly set when num_moe_layers > 0.""" + if self.num_moe_layers > 0: + assert self.num_experts, f"{self.num_experts=}" + assert self.num_experts_per_tok, f"{self.num_experts_per_tok=}" + assert self.moe_intermediate_size, f"{self.moe_intermediate_size=}" + return self + + @classmethod + def component_type(cls) -> str: + return "ffn" + + @classmethod + def get_parser(cls) -> ParserChain: + return ParserChain( + BaseConfigParser(), + FfnParallelParser(), + BaseFfnConfigParser(), + InterleaveMoeLayerStepParser(), + MoeLayerFreqParser(), + FfnQuantizationConfigParser(), + ) + + def get_num_flops_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + """Calculate flops breakdown for FFN layers.""" + L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size + Lm, E, MI, S = ( + self.num_moe_layers, + self.num_experts_per_tok, + self.moe_intermediate_size, + self.num_shared_experts, + ) + T = ctx.total_num_tokens() + + Ld = L - Lm + + num_activated_tokens = T * E if E else 0 + + if per_gpu: + Ld //= self.pp_size + Lm //= self.pp_size + + DI //= self.ffn_tp_size + if MI is not None: + MI //= self.ffn_tp_size + if E: + num_activated_tokens //= self.ffn_ep_size + + flops = {} + + # Dense FFN layers (SwiGLU: 3 linear layers: up, gate, down) + if Ld: + flops["dense_ffn"] = 2 * D * 3 * DI * T * Ld + + # MoE routed experts (each token activates E experts) + if Lm and E: + flops["routed_ffn"] = 2 * D * 3 * MI * num_activated_tokens * Lm + + # MoE shared experts (all S shared experts run for every token) + if Lm and S: + flops["shared_ffn"] = 2 * D * 3 * MI * S * T * Lm + + return flops + + def get_read_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + """Calculate read memory traffic for FFN layers.""" + L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size + Lm, E, MI, S = ( + self.num_moe_layers, + 
self.num_experts_per_tok, + self.moe_intermediate_size, + self.num_shared_experts, + ) + T = ctx.total_num_tokens() + num_experts = self.num_experts + + Ld = L - Lm + + num_activated_tokens = T * E if E else 0 + + if per_gpu: + Ld //= self.pp_size + Lm //= self.pp_size + + DI //= self.ffn_tp_size + if MI is not None: + MI //= self.ffn_tp_size + if E: + num_activated_tokens //= self.ffn_ep_size + if num_experts is not None: + num_experts //= self.ffn_ep_size + + read_bytes = {} + + # Dense FFN layers (3 GEMMs: up, gate, down projections + SiLU activation) + if Ld: + read_bytes["dense_up_gate_input"] = int( + T * D * self.activation_byte_size * Ld + ) + read_bytes["dense_up_gate_weights"] = int( + 2 * D * DI * self.weight_byte_size * Ld + ) + read_bytes["dense_silu_input"] = int( + 2 * T * DI * self.activation_byte_size * Ld + ) + read_bytes["dense_down_input"] = int( + T * DI * self.activation_byte_size * Ld + ) + read_bytes["dense_down_weights"] = int(D * DI * self.weight_byte_size * Ld) + + if Lm: + # MoE routed expert reads + if E: + # FIXME: Assume perfect load balancing for now. + num_activated_experts = min(num_activated_tokens, num_experts) + + read_bytes["routed_up_gate_input"] = int( + num_activated_tokens * D * self.activation_byte_size * Lm + ) + read_bytes["routed_up_gate_weights"] = int( + 2 * D * MI * num_activated_experts * self.weight_byte_size * Lm + ) + read_bytes["routed_silu_input"] = int( + 2 * num_activated_tokens * MI * self.activation_byte_size * Lm + ) + read_bytes["routed_down_input"] = int( + num_activated_tokens * MI * self.activation_byte_size * Lm + ) + read_bytes["routed_down_weights"] = int( + D * MI * num_activated_experts * self.weight_byte_size * Lm + ) + + # MoE shared expert reads + if S: + read_bytes["shared_up_gate_input"] = int( + T * D * self.activation_byte_size * Lm + ) + read_bytes["shared_up_gate_weights"] = int( + 2 * D * MI * S * self.weight_byte_size * Lm + ) + read_bytes["shared_silu_input"] = int( + 2 * T * MI * S * self.activation_byte_size * Lm + ) + read_bytes["shared_down_input"] = int( + T * MI * self.activation_byte_size * Lm + ) + read_bytes["shared_down_weights"] = int( + D * MI * S * self.weight_byte_size * Lm + ) + + return read_bytes + + def get_write_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + """Calculate write memory traffic for FFN layers.""" + L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size + Lm, E, MI, S = ( + self.num_moe_layers, + self.num_experts_per_tok, + self.moe_intermediate_size, + self.num_shared_experts, + ) + T = ctx.total_num_tokens() + + Ld = L - Lm + + num_activated_tokens = T * E if E else 0 + + if per_gpu: + Ld //= self.pp_size + Lm //= self.pp_size + + DI //= self.ffn_tp_size + if MI is not None: + MI //= self.ffn_tp_size + if E: + num_activated_tokens //= self.ffn_ep_size + + write_bytes = {} + + # Dense FFN layers + if Ld: + write_bytes["dense_up_gate_output"] = int( + 2 * T * DI * self.activation_byte_size * Ld + ) + write_bytes["dense_silu_output"] = int( + T * DI * self.activation_byte_size * Ld + ) + write_bytes["dense_down_output"] = int( + T * D * self.activation_byte_size * Ld + ) + + # MoE outputs + if Lm: + if E: + write_bytes["routed_up_gate_output"] = int( + 2 * num_activated_tokens * MI * self.activation_byte_size * Lm + ) + write_bytes["routed_silu_output"] = int( + num_activated_tokens * MI * self.activation_byte_size * Lm + ) + write_bytes["routed_down_output"] = int( + num_activated_tokens * D * 
self.activation_byte_size * Lm + ) + if S: + write_bytes["shared_up_gate_output"] = int( + 2 * T * S * MI * self.activation_byte_size * Lm + ) + write_bytes["shared_silu_output"] = int( + T * S * MI * self.activation_byte_size * Lm + ) + write_bytes["shared_down_output"] = int( + T * S * D * self.activation_byte_size * Lm + ) + + return write_bytes + + +#### Unembed #### + + +class UnembedMetrics(ComponentMetrics): + # From BaseConfigParser + hidden_size: int = Field(..., gt=0) + vocab_size: int = Field(..., gt=0) + weight_byte_size: int = Field(..., gt=0) + activation_byte_size: int = Field(..., gt=0) + + tp_size: int + + @classmethod + def component_type(cls) -> str: + return "unembed" + + @classmethod + def get_parser(cls) -> ParserChain: + return ParserChain( + BaseConfigParser(), + ) + + def get_num_flops_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + """Calculate flops breakdown for unembedding layer.""" + D, V = self.hidden_size, self.vocab_size + T = ctx.total_num_tokens() + + if per_gpu: + V //= self.tp_size + + return { + "unembed": 2 * T * D * V, + } + + def get_read_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + """Calculate read memory traffic for unembedding layer.""" + D, V = self.hidden_size, self.vocab_size + T = ctx.total_num_tokens() + + if per_gpu: + V //= self.tp_size + + return { + "input": T * D * self.activation_byte_size, + "weight": D * V * self.weight_byte_size, + } + + def get_write_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + """Calculate write memory traffic for unembedding layer.""" + V = self.vocab_size + T = ctx.total_num_tokens() + + if per_gpu: + V //= self.tp_size + + return { + "output": T * V * self.activation_byte_size, + } + + +#### ModelMetrics #### + + +class ModelMetrics: + def __init__(self, vllm_config: VllmConfig) -> None: + """ + Parse vllm_config to instantiate metrics for each component. + is_enabled() will return False if no component metrics could be instantiated. 
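+
+        A rough usage sketch (the construction of a real VllmConfig is elided
+        here; the other names are the ones defined in this module):
+
+            metrics = ModelMetrics(vllm_config)
+            if metrics.is_enabled():
+                ctx = ExecutionContext()
+                ctx.add(128, 128, is_prefill=True)  # 128 prompt tokens, no prefix cache
+                flops_per_gpu = metrics.get_num_flops(ctx, per_gpu=True)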
+ """ + + self.vllm_config = vllm_config + + self.metrics: list[ComponentMetrics] = [] + for metric_cls in ComponentMetrics.registered_metrics(): + try: + metric = metric_cls.from_vllm_config(vllm_config) + self.metrics.append(metric) + logger.info( + "Instantiated ComponentMetrics [%s] with (%s)", + metric.component_type(), + str(metric), + ) + except InvalidComponent as e: + logger.debug( + "Failed to instantiate %s from %s", + metric_cls.component_type(), + str(e), + ) + + def is_enabled(self) -> bool: + return len(self.metrics) > 0 + + def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int: + return sum(metric.get_num_flops(ctx, per_gpu) for metric in self.metrics) + + def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int: + return sum(metric.get_read_bytes(ctx, per_gpu) for metric in self.metrics) + + def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int: + return sum(metric.get_write_bytes(ctx, per_gpu) for metric in self.metrics) + + def get_num_flops_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + total = {} + for metric in self.metrics: + breakdown = metric.get_num_flops_breakdown(ctx, per_gpu) + component = metric.component_type() + prefixed = {f"{component}.{key}": val for key, val in breakdown.items()} + total.update(prefixed) + return total + + def get_read_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + total = {} + for metric in self.metrics: + breakdown = metric.get_read_bytes_breakdown(ctx, per_gpu) + component = metric.component_type() + prefixed = {f"{component}.{key}": val for key, val in breakdown.items()} + total.update(prefixed) + return total + + def get_write_bytes_breakdown( + self, ctx: ExecutionContext, per_gpu: bool = True + ) -> dict[str, int]: + total = {} + for metric in self.metrics: + breakdown = metric.get_write_bytes_breakdown(ctx, per_gpu) + component = metric.component_type() + prefixed = {f"{component}.{key}": val for key, val in breakdown.items()} + total.update(prefixed) + return total + + def get_step_perf_stats_per_gpu( + self, scheduler_output: SchedulerOutput + ) -> PerfStats: + """ + Calculate perf stats for the current step based on scheduled tokens. 
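+
+        Illustrative example (hypothetical numbers): a step that schedules one
+        new prefill request with 512 prompt tokens plus two cached decode
+        requests (1 new token each, with 128 and 1024 previously computed
+        tokens) builds an ExecutionContext roughly equivalent to:
+
+            ctx = ExecutionContext()
+            ctx.add(512, 512, is_prefill=True)
+            ctx.add(1, 129, is_prefill=False)
+            ctx.add(1, 1025, is_prefill=False)
+
+        The per-component flops/bytes breakdowns are then evaluated against
+        this context and summed into a single PerfStats object.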
+ """ + + t0 = time.monotonic() + + # Build a single batch context + ctx = ExecutionContext() + + # Process new requests (these are in prefill phase) + for new_req in scheduler_output.scheduled_new_reqs: + req_id = new_req.req_id + num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) + if num_tokens == 0: + continue + + # For new requests, context_len = num_computed_tokens + num_tokens + # num_computed_tokens represents previously computed tokens in the sequence + context_len = new_req.num_computed_tokens + num_tokens + ctx.add(num_tokens, context_len, is_prefill=True) + + # Process cached requests (continuing requests) + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) + if num_tokens == 0: + continue + + # For cached requests, we have the current num_computed_tokens + num_computed_tokens = cached_reqs.num_computed_tokens[i] + context_len = num_computed_tokens + num_tokens + + # Cached requests are typically in decode phase (num_tokens == 1) + # unless they're doing chunked prefill (num_tokens > 1) + is_prefill = num_tokens > 1 + ctx.add(num_tokens, context_len, is_prefill) + + num_flops_breakdown = self.get_num_flops_breakdown(ctx, True) + read_bytes_breakdown = self.get_read_bytes_breakdown(ctx, True) + write_bytes_breakdown = self.get_write_bytes_breakdown(ctx, True) + perf_stats = PerfStats( + sum(num_flops_breakdown.values()), + sum(read_bytes_breakdown.values()), + sum(write_bytes_breakdown.values()), + ) + + if envs.VLLM_DEBUG_MFU_METRICS: + perf_stats.debug_stats = DebugPerfStats( + time.monotonic() - t0, + ctx.num_prefill_requests, + ctx.num_decode_requests, + asdict(ctx), + num_flops_breakdown, + read_bytes_breakdown, + write_bytes_breakdown, + ) + + return perf_stats + + +#### Logging #### + + +class PerfMetricsDebugLogging: + def __init__(self): + self.reset() + + def reset(self): + self.total_calc_duration: float = 0.0 + self.total_num_prefill_requests: int = 0 + self.total_num_decode_requests: int = 0 + self.total_num_batches: int = 0 + self.total_context_breakdown: dict[str, int] = {} + self.total_num_flops_per_gpu_breakdown: dict[str, int] = {} + self.total_read_bytes_per_gpu_breakdown: dict[str, int] = {} + self.total_write_bytes_per_gpu_breakdown: dict[str, int] = {} + + def observe(self, debug_stats: DebugPerfStats) -> None: + self.total_calc_duration += debug_stats.calc_duration + self.total_num_prefill_requests += debug_stats.num_prefill_requests + self.total_num_decode_requests += debug_stats.num_decode_requests + self.total_num_batches += 1 + + for dst, src in zip( + [ + self.total_context_breakdown, + self.total_num_flops_per_gpu_breakdown, + self.total_read_bytes_per_gpu_breakdown, + self.total_write_bytes_per_gpu_breakdown, + ], + [ + debug_stats.context_breakdown, + debug_stats.num_flops_per_gpu_breakdown, + debug_stats.num_read_bytes_per_gpu_breakdown, + debug_stats.num_write_bytes_per_gpu_breakdown, + ], + ): + assert isinstance(src, dict) + for key, val in src.items(): + dst[key] = dst.get(key, 0) + val + + def log(self, log_fn, log_prefix: str, delta_time: float): + # pretty print breakdowns + total_num_flops_per_gpu_breakdown = { + k: f"{v / 1e12:.1f}TF" + for k, v in self.total_num_flops_per_gpu_breakdown.items() + } + total_read_bytes_per_gpu_breakdown = { + k: f"{v / 1e9:.1f}GB" + for k, v in self.total_read_bytes_per_gpu_breakdown.items() + } + total_write_bytes_per_gpu_breakdown = { + k: f"{v / 1e9:.1f}GB" + for k, v 
in self.total_write_bytes_per_gpu_breakdown.items()
+        }
+
+        logger.debug(
+            "%sMFU details: %s",
+            log_prefix,
+            json.dumps(
+                {
+                    "prefill_reqs": self.total_num_prefill_requests,
+                    "decode_reqs": self.total_num_decode_requests,
+                    "num_batches": self.total_num_batches,
+                    "context_breakdown": self.total_context_breakdown,
+                    "flops_breakdown": total_num_flops_per_gpu_breakdown,
+                    "num_read_bytes_breakdown": total_read_bytes_per_gpu_breakdown,
+                    "num_write_bytes_breakdown": (total_write_bytes_per_gpu_breakdown),
+                    "duration": f"{delta_time:.1f}s",
+                    "mfu_calc_overhead": (
+                        f"{self.total_calc_duration / delta_time:.1%}"
+                    ),
+                },
+                indent=2,
+            ),
+        )
+
+
+class PerfMetricsLogging:
+    def __init__(self, vllm_config: VllmConfig):
+        self.vllm_config = vllm_config
+        self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
+
+        self.debug_logging: PerfMetricsDebugLogging | None = None
+        if envs.VLLM_DEBUG_MFU_METRICS:
+            self.debug_logging = PerfMetricsDebugLogging()
+
+        self.reset()
+
+    def reset(self):
+        self.last_log_time = time.monotonic()
+
+        self.total_num_flops_per_gpu: int = 0
+        self.total_read_bytes_per_gpu: int = 0
+        self.total_write_bytes_per_gpu: int = 0
+
+        if self.debug_logging:
+            self.debug_logging.reset()
+
+    def observe(self, perf_stats: PerfStats) -> None:
+        self.total_num_flops_per_gpu += perf_stats.num_flops_per_gpu
+        self.total_read_bytes_per_gpu += perf_stats.num_read_bytes_per_gpu
+        self.total_write_bytes_per_gpu += perf_stats.num_write_bytes_per_gpu
+
+        if self.debug_logging:
+            assert perf_stats.debug_stats is not None
+            self.debug_logging.observe(perf_stats.debug_stats)
+
+    def log(self, log_fn=logger.info, log_prefix: str = "") -> None:
+        if not (
+            self.total_num_flops_per_gpu
+            or self.total_read_bytes_per_gpu
+            or self.total_write_bytes_per_gpu
+        ):
+            return
+
+        now = time.monotonic()
+        delta_time = now - self.last_log_time
+
+        if delta_time <= 0.0:
+            avg_tflops_per_gpu = 0.0
+            avg_gbps_per_gpu = 0.0
+        else:
+            avg_tflops_per_gpu = self.total_num_flops_per_gpu / delta_time / 1e12
+            avg_gbps_per_gpu = (
+                (self.total_read_bytes_per_gpu + self.total_write_bytes_per_gpu)
+                / delta_time
+                / 1e9
+            )
+
+        log_fn(
+            "%sMFU: %.1f TF/s/GPU %.1f GB/s/GPU",
+            log_prefix,
+            avg_tflops_per_gpu,
+            avg_gbps_per_gpu,
+        )
+
+        if self.debug_logging:
+            self.debug_logging.log(log_fn, log_prefix, delta_time)
+
+        self.reset()
+
+
+#### Utility functions ####
+
+
+def get_required(obj: object, attr: str):
+    """Get an attr from an object, or raise InvalidComponent if it is not set."""
+    if not hasattr(obj, attr):
+        raise InvalidComponent(f"Missing required attr {attr} in config")
+    return getattr(obj, attr)
+
+
+def getattr_from_list(obj: object, attrs: list[str], default: object = None):
+    """Return the first attr from `attrs` that exists on the object.
+    Otherwise return the given default."""
+    for attr in attrs:
+        if hasattr(obj, attr):
+            return getattr(obj, attr)
+    return default
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index a0cc58d0a64e8..cb1a860e38fbc 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
 
 import vllm.envs as envs
 from vllm.compilation.cuda_graph import CUDAGraphStat
+from vllm.v1.metrics.perf import PerfStats
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 
 if TYPE_CHECKING:
@@ -186,6 +187,8 @@ class SchedulerStats:
 
     cudagraph_stats: CUDAGraphStat | None = None
 
+    perf_stats: PerfStats | None = None
+
 
 @dataclass
 class RequestStateStats:
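For reference, a minimal standalone sketch of how the analytic formulas above combine. The model shapes and helper names are hypothetical and illustrative only; the snippet just re-applies the dense-FFN and unembed expressions from FfnMetrics and UnembedMetrics and does not exercise vLLM itself.

# Back-of-envelope sketch of the analytic estimators, with made-up shapes.
def dense_ffn_flops(num_tokens: int, hidden: int, intermediate: int, layers: int) -> int:
    # SwiGLU FFN: up, gate and down projections -> 3 GEMMs of hidden x intermediate.
    return 2 * hidden * 3 * intermediate * num_tokens * layers


def unembed_flops(num_tokens: int, hidden: int, vocab: int) -> int:
    # Final projection from hidden states to vocabulary logits.
    return 2 * num_tokens * hidden * vocab


if __name__ == "__main__":
    # Hypothetical dense model: 4096 tokens, D=4096, DI=14336, 32 layers, 128k vocab.
    T, D, DI, L, V = 4096, 4096, 14336, 32, 128256
    total = dense_ffn_flops(T, D, DI, L) + unembed_flops(T, D, V)
    print(f"~{total / 1e12:.1f} TFLOPs for a {T}-token prefill (FFN + unembed only)")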