vllm/vllm/lora/layers/fused_moe.py
Jee Jee Li c069086b9c
[Bugfix] Fix getting device for MoE LoRA (#29475)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-11-26 23:16:07 -08:00

719 lines
27 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.distributed.utils import divide
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
modular_marlin_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
try_get_optimal_moe_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from .utils import _get_lora_device
class FusedMoEWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``FusedMoE`` layer.

    The w1 (gate) and w3 (up) LoRA weights are stored as two separate slices;
    w2 (down) LoRA weights are stored as a single slice. LoRA is applied by
    hooking the modular fused-MoE kernel built in
    ``_inject_lora_into_fused_moe``.
    """

    def __init__(self, base_layer: FusedMoE) -> None:
        super().__init__()
        self.base_layer = base_layer

        assert not self.base_layer.use_ep, (
            "EP support for Fused MoE LoRA is not implemented yet."
        )
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.device = _get_lora_device(base_layer)
        # Number of LoRA slices for the merged w1/w3 (w13) projection.
        self._w13_slices = 2
        self._inject_lora_into_fused_moe()

    def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
        """Return a copy of ``config`` with keys in upper-case kernel form.

        Keys for which ``str.islower()`` is true are upper-cased, except that
        keys starting with ``block_`` become ``BLOCK_SIZE_<last component>``
        (e.g. ``block_size_m`` -> ``BLOCK_SIZE_M``). Other keys are kept.
        """
        normalized_config = {}
        for key, value in config.items():
            if key.islower():
                if key.startswith("block_"):
                    normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
                else:
                    normalized_key = key.upper()
            else:
                normalized_key = key
            normalized_config[normalized_key] = value
        return normalized_config

    def _get_lora_moe_configs(
        self,
        op_prefix: str,
        num_loras: int,
        rank: int,
        num_slices: int,
        M: int,
        layer: FusedMoE,
        top_k: int,
        config_dtype: str,
    ):
        """Return ``(shrink_config, expand_config)`` with normalized keys.

        When ``VLLM_TUNED_CONFIG_FOLDER`` is set, each config is looked up
        separately via ``get_lora_op_configs`` for the
        ``fused_moe_lora_{op_prefix}_{shrink,expand}`` op. Otherwise both are
        derived from the base fused-MoE config for the layer's weight shapes.
        """
        if envs.VLLM_TUNED_CONFIG_FOLDER:
            hidden_size = layer.hidden_size
            intermediate_size = layer.intermediate_size_per_partition
            shrink_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_shrink",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,
            )
            expand_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_expand",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,  # lora_a_stacked.shape[-1],
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2],
            )
            return (
                self._normalize_keys(shrink_config),
                self._normalize_keys(expand_config),
            )

        # Fall back to the default fused-MoE config. The lookup is the same
        # for shrink and expand, so do it once; normalizing twice still yields
        # two independent dicts.
        get_config_func = functools.partial(
            try_get_optimal_moe_config,
            layer.w13_weight.size(),
            layer.w2_weight.size(),
            top_k,
            config_dtype,
            block_shape=layer.quant_method.moe_quant_config.block_shape,
        )
        default_config = get_config_func(M)
        return (
            self._normalize_keys(default_config),
            self._normalize_keys(default_config),
        )

    def _inject_lora_into_fused_moe(self):
        """Replace the base layer's quant method with a LoRA-aware kernel.

        Builds a modular fused-MoE kernel (Marlin when the quant config uses
        MXFP4 W4A16, Triton otherwise) and wraps three of its steps:

        * ``forward``: records the routing kwargs in ``moe_state_dict``.
        * ``fused_experts.activation``: aligns tokens per LoRA/expert and
          calls ``punica_wrapper.add_lora_fused_moe`` with the w13 LoRA,
          passing the activation's input as the first (output) argument,
          then runs the activation.
        * ``fused_experts.moe_sum``: calls ``add_lora_fused_moe`` with the w2
          LoRA on the recorded activation output, passing ``moe_sum``'s first
          argument as the output, then runs the reduction.

        The hooks communicate through ``moe_state_dict``. The activation hook
        reads what the forward hook wrote, and the moe_sum hook reads what the
        activation hook wrote.
        """
        moe_state_dict = {}
        top_k = self.base_layer.top_k

        self.base_layer.ensure_moe_quant_config_init()
        quant_config = self.base_layer.quant_method.moe_quant_config

        if quant_config.use_mxfp4_w4a16:
            m_fused_moe_fn = modular_marlin_fused_moe(
                quant_config, shared_experts=self.base_layer.shared_experts
            )
        else:
            m_fused_moe_fn = modular_triton_fused_moe(
                quant_config, shared_experts=self.base_layer.shared_experts
            )

        def lora_configs(layer, op_prefix, lora_a_stacked, num_slices):
            # Shared by both hooks: derive the kernel configs for the batch
            # recorded by the forward hook.
            # Returns (num_tokens, max_lora_rank, shrink_config, expand_config).
            hidden_states = moe_state_dict["hidden_states"]
            config_dtype = _get_config_dtype_str(
                dtype=hidden_states.dtype,
                use_fp8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
            )
            num_tokens = hidden_states.size(0)
            M = min(num_tokens, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
            # LoRA A buffers are allocated as
            # (max_loras, num_experts, rank, in_features).
            max_lora_rank = lora_a_stacked[0].shape[-2]
            shrink_config, expand_config = self._get_lora_moe_configs(
                op_prefix=op_prefix,
                num_loras=self.max_loras,
                rank=max_lora_rank,
                num_slices=num_slices,
                M=M,
                layer=layer,
                top_k=top_k,
                config_dtype=config_dtype,
            )
            return num_tokens, max_lora_rank, shrink_config, expand_config

        def fwd_decorator(layer, func):
            def wrapper(*args, **kwargs):
                # Record the routing inputs for the activation/moe_sum hooks.
                for key in (
                    "hidden_states",
                    "topk_ids",
                    "topk_weights",
                    "expert_map",
                    "apply_router_weight_on_input",
                ):
                    moe_state_dict[key] = kwargs[key]
                return func(*args, **kwargs)

            return wrapper

        def act_decorator(layer, func):
            def wrapper(*args, **kwargs):
                _, output, act_input = args

                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]
                curr_topk_ids = moe_state_dict["topk_ids"]
                expert_map = moe_state_dict["expert_map"]

                num_tokens, max_lora_rank, shrink_config, expand_config = (
                    lora_configs(
                        layer, "w13", self.w13_lora_a_stacked, self._w13_slices
                    )
                )

                # The block size of M comes from the tuned or default config.
                (
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                ) = self.punica_wrapper.moe_lora_align_block_size(
                    curr_topk_ids,
                    num_tokens,
                    shrink_config["BLOCK_SIZE_M"],
                    self.base_layer.local_num_experts,
                    self.max_loras,
                    self.adapter_enabled,
                    expert_map,
                )

                # The moe_sum hook reuses this alignment.
                moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
                moe_state_dict["expert_ids_lora"] = expert_ids_lora
                moe_state_dict["num_tokens_post_padded_lora"] = (
                    num_tokens_post_padded_lora
                )

                self.punica_wrapper.add_lora_fused_moe(
                    act_input.view(-1, top_k, act_input.shape[-1]),
                    hidden_states,
                    self.w13_lora_a_stacked,
                    self.w13_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora.view(self.max_loras, -1),
                    expert_ids_lora.view(self.max_loras, -1),
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,
                    expand_config,
                    self.adapter_enabled,
                    fully_sharded=self.fully_sharded,
                )

                result = func(*args, **kwargs)
                # The activation output is the input of the w2 LoRA.
                moe_state_dict["intermediate_cache2"] = output
                return result

            return wrapper

        def moe_sum_decorator(layer, func):
            def wrapper(*args, **kwargs):
                topk_weights = moe_state_dict["topk_weights"]
                _, max_lora_rank, shrink_config, expand_config = lora_configs(
                    layer, "w2", self.w2_lora_a_stacked, 1
                )

                sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
                expert_ids_lora = moe_state_dict["expert_ids_lora"]
                num_tokens_post_padded_lora = moe_state_dict[
                    "num_tokens_post_padded_lora"
                ]
                intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                intermediate_cache3 = args[0]

                # With fully sharded LoRA, w2 LoRA B is sliced along the hidden
                # dim, so each rank writes at its own offset. The shard size is
                # only needed then; ``divide`` asserts divisibility, which must
                # not be required in the non-sharded case.
                offset = (
                    divide(self.base_layer.hidden_size, self.tp_size) * self.tp_rank
                    if self.fully_sharded
                    else 0
                )

                self.punica_wrapper.add_lora_fused_moe(
                    intermediate_cache3,
                    intermediate_cache2,
                    self.w2_lora_a_stacked,
                    self.w2_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora.view(self.max_loras, -1),
                    expert_ids_lora.view(self.max_loras, -1),
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,
                    expand_config,
                    self.adapter_enabled,
                    True,  # presumably the mul_routed_weight flag; TODO confirm
                    fully_sharded=self.fully_sharded,
                    offset=offset,
                )
                return func(*args, **kwargs)

            return wrapper

        fused_experts = m_fused_moe_fn.fused_experts

        m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
        fused_experts.activation = act_decorator(
            self.base_layer, fused_experts.activation
        )
        fused_experts.moe_sum = moe_sum_decorator(
            self.base_layer, fused_experts.moe_sum
        )

        self.base_layer.quant_method = FusedMoEModularMethod(
            self.base_layer.quant_method, m_fused_moe_fn
        )

    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
    ):
        """Allocate zeroed LoRA A buffers for w13 (one per slice) and w2.

        Shapes are (max_loras, local_num_experts, rank, in_features). With
        fully sharded LoRA, the w13 rank dim is divided by the TP size.
        """
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank
                    if not self.fully_sharded
                    else divide(lora_config.max_lora_rank, self.tp_size),
                    self.base_layer.hidden_size,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )

        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank,
                    self.base_layer.intermediate_size_per_partition,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        """Allocate zeroed LoRA B buffers for w13 (one per slice) and w2.

        Shapes are (max_loras, local_num_experts, out_features, rank). With
        fully sharded LoRA, the w2 output (hidden) dim is divided by the TP
        size.
        """
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        Allocates the stacked LoRA buffers and ``adapter_enabled``
        (``max_loras + 1`` int flags, all zero).
        """
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)
        # Flat per-expert views used by
        # 'LoRALayerWeights.create_dummy_lora_weights' to create dummy LoRA
        # weights. Per expert the order is gate_proj (w1), down_proj (w2),
        # up_proj (w3), which is the layout ``set_lora`` consumes.
        # TODO Optimize this section
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
            for experts_id in range(self.base_layer.local_num_experts):
                self.lora_a_stacked.extend(
                    (
                        self.w13_lora_a_stacked[0][lora_id][experts_id],
                        self.w2_lora_a_stacked[0][lora_id][experts_id],
                        self.w13_lora_a_stacked[1][lora_id][experts_id],
                    )
                )
                self.lora_b_stacked.extend(
                    (
                        self.w13_lora_b_stacked[0][lora_id][experts_id],
                        self.w2_lora_b_stacked[0][lora_id][experts_id],
                        self.w13_lora_b_stacked[1][lora_id][experts_id],
                    )
                )

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0 and disables the slot."""
        for pos in range(self._w13_slices):
            self.w13_lora_a_stacked[pos][index] = 0
            self.w13_lora_b_stacked[pos][index] = 0

        self.w2_lora_a_stacked[0][index] = 0
        self.w2_lora_b_stacked[0][index] = 0
        self.adapter_enabled[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        ``lora_a``/``lora_b`` are flat lists with three entries per expert,
        ordered w1, w2, w3. Experts whose A entries are ``None`` are skipped,
        so they keep the zeros written by ``reset_lora``.
        """
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        # This rank's slice of the intermediate dim (w1/w3 B rows, w2 A cols).
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = start_idx + shard_size
        # Fully sharded (S-LoRA) slices: W1/W3 A along the rank dim and W2 B
        # along the hidden dim. Sizes come from the allocated buffers.
        w13_shard_size = self.w13_lora_a_stacked[0].shape[2]
        w13_start_idx = self.tp_rank * w13_shard_size
        w13_end_idx = w13_start_idx + w13_shard_size
        w2_shard_size = self.w2_lora_b_stacked[0].shape[2]
        w2_start_idx = self.tp_rank * w2_shard_size
        w2_end_idx = w2_start_idx + w2_shard_size

        for eid in range(len(lora_a) // 3):
            w1_lora_a = lora_a[eid * 3]
            w2_lora_a = lora_a[eid * 3 + 1]
            w3_lora_a = lora_a[eid * 3 + 2]
            w1_lora_b = lora_b[eid * 3]
            w2_lora_b = lora_b[eid * 3 + 1]
            w3_lora_b = lora_b[eid * 3 + 2]

            # Handle the case of adding LoRA to only a subset of experts
            if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
                continue

            if self.tp_size > 1:
                w1_lora_b = w1_lora_b[start_idx:end_idx, :]
                w3_lora_b = w3_lora_b[start_idx:end_idx, :]
                w2_lora_a = w2_lora_a[:, start_idx:end_idx]

                if self.fully_sharded:
                    w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
                    w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]
                    w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]

            # w1 lora_a
            self.w13_lora_a_stacked[0][
                index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
            ].copy_(w1_lora_a, non_blocking=True)
            # w3 lora_a
            self.w13_lora_a_stacked[1][
                index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
            ].copy_(w3_lora_a, non_blocking=True)

            # w1 lora_b
            self.w13_lora_b_stacked[0][
                index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
            ].copy_(w1_lora_b, non_blocking=True)
            # w3 lora_b
            self.w13_lora_b_stacked[1][
                index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
            ].copy_(w3_lora_b, non_blocking=True)

            # w2 lora_a / lora_b
            self.w2_lora_a_stacked[0][
                index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
            ].copy_(w2_lora_a, non_blocking=True)
            self.w2_lora_b_stacked[0][
                index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
            ].copy_(w2_lora_b, non_blocking=True)

    def forward(self, *args, **kwargs):
        """Delegate to the base layer (LoRA is applied via the kernel hooks)."""
        return self.base_layer.forward(*args, **kwargs)

    def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
        """Delegate to the base layer's TP all-reduce helper."""
        return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)

    @property
    def _shared_experts(self):
        return self.base_layer._shared_experts

    @property
    def quant_method(self):
        return self.base_layer.quant_method

    @property
    def is_internal_router(self) -> bool:
        return self.base_layer.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer.

        Matches an exact ``FusedMoE`` whose packed module list has two
        entries.
        """
        return type(source_layer) is FusedMoE and len(packed_modules_list) == 2
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    """``FusedMoE`` LoRA wrapper for adapters with one w13 and one w2 tensor.

    ``set_lora`` expects exactly two A and two B tensors (w13 and w2), each
    covering all experts. The w13 LoRA is kept as a single slice whose B
    buffer spans ``2 * intermediate_size_per_partition`` rows.
    """

    def __init__(self, base_layer):
        super().__init__(base_layer)
        # Must be assigned after the parent constructor, which sets it to 2.
        # The hooks installed by the parent read it at call time.
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        """Allocate zeroed LoRA B buffers with a single merged w13 slice."""
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition * 2,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices (without the flat per-expert view lists)."""
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)

    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        """Slice w13 A along the rank dim when fully sharded across TP ranks."""
        if self.tp_size == 1 or not self.fully_sharded:
            return w13_lora_a

        # w13_lora_a shape (num_experts,rank,input_size)
        current_lora_rank = w13_lora_a.shape[1]
        assert current_lora_rank % self.tp_size == 0
        sliced_rank = current_lora_rank // self.tp_size
        start_idx = self.tp_rank * sliced_rank
        end_idx = (self.tp_rank + 1) * sliced_rank
        return w13_lora_a[:, start_idx:end_idx, :]

    def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
        """Take this rank's intermediate shard of both w1 and w3 in w13 B.

        With ``is_interleave`` the w1/w3 rows alternate (even rows w1, odd
        rows w3) and the result keeps that interleaving. Otherwise the first
        half of the rows is w1 and the second half w3, and the result is the
        concatenation of the two shards.
        """
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        if is_interleave:
            # For models like GPT-OSS, the weights of w1 (gate_proj) and
            # w3 (up_proj) are stored interleaved, so the LoRA B rows must be
            # de-interleaved, sharded, then re-interleaved.
            w1_lora_b = w13_lora_b[:, ::2, :]
            w3_lora_b = w13_lora_b[:, 1::2, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
                1, 2
            )

        slice_size = w13_lora_b.shape[1] // 2
        w1_lora_b = w13_lora_b[:, :slice_size, :]
        w3_lora_b = w13_lora_b[:, slice_size:, :]
        sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
        sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

        return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        """Take this rank's intermediate (input) shard of w2 A."""
        if self.tp_size == 1:
            return w2_lora_a
        # w2_lora_a shape (num_experts,rank,input_size)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_a[:, :, start_idx:end_idx]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        """Slice w2 B along the hidden (output) dim when fully sharded."""
        if self.tp_size == 1 or not self.fully_sharded:
            return w2_lora_b
        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape (num_experts,output_size,rank)
        current_lora_size = w2_lora_b.shape[1]

        sliced_size = current_lora_size // self.tp_size
        start_idx = self.tp_rank * sliced_size
        end_idx = (self.tp_rank + 1) * sliced_size
        return w2_lora_b[:, start_idx:end_idx, :]

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` is ``[w13_a, w2_a]`` and ``lora_b`` is ``[w13_b, w2_b]``.
        The A tensors are reshaped to (num_experts, rank, input_size). The B
        tensors are reshaped to (output_size, num_experts, rank) and then
        permuted to (num_experts, output_size, rank).
        """
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        num_experts = self.w13_lora_a_stacked[0].shape[1]
        w13_lora_a, w2_lora_a = lora_a
        w13_lora_b, w2_lora_b = lora_b

        # (num_experts,rank,input_size)
        w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
        w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
        # (output_size,num_experts,rank)
        w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
        w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
        # (num_experts,output_size,rank)
        w13_lora_b = w13_lora_b.permute(1, 0, 2)
        w2_lora_b = w2_lora_b.permute(1, 0, 2)

        sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
        sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)
        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
        ].copy_(sliced_w13_lora_a, non_blocking=True)
        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
        ].copy_(sliced_w13_lora_b, non_blocking=True)
        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    @property
    def w13_input_size(self):
        """
        Full size
        """
        return self.w13_lora_a_stacked[0].shape[-1]

    @property
    def w13_output_size(self):
        """
        Full size
        """
        return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size

    @property
    def w2_input_size(self):
        """
        Full size
        """
        return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size

    @property
    def w2_output_size(self):
        """
        Full size
        """
        # w2 projects back to the hidden dim. The previous expression
        # (w2_lora_a_stacked[0].shape[-2]) read the LoRA rank dim of the
        # (max_loras, experts, rank, intermediate) A buffer instead.
        return self.base_layer.hidden_size

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer.

        Matches an exact ``FusedMoE`` whose packed module list has one entry.
        """
        return type(source_layer) is FusedMoE and len(packed_modules_list) == 1