Optimize model execution with CUDA graph (#1926)

Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Woosuk Kwon 2023-12-16 21:12:08 -08:00 committed by GitHub
parent eed74a558f
commit 37ca558103
34 changed files with 557 additions and 254 deletions


@ -23,6 +23,7 @@ def main(args: argparse.Namespace):
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
)
sampling_params = SamplingParams(
@ -111,6 +112,9 @@ if __name__ == '__main__':
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--enforce-eager',
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
'--profile',
action='store_true',


@ -69,7 +69,8 @@ def run_vllm(
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int] = None,
max_model_len: Optional[int],
enforce_eager: bool,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@ -81,6 +82,7 @@ def run_vllm(
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
)
# Add the requests to the engine.
@ -204,7 +206,7 @@ def main(args: argparse.Namespace):
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len)
args.max_model_len, args.enforce_eager)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@ -279,6 +281,9 @@ if __name__ == "__main__":
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model


@ -12,3 +12,4 @@ fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
aioprometheus[starlette]
cupy-cuda12x # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. # FIXME: Fix this in setup.py.


@ -49,6 +49,12 @@ class ModelConfig:
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
"""
def __init__(
@ -65,6 +71,8 @@ class ModelConfig:
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
@ -76,6 +84,8 @@ class ModelConfig:
self.revision = revision
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
@ -95,6 +105,7 @@ class ModelConfig:
self._verify_load_format()
self._verify_tokenizer_mode()
self._verify_quantization()
self._verify_cuda_graph()
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
@ -169,6 +180,12 @@ class ModelConfig:
"optimized yet. The speed can be slower than "
"non-quantized models.")
def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
self.max_context_len_to_capture = self.max_model_len
self.max_context_len_to_capture = min(self.max_context_len_to_capture,
self.max_model_len)
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
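
The clamping added in _verify_cuda_graph is worth spelling out; a small self-contained sketch of the same logic (the helper name and the values are illustrative, not part of this commit):

from typing import Optional

def clamp_capture_len(max_context_len_to_capture: Optional[int],
                      max_model_len: int) -> int:
    # Mirrors ModelConfig._verify_cuda_graph: default to the model's context
    # length, and never capture graphs for contexts longer than the model allows.
    if max_context_len_to_capture is None:
        max_context_len_to_capture = max_model_len
    return min(max_context_len_to_capture, max_model_len)

assert clamp_capture_len(None, 4096) == 4096   # default follows max_model_len
assert clamp_capture_len(8192, 4096) == 4096   # capped at max_model_len
assert clamp_capture_len(2048, 4096) == 2048   # smaller explicit limit is kept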


@ -33,6 +33,8 @@ class EngineArgs:
revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
def __post_init__(self):
if self.tokenizer is None:
@ -182,6 +184,17 @@ class EngineArgs:
choices=['awq', 'gptq', 'squeezellm', None],
default=None,
help='Method used to quantize the weights')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
'will use eager mode and CUDA graph in hybrid '
'for maximal performance and flexibility.')
parser.add_argument('--max-context-len-to-capture',
type=int,
default=EngineArgs.max_context_len_to_capture,
help='maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.')
return parser
@classmethod
@ -200,7 +213,8 @@ class EngineArgs:
self.download_dir, self.load_format,
self.dtype, self.seed, self.revision,
self.tokenizer_revision, self.max_model_len,
self.quantization)
self.quantization, self.enforce_eager,
self.max_context_len_to_capture)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space,
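
As a hedged sketch, the two new flags reach EngineArgs through the existing add_cli_args/from_cli_args plumbing roughly like this (the model name is illustrative):

import argparse

from vllm.engine.arg_utils import EngineArgs

parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--enforce-eager",                        # disable CUDA graphs entirely
    "--max-context-len-to-capture", "4096",   # or just cap the captured context length
])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.enforce_eager
assert engine_args.max_context_len_to_capture == 4096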


@ -17,7 +17,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
from vllm.utils import Counter, get_open_port
if ray:
from ray.air.util.torch_dist import init_torch_dist_process_group
@ -84,6 +84,7 @@ class LLMEngine:
f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.
@ -189,6 +190,7 @@ class LLMEngine:
))
self._run_workers(
"init_model",
cupy_port=get_open_port(),
get_all_outputs=True,
)
self._run_workers(
@ -232,6 +234,9 @@ class LLMEngine:
# Initialize the cache.
self._run_workers("init_cache_engine", cache_config=self.cache_config)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self._run_workers("warm_up_model")
@classmethod
def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":


@ -1,9 +1,8 @@
import socket
from typing import Optional, Tuple, TYPE_CHECKING
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip
from vllm.utils import get_open_port, is_hip
logger = init_logger(__name__)
@ -43,12 +42,6 @@ if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,


@ -56,6 +56,12 @@ class LLM:
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
"""
def __init__(
@ -72,6 +78,8 @@ class LLM:
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
@ -89,6 +97,8 @@ class LLM:
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
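
A minimal sketch of the two new knobs from the offline API, assuming an otherwise default setup (the model name is illustrative):

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=False,                # default: hybrid CUDA graph + eager execution
    max_context_len_to_capture=8192,    # longer contexts fall back to eager mode
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
print(outputs[0].outputs[0].text)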


@ -21,12 +21,14 @@ class InputMetadata:
max_context_len: Optional[int],
context_lens: Optional[torch.Tensor],
block_tables: Optional[torch.Tensor],
use_cuda_graph: bool,
) -> None:
self.prompt_lens = prompt_lens
self.max_context_len = max_context_len
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph
self.is_prompt = len(prompt_lens) > 0
# Set during the execution of the first attention op.
@ -39,4 +41,5 @@ class InputMetadata:
f"max_context_len={self.max_context_len}, "
f"slot_mapping={self.slot_mapping}, "
f"context_lens={self.context_lens}, "
f"block_tables={self.block_tables})")
f"block_tables={self.block_tables}, "
f"use_cuda_graph={self.use_cuda_graph})")


@ -24,13 +24,10 @@ class PagedAttention(nn.Module):
can either contain prompt tokens or generation tokens.
The class does the following:
1. Wait for the cache operations (e.g., swap, copy) to finish. The cache
operations are issued by the cache engine before executing the forward
pass of the model, and they are executed asynchronously.
2. Reshape and store the input key and value tensors in the KV cache.
3. Perform (multi-head/multi-query/grouped-query) attention using either
1. Reshape and store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention using either
xformers or the PagedAttention custom op.
4. Return the output tensor.
3. Return the output tensor.
"""
def __init__(
@ -67,7 +64,6 @@ class PagedAttention(nn.Module):
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
"""PagedAttention forward pass.
@ -80,7 +76,6 @@ class PagedAttention(nn.Module):
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
@ -89,10 +84,6 @@ class PagedAttention(nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
slot_mapping = input_metadata.slot_mapping.flatten()
if cache_event is not None:
cache_event.wait()
# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
@ -104,7 +95,7 @@ class PagedAttention(nn.Module):
value,
key_cache,
value_cache,
slot_mapping,
input_metadata.slot_mapping.flatten(),
)
if input_metadata.is_prompt:
@ -165,15 +156,20 @@ class PagedAttention(nn.Module):
output = out.view_as(query)
else:
# Decoding run.
output = _paged_attention(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
if key_cache is not None and value_cache is not None:
output = _paged_attention(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
else:
# This happens during the initial memory profiling run for
# CUDA graphs.
output = torch.zeros_like(query)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)


@ -158,14 +158,12 @@ class AquilaAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
@ -209,7 +207,6 @@ class AquilaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
@ -219,7 +216,6 @@ class AquilaDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
@ -258,18 +254,15 @@ class AquilaModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
@ -296,10 +289,9 @@ class AquilaForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -172,15 +172,13 @@ class BaiChuanAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
@ -221,7 +219,6 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@ -236,7 +233,6 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
@ -273,19 +269,16 @@ class BaiChuanModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -311,10 +304,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -118,14 +118,12 @@ class BloomAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.dense(attn_output)
return output
@ -184,7 +182,6 @@ class BloomBlock(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
@ -201,7 +198,6 @@ class BloomBlock(nn.Module):
hidden_states=layernorm_output,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output)
@ -250,19 +246,16 @@ class BloomModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@ -288,10 +281,9 @@ class BloomForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -100,7 +100,6 @@ class GLMAttention(nn.Module):
position_ids: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@ -113,7 +112,6 @@ class GLMAttention(nn.Module):
key_cache,
value_cache,
input_metadata,
cache_event,
)
attn_output, _ = self.dense(context_layer)
return attn_output
@ -203,7 +201,6 @@ class GLMBlock(nn.Module):
position_ids: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
@ -214,7 +211,6 @@ class GLMBlock(nn.Module):
position_ids=position_ids,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Residual connection.
@ -269,17 +265,14 @@ class GLMTransformer(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
for i in range(self.num_layers):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i],
input_metadata=input_metadata,
cache_event=cache_event,
)
# Final layer norm.
if self.post_layer_norm:
@ -314,8 +307,7 @@ class ChatGLMModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
):
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
# Run encoder.
@ -324,9 +316,7 @@ class ChatGLMModel(nn.Module):
position_ids=position_ids,
kv_caches=kv_caches,
input_metadata=input_metadata,
cache_events=cache_events,
)
return hidden_states
@ -350,10 +340,9 @@ class ChatGLMForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -178,7 +178,6 @@ class FalconAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
@ -187,8 +186,7 @@ class FalconAttention(nn.Module):
if self.use_rotary:
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output, bias = self.dense(attn_output)
return attn_output, bias
@ -266,8 +264,7 @@ class FalconDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
):
) -> torch.Tensor:
residual = hidden_states
if self.config.new_decoder_architecture:
@ -282,7 +279,6 @@ class FalconDecoderLayer(nn.Module):
hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias
@ -311,7 +307,6 @@ class FalconDecoderLayer(nn.Module):
mlp_output += mlp_bias
output = mlp_output + residual
return output
@ -349,18 +344,15 @@ class FalconModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@ -389,14 +381,12 @@ class FalconForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
input_metadata,
cache_events,
)
return hidden_states


@ -82,13 +82,12 @@ class GPT2Attention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
input_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
@ -148,7 +147,6 @@ class GPT2Block(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
@ -156,7 +154,6 @@ class GPT2Block(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# residual connection
hidden_states = attn_output + residual
@ -196,17 +193,14 @@ class GPT2Model(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@ -232,10 +226,9 @@ class GPT2LMHeadModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -95,7 +95,6 @@ class GPTBigCodeAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
@ -107,7 +106,7 @@ class GPTBigCodeAttention(nn.Module):
)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
input_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
@ -167,7 +166,6 @@ class GPTBigCodeBlock(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
@ -175,7 +173,6 @@ class GPTBigCodeBlock(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# residual connection
hidden_states = attn_output + residual
@ -215,17 +212,14 @@ class GPTBigCodeModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@ -251,10 +245,9 @@ class GPTBigCodeForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -94,14 +94,12 @@ class GPTJAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output, _ = self.out_proj(attn_output)
return attn_output
@ -156,7 +154,6 @@ class GPTJBlock(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
@ -165,7 +162,6 @@ class GPTJBlock(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual
@ -196,18 +192,15 @@ class GPTJModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
@ -238,10 +231,9 @@ class GPTJForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -92,14 +92,12 @@ class GPTNeoXAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.dense(attn_output)
return output
@ -155,7 +153,6 @@ class GPTNeoXLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention(
@ -163,7 +160,6 @@ class GPTNeoXLayer(nn.Module):
hidden_states=attn_input,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.use_parallel_residual:
@ -210,18 +206,15 @@ class GPTNeoXModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
@ -250,10 +243,9 @@ class GPTNeoXForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -110,14 +110,12 @@ class InternLMAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
@ -160,7 +158,6 @@ class InternLMDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@ -175,7 +172,6 @@ class InternLMDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
@ -214,19 +210,16 @@ class InternLMModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -253,10 +246,9 @@ class InternLMForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -147,14 +147,12 @@ class LlamaAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
@ -198,7 +196,6 @@ class LlamaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@ -213,7 +210,6 @@ class LlamaDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
@ -250,19 +246,16 @@ class LlamaModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -289,10 +282,9 @@ class LlamaForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -145,14 +145,12 @@ class MistralAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
@ -193,7 +191,6 @@ class MistralDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@ -208,7 +205,6 @@ class MistralDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
@ -246,19 +242,16 @@ class MistralModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -285,10 +278,9 @@ class MistralForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -253,14 +253,12 @@ class MixtralAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
@ -297,7 +295,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
@ -312,7 +309,6 @@ class MixtralDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
@ -349,16 +345,14 @@ class MixtralModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata,
cache_event, residual)
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@ -383,10 +377,9 @@ class MixtralForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -117,7 +117,6 @@ class MPTAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
del position_ids # unused.
qkv, _ = self.Wqkv(hidden_states)
@ -128,8 +127,7 @@ class MPTAttention(nn.Module):
q = self.q_ln(q)
k = self.k_ln(k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.out_proj(attn_output)
return output
@ -187,7 +185,6 @@ class MPTBlock(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
x = self.norm_1(hidden_states)
x = self.attn(
@ -195,7 +192,6 @@ class MPTBlock(nn.Module):
hidden_states=x,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = hidden_states + x
x = self.norm_2(hidden_states)
@ -235,18 +231,15 @@ class MPTModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)):
cache_event = None if cache_events is None else cache_events[i]
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm_f(hidden_states)
return hidden_states
@ -274,10 +267,9 @@ class MPTForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -98,13 +98,12 @@ class OPTAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
input_metadata)
output, _ = self.out_proj(attn_output)
return output
@ -154,7 +153,6 @@ class OPTDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
@ -163,8 +161,7 @@ class OPTDecoderLayer(nn.Module):
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event)
input_metadata=input_metadata)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
@ -245,7 +242,6 @@ class OPTDecoder(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
pos_embeds = self.embed_positions(positions)
@ -254,10 +250,8 @@ class OPTDecoder(nn.Module):
hidden_states = inputs_embeds + pos_embeds
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
@ -282,10 +276,8 @@ class OPTModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
return self.decoder(input_ids, positions, kv_caches, input_metadata,
cache_events)
return self.decoder(input_ids, positions, kv_caches, input_metadata)
class OPTForCausalLM(nn.Module):
@ -308,10 +300,9 @@ class OPTForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -135,14 +135,12 @@ class PhiAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.out_proj(attn_output)
return output
@ -195,7 +193,6 @@ class PhiLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln(hidden_states)
@ -204,7 +201,6 @@ class PhiLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_outputs + feed_forward_hidden_states + residual
@ -231,18 +227,15 @@ class PhiModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embd(input_ids)
for i in range(self.config.num_hidden_layers):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
return hidden_states
@ -277,10 +270,9 @@ class PhiForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
hidden_states = self.lm_head.ln(hidden_states)
return hidden_states


@ -112,14 +112,12 @@ class QWenAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.c_proj(attn_output)
return output
@ -156,7 +154,6 @@ class QWenBlock(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@ -170,7 +167,6 @@ class QWenBlock(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
@ -206,19 +202,16 @@ class QWenModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
residual = None
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
residual,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
@ -245,10 +238,9 @@ class QWenLMHeadModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -146,14 +146,12 @@ class YiAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
@ -195,7 +193,6 @@ class YiDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@ -209,7 +206,6 @@ class YiDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
@ -245,19 +241,16 @@ class YiModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -284,10 +277,9 @@ class YiForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(


@ -1,8 +1,10 @@
import torch
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
is_custom_nccl_enabled_for_all_reduce,
)
@ -15,8 +17,12 @@ def tensor_model_parallel_all_reduce(input_):
if get_tensor_model_parallel_world_size() == 1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
if is_custom_nccl_enabled_for_all_reduce():
# TODO: support multiple parallel groups.
cupy_utils.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
return input_


@ -0,0 +1,115 @@
"""CuPy utilities for all-reduce.
We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib
import torch
from torch.distributed import ReduceOp
try:
import cupy
from cupyx.distributed import NCCLBackend
from cupy.cuda import nccl
except ImportError as e:
cupy = e
nccl = None
class NCCLBackend:
...
_OP_MAPPING = {
ReduceOp.SUM: "sum",
ReduceOp.PRODUCT: "prod",
ReduceOp.MIN: "min",
ReduceOp.MAX: "max",
}
class NCCLBackendWithBFloat16(NCCLBackend):
# This is enough to add bfloat16 support for most operations,
# but broadcast will fail (will require changes in compiled
# cupy code).
def _get_nccl_dtype_and_count(self, array, count=None):
nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
torch_dtype = getattr(array, "_torch_dtype", None)
if torch_dtype is torch.bfloat16:
nccl_dtype = nccl.NCCL_BFLOAT16
return nccl_dtype, count
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return _NCCL_BACKEND is not None
@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream) -> None:
"""Set the cuda stream for communication"""
cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
stream.device_index)
with cupy_stream:
yield
def init_process_group(world_size: int, rank: int, host: str,
port: int) -> None:
"""Initializes the CuPy NCCL backend.
# TODO: handle NCCL timeouts.
"""
assert not is_initialized()
if isinstance(cupy, Exception):
raise ImportError(
"NCCLBackend is not available. Please install cupy.") from cupy
# TODO(woosuk): Create TP and PP process groups for CuPy.
global _NCCL_BACKEND
global _WORLD_SIZE
assert world_size > 0, f"{world_size=} should be a positive integer"
assert 0 <= rank < world_size, (
f"{rank=} should be a integer between [0, {world_size})")
cupy.cuda.runtime.setDevice(torch.cuda.current_device())
_NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
_WORLD_SIZE = world_size
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
# Hack to support bfloat16
torch_dtype = input_.dtype
if torch_dtype is torch.bfloat16:
# We need to view as float16, otherwise
# cupy will fail. This will not change
# the underlying data.
input_ = input_.view(torch.float16)
cupy_input = cupy.asarray(input_)
cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access
_NCCL_BACKEND.all_reduce(in_array=cupy_input,
out_array=cupy_input,
op=_OP_MAPPING[op])
def destroy_process_group() -> None:
"""Destroys the NCCL backend."""
global _NCCL_BACKEND
global _WORLD_SIZE
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def get_world_size() -> int:
"""Returns the world size."""
return _WORLD_SIZE
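
The bfloat16 workaround in all_reduce above leans on the fact that viewing a 16-bit tensor under a different 16-bit dtype reinterprets the same storage without copying, so CuPy can move the raw bytes while NCCL is told the real dtype. A quick sketch of that property:

import torch

x = torch.ones(4, dtype=torch.bfloat16)
y = x.view(torch.float16)               # same memory, only the dtype tag changes
assert y.data_ptr() == x.data_ptr()     # no copy was made
assert y.view(torch.bfloat16).equal(x)  # round-trips back to the original values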


@ -3,9 +3,12 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
import torch
from vllm.model_executor.parallel_utils import cupy_utils
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
@ -177,3 +180,37 @@ def destroy_model_parallel():
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None
# Destroy the cupy states if any.
cupy_utils.destroy_process_group()
# Whether to use cupy for nccl all reduce.
# We use cupy for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
_ENABLE_CUPY_FOR_ALL_REDUCE = False
@contextlib.contextmanager
def with_custom_nccl_for_all_reduce():
"""use custom nccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
# No-op.
# NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
yield
else:
global _ENABLE_CUPY_FOR_ALL_REDUCE
old = _ENABLE_CUPY_FOR_ALL_REDUCE
_ENABLE_CUPY_FOR_ALL_REDUCE = True
stream = torch.cuda.current_stream()
with cupy_utils.set_cupy_stream(stream):
yield
_ENABLE_CUPY_FOR_ALL_REDUCE = old
def is_custom_nccl_enabled_for_all_reduce():
"""check if custom nccl is enabled for all reduce"""
global _ENABLE_CUPY_FOR_ALL_REDUCE
return _ENABLE_CUPY_FOR_ALL_REDUCE
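
A hedged sketch of how this toggle is presumably meant to be used: wrap the region being captured so every tensor_model_parallel_all_reduce inside it routes through cupy_utils instead of torch.distributed (assumes the tensor-parallel group is already initialized; the wrapper function here is illustrative, not part of this commit):

import torch

from vllm.model_executor.parallel_utils.parallel_state import (
    with_custom_nccl_for_all_reduce)

def capture_forward(run_model):
    # Inside the context manager, the all-reduce in communication_op.py goes
    # through cupy_utils.all_reduce(), which is what makes tensor-parallel
    # layers capturable by torch.cuda.CUDAGraph.
    with with_custom_nccl_for_all_reduce():
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            run_model()
    return graph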


@ -1,4 +1,5 @@
import enum
import socket
import uuid
from platform import uname
@ -52,3 +53,9 @@ def random_uuid() -> str:
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
return "microsoft" in " ".join(uname()).lower()
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]


@ -1,16 +1,25 @@
from typing import Dict, List, Optional, Tuple
import time
from typing import Dict, List, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.parallel_state import (
with_custom_nccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
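
# _get_graph_batch_size is referenced in the note above and used further down,
# but its definition is not visible in this excerpt. Presumably it rounds a
# runtime batch size up to the nearest captured size, roughly like this
# (a hedged sketch, not the committed code; the caller separately checks that
# batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]):
def _get_graph_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    return (batch_size + 7) // 8 * 8   # next multiple of 8
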
class ModelRunner:
@ -32,12 +41,31 @@ class ModelRunner:
self.model = None
self.block_size = None # Set after initial profiling.
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
self.max_context_len_to_capture = (
self.model_config.max_context_len_to_capture
if self.model_config is not None else 0)
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables = None # Set after initial profiling.
def load_model(self) -> None:
self.model = get_model(self.model_config)
def set_block_size(self, block_size: int) -> None:
self.block_size = block_size
max_num_blocks = (self.max_context_len_to_capture + block_size -
1) // block_size
self.graph_block_tables = np.zeros(
(max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
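
# For scale, a worked example (block_size = 16 is an assumed default; 8192 and
# 256 come from max_context_len_to_capture and max(_BATCH_SIZES_TO_CAPTURE) above):
#   max_num_blocks = (8192 + 16 - 1) // 16 = 512
#   graph_block_tables.shape == (256, 512), int32, roughly 0.5 MB,
# so the padded block tables are filled in place each step instead of being
# rebuilt as Python lists.
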
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -111,6 +139,7 @@ class ModelRunner:
max_context_len=None,
context_lens=None,
block_tables=None,
use_cuda_graph=False,
)
return input_tokens, input_positions, input_metadata
@ -154,27 +183,62 @@ class ModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
batch_size = len(input_tokens)
max_context_len = max(context_lens)
use_captured_graph = (
not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_context_len <= self.max_context_len_to_capture)
if use_captured_graph:
# Pad the input tokens, positions, and slot mapping to match the
# batch size of the captured graph.
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
for _ in range(graph_batch_size - batch_size):
input_tokens.append([])
input_positions.append([])
slot_mapping.append([])
context_lens.append(1)
block_tables.append([])
batch_size = graph_batch_size
# When using CUDA graph, we don't need to make the tensors on the GPU
# because they will be eventually copied to the designated GPU buffer.
device = "cpu" if use_captured_graph else "cuda"
input_tokens = _make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long)
dtype=torch.long,
device=device)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long)
dtype=torch.long,
device=device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long)
max_context_len = max(context_lens)
dtype=torch.long,
device=device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
max_block_table_len = max([len(t) for t in block_tables])
block_tables = _make_tensor_with_pad(block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int)
device=device)
if use_captured_graph:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size]
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.from_numpy(input_block_tables).to(device)
else:
block_tables = _make_tensor_with_pad(
block_tables,
max_len=max_context_len,
pad=0,
dtype=torch.int,
)
input_metadata = InputMetadata(
prompt_lens=[],
@ -182,6 +246,7 @@ class ModelRunner:
max_context_len=max_context_len,
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
return input_tokens, input_positions, input_metadata
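The decode-time padding above can be sketched in isolation. This is a hedged, standalone example: BUCKETS mirrors _BATCH_SIZES_TO_CAPTURE and pad_decode_batch is an illustrative helper, neither is part of this commit. Positions and slot mapping are padded the same way in the real code.

BUCKETS = [1, 2, 4] + [8 * i for i in range(1, 33)]

def pad_decode_batch(input_tokens, context_lens, block_tables):
    # Round the batch up to the nearest captured bucket and append dummy
    # sequences: empty token/block-table rows with a context length of 1.
    batch_size = len(input_tokens)
    graph_batch_size = next(b for b in BUCKETS if b >= batch_size)
    for _ in range(graph_batch_size - batch_size):
        input_tokens.append([])
        context_lens.append(1)
        block_tables.append([])
    return graph_batch_size

tokens, lens, tables = [[5], [6], [7]], [10, 3, 4], [[1], [2], [3]]
assert pad_decode_batch(tokens, lens, tables) == 4 and len(tokens) == 4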
@ -260,12 +325,11 @@ class ModelRunner:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
cache_events: Optional[List[torch.cuda.Event]] = None,
) -> SamplerOutput:
# NOTE: We assume that all sequences in the group are either all prompts
# or all decodes.
# Prepare input tensors.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
inputs = self._prepare_prompt(seq_group_metadata_list)
input_tokens, input_positions, input_metadata = inputs
@ -276,12 +340,16 @@ class ModelRunner:
input_metadata.prompt_lens)
# Execute the model.
hidden_states = self.model(
if input_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
hidden_states = model_executable(
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
input_metadata=input_metadata,
cache_events=cache_events,
)
# Sample the next token.
@ -319,8 +387,139 @@ class ModelRunner:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [(None, None)] * num_layers
self.execute_model(seqs, kv_caches)
torch.cuda.synchronize()
return
@torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None:
assert not self.model_config.enforce_eager
logger.info("Capturing the model for CUDA graphs. This may lead to "
"unexpected consequences if the model is not static. To "
"run the model in eager mode, set 'enforce_eager=True' or "
"use '--enforce-eager' in the CLI.")
start_time = time.perf_counter()
# Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, 1,
dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of the CUDA graphs.
for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
# Create dummy input_metadata.
input_metadata = InputMetadata(
prompt_lens=[],
slot_mapping=slot_mapping[:batch_size],
max_context_len=self.max_context_len_to_capture,
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
input_tokens[:batch_size],
input_positions[:batch_size],
kv_caches,
input_metadata,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
end_time = time.perf_counter()
elapsed_time = end_time - start_time
# This usually takes < 10 seconds.
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
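capture_model combines two ingredients: a warm-up run before capture and a memory pool shared across the captured graphs. A minimal self-contained sketch of that pattern follows; it assumes a CUDA device, and the Linear layer and shapes are toy stand-ins, not part of this commit.

import torch

model = torch.nn.Linear(64, 64).cuda().eval()
graphs, inputs, outputs = {}, {}, {}
pool = None

with torch.inference_mode():
    for batch_size in (8, 4):          # largest first, as in capture_model
        x = torch.zeros(batch_size, 64, device="cuda")
        model(x)                       # warm-up run; not captured
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            y = model(x)
        pool = graph.pool()            # reuse this pool for the next graph
        graphs[batch_size], inputs[batch_size], outputs[batch_size] = graph, x, y

    # Replay: copy new data into the fixed input buffer, then rerun the graph.
    inputs[4].copy_(torch.randn(4, 64, device="cuda"))
    graphs[4].replay()
    print(outputs[4].shape)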
class CUDAGraphRunner:
def __init__(self, model: nn.Module):
self.model = model
self.graph = None
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
def capture(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
memory_pool,
) -> None:
assert self.graph is None
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
with with_custom_nccl_for_all_reduce():
self.model(
input_ids,
positions,
kv_caches,
input_metadata,
)
torch.cuda.synchronize()
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
with with_custom_nccl_for_all_reduce():
hidden_states = self.model(
input_ids,
positions,
kv_caches,
input_metadata,
)
torch.cuda.synchronize()
# Save the input and output buffers.
self.input_buffers = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": input_metadata.slot_mapping,
"context_lens": input_metadata.context_lens,
"block_tables": input_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
input_metadata: InputMetadata,
) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them.
del kv_caches
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids)
self.input_buffers["positions"].copy_(positions)
self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
# Run the graph.
self.graph.replay()
# Return the output tensor.
return self.output_buffers["hidden_states"]
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
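The runner above follows a simple contract: call arguments are copied into the buffers that were captured, the graph is replayed, and the same output buffer is returned on every call. A toy sketch of that contract, with TinyGraphRunner as an illustrative name only and a single-tensor workload standing in for the model:

import torch

class TinyGraphRunner:
    def __init__(self, fn, example_input):
        # Capture fn once on a fixed input buffer after a warm-up run.
        self.inp = example_input.clone()
        fn(self.inp)
        torch.cuda.synchronize()
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.out = fn(self.inp)

    def __call__(self, x):
        # Overwrite the captured input buffer, replay, return the output buffer.
        self.inp.copy_(x)
        self.graph.replay()
        return self.out

runner = TinyGraphRunner(torch.nn.ReLU().cuda(), torch.zeros(8, device="cuda"))
print(runner(torch.randn(8, device="cuda")))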
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
@ -332,6 +531,16 @@ def _make_tensor_with_pad(
max_len: int,
pad: int,
dtype: torch.dtype,
device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device="cuda")
return torch.tensor(padded_x, dtype=dtype, device=device)
def _get_graph_batch_size(batch_size: int) -> int:
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
else:
return (batch_size + 7) // 8 * 8
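As a quick, illustrative sanity check (assuming it sits in the same module as _get_graph_batch_size and _BATCH_SIZES_TO_CAPTURE above), the rounding always lands on a captured bucket for batch sizes up to 256:

# Illustrative check only; not part of this commit.
assert all(_get_graph_batch_size(b) in _BATCH_SIZES_TO_CAPTURE
           for b in range(1, 257))
assert [_get_graph_batch_size(b) for b in (1, 3, 5, 9, 17)] == [1, 4, 8, 16, 24]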

View File

@ -8,6 +8,7 @@ import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@ -46,7 +47,7 @@ class Worker:
self.cache_events = None
self.gpu_cache = None
def init_model(self):
def init_model(self, cupy_port: Optional[int] = None):
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
# Env vars will be set by Ray.
@ -62,7 +63,7 @@ class Worker:
# Initialize the distributed environment.
_init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method)
cupy_port, self.distributed_init_method)
# Initialize the model.
set_random_seed(self.model_config.seed)
@ -100,10 +101,6 @@ class Worker:
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
torch.cuda.empty_cache()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
return num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig) -> None:
@ -114,6 +111,13 @@ class Worker:
self.gpu_cache = self.cache_engine.gpu_cache
self.model_runner.set_block_size(self.cache_engine.block_size)
def warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model(self.gpu_cache)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
@torch.inference_mode()
def execute_model(
self,
@ -136,21 +140,24 @@ class Worker:
cache_events = self.cache_events if issued_cache_op else None
# Wait for cache operations to finish.
# TODO(woosuk): Profile swapping overhead and optimize if needed.
if cache_events is not None:
for event in cache_events:
event.wait()
# If there is no input, we don't need to execute the model.
if not seq_group_metadata_list:
if cache_events is not None:
for event in cache_events:
event.wait()
return {}
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache, cache_events)
self.gpu_cache)
return output
def _init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
cupy_port: Optional[int],
distributed_init_method: Optional[str] = None,
) -> None:
"""Initialize the distributed environment."""
@ -173,8 +180,29 @@ def _init_distributed_environment(
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
if cupy_utils.is_initialized():
cupy_world_size = cupy_utils.get_world_size()
if cupy_world_size != parallel_config.world_size:
raise RuntimeError(
"cupy.distributed is already initialized but the cupy world "
"size does not match parallel_config.world_size "
f"({cupy_world_size} vs. {parallel_config.world_size}).")
elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize the CuPy process group when the world
# size is 1.
# TODO(woosuk): Support multi-node connection.
cupy_utils.init_process_group(
world_size=parallel_config.world_size,
rank=rank,
host="localhost",
port=cupy_port,
)
if parallel_config.world_size > 1:
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
cupy_utils.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)