mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 16:35:43 +08:00
[Bugfix] Fix test_eagle test (#18223)
Signed-off-by: Lucia Fang <fanglu@fb.com>
This commit is contained in:
parent
0b34593017
commit
8795eb9975
@ -115,14 +115,15 @@ def test_prepare_inputs():
|
||||
("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
|
||||
('model', 'embed_tokens')),
|
||||
])
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
|
||||
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
|
||||
mock_registry, mock_get_layers, method, proposer_helper,
|
||||
draft_model_dir, target_attribute_path):
|
||||
mock_registry, mock_get_layers, mock_get_pp_group, method,
|
||||
proposer_helper, draft_model_dir, target_attribute_path):
|
||||
|
||||
# Setup mock for model class
|
||||
mock_model_cls = mock.MagicMock()
|
||||
@ -158,6 +159,11 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
|
||||
# Make mock_get_layers return different values for each call
|
||||
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
|
||||
|
||||
# Setup mock for pp group to return the appropriate value for world size
|
||||
mock_pp_group = mock.MagicMock()
|
||||
mock_pp_group.world_size = 2 if method == "eagle" else 1
|
||||
mock_get_pp_group.return_value = mock_pp_group
|
||||
|
||||
# Setup model loader mock
|
||||
mock_loader = mock.MagicMock()
|
||||
mock_get_loader.return_value = mock_loader
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user