diff --git a/benchmarks/kernels/bench_fp8_gemm.py b/benchmarks/kernels/bench_fp8_gemm.py
index 640a33419005..b964ed242edf 100644
--- a/benchmarks/kernels/bench_fp8_gemm.py
+++ b/benchmarks/kernels/bench_fp8_gemm.py
@@ -5,11 +5,11 @@ import copy
 import itertools
 
 import torch
-import triton
 from weight_shapes import WEIGHT_SHAPES
 
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
+from vllm.triton_utils import triton
 
 
 @triton.testing.perf_report(
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index eff8eff43ea9..c93b7f57c041 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -9,6 +9,7 @@ import torch
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
+from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.v1.spec_decode.eagle import EagleProposer
 
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -113,21 +114,26 @@ def test_prepare_inputs():
     assert torch.equal(token_indices, expected_token_indices)
 
 
-@pytest.mark.parametrize(
-    "method,proposer_helper,draft_model_dir,target_attribute_path", [
-        ("eagle", lambda k: _create_proposer("eagle", k), eagle_dir,
-         ('lm_head', )),
-        ("eagle3", lambda k: _create_proposer("eagle3", k), eagle3_dir,
-         ('model', 'embed_tokens')),
-    ])
+@pytest.mark.parametrize("method,proposer_helper", [
+    ("eagle", lambda k: _create_proposer("eagle", k)),
+    ("eagle3", lambda k: _create_proposer("eagle3", k)),
+])
+@pytest.mark.parametrize("pp_size", [1, 2])
+@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
 @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
 @mock.patch('vllm.v1.spec_decode.eagle.get_model')
 def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
-                    proposer_helper, draft_model_dir, target_attribute_path):
-
-    # Setup model mock
+                    proposer_helper, pp_size, use_distinct_embed_tokens):
+    # Setup draft model mock
     mock_model = mock.MagicMock()
+    if use_distinct_embed_tokens:
+        # Some models can have a different hidden size than the target model,
+        # so we test that their embed_tokens doesn't get overwritten
+        mock_model.model.embed_tokens.weight.shape = (131072, 2048)
+    else:
+        mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+
     mock_get_model.return_value = mock_model
 
     # Setup mocks for attention layers
@@ -145,22 +151,24 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
 
     # Setup mock for pp group to return the appropriate value for world size
     mock_pp_group = mock.MagicMock()
-    mock_pp_group.world_size = 2 if method == "eagle" else 1
+    mock_pp_group.world_size = pp_size
     mock_get_pp_group.return_value = mock_pp_group
 
-    # Setup target model with the appropriate attributes
-    target_model = mock.MagicMock()
+    # Setup the target model mock with a custom class so that
+    # isinstance() checks match the expected type.
+    class _TargetModelStub(LlamaForCausalLM):
+        model: mock.MagicMock
+        lm_head: mock.MagicMock
 
-    # Create the necessary attributes on the target model
-    current_obj = target_model
-    for i, attr in enumerate(target_attribute_path):
-        if i == len(target_attribute_path) - 1:
-            # Set the last attribute in the path to a MagicMock
-            setattr(current_obj, attr, mock.MagicMock())
-        else:
-            # Create intermediate objects if needed
-            setattr(current_obj, attr, mock.MagicMock())
-            current_obj = getattr(current_obj, attr)
+    target_model = mock.create_autospec(_TargetModelStub, instance=True)
+    target_model.model = mock.MagicMock()
+    target_model.model.embed_tokens.weight.shape = (131072, 4096)
+
+    from vllm.model_executor.models import SupportsMultiModal
+    assert not isinstance(target_model, SupportsMultiModal)
+
+    if method == "eagle":
+        target_model.lm_head = mock.MagicMock()
 
     # Create proposer using the helper function
     proposer = proposer_helper(k=8)
@@ -171,10 +179,18 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
     # Verify common interactions
     mock_get_model.assert_called_once()
 
-    # Verify the specific attribute sharing based on the method
+    # Verify that EAGLE models gain the lm head from the target model
    if method == "eagle":
         assert proposer.model.lm_head == target_model.lm_head
+
+    # Verify that the embed tokens are set correctly
+    # If pp_size is > 1, the embed tokens should be distinct
+    if pp_size > 1 or use_distinct_embed_tokens:
+        assert proposer.model.model.embed_tokens != \
+            target_model.model.embed_tokens
     else:
+        # When pp_size is 1 and the draft and target models have
+        # embed_tokens of the same shape, they should be shared.
         assert proposer.model.model.embed_tokens == \
             target_model.model.embed_tokens
 
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index f73b863fef23..c7690604c1d0 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -55,13 +55,11 @@ class LlamaModel(nn.Module):
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
 
-        # if PP disabled then draft will share embed with target
-        if get_pp_group().world_size > 1:
-            self.embed_tokens = VocabParallelEmbedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                prefix=maybe_prefix(prefix, "embed_tokens"),
-            )
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
 
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
@@ -164,4 +162,4 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
             if "lm_head" not in name:
                 name = "model." + name
             model_weights[name] = loaded_weight
-        return loader.load_weights(model_weights.items())
+        loader.load_weights(model_weights.items())
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index d31a321b876a..7fc9fe2ebb6f 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -10,7 +10,6 @@ from transformers import LlamaConfig
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -95,13 +94,11 @@ class LlamaModel(nn.Module):
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
 
-        # if PP disabled then draft will share embed with target
-        if get_pp_group().world_size > 1:
-            self.embed_tokens = VocabParallelEmbedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                prefix=maybe_prefix(prefix, "embed_tokens"),
-            )
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
 
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
@@ -240,6 +237,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         model_weights = {}
         includes_draft_id_mapping = False
+        includes_embed_tokens = False
         for name, loaded_weight in weights:
             if "t2d" in name:
                 continue
@@ -248,12 +246,18 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
                 includes_draft_id_mapping = True
             elif "lm_head" not in name:
                 name = "model." + name
+            if "embed_tokens" in name:
+                includes_embed_tokens = True
             model_weights[name] = loaded_weight
 
+        skip_substrs = []
+        if not includes_draft_id_mapping:
+            skip_substrs.append("draft_id_to_target_id")
+        if not includes_embed_tokens:
+            skip_substrs.append("embed_tokens")
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
-            skip_substrs=["draft_id_to_target_id"] \
-                if not includes_draft_id_mapping else None,
+            skip_substrs=skip_substrs,
         )
         loader.load_weights(model_weights.items())
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 8bbda94e5097..8ad66776c4e9 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -173,6 +173,7 @@ class CudaPlatformBase(Platform):
     def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None
                                  ) -> float:
+        torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(device)
         return torch.cuda.max_memory_allocated(device)
 
diff --git a/vllm/utils.py b/vllm/utils.py
index 4f905e505dbe..c19c0221cf83 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -899,6 +899,7 @@ class DeviceMemoryProfiler:
     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
         from vllm.platforms import current_platform
+        gc.collect()
         return current_platform.get_current_memory_usage(self.device)
 
     def __enter__(self):
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 416bc8af18ab..4b5c9b7ec640 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -330,16 +330,19 @@ class EagleProposer:
         self.attn_layer_names = list(draft_attn_layer_names)
 
         # share embed_tokens with the target model if needed
-        if get_pp_group().world_size == 1:
+        if get_pp_group().world_size == 1 \
+            and self.model.model.embed_tokens.weight.shape \
+                == target_model.model.embed_tokens.weight.shape:
             logger.info(
-                "The EAGLE head shares the same vocab embedding" \
+                "Assuming the EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
+            del self.model.model.embed_tokens
             self.model.model.embed_tokens = target_model.model.embed_tokens
         else:
             logger.info(
-                "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
-                " weights instead of sharing them with the target model."
+                "The EAGLE head's vocab embedding will be loaded separately" \
+                " from the target model."
             )
 
         # share lm_head with the target model if needed
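
The behavioral change in vllm/v1/spec_decode/eagle.py above reduces to one rule: the EAGLE draft head reuses the target model's embed_tokens only when pipeline parallelism is disabled and the two embedding weights have the same shape; otherwise the draft head keeps and loads its own copy. A minimal standalone sketch of that rule follows; the helper name should_share_embed_tokens and the toy sizes are illustrative assumptions, not part of the patch.

    import torch.nn as nn


    def should_share_embed_tokens(pp_world_size: int, draft_embed: nn.Embedding,
                                  target_embed: nn.Embedding) -> bool:
        # Share the target model's vocab embedding with the draft head only when
        # PP is disabled and both embeddings have identical shapes.
        return (pp_world_size == 1
                and draft_embed.weight.shape == target_embed.weight.shape)


    # Toy sizes (the test above uses 131072 x 4096 and 131072 x 2048).
    target = nn.Embedding(1024, 64)
    assert should_share_embed_tokens(1, nn.Embedding(1024, 64), target)
    assert not should_share_embed_tokens(1, nn.Embedding(1024, 32), target)  # shape mismatch
    assert not should_share_embed_tokens(2, nn.Embedding(1024, 64), target)  # PP enabled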