[MODEL] LoRA support for Jamba model (#11209)
Signed-off-by: Erez Schwartz <erezs@ai21.com>
parent 101418096f
commit 55509c2114
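Summary: this change adds LoRA support for the Jamba model. MambaMixer, JambaMambaDecoderLayer, MambaDecoderLayer, and MambaModel gain an is_lora_enabled flag so x_proj and out_proj receive contiguous tensors when the LoRA kernels are active, JambaForCausalLM's packed_modules_mapping and supported_lora_modules are extended to cover the Mamba projections, and a Jamba LoRA adapter fixture plus an end-to-end test are added under tests/lora.

The offline usage this enables looks roughly like the sketch below, adapted from the new test. The adapter path and the resources implied by tensor_parallel_size are illustrative assumptions; any local Jamba LoRA adapter directory would do.

    import vllm
    from vllm.lora.request import LoRARequest

    # Illustrative sketch only; "/path/to/jamba-lora" is a placeholder.
    llm = vllm.LLM(
        "ai21labs/AI21-Jamba-1.5-Mini",
        enable_lora=True,          # exercises the new Jamba LoRA path
        max_loras=4,
        tensor_parallel_size=4,
    )
    outputs = llm.generate(
        ["Write a story about a sheep and a goat."],
        vllm.SamplingParams(temperature=0, max_tokens=40),
        lora_request=LoRARequest("jamba-lora", 1, "/path/to/jamba-lora"),
    )
    print(outputs[0].outputs[0].text)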
@@ -4,6 +4,7 @@ from typing import Dict, List, TypedDict
 from unittest.mock import MagicMock, patch
 
 import pytest
+import safetensors
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download
@@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules():
     return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
 
 
+@pytest.fixture(scope="session")
+def jamba_lora_files():
+    # some of the adapters have unnecessary weights for serving,
+    # hence we remove them
+    def remove_unnecessary_weights(path):
+        lora_path = f"{adapter_path}/adapter_model.safetensors"
+        tensors = safetensors.torch.load_file(lora_path)
+        nonlora_keys = []
+        for k in list(tensors.keys()):
+            if "lora" not in k:
+                nonlora_keys.append(k)
+        for k in nonlora_keys:
+            del tensors[k]
+        safetensors.torch.save_file(tensors, lora_path)
+
+    adapter_path = snapshot_download(
+        repo_id=
+        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")
+
+    remove_unnecessary_weights(adapter_path)
+    return adapter_path
+
+
 @pytest.fixture(scope="session")
 def gemma_lora_files():
     return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
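Note: the jamba_lora_files fixture strips every tensor whose key does not contain "lora" before the adapter is used, since the published checkpoint carries extra weights that are unnecessary for serving. A standalone sanity check along the same lines (illustrative, not part of the commit; adapter_dir is whatever the fixture returned):

    import safetensors.torch

    def assert_lora_only(adapter_dir: str) -> None:
        # Load the cleaned adapter file and confirm that only LoRA tensors
        # (lora_A / lora_B deltas) survived remove_unnecessary_weights().
        tensors = safetensors.torch.load_file(
            f"{adapter_dir}/adapter_model.safetensors")
        leftovers = [k for k in tensors if "lora" not in k]
        assert not leftovers, f"non-LoRA tensors still present: {leftovers}"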
tests/lora/test_jamba.py  (new file, 54 lines)
@@ -0,0 +1,54 @@
+from typing import List
+
+import pytest
+import torch
+
+import vllm
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
+
+MAX_TOKENS = 40
+
+
+def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
+              prompts: List[str]) -> List[str]:
+
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts: List[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+@pytest.mark.parametrize("tp_size", [4])
+def test_jamba_lora(jamba_lora_files, tp_size):
+    """Original test, the LoRA model has the common target modules, not all"""
+    if torch.cuda.device_count() < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+
+    prompts = ["Write a story about a sheep and a goat."]
+
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        max_num_seqs=16,
+        max_loras=4,
+        distributed_executor_backend="ray",
+        tensor_parallel_size=tp_size,
+    )
+
+    expected_jamba_output = [
+        """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming"""  # noqa: E501
+    ]
+    assert do_sample(llm, jamba_lora_files, lora_id=1,
+                     prompts=prompts) == expected_jamba_output
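Note: the new test skips itself when fewer than four GPUs are visible and runs the adapter end to end against a fixed expected completion. One way to invoke it directly from Python (plain pytest; the -v flag is just for verbose output):

    import pytest

    # Equivalent to running `pytest tests/lora/test_jamba.py -v` from the
    # repository root; downloads the jamba_lora_files adapter and needs at
    # least 4 GPUs, otherwise the test body skips.
    pytest.main(["tests/lora/test_jamba.py", "-v"])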
@@ -42,12 +42,14 @@ class MambaMixer(CustomOp):
                  use_rms_norm: bool,
                  rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
-                 activation="silu"):
+                 activation="silu",
+                 is_lora_enabled: bool = False):
         super().__init__()
         self.time_step_rank = time_step_rank
         self.ssm_state_size = ssm_state_size
         self.use_rms_norm = use_rms_norm
         self.activation = activation
+        self.is_lora_enabled = is_lora_enabled
 
         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
@@ -63,6 +65,7 @@ class MambaMixer(CustomOp):
         self.in_proj = MergedColumnParallelLinear(hidden_size,
                                                   [intermediate_size] * 2,
                                                   bias=use_bias)
+
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
             intermediate_size,
@@ -170,6 +173,12 @@ class MambaMixer(CustomOp):
 
         # 3. State Space Model sequence transformation
         # 3.a. input varying initialization of time_step, B and C
-        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
+
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            ssm_parameters = self.x_proj(
+                hidden_states.transpose(-2, -1).contiguous())[0]
+        else:
+            ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
 
         time_step, B, C = torch.split(
@@ -222,6 +231,11 @@ class MambaMixer(CustomOp):
         scan_outputs = scan_outputs.transpose(0, 1)
 
         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-                                                                     -1))[0]
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1).contiguous())[0]
+        else:
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1))[0]
         return contextualized_states
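Note: both new branches exist because Tensor.transpose() returns a non-contiguous view, and, per the comments above, the LoRA kernel needs a contiguous input before x_proj / out_proj are applied. A minimal, vLLM-independent illustration of the property being worked around:

    import torch

    x = torch.randn(2, 8, 16)                  # (batch, dim, seq); shapes are arbitrary
    y = x.transpose(-2, -1)                    # a view with swapped strides
    print(y.is_contiguous())                   # False
    print(y.contiguous().is_contiguous())      # True: .contiguous() materializes
                                               # a densely laid-out copy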
@@ -107,9 +107,11 @@ class JambaMambaDecoderLayer(nn.Module):
                  layer_idx: int,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 is_lora_enabled: Optional[bool] = False,
+                 **kwargs) -> None:
         super().__init__()
         self.config = config
+        self.is_lora_enabled = is_lora_enabled
         self.mamba = MambaMixer(hidden_size= config.hidden_size,
                                 ssm_state_size = config.mamba_d_state,
                                 conv_kernel_size = config.mamba_d_conv,
@@ -120,7 +122,9 @@ class JambaMambaDecoderLayer(nn.Module):
                                 use_bias = config.mamba_proj_bias,
                                 use_rms_norm=True,
                                 rms_norm_eps=config.rms_norm_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled = self.is_lora_enabled
+                                )
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -156,14 +160,13 @@ class JambaMambaDecoderLayer(nn.Module):
 
 class JambaAttentionDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: JambaConfig,
-        layer_idx: int,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 config: JambaConfig,
+                 layer_idx: int,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "",
+                 **kwargs) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -287,17 +290,18 @@ class JambaModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
 
+        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
+
         def get_layer(prefix: str):
             layer_idx = int(prefix.rsplit(".", 1)[1])
             layer_class = ALL_DECODER_LAYER_TYPES[
                 config.layers_block_type[layer_idx]]
-            return layer_class(
-                config,
-                layer_idx,
-                cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
+            return layer_class(config,
+                               layer_idx,
+                               cache_config,
+                               quant_config=quant_config,
+                               prefix=prefix,
+                               **extra_kwargs)
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
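Note: JambaModel threads the flag to every decoder layer through extra_kwargs, and JambaAttentionDecoderLayer (which has no Mamba mixer) simply absorbs it via **kwargs so both layer types keep a single construction call site. A stripped-down sketch of that pattern (class and argument names here are illustrative, not vLLM's):

    from typing import Optional

    class MixerLayer:
        # Consumes the flag: it changes how projections are called later.
        def __init__(self, is_lora_enabled: Optional[bool] = False, **kwargs):
            self.is_lora_enabled = is_lora_enabled

    class AttnLayer:
        # Ignores the flag: **kwargs keeps the constructor signature compatible.
        def __init__(self, **kwargs):
            pass

    def build_layers(layer_classes, lora_config):
        extra_kwargs = {"is_lora_enabled": bool(lora_config)}
        return [cls(**extra_kwargs) for cls in layer_classes]

    layers = build_layers([MixerLayer, AttnLayer], lora_config=None)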
@@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
+        "in_proj": ["in_proj"],
     }
 
     # LoRA specific attributes
     supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
+        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
+        "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
     ]
     embedding_modules = {
         "embed_tokens": "input_embeddings",
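Note: supported_lora_modules now covers the Mamba projections (in_proj, x_proj, out_proj) and the MLP/MoE projections in addition to attention and embeddings, and packed_modules_mapping gains an "in_proj" entry; q_proj/k_proj/v_proj deltas from an adapter are still folded into the packed qkv_proj. An illustrative helper (not part of the commit) for listing which modules a PEFT-style adapter actually targets, assuming the adapter_model.safetensors layout used by the fixture above:

    import re
    import safetensors.torch

    def adapter_target_modules(adapter_dir: str) -> set:
        # Keys typically look like "...layers.0.mamba.in_proj.lora_A.weight";
        # grab the module name that precedes the lora_A / lora_B suffix.
        tensors = safetensors.torch.load_file(
            f"{adapter_dir}/adapter_model.safetensors")
        modules = set()
        for key in tensors:
            match = re.search(r"\.(\w+)\.lora_[AB]", key)
            if match:
                modules.add(match.group(1))
        return modules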
@@ -446,7 +449,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs):
         if self.mamba_cache is None:
-
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module):
     def __init__(self,
                  config: MambaConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 is_lora_enabled: Optional[bool] = False) -> None:
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
+        self.is_lora_enabled = is_lora_enabled
         mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
         self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
@@ -53,7 +55,8 @@ class MambaDecoderLayer(nn.Module):
                                 use_rms_norm=self.is_falcon_mamba,
                                 rms_norm_has_weight=not self.is_falcon_mamba,
                                 rms_norm_eps=mixer_rms_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled=self.is_lora_enabled)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -85,6 +88,7 @@ class MambaModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        is_lora_enabled = bool(lora_config)
 
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -101,8 +105,10 @@ class MambaModel(nn.Module):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: MambaDecoderLayer(
-                config, cache_config=cache_config, quant_config=quant_config),
+            lambda prefix: MambaDecoderLayer(config,
+                                             cache_config=cache_config,
+                                             quant_config=quant_config,
+                                             is_lora_enabled=is_lora_enabled),
             prefix=f"{prefix}.layers")
 
         self.norm_f = RMSNorm(config.hidden_size,