mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 06:55:01 +08:00
[Bugfix] Fix test_eagle test (#18223)
Signed-off-by: Lucia Fang <fanglu@fb.com>
This commit is contained in:
parent
0b34593017
commit
8795eb9975
@ -115,14 +115,15 @@ def test_prepare_inputs():
|
|||||||
("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
|
("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
|
||||||
('model', 'embed_tokens')),
|
('model', 'embed_tokens')),
|
||||||
])
|
])
|
||||||
|
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
|
||||||
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
||||||
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
|
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
|
||||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
|
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
|
||||||
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
|
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
|
||||||
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
|
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
|
||||||
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
|
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
|
||||||
mock_registry, mock_get_layers, method, proposer_helper,
|
mock_registry, mock_get_layers, mock_get_pp_group, method,
|
||||||
draft_model_dir, target_attribute_path):
|
proposer_helper, draft_model_dir, target_attribute_path):
|
||||||
|
|
||||||
# Setup mock for model class
|
# Setup mock for model class
|
||||||
mock_model_cls = mock.MagicMock()
|
mock_model_cls = mock.MagicMock()
|
||||||
@ -158,6 +159,11 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
|
|||||||
# Make mock_get_layers return different values for each call
|
# Make mock_get_layers return different values for each call
|
||||||
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
|
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
|
||||||
|
|
||||||
|
# Setup mock for pp group to return the appropriate value for world size
|
||||||
|
mock_pp_group = mock.MagicMock()
|
||||||
|
mock_pp_group.world_size = 2 if method == "eagle" else 1
|
||||||
|
mock_get_pp_group.return_value = mock_pp_group
|
||||||
|
|
||||||
# Setup model loader mock
|
# Setup model loader mock
|
||||||
mock_loader = mock.MagicMock()
|
mock_loader = mock.MagicMock()
|
||||||
mock_get_loader.return_value = mock_loader
|
mock_get_loader.return_value = mock_loader
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user