[Bugfix] Fix test_eagle test (#18223)

Signed-off-by: Lucia Fang <fanglu@fb.com>
This commit is contained in:
Lucia Fang 2025-05-15 15:59:42 -07:00 committed by GitHub
parent 0b34593017
commit 8795eb9975
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -115,14 +115,15 @@ def test_prepare_inputs():
("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
('model', 'embed_tokens')),
])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
mock_registry, mock_get_layers, method, proposer_helper,
draft_model_dir, target_attribute_path):
mock_registry, mock_get_layers, mock_get_pp_group, method,
proposer_helper, draft_model_dir, target_attribute_path):
# Setup mock for model class
mock_model_cls = mock.MagicMock()
@ -158,6 +159,11 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
# Make mock_get_layers return different values for each call
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
# Setup mock for pp group to return the appropriate value for world size
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 2 if method == "eagle" else 1
mock_get_pp_group.return_value = mock_pp_group
# Setup model loader mock
mock_loader = mock.MagicMock()
mock_get_loader.return_value = mock_loader