diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index a7e148d01cad7..7d93a44c50595 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -115,14 +115,15 @@ def test_prepare_inputs(): ("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir, ('model', 'embed_tokens')), ]) +@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') @mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry') @mock.patch('vllm.v1.spec_decode.eagle.get_model_loader') @mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype') @mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config') def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader, - mock_registry, mock_get_layers, method, proposer_helper, - draft_model_dir, target_attribute_path): + mock_registry, mock_get_layers, mock_get_pp_group, method, + proposer_helper, draft_model_dir, target_attribute_path): # Setup mock for model class mock_model_cls = mock.MagicMock() @@ -158,6 +159,11 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader, # Make mock_get_layers return different values for each call mock_get_layers.side_effect = [target_attn_layers, all_attn_layers] + # Setup mock for pp group to return the appropriate value for world size + mock_pp_group = mock.MagicMock() + mock_pp_group.world_size = 2 if method == "eagle" else 1 + mock_get_pp_group.return_value = mock_pp_group + # Setup model loader mock mock_loader = mock.MagicMock() mock_get_loader.return_value = mock_loader