From fa72f9a8126051abb9d00144f11aeb7615f36d21 Mon Sep 17 00:00:00 2001
From: aws-elaineyz
Date: Thu, 22 May 2025 02:20:36 -0700
Subject: [PATCH] Order sequence ids + config update to support specifying
 custom quantization layers (#18279)

Signed-off-by: Elaine Zhao
Co-authored-by: Tailin Pan
Co-authored-by: Rishabh Rajesh
Co-authored-by: Yishan McNabb
Co-authored-by: Patrick Lange
Co-authored-by: Maxwell Goldberg
Co-authored-by: Aakash Shetty
---
 tests/neuron/2_core/test_mistral.py      | 40 ++++++++++++++---
 .../model_loader/neuronx_distributed.py  | 43 ++++++++++++++++---
 2 files changed, 73 insertions(+), 10 deletions(-)

diff --git a/tests/neuron/2_core/test_mistral.py b/tests/neuron/2_core/test_mistral.py
index 8acd082f2ded..cc3b53a9d7c9 100644
--- a/tests/neuron/2_core/test_mistral.py
+++ b/tests/neuron/2_core/test_mistral.py
@@ -7,7 +7,7 @@ def test_mistral():
     llm = LLM(model="mistralai/Mistral-7B-v0.1",
               tensor_parallel_size=2,
               max_num_seqs=4,
-              max_model_len=512,
+              max_model_len=128,
               use_v2_block_manager=True,
               override_neuron_config={
                   "sequence_parallel_enabled": False,
@@ -15,16 +15,46 @@ def test_mistral():
               },
               device="neuron")
 
+    # Send more prompts than the compiled batch size (4) and request
+    # varying generation lengths to test accuracy related to Neuron
+    # specific sequence id sorting.
     prompts = [
         "The president of the United States is",
         "The capital of France is",
+        "What is Annapurna labs?",
+        "I believe the meaning of life is",
+        "Tell me a story about a brave knight",
+        "Hello, my name is Llama",
     ]
-    outputs = llm.generate(prompts, SamplingParams(top_k=1))
+
+    sampling_params = [
+        SamplingParams(top_k=1, max_tokens=10),
+        SamplingParams(top_k=1, max_tokens=20),
+        SamplingParams(top_k=1, max_tokens=30),
+        SamplingParams(top_k=1, max_tokens=40),
+        SamplingParams(top_k=1, max_tokens=50),
+        SamplingParams(top_k=1, max_tokens=60)
+    ]
+
+    outputs = llm.generate(prompts, sampling_params)
 
     expected_outputs = [
-        " the most powerful person in the world. He is the head of state "
-        "and head",
-        " a city of many faces. It is a city of history, culture, art"
+        " the most powerful person in the world. He is",
+        " a city of many faces. It is a city of history, culture, art, "
+        "fashion, and",
+        "\n\nAnnapurna Labs is a semiconductor company that was founded "
+        "in 2013 by Amazon. The company is",
+        " to be happy.\n\nI believe that happiness is a choice.\n\nI "
+        "believe that happiness is a state of mind.\n\nI believe that "
+        "happiness is a journey.\n\nI believe",
+        " who rescued a princess from a dragon.\n\nTell me a story about"
+        " a princess who rescued herself from a dragon.\n\nTell me a "
+        "story about a princess who rescued herself from a dragon and "
+        "then rescued a knight from",
+        " and I am a 10 year old male. I am a very friendly and "
+        "affectionate boy who loves to be around people. I am a very "
+        "active boy who loves to play and run around. I am a very smart "
+        "boy who loves to learn new things. I am a very loyal boy"
     ]
 
     for expected_output, output in zip(expected_outputs, outputs):
diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py
index 3a4d93c8c13f..557feea46a90 100644
--- a/vllm/model_executor/model_loader/neuronx_distributed.py
+++ b/vllm/model_executor/model_loader/neuronx_distributed.py
@@ -87,16 +87,29 @@ class NeuronCausalLM(nn.Module):
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
         # on-device sampling
         if self.config.neuron_config.on_device_sampling_config:
-            return output.hidden_states
+            output = output.hidden_states
         else:
-            return output.logits[:, -1, :]
+            output = output.logits[:, -1, :]
+
+        restored_indices = torch.argsort(sorted_indices)
+        if input_block_ids.shape[0] != 1:
+            output = torch.index_select(output, 0, restored_indices)
+
+        return output
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
@@ -340,14 +353,26 @@ class NeuronSpeculationCausalLM(nn.Module):
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
 
+        restored_indices = torch.argsort(sorted_indices)
+
         # CTX encoding
         if (positions[:, 0]).sum().item() == 0:
-            return output.fused_outputs[0][:, 0:1]
+            output = output.fused_outputs[0][:, 0:1]
+            if input_block_ids.shape[0] != 1:
+                output = torch.index_select(output, 0, restored_indices)
+            return output
 
         # Fused Spec (Generation)
         accepted_tokens_with_padding = output.fused_outputs[0]
@@ -362,6 +387,10 @@ class NeuronSpeculationCausalLM(nn.Module):
                                           -1) >= generated_token_counts
         accepted_tokens_with_padding[mask] = -1
 
+        if input_block_ids.shape[0] != 1:
+            accepted_tokens_with_padding = torch.index_select(
+                accepted_tokens_with_padding, 0, restored_indices)
+
         return accepted_tokens_with_padding
 
     def sample(
@@ -416,6 +445,10 @@ class NeuronSpeculationCausalLM(nn.Module):
             draft_neuron_config.speculation_length = 0
         draft_neuron_config.trace_tokengen_model = True
         draft_neuron_config.enable_fused_speculation = False
+        if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
+                   None):
+            draft_neuron_config.modules_to_not_convert = (
+                draft_neuron_config.draft_model_modules_to_not_convert)
         if config.neuron_config.enable_eagle_speculation:
             draft_neuron_config.is_eagle_draft = True
             draft_neuron_config.sequence_parallel_enabled = False
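
A minimal standalone PyTorch sketch of the sort-then-restore pattern that both forward passes above apply (illustration only, not part of the patch; the Neuron model call is replaced by a hypothetical model_fn stand-in): sequence ids are sorted ascending before the model call, every per-sequence tensor is permuted with the same indices, and torch.argsort of that permutation restores the caller's original row order afterwards.

    import torch

    def forward_with_sorted_seq_ids(input_ids, positions, seq_ids,
                                    sampling_params, model_fn):
        # Sort sequence ids ascending and remember the permutation used.
        sorted_seq_ids, sorted_indices = torch.sort(seq_ids)
        # Reorder every per-sequence tensor with the same permutation.
        input_ids = torch.index_select(input_ids, 0, sorted_indices)
        positions = torch.index_select(positions, 0, sorted_indices)
        sampling_params = torch.index_select(sampling_params, 0, sorted_indices)

        output = model_fn(input_ids, positions, sorted_seq_ids, sampling_params)

        # argsort of a permutation is its inverse, so row i of the restored
        # output corresponds to the caller's original sequence i.
        restored_indices = torch.argsort(sorted_indices)
        if seq_ids.shape[0] != 1:
            output = torch.index_select(output, 0, restored_indices)
        return output

    # Toy check with an "identity" model: the caller's row order is preserved.
    seq_ids = torch.tensor([2, 0, 3, 1])
    tokens = torch.tensor([[20], [0], [30], [10]])
    positions = torch.zeros_like(tokens)
    params = torch.zeros(4, 2)
    out = forward_with_sorted_seq_ids(tokens, positions, seq_ids, params,
                                      lambda ids, pos, sid, sp: ids)
    assert torch.equal(out, tokens)

The single-sequence case skips the restore step, mirroring the input_block_ids.shape[0] != 1 guard in the patch.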
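
The config change is smaller: when the target model's neuron_config carries a draft_model_modules_to_not_convert list, the deep-copied draft config adopts it as its own modules_to_not_convert, so the draft model can keep a different set of layers unquantized. A sketch of that passthrough with plain stand-in objects (the real configs come from NeuronX Distributed Inference; only the two attribute names are taken from the patch, and the layer names are made-up examples):

    from copy import deepcopy
    from types import SimpleNamespace

    # Stand-in for the target model's neuron_config; only the fields the
    # passthrough touches are modeled here (layer names are examples).
    neuron_config = SimpleNamespace(
        modules_to_not_convert=["lm_head"],
        draft_model_modules_to_not_convert=["lm_head", "embed_tokens"],
    )

    draft_neuron_config = deepcopy(neuron_config)
    # Same check as the added lines: prefer the draft-specific exclusion
    # list whenever the user supplied one.
    if getattr(neuron_config, "draft_model_modules_to_not_convert", None):
        draft_neuron_config.modules_to_not_convert = (
            draft_neuron_config.draft_model_modules_to_not_convert)

    assert draft_neuron_config.modules_to_not_convert == [
        "lm_head", "embed_tokens"
    ]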