mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-21 11:31:20 +08:00
Optimize model execution with CUDA graph (#1926)
Co-authored-by: Chen Shen <scv119@gmail.com> Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
parent
eed74a558f
commit
37ca558103
@ -23,6 +23,7 @@ def main(args: argparse.Namespace):
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
dtype=args.dtype,
|
||||
enforce_eager=args.enforce_eager,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
@ -111,6 +112,9 @@ if __name__ == '__main__':
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='enforce eager mode and disable CUDA graph')
|
||||
parser.add_argument(
|
||||
'--profile',
|
||||
action='store_true',
|
||||
|
||||
@ -69,7 +69,8 @@ def run_vllm(
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
dtype: str,
|
||||
max_model_len: Optional[int] = None,
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(
|
||||
@ -81,6 +82,7 @@ def run_vllm(
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
@ -204,7 +206,7 @@ def main(args: argparse.Namespace):
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype,
|
||||
args.max_model_len)
|
||||
args.max_model_len, args.enforce_eager)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@ -279,6 +281,9 @@ if __name__ == "__main__":
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument("--enforce-eager",
|
||||
action="store_true",
|
||||
help="enforce eager execution")
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
||||
@ -12,3 +12,4 @@ fastapi
|
||||
uvicorn[standard]
|
||||
pydantic == 1.10.13 # Required for OpenAI server.
|
||||
aioprometheus[starlette]
|
||||
cupy-cuda12x # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. # FIXME: Fix this in setup.py.
|
||||
|
||||
@ -49,6 +49,12 @@ class ModelConfig:
|
||||
output). If None, will be derived from the model.
|
||||
quantization: Quantization method that was used to quantize the model
|
||||
weights. If None, we assume the model weights are not quantized.
|
||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||
disable CUDA graph and always execute the model in eager mode.
|
||||
If False, we will use CUDA graph and eager execution in hybrid.
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -65,6 +71,8 @@ class ModelConfig:
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
enforce_eager: bool = False,
|
||||
max_context_len_to_capture: Optional[int] = None,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
@ -76,6 +84,8 @@ class ModelConfig:
|
||||
self.revision = revision
|
||||
self.tokenizer_revision = tokenizer_revision
|
||||
self.quantization = quantization
|
||||
self.enforce_eager = enforce_eager
|
||||
self.max_context_len_to_capture = max_context_len_to_capture
|
||||
|
||||
if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
|
||||
# download model from ModelScope hub,
|
||||
@ -95,6 +105,7 @@ class ModelConfig:
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
self._verify_quantization()
|
||||
self._verify_cuda_graph()
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
@ -169,6 +180,12 @@ class ModelConfig:
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.")
|
||||
|
||||
def _verify_cuda_graph(self) -> None:
|
||||
if self.max_context_len_to_capture is None:
|
||||
self.max_context_len_to_capture = self.max_model_len
|
||||
self.max_context_len_to_capture = min(self.max_context_len_to_capture,
|
||||
self.max_model_len)
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
|
||||
@ -33,6 +33,8 @@ class EngineArgs:
|
||||
revision: Optional[str] = None
|
||||
tokenizer_revision: Optional[str] = None
|
||||
quantization: Optional[str] = None
|
||||
enforce_eager: bool = False
|
||||
max_context_len_to_capture: int = 8192
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer is None:
|
||||
@ -182,6 +184,17 @@ class EngineArgs:
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
default=None,
|
||||
help='Method used to quantize the weights')
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='Always use eager-mode PyTorch. If False, '
|
||||
'will use eager mode and CUDA graph in hybrid '
|
||||
'for maximal performance and flexibility.')
|
||||
parser.add_argument('--max-context-len-to-capture',
|
||||
type=int,
|
||||
default=EngineArgs.max_context_len_to_capture,
|
||||
help='maximum context length covered by CUDA '
|
||||
'graphs. When a sequence has context length '
|
||||
'larger than this, we fall back to eager mode.')
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@ -200,7 +213,8 @@ class EngineArgs:
|
||||
self.download_dir, self.load_format,
|
||||
self.dtype, self.seed, self.revision,
|
||||
self.tokenizer_revision, self.max_model_len,
|
||||
self.quantization)
|
||||
self.quantization, self.enforce_eager,
|
||||
self.max_context_len_to_capture)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space,
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter
|
||||
from vllm.utils import Counter, get_open_port
|
||||
|
||||
if ray:
|
||||
from ray.air.util.torch_dist import init_torch_dist_process_group
|
||||
@ -84,6 +84,7 @@ class LLMEngine:
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"enforce_eager={model_config.enforce_eager}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
@ -189,6 +190,7 @@ class LLMEngine:
|
||||
))
|
||||
self._run_workers(
|
||||
"init_model",
|
||||
cupy_port=get_open_port(),
|
||||
get_all_outputs=True,
|
||||
)
|
||||
self._run_workers(
|
||||
@ -232,6 +234,9 @@ class LLMEngine:
|
||||
|
||||
# Initialize the cache.
|
||||
self._run_workers("init_cache_engine", cache_config=self.cache_config)
|
||||
# Warm up the model. This includes capturing the model into CUDA graph
|
||||
# if enforce_eager is False.
|
||||
self._run_workers("warm_up_model")
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import socket
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
from vllm.utils import get_open_port, is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -43,12 +42,6 @@ if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
|
||||
def get_open_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def initialize_cluster(
|
||||
parallel_config: ParallelConfig,
|
||||
engine_use_ray: bool = False,
|
||||
|
||||
@ -56,6 +56,12 @@ class LLM:
|
||||
when their `best_of` sampling parameters are larger than 1. If all
|
||||
requests will have `best_of=1`, you can safely set this to 0.
|
||||
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||
disable CUDA graph and always execute the model in eager mode.
|
||||
If False, we will use CUDA graph and eager execution in hybrid.
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -72,6 +78,8 @@ class LLM:
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: int = 4,
|
||||
enforce_eager: bool = False,
|
||||
max_context_len_to_capture: int = 8192,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
@ -89,6 +97,8 @@ class LLM:
|
||||
seed=seed,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
swap_space=swap_space,
|
||||
enforce_eager=enforce_eager,
|
||||
max_context_len_to_capture=max_context_len_to_capture,
|
||||
**kwargs,
|
||||
)
|
||||
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
@ -21,12 +21,14 @@ class InputMetadata:
|
||||
max_context_len: Optional[int],
|
||||
context_lens: Optional[torch.Tensor],
|
||||
block_tables: Optional[torch.Tensor],
|
||||
use_cuda_graph: bool,
|
||||
) -> None:
|
||||
self.prompt_lens = prompt_lens
|
||||
self.max_context_len = max_context_len
|
||||
self.slot_mapping = slot_mapping
|
||||
self.context_lens = context_lens
|
||||
self.block_tables = block_tables
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
|
||||
self.is_prompt = len(prompt_lens) > 0
|
||||
# Set during the execution of the first attention op.
|
||||
@ -39,4 +41,5 @@ class InputMetadata:
|
||||
f"max_context_len={self.max_context_len}, "
|
||||
f"slot_mapping={self.slot_mapping}, "
|
||||
f"context_lens={self.context_lens}, "
|
||||
f"block_tables={self.block_tables})")
|
||||
f"block_tables={self.block_tables}, "
|
||||
f"use_cuda_graph={self.use_cuda_graph})")
|
||||
|
||||
@ -24,13 +24,10 @@ class PagedAttention(nn.Module):
|
||||
can either contain prompt tokens or generation tokens.
|
||||
The class does the following:
|
||||
|
||||
1. Wait for the cache operations (e.g., swap, copy) to finish. The cache
|
||||
operations are issued by the cache engine before executing the forward
|
||||
pass of the model, and they are executed asynchronously.
|
||||
2. Reshape and store the input key and value tensors in the KV cache.
|
||||
3. Perform (multi-head/multi-query/grouped-query) attention using either
|
||||
1. Reshape and store the input key and value tensors in the KV cache.
|
||||
2. Perform (multi-head/multi-query/grouped-query) attention using either
|
||||
xformers or the PagedAttention custom op.
|
||||
4. Return the output tensor.
|
||||
3. Return the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -67,7 +64,6 @@ class PagedAttention(nn.Module):
|
||||
key_cache: Optional[torch.Tensor],
|
||||
value_cache: Optional[torch.Tensor],
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
"""PagedAttention forward pass.
|
||||
|
||||
@ -80,7 +76,6 @@ class PagedAttention(nn.Module):
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
block_size]
|
||||
input_metadata: metadata for the inputs.
|
||||
cache_event: event to wait for the cache operations to finish.
|
||||
Returns:
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
"""
|
||||
@ -89,10 +84,6 @@ class PagedAttention(nn.Module):
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
slot_mapping = input_metadata.slot_mapping.flatten()
|
||||
|
||||
if cache_event is not None:
|
||||
cache_event.wait()
|
||||
|
||||
# Reshape the keys and values and store them in the cache.
|
||||
# If key_cache and value_cache are not provided, the new key and value
|
||||
@ -104,7 +95,7 @@ class PagedAttention(nn.Module):
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
input_metadata.slot_mapping.flatten(),
|
||||
)
|
||||
|
||||
if input_metadata.is_prompt:
|
||||
@ -165,15 +156,20 @@ class PagedAttention(nn.Module):
|
||||
output = out.view_as(query)
|
||||
else:
|
||||
# Decoding run.
|
||||
output = _paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
if key_cache is not None and value_cache is not None:
|
||||
output = _paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
else:
|
||||
# This happens during the initial memory profiling run for
|
||||
# CUDA graphs.
|
||||
output = torch.zeros_like(query)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(batch_size, seq_len, hidden_size)
|
||||
|
||||
@ -158,14 +158,12 @@ class AquilaAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -209,7 +207,6 @@ class AquilaDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
@ -219,7 +216,6 @@ class AquilaDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -258,18 +254,15 @@ class AquilaModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
@ -296,10 +289,9 @@ class AquilaForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -172,15 +172,13 @@ class BaiChuanAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.W_pack(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
if self.postion_embedding != "ALIBI":
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -221,7 +219,6 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -236,7 +233,6 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -273,19 +269,16 @@ class BaiChuanModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@ -311,10 +304,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -118,14 +118,12 @@ class BloomAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
del position_ids # Unused.
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
@ -184,7 +182,6 @@ class BloomBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
@ -201,7 +198,6 @@ class BloomBlock(nn.Module):
|
||||
hidden_states=layernorm_output,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
attention_output = attention_output + residual
|
||||
layernorm_output = self.post_attention_layernorm(attention_output)
|
||||
@ -250,19 +246,16 @@ class BloomModel(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.word_embeddings(input_ids)
|
||||
hidden_states = self.word_embeddings_layernorm(hidden_states)
|
||||
for i in range(len(self.h)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@ -288,10 +281,9 @@ class BloomForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -100,7 +100,6 @@ class GLMAttention(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
@ -113,7 +112,6 @@ class GLMAttention(nn.Module):
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
attn_output, _ = self.dense(context_layer)
|
||||
return attn_output
|
||||
@ -203,7 +201,6 @@ class GLMBlock(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# hidden_states: [num_tokens, h]
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
@ -214,7 +211,6 @@ class GLMBlock(nn.Module):
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Residual connection.
|
||||
@ -269,17 +265,14 @@ class GLMTransformer(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
for i in range(self.num_layers):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_caches[i],
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
# Final layer norm.
|
||||
if self.post_layer_norm:
|
||||
@ -314,8 +307,7 @@ class ChatGLMModel(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
|
||||
# Run encoder.
|
||||
@ -324,9 +316,7 @@ class ChatGLMModel(nn.Module):
|
||||
position_ids=position_ids,
|
||||
kv_caches=kv_caches,
|
||||
input_metadata=input_metadata,
|
||||
cache_events=cache_events,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -350,10 +340,9 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -178,7 +178,6 @@ class FalconAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, bias = self.query_key_value(hidden_states)
|
||||
if bias is not None:
|
||||
@ -187,8 +186,7 @@ class FalconAttention(nn.Module):
|
||||
if self.use_rotary:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
attn_output, bias = self.dense(attn_output)
|
||||
return attn_output, bias
|
||||
|
||||
@ -266,8 +264,7 @@ class FalconDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
if self.config.new_decoder_architecture:
|
||||
@ -282,7 +279,6 @@ class FalconDecoderLayer(nn.Module):
|
||||
hidden_states=attention_layernorm_out,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
if self.reduce_row_parallel_results and attention_bias is not None:
|
||||
attention_output += attention_bias
|
||||
@ -311,7 +307,6 @@ class FalconDecoderLayer(nn.Module):
|
||||
mlp_output += mlp_bias
|
||||
|
||||
output = mlp_output + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@ -349,18 +344,15 @@ class FalconModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.word_embeddings(input_ids)
|
||||
for i in range(len(self.h)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@ -389,14 +381,12 @@ class FalconForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
input_metadata,
|
||||
cache_events,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -82,13 +82,12 @@ class GPT2Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
input_metadata)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@ -148,7 +147,6 @@ class GPT2Block(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
@ -156,7 +154,6 @@ class GPT2Block(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
@ -196,17 +193,14 @@ class GPT2Model(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
for i in range(len(self.h)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
|
||||
cache_event)
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@ -232,10 +226,9 @@ class GPT2LMHeadModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -95,7 +95,6 @@ class GPTBigCodeAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split(
|
||||
@ -107,7 +106,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
input_metadata)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@ -167,7 +166,6 @@ class GPTBigCodeBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
@ -175,7 +173,6 @@ class GPTBigCodeBlock(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
@ -215,17 +212,14 @@ class GPTBigCodeModel(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
for i in range(len(self.h)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
|
||||
cache_event)
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@ -251,10 +245,9 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -94,14 +94,12 @@ class GPTJAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
attn_output, _ = self.out_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@ -156,7 +154,6 @@ class GPTJBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
@ -165,7 +162,6 @@ class GPTJBlock(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
mlp_output = self.mlp(hidden_states)
|
||||
hidden_states = attn_output + mlp_output + residual
|
||||
@ -196,18 +192,15 @@ class GPTJModel(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.wte(input_ids)
|
||||
for i in range(len(self.h)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@ -238,10 +231,9 @@ class GPTJForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -92,14 +92,12 @@ class GPTNeoXAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
@ -155,7 +153,6 @@ class GPTNeoXLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.input_layernorm(hidden_states)
|
||||
attn_output = self.attention(
|
||||
@ -163,7 +160,6 @@ class GPTNeoXLayer(nn.Module):
|
||||
hidden_states=attn_input,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
if self.use_parallel_residual:
|
||||
@ -210,18 +206,15 @@ class GPTNeoXModel(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_in(input_ids)
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
return hidden_states
|
||||
@ -250,10 +243,9 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -110,14 +110,12 @@ class InternLMAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -160,7 +158,6 @@ class InternLMDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -175,7 +172,6 @@ class InternLMDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -214,19 +210,16 @@ class InternLMModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@ -253,10 +246,9 @@ class InternLMForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -147,14 +147,12 @@ class LlamaAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -198,7 +196,6 @@ class LlamaDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -213,7 +210,6 @@ class LlamaDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -250,19 +246,16 @@ class LlamaModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@ -289,10 +282,9 @@ class LlamaForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -145,14 +145,12 @@ class MistralAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -193,7 +191,6 @@ class MistralDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -208,7 +205,6 @@ class MistralDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -246,19 +242,16 @@ class MistralModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@ -285,10 +278,9 @@ class MistralForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -253,14 +253,12 @@ class MixtralAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -297,7 +295,6 @@ class MixtralDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
@ -312,7 +309,6 @@ class MixtralDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -349,16 +345,14 @@ class MixtralModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i], input_metadata,
|
||||
cache_event, residual)
|
||||
residual)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
@ -383,10 +377,9 @@ class MixtralForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -117,7 +117,6 @@ class MPTAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
del position_ids # unused.
|
||||
qkv, _ = self.Wqkv(hidden_states)
|
||||
@ -128,8 +127,7 @@ class MPTAttention(nn.Module):
|
||||
q = self.q_ln(q)
|
||||
k = self.k_ln(k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -187,7 +185,6 @@ class MPTBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
x = self.norm_1(hidden_states)
|
||||
x = self.attn(
|
||||
@ -195,7 +192,6 @@ class MPTBlock(nn.Module):
|
||||
hidden_states=x,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = hidden_states + x
|
||||
x = self.norm_2(hidden_states)
|
||||
@ -235,18 +231,15 @@ class MPTModel(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.wte(input_ids)
|
||||
for i in range(len(self.blocks)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
block = self.blocks[i]
|
||||
hidden_states = block(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.norm_f(hidden_states)
|
||||
return hidden_states
|
||||
@ -274,10 +267,9 @@ class MPTForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -98,13 +98,12 @@ class OPTAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
input_metadata)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -154,7 +153,6 @@ class OPTDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
@ -163,8 +161,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event)
|
||||
input_metadata=input_metadata)
|
||||
hidden_states = residual + hidden_states
|
||||
# 350m applies layer norm AFTER attention
|
||||
if not self.do_layer_norm_before:
|
||||
@ -245,7 +242,6 @@ class OPTDecoder(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
pos_embeds = self.embed_positions(positions)
|
||||
@ -254,10 +250,8 @@ class OPTDecoder(nn.Module):
|
||||
hidden_states = inputs_embeds + pos_embeds
|
||||
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
|
||||
cache_event)
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
|
||||
|
||||
if self.final_layer_norm is not None:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
@ -282,10 +276,8 @@ class OPTModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
return self.decoder(input_ids, positions, kv_caches, input_metadata,
|
||||
cache_events)
|
||||
return self.decoder(input_ids, positions, kv_caches, input_metadata)
|
||||
|
||||
|
||||
class OPTForCausalLM(nn.Module):
|
||||
@ -308,10 +300,9 @@ class OPTForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -135,14 +135,12 @@ class PhiAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.Wqkv(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(position_ids, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -195,7 +193,6 @@ class PhiLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln(hidden_states)
|
||||
@ -204,7 +201,6 @@ class PhiLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
||||
@ -231,18 +227,15 @@ class PhiModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embd(input_ids)
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
@ -277,10 +270,9 @@ class PhiForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
hidden_states = self.lm_head.ln(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@ -112,14 +112,12 @@ class QWenAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
|
||||
output, _ = self.c_proj(attn_output)
|
||||
return output
|
||||
@ -156,7 +154,6 @@ class QWenBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -170,7 +167,6 @@ class QWenBlock(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -206,19 +202,16 @@ class QWenModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.wte(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.h)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
@ -245,10 +238,9 @@ class QWenLMHeadModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -146,14 +146,12 @@ class YiAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -195,7 +193,6 @@ class YiDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
@ -209,7 +206,6 @@ class YiDecoderLayer(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
@ -245,19 +241,16 @@ class YiModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@ -284,10 +277,9 @@ class YiForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
input_metadata)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils import cupy_utils
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_group,
|
||||
is_custom_nccl_enabled_for_all_reduce,
|
||||
)
|
||||
|
||||
|
||||
@ -15,8 +17,12 @@ def tensor_model_parallel_all_reduce(input_):
|
||||
if get_tensor_model_parallel_world_size() == 1:
|
||||
return input_
|
||||
# All-reduce.
|
||||
torch.distributed.all_reduce(input_,
|
||||
group=get_tensor_model_parallel_group())
|
||||
if is_custom_nccl_enabled_for_all_reduce():
|
||||
# TODO: support multiple parallel groups.
|
||||
cupy_utils.all_reduce(input_)
|
||||
else:
|
||||
torch.distributed.all_reduce(input_,
|
||||
group=get_tensor_model_parallel_group())
|
||||
return input_
|
||||
|
||||
|
||||
|
||||
115
vllm/model_executor/parallel_utils/cupy_utils.py
Normal file
115
vllm/model_executor/parallel_utils/cupy_utils.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""CuPy utilities for all-reduce.
|
||||
|
||||
We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
|
||||
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
|
||||
CUDA graphs.
|
||||
|
||||
TODO: Remove this file when torch.distributed.all_reduce is fixed.
|
||||
"""
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
try:
|
||||
import cupy
|
||||
from cupyx.distributed import NCCLBackend
|
||||
from cupy.cuda import nccl
|
||||
except ImportError as e:
|
||||
cupy = e
|
||||
nccl = None
|
||||
|
||||
class NCCLBackend:
|
||||
...
|
||||
|
||||
|
||||
_OP_MAPPING = {
|
||||
ReduceOp.SUM: "sum",
|
||||
ReduceOp.PRODUCT: "prod",
|
||||
ReduceOp.MIN: "min",
|
||||
ReduceOp.MAX: "max",
|
||||
}
|
||||
|
||||
|
||||
class NCCLBackendWithBFloat16(NCCLBackend):
|
||||
# This is enough to add bfloat16 support for most operations,
|
||||
# but broadcast will fail (will require changes in compiled
|
||||
# cupy code).
|
||||
def _get_nccl_dtype_and_count(self, array, count=None):
|
||||
nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
|
||||
torch_dtype = getattr(array, "_torch_dtype", None)
|
||||
if torch_dtype is torch.bfloat16:
|
||||
nccl_dtype = nccl.NCCL_BFLOAT16
|
||||
return nccl_dtype, count
|
||||
|
||||
|
||||
_NCCL_BACKEND = None
|
||||
_WORLD_SIZE = 0
|
||||
|
||||
|
||||
def is_initialized() -> bool:
|
||||
"""Returns whether the NCCL backend is initialized."""
|
||||
return _NCCL_BACKEND is not None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_cupy_stream(stream: torch.cuda.Stream) -> None:
|
||||
"""Set the cuda stream for communication"""
|
||||
cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
|
||||
stream.device_index)
|
||||
with cupy_stream:
|
||||
yield
|
||||
|
||||
|
||||
def init_process_group(world_size: int, rank: int, host: str,
|
||||
port: int) -> None:
|
||||
"""Initializes the CuPy NCCL backend.
|
||||
|
||||
# TODO: handle NCCL timeouts.
|
||||
"""
|
||||
assert not is_initialized()
|
||||
|
||||
if isinstance(cupy, Exception):
|
||||
raise ImportError(
|
||||
"NCCLBackend is not available. Please install cupy.") from cupy
|
||||
|
||||
# TODO(woosuk): Create TP and PP process groups for CuPy.
|
||||
global _NCCL_BACKEND
|
||||
global _WORLD_SIZE
|
||||
assert world_size > 0, f"{world_size=} should be a positive integer"
|
||||
assert 0 <= rank < world_size, (
|
||||
f"{rank=} should be a integer between [0, {world_size})")
|
||||
|
||||
cupy.cuda.runtime.setDevice(torch.cuda.current_device())
|
||||
_NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
|
||||
_WORLD_SIZE = world_size
|
||||
|
||||
|
||||
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
|
||||
"""All-reduces the input tensor across the process group."""
|
||||
assert input_.is_cuda, f"{input_} should be a cuda tensor"
|
||||
# Hack to support bfloat16
|
||||
torch_dtype = input_.dtype
|
||||
if torch_dtype is torch.bfloat16:
|
||||
# We need to view as float16, otherwise
|
||||
# cupy will fail. This will not change
|
||||
# the underlying data.
|
||||
input_ = input_.view(torch.float16)
|
||||
cupy_input = cupy.asarray(input_)
|
||||
cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access
|
||||
_NCCL_BACKEND.all_reduce(in_array=cupy_input,
|
||||
out_array=cupy_input,
|
||||
op=_OP_MAPPING[op])
|
||||
|
||||
|
||||
def destroy_process_group() -> None:
|
||||
"""Destroys the NCCL backend."""
|
||||
global _NCCL_BACKEND
|
||||
global _WORLD_SIZE
|
||||
_NCCL_BACKEND = None
|
||||
_WORLD_SIZE = 0
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Returns the world size."""
|
||||
return _WORLD_SIZE
|
||||
@ -3,9 +3,12 @@
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
"""Tensor and pipeline parallel groups."""
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils import cupy_utils
|
||||
|
||||
# Tensor model parallel group that the current rank belongs to.
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
# Pipeline model parallel group that the current rank belongs to.
|
||||
@ -177,3 +180,37 @@ def destroy_model_parallel():
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
global _PIPELINE_GLOBAL_RANKS
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
|
||||
# Destroy the cupy states if any.
|
||||
cupy_utils.destroy_process_group()
|
||||
|
||||
|
||||
# Whether to use cupy for nccl all reduce.
|
||||
# We use cupy for all reduce when using CUDA graph, because torch.distributed
|
||||
# is not well supported by CUDA graph.
|
||||
_ENABLE_CUPY_FOR_ALL_REDUCE = False
|
||||
|
||||
|
||||
@contextlib.contextmanager
def with_custom_nccl_for_all_reduce():
    """Use custom NCCL (via CuPy) instead of torch.distributed for all-reduce.

    While the context is active, the module-level flag
    _ENABLE_CUPY_FOR_ALL_REDUCE is set to True and the current CUDA stream is
    registered with CuPy so its collectives run on that stream. When tensor
    parallelism is disabled (tp_size == 1) this is a no-op.
    """
    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        # No-op.
        # NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
        yield
    else:
        global _ENABLE_CUPY_FOR_ALL_REDUCE
        old = _ENABLE_CUPY_FOR_ALL_REDUCE
        _ENABLE_CUPY_FOR_ALL_REDUCE = True

        stream = torch.cuda.current_stream()
        try:
            with cupy_utils.set_cupy_stream(stream):
                yield
        finally:
            # BUGFIX: restore the previous flag even if the body raises.
            # Without try/finally, an exception inside the context would leave
            # _ENABLE_CUPY_FOR_ALL_REDUCE stuck at True for all later calls.
            _ENABLE_CUPY_FOR_ALL_REDUCE = old
|
||||
|
||||
|
||||
def is_custom_nccl_enabled_for_all_reduce() -> bool:
    """Return True if custom (CuPy) NCCL is currently used for all-reduce."""
    # Reading a module-level name needs no `global` declaration; the original
    # redundant `global` statement is dropped.
    return _ENABLE_CUPY_FOR_ALL_REDUCE
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import enum
|
||||
import socket
|
||||
import uuid
|
||||
from platform import uname
|
||||
|
||||
@ -52,3 +53,9 @@ def random_uuid() -> str:
|
||||
def in_wsl() -> bool:
    """Return True when running under Windows Subsystem for Linux."""
    # Reference: https://github.com/microsoft/WSL/issues/4071
    # WSL kernels report "microsoft" somewhere in their uname fields.
    return any("microsoft" in field.lower() for field in uname())
|
||||
|
||||
|
||||
def get_open_port():
    """Ask the OS for a free TCP port number and return it."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 makes the kernel pick an unused ephemeral port.
        sock.bind(("", 0))
        return sock.getsockname()[1]
    finally:
        sock.close()
|
||||
|
||||
@ -1,16 +1,25 @@
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import time
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
with_custom_nccl_for_all_reduce)
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
_PAD_SLOT_ID = -1
|
||||
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
|
||||
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
|
||||
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
@ -32,12 +41,31 @@ class ModelRunner:
|
||||
self.model = None
|
||||
self.block_size = None # Set after initial profiling.
|
||||
|
||||
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
||||
self.graph_memory_pool = None # Set during graph capture.
|
||||
|
||||
self.max_context_len_to_capture = (
|
||||
self.model_config.max_context_len_to_capture
|
||||
if self.model_config is not None else 0)
|
||||
# When using CUDA graph, the input block tables must be padded to
|
||||
# max_context_len_to_capture. However, creating the block table in
|
||||
# Python can be expensive. To optimize this, we cache the block table
|
||||
# in numpy and only copy the actual input content at every iteration.
|
||||
# The shape of the cached block table will be
|
||||
# (max batch size to capture, max context len to capture / block size).
|
||||
self.graph_block_tables = None # Set after initial profiling.
|
||||
|
||||
    def load_model(self) -> None:
        """Build the model via get_model(self.model_config) and cache it on
        self.model for later execution / graph capture."""
        self.model = get_model(self.model_config)
|
||||
|
||||
def set_block_size(self, block_size: int) -> None:
|
||||
self.block_size = block_size
|
||||
|
||||
max_num_blocks = (self.max_context_len_to_capture + block_size -
|
||||
1) // block_size
|
||||
self.graph_block_tables = np.zeros(
|
||||
(max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
@ -111,6 +139,7 @@ class ModelRunner:
|
||||
max_context_len=None,
|
||||
context_lens=None,
|
||||
block_tables=None,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
return input_tokens, input_positions, input_metadata
|
||||
|
||||
@ -154,27 +183,62 @@ class ModelRunner:
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
block_tables.append(block_table)
|
||||
|
||||
batch_size = len(input_tokens)
|
||||
max_context_len = max(context_lens)
|
||||
use_captured_graph = (
|
||||
not self.model_config.enforce_eager
|
||||
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
||||
and max_context_len <= self.max_context_len_to_capture)
|
||||
if use_captured_graph:
|
||||
# Pad the input tokens, positions, and slot mapping to match the
|
||||
# batch size of the captured graph.
|
||||
graph_batch_size = _get_graph_batch_size(batch_size)
|
||||
assert graph_batch_size >= batch_size
|
||||
for _ in range(graph_batch_size - batch_size):
|
||||
input_tokens.append([])
|
||||
input_positions.append([])
|
||||
slot_mapping.append([])
|
||||
context_lens.append(1)
|
||||
block_tables.append([])
|
||||
batch_size = graph_batch_size
|
||||
|
||||
# When using CUDA graph, we don't need to make the tensors on the GPU
|
||||
# because they will be eventually copied to the designated GPU buffer.
|
||||
device = "cpu" if use_captured_graph else "cuda"
|
||||
input_tokens = _make_tensor_with_pad(input_tokens,
|
||||
max_len=1,
|
||||
pad=0,
|
||||
dtype=torch.long)
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
input_positions = _make_tensor_with_pad(input_positions,
|
||||
max_len=1,
|
||||
pad=0,
|
||||
dtype=torch.long)
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
slot_mapping = _make_tensor_with_pad(slot_mapping,
|
||||
max_len=1,
|
||||
pad=_PAD_SLOT_ID,
|
||||
dtype=torch.long)
|
||||
max_context_len = max(context_lens)
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
context_lens = torch.tensor(context_lens,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
max_block_table_len = max([len(t) for t in block_tables])
|
||||
block_tables = _make_tensor_with_pad(block_tables,
|
||||
max_len=max_block_table_len,
|
||||
pad=0,
|
||||
dtype=torch.int)
|
||||
device=device)
|
||||
|
||||
if use_captured_graph:
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
input_block_tables = self.graph_block_tables[:batch_size]
|
||||
for i, block_table in enumerate(block_tables):
|
||||
if block_table:
|
||||
input_block_tables[i, :len(block_table)] = block_table
|
||||
block_tables = torch.from_numpy(input_block_tables).to(device)
|
||||
else:
|
||||
block_tables = _make_tensor_with_pad(
|
||||
block_tables,
|
||||
max_len=max_context_len,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
)
|
||||
|
||||
input_metadata = InputMetadata(
|
||||
prompt_lens=[],
|
||||
@ -182,6 +246,7 @@ class ModelRunner:
|
||||
max_context_len=max_context_len,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
)
|
||||
return input_tokens, input_positions, input_metadata
|
||||
|
||||
@ -260,12 +325,11 @@ class ModelRunner:
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
cache_events: Optional[List[torch.cuda.Event]] = None,
|
||||
) -> SamplerOutput:
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
# all decodes.
|
||||
# Prepare input tensors.
|
||||
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
# Prepare input tensors.
|
||||
if is_prompt:
|
||||
inputs = self._prepare_prompt(seq_group_metadata_list)
|
||||
input_tokens, input_positions, input_metadata = inputs
|
||||
@ -276,12 +340,16 @@ class ModelRunner:
|
||||
input_metadata.prompt_lens)
|
||||
|
||||
# Execute the model.
|
||||
hidden_states = self.model(
|
||||
if input_metadata.use_cuda_graph:
|
||||
graph_batch_size = input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[graph_batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
hidden_states = model_executable(
|
||||
input_ids=input_tokens,
|
||||
positions=input_positions,
|
||||
kv_caches=kv_caches,
|
||||
input_metadata=input_metadata,
|
||||
cache_events=cache_events,
|
||||
)
|
||||
|
||||
# Sample the next token.
|
||||
@ -319,8 +387,139 @@ class ModelRunner:
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches = [(None, None)] * num_layers
|
||||
self.execute_model(seqs, kv_caches)
|
||||
torch.cuda.synchronize()
|
||||
return
|
||||
|
||||
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[KVCache]) -> None:
        """Capture CUDA graphs of a decoding step, one per batch size.

        For every batch size in _BATCH_SIZES_TO_CAPTURE (largest first), runs
        the model once under graph capture with dummy inputs and stores the
        resulting CUDAGraphRunner in self.graph_runners[batch_size]. All
        graphs share one memory pool via self.graph_memory_pool.

        Args:
            kv_caches: Per-layer (key, value) cache tensors, passed unchanged
                into every capture.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        # Smaller batches use a leading slice of these max-sized buffers.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, 1,
                                      dtype=torch.long).cuda()
        # Padded entries use _PAD_SLOT_ID so they don't map to a real slot.
        slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # NOTE: Capturing the largest batch size first may help reduce the
        # memory usage of CUDA graph.
        for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
            # Create dummy input_metadata.
            input_metadata = InputMetadata(
                prompt_lens=[],
                slot_mapping=slot_mapping[:batch_size],
                max_context_len=self.max_context_len_to_capture,
                context_lens=context_lens[:batch_size],
                block_tables=block_tables[:batch_size],
                use_cuda_graph=True,
            )

            graph_runner = CUDAGraphRunner(self.model)
            graph_runner.capture(
                input_tokens[:batch_size],
                input_positions[:batch_size],
                kv_caches,
                input_metadata,
                memory_pool=self.graph_memory_pool,
            )
            # Reuse the first graph's pool for all subsequent captures so the
            # graphs share memory instead of each allocating its own.
            self.graph_memory_pool = graph_runner.graph.pool()
            self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
|
||||
|
||||
|
||||
class CUDAGraphRunner:
    """Wraps a model forward pass in a captured CUDA graph.

    capture() records one forward pass into a torch.cuda.CUDAGraph and
    remembers the input/output tensors it was recorded with; forward() copies
    fresh inputs into those recorded buffers and replays the graph.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # The captured torch.cuda.CUDAGraph; set by capture().
        self.graph = None
        # Tensors the captured graph reads; refreshed via copy_() per run.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        # Tensors the captured graph writes its results into.
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        memory_pool,
    ) -> None:
        """Record one model forward pass into a CUDA graph.

        Args:
            input_ids: Token-id buffer the graph will keep reading from.
            positions: Position-id buffer, same lifetime constraint.
            kv_caches: Per-layer (key, value) cache tensors.
            input_metadata: Decode metadata; its slot_mapping, context_lens,
                and block_tables tensors are saved as graph input buffers.
            memory_pool: CUDA graph memory pool to capture into (may be None
                for the first capture; see torch.cuda.graph).
        """
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with with_custom_nccl_for_all_reduce():
            self.model(
                input_ids,
                positions,
                kv_caches,
                input_metadata,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
            with with_custom_nccl_for_all_reduce():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    input_metadata,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Copy inputs into the captured buffers, replay the graph, and
        return the hidden states tensor recorded at capture time."""
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids)
        self.input_buffers["positions"].copy_(positions)
        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
        self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
        self.input_buffers["block_tables"].copy_(input_metadata.block_tables)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        """Make the runner callable like the wrapped model."""
        return self.forward(*args, **kwargs)
|
||||
|
||||
|
||||
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
|
||||
assert len(x) <= max_len
|
||||
@ -332,6 +531,16 @@ def _make_tensor_with_pad(
|
||||
max_len: int,
|
||||
pad: int,
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
) -> torch.Tensor:
|
||||
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
|
||||
return torch.tensor(padded_x, dtype=dtype, device="cuda")
|
||||
return torch.tensor(padded_x, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def _get_graph_batch_size(batch_size: int) -> int:
|
||||
if batch_size <= 2:
|
||||
return batch_size
|
||||
elif batch_size <= 4:
|
||||
return 4
|
||||
else:
|
||||
return (batch_size + 7) // 8 * 8
|
||||
|
||||
@ -8,6 +8,7 @@ import torch.distributed
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.parallel_utils import cupy_utils
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
initialize_model_parallel)
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
@ -46,7 +47,7 @@ class Worker:
|
||||
self.cache_events = None
|
||||
self.gpu_cache = None
|
||||
|
||||
def init_model(self):
|
||||
def init_model(self, cupy_port: Optional[int] = None):
|
||||
# This env var set by Ray causes exceptions with graph building.
|
||||
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
||||
# Env vars will be set by Ray.
|
||||
@ -62,7 +63,7 @@ class Worker:
|
||||
|
||||
# Initialize the distributed environment.
|
||||
_init_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method)
|
||||
cupy_port, self.distributed_init_method)
|
||||
|
||||
# Initialize the model.
|
||||
set_random_seed(self.model_config.seed)
|
||||
@ -100,10 +101,6 @@ class Worker:
|
||||
num_gpu_blocks = max(num_gpu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def init_cache_engine(self, cache_config: CacheConfig) -> None:
|
||||
@ -114,6 +111,13 @@ class Worker:
|
||||
self.gpu_cache = self.cache_engine.gpu_cache
|
||||
self.model_runner.set_block_size(self.cache_engine.block_size)
|
||||
|
||||
def warm_up_model(self) -> None:
|
||||
if not self.model_config.enforce_eager:
|
||||
self.model_runner.capture_model(self.gpu_cache)
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@ -136,21 +140,24 @@ class Worker:
|
||||
|
||||
cache_events = self.cache_events if issued_cache_op else None
|
||||
|
||||
# Wait for cache operations to finish.
|
||||
# TODO(woosuk): Profile swapping overhead and optimize if needed.
|
||||
if cache_events is not None:
|
||||
for event in cache_events:
|
||||
event.wait()
|
||||
# If there is no input, we don't need to execute the model.
|
||||
if not seq_group_metadata_list:
|
||||
if cache_events is not None:
|
||||
for event in cache_events:
|
||||
event.wait()
|
||||
return {}
|
||||
|
||||
output = self.model_runner.execute_model(seq_group_metadata_list,
|
||||
self.gpu_cache, cache_events)
|
||||
self.gpu_cache)
|
||||
return output
|
||||
|
||||
|
||||
def _init_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
cupy_port: Optional[int],
|
||||
distributed_init_method: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
@ -173,8 +180,29 @@ def _init_distributed_environment(
|
||||
init_method=distributed_init_method,
|
||||
)
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||
if cupy_utils.is_initialized():
|
||||
cupy_world_size = cupy_utils.get_world_size()
|
||||
if cupy_world_size != parallel_config.world_size:
|
||||
raise RuntimeError(
|
||||
"cupy.distributed is already initialized but the cupy world "
|
||||
"size does not match parallel_config.world_size "
|
||||
f"({cupy_world_size} vs. {parallel_config.world_size}).")
|
||||
elif parallel_config.world_size > 1:
|
||||
# NOTE(woosuk): We don't initialize CuPy process group when world size
|
||||
# is 1.
|
||||
# TODO(woosuk): Support multi-node connection.
|
||||
cupy_utils.init_process_group(
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
host="localhost",
|
||||
port=cupy_port,
|
||||
)
|
||||
|
||||
if parallel_config.world_size > 1:
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||
cupy_utils.all_reduce(torch.zeros(1).cuda())
|
||||
|
||||
initialize_model_parallel(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user