mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00

Add GPT-OSS model code and config [1/N] (#22327)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent 796bae07c5
commit de98252f49
@@ -197,6 +197,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                          {"6b": "EleutherAI/gpt-j-6b"}),
    "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m",
                                          {"1b": "EleutherAI/pythia-1.4b"}),
    "GptOssForCausalLM": _HfExamplesInfo("openai/gpt-oss-20b"),
    "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
    "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
    "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview"),  # noqa: E501
@@ -247,6 +247,34 @@ class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
                    config.max_model_len)


class GptOssConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        decoding_config = vllm_config.decoding_config
        if decoding_config.reasoning_backend == "":
            decoding_config.reasoning_backend = "openai"

        # Increase the max capture size from 512 to 1024 for performance.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 83.
        scheduler_config = vllm_config.scheduler_config
        if len(scheduler_config.cuda_graph_sizes) == 1:
            max_capture_size = scheduler_config.cuda_graph_sizes[0]
            # FIXME(woosuk): When using full cuda graph with FA3, the max
            # supported size is 992.
            if max_capture_size < 1024:
                cuda_graph_sizes = [1, 2, 4]
                # Step size 8 for small batch sizes
                cuda_graph_sizes += [i for i in range(8, 256, 8)]
                # Step size 16 for larger batch sizes
                cuda_graph_sizes += [i for i in range(256, 1025, 16)]
                scheduler_config.cuda_graph_sizes = cuda_graph_sizes
                logger.info(
                    "Overriding max cuda graph capture size to "
                    "%d for performance.", 1024)


class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):

    @classmethod
@@ -345,4 +373,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
    "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
    "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
    "GptOssForCausalLM": GptOssConfig,
}
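A quick way to sanity-check the override above (illustrative snippet, not part of the patch): the schedule caps at 1024 and yields 83 capture sizes, consistent with the NOTE in the code.

    # Reproduce the capture-size schedule installed by GptOssConfig.
    cuda_graph_sizes = [1, 2, 4]
    cuda_graph_sizes += [i for i in range(8, 256, 8)]       # 8, 16, ..., 248
    cuda_graph_sizes += [i for i in range(256, 1025, 16)]   # 256, 272, ..., 1024

    assert max(cuda_graph_sizes) == 1024
    assert len(cuda_graph_sizes) == 83   # 3 + 31 + 49, i.e. the "67 -> 83" NOTE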
vllm/model_executor/models/gpt_oss.py  (new file, 472 lines)
@@ -0,0 +1,472 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional

import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig

from vllm import envs
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

from .utils import extract_layer_index, maybe_prefix

class OAIAttention(nn.Module):

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type":
                "yarn",
                "factor":
                config.rope_scaling["factor"],
                "original_max_position_embeddings":
                config.rope_scaling["original_max_position_embeddings"],
                "beta_fast":
                config.rope_ntk_beta,
                "beta_slow":
                config.rope_ntk_alpha,
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        attention_sink_dtype = (
            torch.float32 if envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
            or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION else torch.bfloat16)
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size,
                        dtype=attention_sink_dtype,
                        requires_grad=False))

        self.norm = RMSNorm(config.hidden_size, eps=1e-5)

        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
        sliding_window = (config.sliding_window if self.layer_idx %
                          2 == 0 else None)
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        t = self.norm(hidden_states)

        qkv, _ = self.qkv(t)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)

        return output + hidden_states

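# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the committed gpt_oss.py): how
# OAIAttention's per-rank sizes and the alternating sliding-window pattern
# work out. The config values below are assumptions chosen for illustration
# (roughly gpt-oss-20b-sized); the real values come from the HF GptOssConfig.
_num_attention_heads, _num_key_value_heads, _head_dim = 64, 8, 64  # assumed
_sliding_window = 128  # assumed
_tp_size = 2           # example tensor-parallel degree

_q_size = _num_attention_heads * _head_dim // _tp_size   # 2048 columns per rank
_kv_size = _num_key_value_heads * _head_dim // _tp_size  # 256 columns per rank

# Even layer indices get the sliding window; odd layers attend globally.
_windows = [_sliding_window if i % 2 == 0 else None for i in range(4)]
# -> [128, None, 128, None]
# ---------------------------------------------------------------------------
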
class MLPBlock(torch.nn.Module):

    def __init__(
        self,
        config: GptOssConfig,
        layer_idx: int,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts,
                                      dtype=torch.bfloat16)
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(num_experts=config.num_local_experts,
                                top_k=config.num_experts_per_token,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                apply_router_weight_on_input=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = self.norm(x)
        g = self.router(t)
        t = self.experts(hidden_states=t, router_logits=g)
        return x + t

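# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the committed gpt_oss.py): a
# dense, unfused reference for the routing math that FusedMoE(top_k=...,
# renormalize=True) above computes with fused kernels: softmax over all
# experts, keep the top-k per token, renormalize the kept weights to sum to 1.
def _reference_topk_routing(router_logits: torch.Tensor,
                            top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
# Example: _reference_topk_routing(torch.randn(4, 32), top_k=4) returns per-token
# (4, 4) weights that sum to 1 and the ids of the chosen experts.
# ---------------------------------------------------------------------------
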
class TransformerBlock(torch.nn.Module):

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(config, prefix=f"{prefix}.attn")
        self.mlp = MLPBlock(config,
                            self.layer_idx,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        attn_output = self.attn(hidden_states, positions)
        output = self.mlp(attn_output)
        return output

@support_torch_compile
class GptOssModel(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.config.hidden_size = self.config.hidden_size
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.layers = torch.nn.ModuleList([
            TransformerBlock(
                self.config,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
            ) for layer_idx in range(self.config.num_hidden_layers)
        ])
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x, positions)
        x = self.norm(x)
        return x

class GptOssForCausalLM(nn.Module):

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config.hf_config
        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.model_config.vocab_size,
            self.model_config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(self.model_config.vocab_size)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        assert intermediate_tensors is None
        assert inputs_embeds is None
        return self.model(input_ids, positions)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        rename_mapping = {
            "self_attn": "attn",
            "input_layernorm.weight": "attn.norm.weight",
            "post_attention_layernorm.weight": "mlp.norm.weight",
            "embed_tokens": "embedding",
        }

        def maybe_rename(name: str) -> str:
            for remap_name, new_name in rename_mapping.items():
                if remap_name in name:
                    return name.replace(remap_name, new_name)
            return name

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        mxfp4_block = 32

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        intermediate_size = self.model_config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                tp_size)
        per_rank_intermediate_size = (per_rank_intermediate_size_block *
                                      mxfp4_block)

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        # Attention heads per rank
        heads_per_rank = self.model_config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.model_config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        for name, weight in weights:
            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            if "gate_up_proj_blocks" in name:
                # Handle MLP gate and up projection weights
                new_name = name.replace("gate_up_proj_blocks", "w13_weight")

                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(num_experts, 2 * intermediate_size,
                                     -1).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_blocks" in name:
                # Handle MLP down projection weights
                new_name = name.replace("down_proj_blocks", "w2_weight")
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(num_experts, -1,
                                     intermediate_size // 2).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[...,
                                           tp_rank_start // 2:tp_rank_end // 2]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "gate_up_proj_scales" in name:
                # Handle MLP gate and up projection weights scale
                new_name = name.replace("gate_up_proj_scales",
                                        "w13_weight_scale")
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_scales" in name:
                # Handle MLP down projection weights
                new_name = name.replace("down_proj_scales", "w2_weight_scale")
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[..., tp_rank_start //
                                           mxfp4_block:tp_rank_end //
                                           mxfp4_block]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)
            elif "gate_up_proj_bias" in name:
                # Handle MLP gate and up projection biases
                new_name = name.replace("gate_up_proj_bias", "w13_bias")

                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_bias" in name:
                # Handle MLP down projection bias
                new_name = name.replace("down_proj_bias", "w2_bias")
                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    if tp_rank != 0:
                        weight.zero_()
                weight_loader(param,
                              weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                name = name.replace("self_attn", "attn")
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
                shard_id = ("q" if "q_proj" in name else
                            "k" if "k_proj" in name else "v")
                name = name.replace("self_attn", "attn")
                param_name = name.replace(f"{shard_id}_proj", "qkv")
                param = params_dict[param_name]
                weight_loader = param.weight_loader
                weight_loader(param, weight, loaded_shard_id=shard_id)
                loaded_params.add(param_name)
            else:
                # Handle all other weights with potential renaming
                renamed_name = maybe_rename(name)
                if renamed_name not in params_dict:
                    continue
                param = params_dict[renamed_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, weight)
                loaded_params.add(renamed_name)

        return loaded_params
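The MXFP4 slicing arithmetic in load_weights is easier to follow with concrete numbers. A minimal sketch, assuming intermediate_size=2880 and tp_size=2 (both assumed here purely for illustration; the real values come from the checkpoint config):

    mxfp4_block = 32
    intermediate_size = 2880          # assumed
    tp_size, tp_rank = 2, 1           # assumed

    def cdiv(a: int, b: int) -> int:  # ceiling division, as in vllm.utils.cdiv
        return -(-a // b)

    intermediate_size_block = intermediate_size // mxfp4_block   # 90 blocks
    per_rank_blocks = cdiv(intermediate_size_block, tp_size)     # 45 blocks
    per_rank_intermediate_size = per_rank_blocks * mxfp4_block   # 1440

    tp_rank_start = tp_rank * per_rank_intermediate_size         # 1440
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                      intermediate_size)                         # 2880

    # w13 ("gate_up") rows are then sliced at 2*tp_rank_start : 2*tp_rank_end
    # (gate and up are stacked), packed w2 uint8 columns at half these bounds
    # (two mxfp4 values per byte), and w2 scales at 1/32 of them (one scale
    # per 32-value block).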
@@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
    "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
    "GptOssForCausalLM": ("gpt_oss", "GptOssForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
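With the architecture registered in both tables, a gpt-oss checkpoint should resolve to the new implementation through the usual offline API. A minimal sketch (assuming the checkpoint and its MXFP4 weights are supported end to end once the rest of this [1/N] series lands):

    from vllm import LLM, SamplingParams

    llm = LLM(model="openai/gpt-oss-20b")
    sampling = SamplingParams(temperature=0.7, max_tokens=64)
    outputs = llm.generate(["The capital of France is"], sampling)
    print(outputs[0].outputs[0].text)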