Neuron up mistral (#18222)

Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
Satyajith Chilappagari, 2025-05-19 09:54:47 -07:00, committed by GitHub
parent 8171221834
commit dc1440cf9f
3 changed files with 36 additions and 2 deletions

@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm import LLM, SamplingParams
+
+
+def test_mistral():
+    llm = LLM(model="mistralai/Mistral-7B-v0.1",
+              tensor_parallel_size=2,
+              max_num_seqs=4,
+              max_model_len=512,
+              use_v2_block_manager=True,
+              override_neuron_config={
+                  "sequence_parallel_enabled": False,
+                  "skip_warmup": True
+              },
+              device="neuron")
+
+    prompts = [
+        "The president of the United States is",
+        "The capital of France is",
+    ]
+    outputs = llm.generate(prompts, SamplingParams(top_k=1))
+    expected_outputs = [
+        " the most powerful person in the world. He is the head of state "
+        "and head",
+        " a city of many faces. It is a city of history, culture, art"
+    ]
+
+    for expected_output, output in zip(expected_outputs, outputs):
+        generated_text = output.outputs[0].text
+        assert (expected_output == generated_text)
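For context: SamplingParams(top_k=1) makes decoding greedy (the sampler always
picks the single highest-probability token), so generation is deterministic and
the test can assert exact string equality. A minimal sketch of the same
comparison pattern with stubbed outputs, runnable without Neuron hardware (the
strings below are placeholders, not real model output):

# Sketch: the exact-match check pattern from the test, with placeholder data.
expected = ["alpha", "beta"]
generated = ["alpha", "beta"]  # stand-in for output.outputs[0].text
for want, got in zip(expected, generated, strict=True):
    # strict=True raises if the two lists ever diverge in length
    assert want == got, f"mismatch: {want!r} != {got!r}"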

@@ -48,6 +48,9 @@ TORCH_DTYPE_TO_NEURON_AMP = {
 # Models supported by Neuronx distributed for inference.
 _NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
     "LlamaForCausalLM":
         ("neuronx_distributed_inference.models.llama.modeling_llama",
          "NeuronLlamaForCausalLM"),
+    "MistralForCausalLM":
+        ("neuronx_distributed_inference.models.llama.modeling_llama",
+         "NeuronLlamaForCausalLM"),
     "DbrxForCausalLM":

@@ -51,8 +51,7 @@ class NeuronPlatform(Platform):
         assert (vllm_config.lora_config
                 is None), "LoRA is not supported for Neuron backend."
-        cache_config = vllm_config.cache_config
-        if cache_config:
+        if vllm_config.cache_config and vllm_config.model_config:
             # neuron needs block_size = max_model_len
             vllm_config.cache_config.block_size = \
                 vllm_config.model_config.max_model_len  # type: ignore
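The reworked guard checks model_config as well as cache_config before
dereferencing max_model_len; with the old `if cache_config:` alone, a config
carrying a cache_config but no model_config would fail on
`None.max_model_len`. A minimal repro of the failure mode the new condition
avoids (hypothetical stand-in classes, not vLLM types):

# Sketch: why both configs must be checked before the assignment.
class FakeCacheConfig:
    block_size = 0

class FakeVllmConfig:
    cache_config = FakeCacheConfig()
    model_config = None  # e.g. not yet populated

cfg = FakeVllmConfig()
if cfg.cache_config and cfg.model_config:  # new guard: skips safely
    cfg.cache_config.block_size = cfg.model_config.max_model_len
# The old guard `if cfg.cache_config:` would raise AttributeError on None.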