[MODEL] LoRA support for Jamba model (#11209)

Signed-off-by: Erez Schwartz <erezs@ai21.com>
ErezSC42 2024-12-27 19:58:21 +02:00 committed by GitHub
parent 101418096f
commit 55509c2114
5 changed files with 132 additions and 32 deletions


@@ -4,6 +4,7 @@ from typing import Dict, List, TypedDict
 from unittest.mock import MagicMock, patch
 
 import pytest
+import safetensors
 import torch
 import torch.nn as nn
 from huggingface_hub import snapshot_download
@@ -169,6 +170,29 @@ def mixtral_lora_files_all_target_modules():
     return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
 
 
+@pytest.fixture(scope="session")
+def jamba_lora_files():
+    # some of the adapters have unnecessary weights for serving,
+    # hence we remove them
+    def remove_unnecessary_weights(path):
+        lora_path = f"{adapter_path}/adapter_model.safetensors"
+        tensors = safetensors.torch.load_file(lora_path)
+        nonlora_keys = []
+        for k in list(tensors.keys()):
+            if "lora" not in k:
+                nonlora_keys.append(k)
+        for k in nonlora_keys:
+            del tensors[k]
+        safetensors.torch.save_file(tensors, lora_path)
+
+    adapter_path = snapshot_download(
+        repo_id=
+        "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")
+
+    remove_unnecessary_weights(adapter_path)
+    return adapter_path
+
+
 @pytest.fixture(scope="session")
 def gemma_lora_files():
     return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
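
The jamba_lora_files fixture above downloads a community adapter and strips every tensor whose name does not contain "lora" before handing the directory to the tests. A quick sanity check that the stripping worked; a minimal sketch, not part of the diff, where adapter_path is a placeholder for the directory the fixture returns:

    import safetensors.torch

    adapter_path = "/path/to/adapter"  # placeholder: directory returned by the fixture
    tensors = safetensors.torch.load_file(f"{adapter_path}/adapter_model.safetensors")
    assert all("lora" in name for name in tensors), "non-LoRA tensors still present"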

tests/lora/test_jamba.py (new file, 54 lines)

@@ -0,0 +1,54 @@
+from typing import List
+
+import pytest
+import torch
+
+import vllm
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini"
+
+MAX_TOKENS = 40
+
+
+def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
+              prompts: List[str]) -> List[str]:
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts: List[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+@pytest.mark.parametrize("tp_size", [4])
+def test_jamba_lora(jamba_lora_files, tp_size):
+    """Original test, the LoRA model has the common target modules, not all"""
+    if torch.cuda.device_count() < tp_size:
+        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+
+    prompts = ["Write a story about a sheep and a goat."]
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        max_num_seqs=16,
+        max_loras=4,
+        distributed_executor_backend="ray",
+        tensor_parallel_size=tp_size,
+    )
+
+    expected_jamba_output = [
+        """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. Clara was a gentle creature, always nibbling on the soft grass and humming"""  # noqa: E501
+    ]
+    assert do_sample(llm, jamba_lora_files, lora_id=1,
+                     prompts=prompts) == expected_jamba_output
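
Outside the test harness, the same API surface is all that is needed to serve the adapter. A minimal sketch, not part of the diff: the adapter path is a placeholder and GPU sizing (tensor_parallel_size and friends) is left out:

    import vllm
    from vllm.lora.request import LoRARequest

    llm = vllm.LLM("ai21labs/AI21-Jamba-1.5-Mini", enable_lora=True, max_loras=1)
    params = vllm.SamplingParams(temperature=0, max_tokens=40)
    outputs = llm.generate(
        ["Write a story about a sheep and a goat."],
        params,
        # adapter name, integer id, and local path (path is a placeholder)
        lora_request=LoRARequest("jamba-story-lora", 1, "/path/to/adapter"),
    )
    print(outputs[0].outputs[0].text)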


@@ -42,12 +42,14 @@ class MambaMixer(CustomOp):
                  use_rms_norm: bool,
                  rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
-                 activation="silu"):
+                 activation="silu",
+                 is_lora_enabled: bool = False):
         super().__init__()
         self.time_step_rank = time_step_rank
         self.ssm_state_size = ssm_state_size
         self.use_rms_norm = use_rms_norm
         self.activation = activation
+        self.is_lora_enabled = is_lora_enabled
 
         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
@@ -63,6 +65,7 @@ class MambaMixer(CustomOp):
         self.in_proj = MergedColumnParallelLinear(hidden_size,
                                                   [intermediate_size] * 2,
                                                   bias=use_bias)
+
         # selective projection used to make dt, B and C input dependent
         self.x_proj = RowParallelLinear(
             intermediate_size,
@@ -170,6 +173,12 @@ class MambaMixer(CustomOp):
         # 3. State Space Model sequence transformation
         # 3.a. input varying initialization of time_step, B and C
-        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            ssm_parameters = self.x_proj(
+                hidden_states.transpose(-2, -1).contiguous())[0]
+        else:
+            ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
 
         time_step, B, C = torch.split(
@@ -222,6 +231,11 @@ class MambaMixer(CustomOp):
         scan_outputs = scan_outputs.transpose(0, 1)
 
         # 4. Final linear projection
-        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-                                                                      -1))[0]
+        if self.is_lora_enabled:
+            # lora kernel requires contiguous tensor
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1).contiguous())[0]
+        else:
+            contextualized_states = self.out_proj(
+                scan_outputs.transpose(-2, -1))[0]
         return contextualized_states
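
The two .contiguous() calls exist because transpose only returns a strided view, while the LoRA path through x_proj and out_proj expects densely laid-out input (per the "lora kernel requires contiguous tensor" comments above). A minimal sketch of the underlying PyTorch behavior, for illustration only:

    import torch

    x = torch.randn(2, 64, 128)              # (batch, dim, seq_len)
    t = x.transpose(-2, -1)                  # swapped strides, no copy
    print(t.is_contiguous())                 # False
    print(t.contiguous().is_contiguous())    # True: .contiguous() materializes a dense copy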


@@ -107,9 +107,11 @@ class JambaMambaDecoderLayer(nn.Module):
                  layer_idx: int,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+                 is_lora_enabled: Optional[bool] = False,
+                 **kwargs) -> None:
         super().__init__()
         self.config = config
+        self.is_lora_enabled = is_lora_enabled
         self.mamba = MambaMixer(hidden_size= config.hidden_size,
                                 ssm_state_size = config.mamba_d_state,
                                 conv_kernel_size = config.mamba_d_conv,
@@ -120,7 +122,9 @@ class JambaMambaDecoderLayer(nn.Module):
                                 use_bias = config.mamba_proj_bias,
                                 use_rms_norm=True,
                                 rms_norm_eps=config.rms_norm_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled = self.is_lora_enabled
+                                )
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -156,14 +160,13 @@
 
 class JambaAttentionDecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: JambaConfig,
-        layer_idx: int,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self,
+                 config: JambaConfig,
+                 layer_idx: int,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "",
+                 **kwargs) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -287,17 +290,18 @@ class JambaModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
 
+        extra_kwargs = {"is_lora_enabled": bool(vllm_config.lora_config)}
+
         def get_layer(prefix: str):
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[
                config.layers_block_type[layer_idx]]
-            return layer_class(
-                config,
-                layer_idx,
-                cache_config,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
+            return layer_class(config,
+                               layer_idx,
+                               cache_config,
+                               quant_config=quant_config,
+                               prefix=prefix,
+                               **extra_kwargs)
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
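
The extra_kwargs / **kwargs pattern above lets JambaModel pass is_lora_enabled to every layer class uniformly: the Mamba decoder layer consumes the flag, while the attention decoder layer simply absorbs it. A stripped-down illustration with hypothetical classes, not vLLM code:

    class MambaLikeLayer:
        def __init__(self, config, layer_idx, is_lora_enabled: bool = False, **kwargs):
            self.is_lora_enabled = is_lora_enabled  # the flag is consumed here

    class AttentionLikeLayer:
        def __init__(self, config, layer_idx, **kwargs):
            pass  # is_lora_enabled lands in **kwargs and is ignored

    extra_kwargs = {"is_lora_enabled": True}
    layers = [cls(config=None, layer_idx=i, **extra_kwargs)
              for i, cls in enumerate((MambaLikeLayer, AttentionLikeLayer))]
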
@@ -371,14 +375,13 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
+        "in_proj": ["in_proj"],
     }
 
     # LoRA specific attributes
     supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
+        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj",
+        "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj"
     ]
     embedding_modules = {
         "embed_tokens": "input_embeddings",
@@ -446,7 +449,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
             num_mamba_layers = self.model_config.get_num_layers_by_block_type(
                 self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
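
The expanded supported_lora_modules list now covers the attention projections, the MLP/MoE projections, and the Mamba mixer projections (in_proj, x_proj, out_proj); packed_modules_mapping is what lets an adapter trained against the unpacked Hugging Face names (q_proj, k_proj, v_proj) match vLLM's fused qkv_proj. A sketch of how such an adapter might be configured on the training side with PEFT; this is an assumption for illustration, not part of this commit:

    from peft import LoraConfig

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",   # attention
            "gate_proj", "up_proj", "down_proj",      # MLP / MoE experts
            "in_proj", "x_proj", "out_proj",          # Mamba mixer
        ],
        task_type="CAUSAL_LM",
    )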


@@ -38,10 +38,12 @@ class MambaDecoderLayer(nn.Module):
     def __init__(self,
                  config: MambaConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 is_lora_enabled: Optional[bool] = False) -> None:
         super().__init__()
         self.config = config
         self.is_falcon_mamba = config.model_type == "falcon_mamba"
+        self.is_lora_enabled = is_lora_enabled
         mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None
         self.mixer = MambaMixer(hidden_size=config.hidden_size,
                                 ssm_state_size=config.state_size,
@@ -53,7 +55,8 @@ class MambaDecoderLayer(nn.Module):
                                 use_rms_norm=self.is_falcon_mamba,
                                 rms_norm_has_weight=not self.is_falcon_mamba,
                                 rms_norm_eps=mixer_rms_eps,
-                                activation=config.hidden_act)
+                                activation=config.hidden_act,
+                                is_lora_enabled=self.is_lora_enabled)
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -85,6 +88,7 @@ class MambaModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        is_lora_enabled = bool(lora_config)
 
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -101,8 +105,10 @@ class MambaModel(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: MambaDecoderLayer(
-                config, cache_config=cache_config, quant_config=quant_config),
+            lambda prefix: MambaDecoderLayer(config,
+                                             cache_config=cache_config,
+                                             quant_config=quant_config,
+                                             is_lora_enabled=is_lora_enabled),
             prefix=f"{prefix}.layers")
 
         self.norm_f = RMSNorm(config.hidden_size,