Order sequence ids + config update to support specifying custom quantization layers (#18279)

Signed-off-by: Elaine Zhao <elaineyz@amazon.com>
Co-authored-by: Tailin Pan <tailinpa@amazon.com>
Co-authored-by: Rishabh Rajesh <rishyraj@amazon.com>
Co-authored-by: Yishan McNabb <yishanm@amazon.com>
Co-authored-by: Patrick Lange <patlange@amazon.com>
Co-authored-by: Maxwell Goldberg <mgld@amazon.com>
Co-authored-by: Aakash Shetty <sheaak@amazon.com>
Authored by aws-elaineyz on 2025-05-22 02:20:36 -07:00; committed by GitHub
commit fa72f9a812 (parent ebed81fbf5)
2 changed files with 73 additions and 10 deletions
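
The core of this commit is a sort-before-call / restore-after-call permutation round trip: sequence (block) ids are sorted into ascending order before invoking the compiled Neuron model, and the outputs are put back into the original request order afterwards. A minimal standalone sketch of that pattern, not the vLLM code itself (tensor values here are made up):

import torch

# Sort requests by sequence/block id before the Neuron call, then
# undo the permutation so outputs match the original request order.
input_block_ids = torch.tensor([2, 0, 3, 1])
logits = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # stand-in output

sorted_ids, sorted_indices = torch.sort(input_block_ids)
sorted_logits = torch.index_select(logits, 0, sorted_indices)

# torch.argsort of a permutation yields its inverse permutation.
restored_indices = torch.argsort(sorted_indices)
assert torch.equal(
    torch.index_select(sorted_logits, 0, restored_indices), logits)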

View File

@@ -7,7 +7,7 @@ def test_mistral():
     llm = LLM(model="mistralai/Mistral-7B-v0.1",
               tensor_parallel_size=2,
               max_num_seqs=4,
-              max_model_len=512,
+              max_model_len=128,
               use_v2_block_manager=True,
               override_neuron_config={
                   "sequence_parallel_enabled": False,
@@ -15,16 +15,46 @@ def test_mistral():
               },
               device="neuron")
 
+    # Send more prompts than the compiled batch size (4) and request
+    # varying generation lengths to test accuracy related to Neuron
+    # specific sequence id sorting.
     prompts = [
         "The president of the United States is",
         "The capital of France is",
+        "What is Annapurna labs?",
+        "I believe the meaning of life is",
+        "Tell me a story about a brave knight",
+        "Hello, my name is Llama",
     ]
-    outputs = llm.generate(prompts, SamplingParams(top_k=1))
+    sampling_params = [
+        SamplingParams(top_k=1, max_tokens=10),
+        SamplingParams(top_k=1, max_tokens=20),
+        SamplingParams(top_k=1, max_tokens=30),
+        SamplingParams(top_k=1, max_tokens=40),
+        SamplingParams(top_k=1, max_tokens=50),
+        SamplingParams(top_k=1, max_tokens=60)
+    ]
+    outputs = llm.generate(prompts, sampling_params)
 
     expected_outputs = [
-        " the most powerful person in the world. He is the head of state "
-        "and head",
-        " a city of many faces. It is a city of history, culture, art"
+        " the most powerful person in the world. He is",
+        " a city of many faces. It is a city of history, culture, art, "
+        "fashion, and",
+        "\n\nAnnapurna Labs is a semiconductor company that was founded "
+        "in 2013 by Amazon. The company is",
+        " to be happy.\n\nI believe that happiness is a choice.\n\nI "
+        "believe that happiness is a state of mind.\n\nI believe that "
+        "happiness is a journey.\n\nI believe",
+        " who rescued a princess from a dragon.\n\nTell me a story about"
+        " a princess who rescued herself from a dragon.\n\nTell me a "
+        "story about a princess who rescued herself from a dragon and "
+        "then rescued a knight from",
+        " and I am a 10 year old male. I am a very friendly and "
+        "affectionate boy who loves to be around people. I am a very "
+        "active boy who loves to play and run around. I am a very smart "
+        "boy who loves to learn new things. I am a very loyal boy"
     ]
 
     for expected_output, output in zip(expected_outputs, outputs):
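
Because each prompt above requests a different max_tokens, sequences finish at different engine steps and the live sequence ids become non-contiguous, which is exactly the condition the sorting change below must handle. The body of the final comparison loop lies outside this hunk; assuming vLLM's usual RequestOutput shape (generated text at output.outputs[0].text), it presumably resembles:

for expected_output, output in zip(expected_outputs, outputs):
    # Hypothetical reconstruction of the truncated loop body: outputs
    # from llm.generate() come back in prompt order, so each generation
    # is compared directly against its expected string.
    generated_text = output.outputs[0].text
    assert generated_text == expected_output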

View File

@@ -87,16 +87,29 @@ class NeuronCausalLM(nn.Module):
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
         # on-device sampling
         if self.config.neuron_config.on_device_sampling_config:
-            return output.hidden_states
+            output = output.hidden_states
         else:
-            return output.logits[:, -1, :]
+            output = output.logits[:, -1, :]
+
+        restored_indices = torch.argsort(sorted_indices)
+        if input_block_ids.shape[0] != 1:
+            output = torch.index_select(output, 0, restored_indices)
+        return output
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
@@ -340,14 +353,26 @@ class NeuronSpeculationCausalLM(nn.Module):
         input_block_ids: torch.Tensor,
         sampling_params: torch.Tensor,
     ) -> torch.Tensor:
+        # sort block ids sequentially for perf/neuron support reasons
+        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
+        input_ids = torch.index_select(input_ids, 0, sorted_indices)
+        positions = torch.index_select(positions, 0, sorted_indices)
+        sampling_params = torch.index_select(sampling_params, 0,
+                                             sorted_indices)
+
         output = self.model(input_ids,
                             attention_mask=None,
                             position_ids=positions,
-                            seq_ids=input_block_ids,
+                            seq_ids=sorted_input_block_ids,
                             sampling_params=sampling_params)
 
+        restored_indices = torch.argsort(sorted_indices)
+
         # CTX encoding
         if (positions[:, 0]).sum().item() == 0:
-            return output.fused_outputs[0][:, 0:1]
+            output = output.fused_outputs[0][:, 0:1]
+            if input_block_ids.shape[0] != 1:
+                output = torch.index_select(output, 0, restored_indices)
+            return output
 
         # Fused Spec (Generation)
         accepted_tokens_with_padding = output.fused_outputs[0]
@@ -362,6 +387,10 @@ class NeuronSpeculationCausalLM(nn.Module):
                 -1) >= generated_token_counts
         accepted_tokens_with_padding[mask] = -1
 
+        if input_block_ids.shape[0] != 1:
+            accepted_tokens_with_padding = torch.index_select(
+                accepted_tokens_with_padding, 0, restored_indices)
+
         return accepted_tokens_with_padding
 
     def sample(
@@ -416,6 +445,10 @@ class NeuronSpeculationCausalLM(nn.Module):
         draft_neuron_config.speculation_length = 0
         draft_neuron_config.trace_tokengen_model = True
         draft_neuron_config.enable_fused_speculation = False
+        if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
+                   None):
+            draft_neuron_config.modules_to_not_convert = (
+                draft_neuron_config.draft_model_modules_to_not_convert)
         if config.neuron_config.enable_eagle_speculation:
             draft_neuron_config.is_eagle_draft = True
             draft_neuron_config.sequence_parallel_enabled = False
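
The quantization half of the commit is the last hunk above: an optional draft_model_modules_to_not_convert entry on the Neuron config, read via getattr and copied onto the draft model's modules_to_not_convert so the named draft layers stay unquantized. A hedged usage sketch via override_neuron_config, following the test file's own API usage (the module name "lm_head" is illustrative, not taken from this diff):

from vllm import LLM

# Ask the Neuron backend to leave the draft model's lm_head (an
# illustrative choice) unquantized while other layers are converted.
llm = LLM(model="mistralai/Mistral-7B-v0.1",
          tensor_parallel_size=2,
          override_neuron_config={
              "draft_model_modules_to_not_convert": ["lm_head"],
          },
          device="neuron")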