Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-20 01:26:03 +08:00)
[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
This commit is contained in: parent 4e12131089, commit e254497b66
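
For reference, the Embedding API added by this commit is served through vLLM's OpenAI-compatible server, so it can also be exercised without the `openai` client. The sketch below is illustrative and is not part of the commit: it assumes an API server is already running on localhost:8000 with intfloat/e5-mistral-7b-instruct loaded, and that the response follows the standard OpenAI embeddings schema (a `data` list whose items carry an `embedding` vector), which is also what the tests in this diff assume.

import requests  # third-party HTTP client, used only for this sketch

# Hypothetical request against the OpenAI-compatible /v1/embeddings route.
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "intfloat/e5-mistral-7b-instruct",
        "input": ["Hello, my name is", "The future of AI is"],
        "encoding_format": "float",
    },
    timeout=60,
)
resp.raise_for_status()
payload = resp.json()
for item in payload["data"]:
    # e5-mistral-7b-instruct produces 4096-dimensional embeddings.
    print(item["index"], len(item["embedding"]))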
examples/offline_inference_embedding.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+from vllm import LLM
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create an LLM.
+model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)
+# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+outputs = model.encode(prompts)
+# Print the outputs.
+for output in outputs:
+    print(output.outputs.embedding)  # list of 4096 floats

examples/openai_embedding_client.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from openai import OpenAI
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    # defaults to os.environ.get("OPENAI_API_KEY")
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+responses = client.embeddings.create(input=[
+    "Hello my name is",
+    "The best thing about vLLM is that it supports many different models"
+],
+                                     model=model)
+
+for data in responses.data:
+    print(data.embedding)  # list of float of len 4096

@@ -19,12 +19,15 @@ pytest-forked
 pytest-asyncio
 pytest-rerunfailures
 pytest-shard
-httpx
+
+# testing utils
+awscli
 einops # required for MPT
+httpx
+peft
 requests
 ray
-peft
-awscli
+sentence-transformers # required for embedding

 # Benchmarking
 aiohttp

@@ -133,6 +133,10 @@ _VISION_LANGUAGE_MODELS = {
     "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
 }

+_EMBEDDING_MODELS = [
+    "intfloat/e5-mistral-7b-instruct",
+]
+

 class HfRunner:

@@ -145,14 +149,7 @@ class HfRunner:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
         self.model_name = model_name
-        if model_name not in _VISION_LANGUAGE_MODELS:
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-                trust_remote_code=True,
-            ).cuda()
-            self.processor = None
-        else:
+        if model_name in _VISION_LANGUAGE_MODELS:
             self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
                 model_name,
                 torch_dtype=torch_dtype,

@@ -162,6 +159,20 @@ class HfRunner:
                 model_name,
                 torch_dtype=torch_dtype,
             )
+        elif model_name in _EMBEDDING_MODELS:
+            # Lazy init required for AMD CI
+            from sentence_transformers import SentenceTransformer
+            self.model = SentenceTransformer(
+                model_name,
+                device="cpu",
+            ).to(dtype=torch_dtype).cuda()
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                trust_remote_code=True,
+            ).cuda()
+        self.processor = None
         if tokenizer_name is None:
             tokenizer_name = model_name
         self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

@@ -334,6 +345,9 @@ class HfRunner:
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]

+    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
+        return self.model.encode(prompts)
+
     def __del__(self):
         del self.model
         cleanup()

@@ -459,6 +473,14 @@ class VllmRunner:
         outputs = self.generate(prompts, beam_search_params)
         return outputs

+    def encode(self, prompts: List[str]) -> List[List[float]]:
+        req_outputs = self.model.encode(prompts)
+        outputs = []
+        for req_output in req_outputs:
+            embedding = req_output.outputs.embedding
+            outputs.append(embedding)
+        return outputs
+
     def __del__(self):
         del self.model
         cleanup()

@@ -9,8 +9,8 @@ from vllm.core.scheduler import Scheduler
 from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
-                           SequenceStatus)
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.utils import Counter

@@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
     new_token_ids = list(range(num_new_tokens))

     outputs = [
-        SequenceGroupOutput(
+        CompletionSequenceGroupOutput(
             samples=[
                 SequenceOutput(
                     parent_seq_id=seq.seq_id,

@@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
     new_token_ids = list(range(num_new_tokens))

     outputs = [
-        SequenceGroupOutput(
+        CompletionSequenceGroupOutput(
             samples=[
                 SequenceOutput(
                     parent_seq_id=seq.seq_id,

@@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
     new_token_ids[eos_index] = eos_token_id

     outputs = [
-        SequenceGroupOutput(
+        CompletionSequenceGroupOutput(
             samples=[
                 SequenceOutput(
                     parent_seq_id=seq.seq_id,

@@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
     new_token_ids[eos_index] = eos_token_id

     outputs = [
-        SequenceGroupOutput(
+        CompletionSequenceGroupOutput(
             samples=[
                 SequenceOutput(
                     parent_seq_id=seq.seq_id,

@@ -14,6 +14,7 @@ class MockModelConfig:
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
+    embedding_mode = False


 @dataclass

@@ -23,6 +23,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
 # technically this needs Mistral-7B-v0.1 as base, but we're not testing
 # generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"

@@ -121,7 +122,7 @@ def zephyr_lora_files():
     return snapshot_download(repo_id=LORA_NAME)


-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def server(zephyr_lora_files):
     ray.init()
     server_runner = ServerRunner.remote([

@@ -150,6 +151,25 @@ def server(zephyr_lora_files):
     ray.shutdown()


+@pytest.fixture(scope="module")
+def embedding_server(zephyr_lora_files):
+    ray.shutdown()
+    ray.init()
+    server_runner = ServerRunner.remote([
+        "--model",
+        EMBEDDING_MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+    ])
+    ray.get(server_runner.ready.remote())
+    yield server_runner
+    ray.shutdown()
+
+
 @pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(

@@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
                 or "less_than_equal" in exc_info.value.message)


+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
+                                model_name: str):
+    input = [
+        "The chef prepared a delicious meal.",
+    ]
+
+    # test single embedding
+    embeddings = await client.embeddings.create(
+        model=model_name,
+        input=input,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert embeddings.data is not None and len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 9
+    assert embeddings.usage.total_tokens == 9
+
+    # test using token IDs
+    input = [1, 1, 1, 1, 1]
+    embeddings = await client.embeddings.create(
+        model=model_name,
+        input=input,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert embeddings.data is not None and len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 5
+    assert embeddings.usage.total_tokens == 5
+
+
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
+                               model_name: str):
+    # test List[str]
+    inputs = [
+        "The cat sat on the mat.", "A feline was resting on a rug.",
+        "Stars twinkle brightly in the night sky."
+    ]
+    embeddings = await client.embeddings.create(
+        model=model_name,
+        input=inputs,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert embeddings.data is not None and len(embeddings.data) == 3
+    assert len(embeddings.data[0].embedding) == 4096
+
+    # test List[List[int]]
+    inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
+              [25, 32, 64, 77]]
+    embeddings = await client.embeddings.create(
+        model=model_name,
+        input=inputs,
+        encoding_format="float",
+    )
+    assert embeddings.id is not None
+    assert embeddings.data is not None and len(embeddings.data) == 4
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 17
+    assert embeddings.usage.total_tokens == 17
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

tests/models/test_embedding.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
+
+Run `pytest tests/models/test_llama_embedding.py`.
+"""
+import pytest
+import torch
+import torch.nn.functional as F
+
+MODELS = [
+    "intfloat/e5-mistral-7b-instruct",
+]
+
+
+def compare_embeddings(embeddings1, embeddings2):
+    similarities = [
+        F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0)
+        for e1, e2 in zip(embeddings1, embeddings2)
+    ]
+    return similarities
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_outputs = hf_model.encode(example_prompts)
+    del hf_model
+
+    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_outputs = vllm_model.encode(example_prompts)
+    del vllm_model
+
+    similarities = compare_embeddings(hf_outputs, vllm_outputs)
+    all_similarities = torch.stack(similarities)
+    tolerance = 1e-2
+    assert torch.all((all_similarities <= 1.0 + tolerance)
+                     & (all_similarities >= 1.0 - tolerance)
+                     ), f"Not all values are within {tolerance} of 1.0"

@@ -36,14 +36,14 @@ def test_logits_processor_force_generate(
     # test logits_processors when prompt_logprobs is not None
     vllm_model.model._add_request(
         prompt=example_prompts[0],
-        sampling_params=params_with_logprobs,
+        params=params_with_logprobs,
         prompt_token_ids=None,
     )

     # test prompt_logprobs is not None
     vllm_model.model._add_request(
         prompt=example_prompts[1],
-        sampling_params=SamplingParams(
+        params=SamplingParams(
             prompt_logprobs=3,
             max_tokens=max_tokens,
         ),

@@ -53,7 +53,7 @@ def test_logits_processor_force_generate(
     # test grouped requests
     vllm_model.model._add_request(
         prompt=example_prompts[2],
-        sampling_params=SamplingParams(max_tokens=max_tokens),
+        params=SamplingParams(max_tokens=max_tokens),
         prompt_token_ids=None,
     )


@@ -60,7 +60,7 @@ def test_random_sample_with_seed(
     llm._add_request(
         prompt=prompt,
         prompt_token_ids=None,
-        sampling_params=params,
+        params=params,
     )

     results = llm._run_engine(use_tqdm=False)

@@ -7,8 +7,8 @@ import torch
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.utils import set_random_seed
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (Logprob, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata, SequenceGroupOutput,
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           SamplerOutput, SequenceData, SequenceGroupMetadata,
                            SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine

@@ -170,7 +170,7 @@ def create_sampler_output_list(

     return [
         SamplerOutput(outputs=[
-            SequenceGroupOutput(
+            CompletionSequenceGroupOutput(
                 samples=[
                     SequenceOutput(
                         output_token=token_id,

@@ -1,14 +1,14 @@
 import pytest

 from tests.core.utils import create_dummy_prompt
-from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput,
-                           SequenceOutput)
+from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput,
+                           SequenceData, SequenceOutput)


 @pytest.fixture
 def sample_outputs():
     return [
-        SequenceGroupOutput(samples=[
+        CompletionSequenceGroupOutput(samples=[
             SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
         ],
                             prompt_logprobs=None) for i in range(5)

@@ -32,7 +32,7 @@ def test_sampler_output_getitem(sampler_output, sample_outputs):


 def test_sampler_output_setitem(sampler_output):
-    new_output = SequenceGroupOutput(samples=[
+    new_output = CompletionSequenceGroupOutput(samples=[
         SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
     ],
                         prompt_logprobs=None)

@@ -6,7 +6,9 @@ from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingOutput,
+                          EmbeddingRequestOutput, RequestOutput)
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams

 __version__ = "0.4.2"

@@ -17,9 +19,12 @@ __all__ = [
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
+    "EmbeddingOutput",
+    "EmbeddingRequestOutput",
     "LLMEngine",
     "EngineArgs",
     "AsyncLLMEngine",
     "AsyncEngineArgs",
     "initialize_ray_cluster",
+    "PoolingParams",
 ]

@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      get_quantization_config)
+from vllm.model_executor.models import ModelRegistry
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron

@@ -22,6 +23,7 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)

 _GB = 1 << 30
+_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768


 class ModelConfig:

@@ -126,6 +128,7 @@ class ModelConfig:
             served_model_name)
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
+        self._verify_embedding_mode()
         self._verify_quantization()
         self._verify_cuda_graph()

@@ -137,6 +140,11 @@ class ModelConfig:
                 "either 'auto' or 'slow'.")
         self.tokenizer_mode = tokenizer_mode

+    def _verify_embedding_mode(self) -> None:
+        architectures = getattr(self.hf_config, "architectures", [])
+        self.embedding_mode = any(
+            ModelRegistry.is_embedding_model(arch) for arch in architectures)
+
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["gptq", "squeezellm"]

@@ -591,6 +599,7 @@ class SchedulerConfig:
             prompt latency) before scheduling next prompt.
         enable_chunked_prefill: If True, prefill requests can be chunked based
             on the remaining max_num_batched_tokens.
+        embedding_mode: Whether the running model is for embedding.
     """

     def __init__(

@@ -602,6 +611,7 @@ class SchedulerConfig:
         num_lookahead_slots: int = 0,
         delay_factor: float = 0.0,
         enable_chunked_prefill: bool = False,
+        embedding_mode: Optional[bool] = False,
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens

@@ -610,6 +620,10 @@ class SchedulerConfig:
                 # It is the values that have the best balance between ITL
                 # and TTFT on A100. Note it is not optimized for throughput.
                 self.max_num_batched_tokens = 512
+            elif embedding_mode:
+                # For embedding, choose specific value for higher throughput
+                self.max_num_batched_tokens = max(
+                    max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
             else:
                 # If max_model_len is too short, use 2048 as the default value
                 # for higher throughput.

@@ -623,6 +637,7 @@ class SchedulerConfig:
         self.num_lookahead_slots = num_lookahead_slots
         self.delay_factor = delay_factor
         self.chunked_prefill_enabled = enable_chunked_prefill
+        self.embedding_mode = embedding_mode

         self._verify_args()

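The SchedulerConfig change above sets the batched-token budget differently in embedding mode: when max_num_batched_tokens is not given explicitly, it becomes max(max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS), with the constant fixed at 32768. A standalone sketch of that arithmetic follows; the helper function is invented for this illustration and does not exist in vLLM.

# Illustrative helper mirroring the embedding-mode branch above.
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768

def embedding_token_budget(max_model_len: int) -> int:
    return max(max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)

print(embedding_token_budget(8192))   # 32768; e.g. the CI server in this diff uses --max-model-len 8192
print(embedding_token_budget(65536))  # 65536; the budget never drops below the model's context length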

vllm/core/embedding_model_block_manager.py (new file, 84 lines)
@@ -0,0 +1,84 @@
+from typing import List, Tuple
+
+from vllm.core.interfaces import AllocStatus, BlockSpaceManager
+from vllm.sequence import Sequence, SequenceGroup
+
+
+class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
+    """An embedding version of BlockSpaceManager for use in environments
+    with embedding models where block management is not required.
+
+    This class provides the same interface as BlockSpaceManager, but its
+    methods perform no actions or return simple values like True in specific
+    actions. It's designed to be used in scenarios where the overhead of
+    block management is unnecessary, such as in an embedding environment.
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        pass
+
+    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
+        # Always return OK for dummy purposes
+        return AllocStatus.OK
+
+    def allocate(self, seq_group: SequenceGroup) -> None:
+        # No actual allocation logic needed
+        pass
+
+    def can_append_slots(self, seq_group: SequenceGroup,
+                         num_lookahead_slots: int) -> bool:
+        return True
+
+    def append_slots(
+        self,
+        seq: Sequence,
+        num_lookahead_slots: int,
+    ) -> List[Tuple[int, int]]:
+        return None  # type: ignore
+
+    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        pass
+
+    def can_swap_in(self, seq_group: SequenceGroup,
+                    num_lookahead_slots: int) -> AllocStatus:
+        return AllocStatus.OK
+
+    def swap_in(self, seq_group: SequenceGroup,
+                num_lookahead_slots: int) -> List[Tuple[int, int]]:
+        return None  # type: ignore
+
+    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
+        return True
+
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
+        return None  # type: ignore
+
+    def free(self, seq: Sequence) -> None:
+        # No operation on free
+        return
+
+    def get_block_table(self, seq: Sequence) -> List[int]:
+        return None  # type: ignore
+
+    def get_num_free_gpu_blocks(self) -> int:
+        return 1
+
+    def get_num_free_cpu_blocks(self) -> int:
+        return 1
+
+    def access_all_blocks_in_seq(
+        self,
+        seq: Sequence,
+        access_time: float,
+    ) -> None:
+        pass
+
+    def get_common_computed_block_ids(self,
+                                      seq_group: SequenceGroup) -> List[int]:
+        return None  # type: ignore
+
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+        pass
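As the docstring above notes, every method of this manager is a no-op or a constant, so embedding requests are never gated on KV-cache blocks (the engine also skips KV-cache initialization in embedding mode, as shown later in this diff). A small illustrative snippet, not part of the commit, exercising that contract through the interfaces introduced here:

# Illustrative only: the "embedding" version string maps to the no-op
# manager, and its scheduling checks always pass. Passing None for
# seq_group is just for this sketch; the methods ignore their arguments.
from vllm.core.embedding_model_block_manager import (
    EmbeddingModelBlockSpaceManager)
from vllm.core.interfaces import AllocStatus, BlockSpaceManager

manager_cls = BlockSpaceManager.get_block_space_manager_class("embedding")
assert manager_cls is EmbeddingModelBlockSpaceManager

manager = manager_cls()
assert manager.can_allocate(seq_group=None) == AllocStatus.OK
assert manager.can_append_slots(seq_group=None, num_lookahead_slots=0)
assert manager.can_swap_out(seq_group=None)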

@@ -35,6 +35,11 @@ class BlockSpaceManager(ABC):
             from vllm.core.block_manager_v2 import BlockSpaceManagerV2
             return BlockSpaceManagerV2

+        if version == "embedding":
+            from vllm.core.embedding_model_block_manager import (
+                EmbeddingModelBlockSpaceManager)
+            return EmbeddingModelBlockSpaceManager
+
         raise ValueError(f"Unknown version {version=}")

     @abstractmethod

@@ -270,9 +270,14 @@ class Scheduler:
             self.scheduler_config.max_model_len,
             self.scheduler_config.max_num_batched_tokens)

+        version = "v1"
+        if self.scheduler_config.use_v2_block_manager:
+            version = "v2"
+        if self.scheduler_config.embedding_mode:
+            version = "embedding"
+
         BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
-            version="v2" if self.scheduler_config.
-            use_v2_block_manager else "v1")
+            version)

         # Create the block space manager.
         self.block_manager = BlockSpaceManagerImpl(

@@ -968,6 +973,7 @@ class Scheduler:
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
                 do_sample=do_sample,
+                pooling_params=seq_group.pooling_params,
                 token_chunk_size=token_chunk_size,
                 lora_request=seq_group.lora_request,
                 computed_block_nums=common_computed_block_nums,

@@ -574,6 +574,7 @@ class EngineArgs:
                 speculative_config.num_lookahead_slots),
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
+            embedding_mode=model_config.embedding_mode,
         )
         lora_config = LoRAConfig(
             max_lora_rank=self.max_lora_rank,

@@ -14,7 +14,8 @@ from vllm.engine.llm_engine import LLMEngine
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import RequestOutput
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
 from vllm.usage.usage_lib import UsageContext

@@ -47,15 +48,16 @@ def _raise_exception_on_finish(


 class AsyncStream:
-    """A stream of RequestOutputs for a request that can be
-    iterated over asynchronously."""
+    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
+    that can be iterated over asynchronously."""

     def __init__(self, request_id: str) -> None:
         self.request_id = request_id
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False

-    def put(self, item: Union[RequestOutput, Exception]) -> None:
+    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
+                              Exception]) -> None:
         if self._finished:
             return
         self._queue.put_nowait(item)

@@ -71,7 +73,7 @@ class AsyncStream:
     def __aiter__(self):
         return self

-    async def __anext__(self) -> RequestOutput:
+    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
         result = await self._queue.get()
         if isinstance(result, Exception):
             raise result

@@ -108,7 +110,8 @@ class RequestTracker:
                 self.abort_request(rid)

     def process_request_output(self,
-                               request_output: RequestOutput,
+                               request_output: Union[RequestOutput,
+                                                     EmbeddingRequestOutput],
                                *,
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""

@@ -196,7 +199,8 @@ class RequestTracker:
 class _AsyncLLMEngine(LLMEngine):
     """Extension of LLMEngine to add async methods."""

-    async def step_async(self) -> List[RequestOutput]:
+    async def step_async(
+            self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
         The workers are ran asynchronously if possible.

@@ -251,7 +255,7 @@ class _AsyncLLMEngine(LLMEngine):
         self,
         request_id: str,
         prompt: Optional[str],
-        sampling_params: SamplingParams,
+        params: Union[SamplingParams, PoolingParams],
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,

@@ -270,8 +274,8 @@ class _AsyncLLMEngine(LLMEngine):

         return self.add_request(request_id,
                                 prompt=prompt,
+                                params=params,
                                 prompt_token_ids=prompt_token_ids,
-                                sampling_params=sampling_params,
                                 arrival_time=arrival_time,
                                 lora_request=lora_request,
                                 multi_modal_data=multi_modal_data)

@@ -511,7 +515,7 @@ class AsyncLLMEngine:
         self,
         request_id: str,
         prompt: Optional[str],
-        sampling_params: SamplingParams,
+        params: Union[SamplingParams, PoolingParams],
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,

@@ -528,9 +532,9 @@ class AsyncLLMEngine:
                 max_log_len]
         logger.info(
             "Received request %s: prompt: %r, "
-            "sampling_params: %s, prompt_token_ids: %s, "
-            "lora_request: %s.", request_id, shortened_prompt,
-            sampling_params, shortened_token_ids, lora_request)
+            "params: %s, prompt_token_ids: %s, "
+            "lora_request: %s.", request_id, shortened_prompt, params,
+            shortened_token_ids, lora_request)

         if not self.is_running:
             if self.start_engine_loop:

@@ -562,7 +566,7 @@ class AsyncLLMEngine:
         stream = self._request_tracker.add_request(
             request_id,
             prompt=prompt,
-            sampling_params=sampling_params,
+            params=params,
             prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
             lora_request=lora_request,

@@ -597,8 +601,8 @@ class AsyncLLMEngine:
             multi_modal_data: Multi modal data per request.

         Yields:
-            The output `RequestOutput` objects from the LLMEngine for the
-            request.
+            The output `RequestOutput` objects from the LLMEngine
+            for the request.

         Details:
             - If the engine is not running, start the background loop,

@@ -643,25 +647,123 @@ class AsyncLLMEngine:
             >>> # Process and return the final output
             >>> ...
         """
-        # Preprocess the request.
-        arrival_time = time.time()
-
-        try:
-            stream = await self.add_request(
+        async for output in self.process_request(
             request_id,
             prompt,
             sampling_params,
+            prompt_token_ids,
+            lora_request,
+            multi_modal_data,
+        ):
+            yield output
+
+    async def encode(
+            self,
+            prompt: Optional[str],
+            pooling_params: PoolingParams,
+            request_id: str,
+            prompt_token_ids: Optional[List[int]] = None,
+            lora_request: Optional[LoRARequest] = None,
+            multi_modal_data: Optional[MultiModalData] = None
+    ) -> AsyncIterator[EmbeddingRequestOutput]:
+        """Generate outputs for a request from an embedding model.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi modal data per request.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+
+        Details:
+            - If the engine is not running, start the background loop,
+              which iteratively invokes
+              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
+              to process the waiting requests.
+            - Add the request to the engine's `RequestTracker`.
+              On the next background loop, this request will be sent to
+              the underlying engine.
+              Also, a corresponding `AsyncStream` will be created.
+            - Wait for the request outputs from `AsyncStream` and yield them.
+
+        Example:
+            >>> # Please refer to entrypoints/api_server.py for
+            >>> # the complete example.
+            >>>
+            >>> # initialize the engine and the example input
+            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
+            >>> example_input = {
+            >>>     "input": "What is LLM?",
+            >>>     "request_id": 0,
+            >>> }
+            >>>
+            >>> # start the generation
+            >>> results_generator = engine.encode(
+            >>>    example_input["input"],
+            >>>    PoolingParams(),
+            >>>    example_input["request_id"])
+            >>>
+            >>> # get the results
+            >>> final_output = None
+            >>> async for request_output in results_generator:
+            >>>     if await request.is_disconnected():
+            >>>         # Abort the request if the client disconnects.
+            >>>         await engine.abort(request_id)
+            >>>         # Return or raise an error
+            >>>         ...
+            >>>     final_output = request_output
+            >>>
+            >>> # Process and return the final output
+            >>> ...
+        """
+        async for output in self.process_request(
+            request_id,
+            prompt,
+            pooling_params,
+            prompt_token_ids,
+            lora_request,
+            multi_modal_data,
+        ):
+            yield output
+
+    async def process_request(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        params: Union[SamplingParams, PoolingParams],
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
+        """Common logic to process requests with SamplingParams or
+        PoolingParams."""
+        arrival_time = time.time()
+
+        stream = await self.add_request(
+            request_id,
+            prompt,
+            params,
             prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
             lora_request=lora_request,
             multi_modal_data=multi_modal_data,
         )

+        try:
             async for request_output in stream:
                 yield request_output
         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
             self._abort(request_id)
             raise e

@ -20,9 +20,12 @@ from vllm.executor.executor_base import ExecutorBase
|
|||||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||||
|
RequestOutputFactory)
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
|
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||||
|
MultiModalData, PoolerOutput, SamplerOutput,
|
||||||
Sequence, SequenceGroup, SequenceGroupMetadata,
|
Sequence, SequenceGroup, SequenceGroupMetadata,
|
||||||
SequenceStatus)
|
SequenceStatus)
|
||||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||||
@ -169,6 +172,7 @@ class LLMEngine:
|
|||||||
load_config=load_config,
|
load_config=load_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not self.model_config.embedding_mode:
|
||||||
self._initialize_kv_caches()
|
self._initialize_kv_caches()
|
||||||
|
|
||||||
# If usage stat is enabled, collect relevant info.
|
# If usage stat is enabled, collect relevant info.
|
||||||
@ -354,7 +358,7 @@ class LLMEngine:
|
|||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
prompt: Optional[str],
|
prompt: Optional[str],
|
||||||
sampling_params: SamplingParams,
|
params: Union[SamplingParams, PoolingParams],
|
||||||
prompt_token_ids: Optional[List[int]] = None,
|
prompt_token_ids: Optional[List[int]] = None,
|
||||||
arrival_time: Optional[float] = None,
|
arrival_time: Optional[float] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -370,7 +374,8 @@ class LLMEngine:
|
|||||||
request_id: The unique ID of the request.
|
request_id: The unique ID of the request.
|
||||||
prompt: The prompt string. Can be None if prompt_token_ids is
|
prompt: The prompt string. Can be None if prompt_token_ids is
|
||||||
provided.
|
provided.
|
||||||
sampling_params: The sampling parameters for text generation.
|
params: Parameters for sampling or pooling. SamplingParams
|
||||||
|
for text generation. PoolingParams for pooling.
|
||||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||||
use the tokenizer to convert the prompts to token IDs.
|
use the tokenizer to convert the prompts to token IDs.
|
||||||
arrival_time: The arrival time of the request. If None, we use
|
arrival_time: The arrival time of the request. If None, we use
|
||||||
@ -404,13 +409,6 @@ class LLMEngine:
|
|||||||
if lora_request is not None and not self.lora_config:
|
if lora_request is not None and not self.lora_config:
|
||||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||||
"not enabled!")
|
"not enabled!")
|
||||||
max_logprobs = self.get_model_config().max_logprobs
|
|
||||||
if (sampling_params.logprobs
|
|
||||||
and sampling_params.logprobs > max_logprobs) or (
|
|
||||||
sampling_params.prompt_logprobs
|
|
||||||
and sampling_params.prompt_logprobs > max_logprobs):
|
|
||||||
raise ValueError(f"Cannot request more than "
|
|
||||||
f"{max_logprobs} logprobs.")
|
|
||||||
if arrival_time is None:
|
if arrival_time is None:
|
||||||
arrival_time = time.time()
|
arrival_time = time.time()
|
||||||
prompt_token_ids = self.encode_request(
|
prompt_token_ids = self.encode_request(
|
||||||
@ -432,6 +430,50 @@ class LLMEngine:
|
|||||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
||||||
eos_token_id, lora_request)
|
eos_token_id, lora_request)
|
||||||
|
|
||||||
|
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||||
|
if isinstance(params, SamplingParams):
|
||||||
|
seq_group = self._create_sequence_group_with_sampling(
|
||||||
|
request_id,
|
||||||
|
seq,
|
||||||
|
params,
|
||||||
|
arrival_time,
|
||||||
|
lora_request,
|
||||||
|
multi_modal_data,
|
||||||
|
)
|
||||||
|
elif isinstance(params, PoolingParams):
|
||||||
|
seq_group = self._create_sequence_group_with_pooling(
|
||||||
|
request_id,
|
||||||
|
seq,
|
||||||
|
params,
|
||||||
|
arrival_time,
|
||||||
|
lora_request,
|
||||||
|
multi_modal_data,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Either SamplingParams or PoolingParams must be provided.")
|
||||||
|
|
||||||
|
# Add the sequence group to the scheduler.
|
||||||
|
self.scheduler.add_seq_group(seq_group)
|
||||||
|
|
||||||
|
def _create_sequence_group_with_sampling(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
seq: Sequence,
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
arrival_time: Optional[float] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
multi_modal_data: Optional[MultiModalData] = None,
|
||||||
|
) -> SequenceGroup:
|
||||||
|
"""Creates a SequenceGroup with SamplingParams."""
|
||||||
|
max_logprobs = self.get_model_config().max_logprobs
|
||||||
|
if (sampling_params.logprobs
|
||||||
|
and sampling_params.logprobs > max_logprobs) or (
|
||||||
|
sampling_params.prompt_logprobs
|
||||||
|
and sampling_params.prompt_logprobs > max_logprobs):
|
||||||
|
raise ValueError(f"Cannot request more than "
|
||||||
|
f"{max_logprobs} logprobs.")
|
||||||
|
|
||||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||||
# this doesn't deep-copy LogitsProcessor objects
|
# this doesn't deep-copy LogitsProcessor objects
|
||||||
sampling_params = sampling_params.clone()
|
sampling_params = sampling_params.clone()
|
||||||
@ -443,11 +485,35 @@ class LLMEngine:
|
|||||||
self.generation_config_fields)
|
self.generation_config_fields)
|
||||||
|
|
||||||
# Create the sequence group.
|
# Create the sequence group.
|
||||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
seq_group = SequenceGroup(request_id=request_id,
|
||||||
arrival_time, lora_request, multi_modal_data)
|
seqs=[seq],
|
||||||
|
arrival_time=arrival_time,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
multi_modal_data=multi_modal_data)
|
||||||
|
|
||||||
# Add the sequence group to the scheduler.
|
return seq_group
|
||||||
self.scheduler.add_seq_group(seq_group)
|
|
||||||
|
def _create_sequence_group_with_pooling(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
seq: Sequence,
|
||||||
|
pooling_params: PoolingParams,
|
||||||
|
arrival_time: Optional[float] = None,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
multi_modal_data: Optional[MultiModalData] = None,
|
||||||
|
) -> SequenceGroup:
|
||||||
|
"""Creates a SequenceGroup with PoolingParams."""
|
||||||
|
# Defensive copy of PoolingParams, which are used by the pooler
|
||||||
|
pooling_params = pooling_params.clone()
|
||||||
|
# Create the sequence group.
|
||||||
|
seq_group = SequenceGroup(request_id=request_id,
|
||||||
|
seqs=[seq],
|
||||||
|
arrival_time=arrival_time,
|
||||||
|
lora_request=lora_request,
|
||||||
|
multi_modal_data=multi_modal_data,
|
||||||
|
pooling_params=pooling_params)
|
||||||
|
return seq_group
|
||||||
|
|
||||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||||
"""Aborts a request(s) with the given ID.
|
"""Aborts a request(s) with the given ID.
|
||||||
@ -484,13 +550,25 @@ class LLMEngine:
|
|||||||
"""Returns True if there are unfinished requests."""
|
"""Returns True if there are unfinished requests."""
|
||||||
return self.scheduler.has_unfinished_seqs()
|
return self.scheduler.has_unfinished_seqs()
|
||||||
|
|
||||||
|
def _process_sequence_group_outputs(
|
||||||
|
self,
|
||||||
|
seq_group: SequenceGroup,
|
||||||
|
outputs: List[EmbeddingSequenceGroupOutput],
|
||||||
|
) -> None:
|
||||||
|
seq_group.embeddings = outputs[0].embeddings
|
||||||
|
|
||||||
|
for seq in seq_group.get_seqs():
|
||||||
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
def _process_model_outputs(
|
def _process_model_outputs(
|
||||||
self,
|
self,
|
||||||
output: List[SamplerOutput],
|
output: List[Union[SamplerOutput, PoolerOutput]],
|
||||||
scheduled_seq_groups: List[ScheduledSequenceGroup],
|
scheduled_seq_groups: List[ScheduledSequenceGroup],
|
||||||
ignored_seq_groups: List[SequenceGroup],
|
ignored_seq_groups: List[SequenceGroup],
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
) -> List[RequestOutput]:
|
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||||
"""Apply the model output to the sequences in the scheduled seq groups.
|
"""Apply the model output to the sequences in the scheduled seq groups.
|
||||||
|
|
||||||
Returns RequestOutputs that can be returned to the client.
|
Returns RequestOutputs that can be returned to the client.
|
||||||
@ -510,6 +588,9 @@ class LLMEngine:
|
|||||||
seq_group = scheduled_seq_group.seq_group
|
seq_group = scheduled_seq_group.seq_group
|
||||||
seq_group.update_num_computed_tokens(
|
seq_group.update_num_computed_tokens(
|
||||||
scheduled_seq_group.token_chunk_size)
|
scheduled_seq_group.token_chunk_size)
|
||||||
|
if self.model_config.embedding_mode:
|
||||||
|
self._process_sequence_group_outputs(seq_group, outputs)
|
||||||
|
continue
|
||||||
|
|
||||||
self.output_processor.process_prompt_logprob(seq_group, outputs)
|
self.output_processor.process_prompt_logprob(seq_group, outputs)
|
||||||
if seq_group_meta.do_sample:
|
if seq_group_meta.do_sample:
|
||||||
@ -519,18 +600,19 @@ class LLMEngine:
|
|||||||
self.scheduler.free_finished_seq_groups()
|
self.scheduler.free_finished_seq_groups()
|
||||||
|
|
||||||
# Create the outputs.
|
# Create the outputs.
|
||||||
request_outputs: List[RequestOutput] = []
|
request_outputs: List[Union[RequestOutput,
|
||||||
|
EmbeddingRequestOutput]] = []
|
||||||
for scheduled_seq_group in scheduled_seq_groups:
|
for scheduled_seq_group in scheduled_seq_groups:
|
||||||
seq_group = scheduled_seq_group.seq_group
|
seq_group = scheduled_seq_group.seq_group
|
||||||
seq_group.maybe_set_first_token_time(now)
|
seq_group.maybe_set_first_token_time(now)
|
||||||
request_output = RequestOutput.from_seq_group(seq_group)
|
request_output = RequestOutputFactory.create(seq_group)
|
||||||
request_outputs.append(request_output)
|
request_outputs.append(request_output)
|
||||||
for seq_group in ignored_seq_groups:
|
for seq_group in ignored_seq_groups:
|
||||||
request_output = RequestOutput.from_seq_group(seq_group)
|
request_output = RequestOutputFactory.create(seq_group)
|
||||||
request_outputs.append(request_output)
|
request_outputs.append(request_output)
|
||||||
return request_outputs
|
return request_outputs
|
||||||
|
|
-    def step(self) -> List[RequestOutput]:
+    def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.

        .. figure:: https://i.imgur.com/sv2HssD.png
@ -570,7 +652,7 @@ class LLMEngine:
            >>> while True:
            >>>     if example_inputs:
            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
-            >>>         engine.add_request(str(req_id), prompt, sampling_params)
+            >>>         engine.add_request(str(req_id),prompt,sampling_params)
            >>>
            >>>     # continue the request processing
            >>>     request_outputs = engine.step()
@ -637,12 +719,15 @@ class LLMEngine:

        # KV Cache Usage in %
        num_total_gpu = self.cache_config.num_gpu_blocks
-        num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
-        gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
+        gpu_cache_usage_sys = 0.
+        if num_total_gpu is not None:
+            num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
+            )
+            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

        num_total_cpu = self.cache_config.num_cpu_blocks
        cpu_cache_usage_sys = 0.
-        if num_total_cpu > 0:
+        if num_total_cpu is not None and num_total_cpu > 0:
            num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
            )
            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
@ -716,7 +801,9 @@ class LLMEngine:
                    seq.get_output_len()
                    for seq in seq_group.get_finished_seqs()
                ])
-                best_of_requests.append(seq_group.sampling_params.best_of)
-                n_requests.append(seq_group.sampling_params.n)
+                if seq_group.sampling_params is not None:
+                    best_of_requests.append(
+                        seq_group.sampling_params.best_of)
+                    n_requests.append(seq_group.sampling_params.n)
                finished_reason_requests.extend([
                    SequenceStatus.get_finished_reason(seq.status)
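The cache-usage metric above is a simple free/total ratio, now guarded against a missing block pool. A minimal standalone sketch of that guarded computation (the block counts below are made up, not taken from a live engine) shows why the None check matters:

from typing import Optional


def cache_usage(num_total: Optional[int], num_free: int) -> float:
    """Return 1 - free/total, or 0.0 when no block pool is configured."""
    if num_total is None or num_total <= 0:
        return 0.0
    return 1.0 - (num_free / num_total)


print(cache_usage(1024, 256))  # 0.75
print(cache_usage(None, 0))    # 0.0 -- no GPU/CPU block pool configured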
@ -6,13 +6,17 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
+from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
-from vllm.outputs import RequestOutput
+from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter

+logger = init_logger(__name__)


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.
@ -164,8 +168,89 @@ class LLM:
            multi_modal_data: Multi modal data.

        Returns:
-            A list of `RequestOutput` objects containing the generated
-            completions in the same order as the input prompts.
+            A list of `RequestOutput` objects containing the
+            generated completions in the same order as the input prompts.
+        """
+        if sampling_params is None:
+            # Use default sampling params.
+            sampling_params = SamplingParams()
+
+        requests_data = self._validate_and_prepare_requests(
+            prompts,
+            sampling_params,
+            prompt_token_ids,
+            lora_request,
+            multi_modal_data,
+        )
+
+        # Add requests to the engine and run the engine
+        for request_data in requests_data:
+            self._add_request(**request_data)
+
+        return self._run_engine(use_tqdm)
+
+    def encode(
+        self,
+        prompts: Optional[Union[str, List[str]]] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       List[PoolingParams]]] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        """Generates the completions for the input prompts.
+
+        NOTE: This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            prompts: A list of prompts to generate completions for.
+            pooling_params: The pooling parameters for pooling. If None, we
+                use the default pooling parameters.
+            prompt_token_ids: A list of token IDs for the prompts. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            multi_modal_data: Multi modal data.
+
+        Returns:
+            A list of `EmbeddingRequestOutput` objects containing the
+            generated embeddings in the same order as the input prompts.
+        """
+        if pooling_params is None:
+            # Use default pooling params.
+            pooling_params = PoolingParams()
+
+        requests_data = self._validate_and_prepare_requests(
+            prompts,
+            pooling_params,
+            prompt_token_ids,
+            lora_request,
+            multi_modal_data,
+        )
+
+        # Add requests to the engine and run the engine
+        for request_data in requests_data:
+            self._add_request(**request_data)
+
+        return self._run_engine(use_tqdm)
+
+    def _validate_and_prepare_requests(
+        self,
+        prompts: Optional[Union[str, List[str]]],
+        params: Union[Union[SamplingParams, PoolingParams],
+                      List[Union[SamplingParams,
+                                 PoolingParams]]],  # Unified parameter
+        prompt_token_ids: Optional[List[List[int]]] = None,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[dict]:
+        """Validates and prepares request data for adding to the engine.
+
+        Ensures prompts and token IDs are consistent, and returns a list of
+        dictionaries with request data for further processing.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
@ -188,40 +273,43 @@ class LLM:
            assert prompt_token_ids is not None
            num_requests = len(prompt_token_ids)

-        if sampling_params is None:
-            # Use default sampling params.
-            sampling_params = SamplingParams()
-
-        elif isinstance(sampling_params,
-                        list) and len(sampling_params) != num_requests:
-            raise ValueError("The lengths of prompts and sampling_params "
-                             "must be the same.")
+        if isinstance(params, list) and len(params) != num_requests:
+            raise ValueError("The lengths of prompts and params "
+                             "must be the same.")
        if multi_modal_data:
            multi_modal_data.data = multi_modal_data.data.to(torch.float16)

        # Add requests to the engine.
+        requests_data = []
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            token_ids = None if prompt_token_ids is None else prompt_token_ids[
                i]
-            self._add_request(
-                prompt,
-                sampling_params[i]
-                if isinstance(sampling_params, list) else sampling_params,
-                token_ids,
-                lora_request=lora_request,
-                # Get ith image while maintaining the batch dim.
-                multi_modal_data=MultiModalData(
-                    type=multi_modal_data.type,
-                    data=multi_modal_data.data[i].unsqueeze(0))
-                if multi_modal_data else None,
-            )
-        return self._run_engine(use_tqdm)
+            multi_modal_item = MultiModalData(
+                type=multi_modal_data.type,
+                data=multi_modal_data.data[i].unsqueeze(0),
+            ) if multi_modal_data else None
+
+            requests_data.append({
+                "prompt":
+                prompt,
+                "params":
+                params[i] if isinstance(params, list) else params,
+                "prompt_token_ids":
+                token_ids,
+                "lora_request":
+                lora_request,
+                "multi_modal_data":
+                multi_modal_item,
+            })
+
+        return requests_data

    def _add_request(
        self,
        prompt: Optional[str],
-        sampling_params: SamplingParams,
+        params: Union[SamplingParams, PoolingParams],
        prompt_token_ids: Optional[List[int]],
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
@ -229,12 +317,14 @@ class LLM:
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id,
                                    prompt,
-                                    sampling_params,
+                                    params,
                                    prompt_token_ids,
                                    lora_request=lora_request,
                                    multi_modal_data=multi_modal_data)

-    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
+    def _run_engine(
+            self, use_tqdm: bool
+    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
@ -245,7 +335,7 @@ class LLM:
                postfix=f"Generation Speed: {0:.2f} toks/s",
            )
        # Run the engine.
-        outputs: List[RequestOutput] = []
+        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
@ -253,8 +343,10 @@ class LLM:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
-                        total_toks += (sum(
-                            len(stp.token_ids) for stp in output.outputs))
+                        if isinstance(output, RequestOutput):
+                            # Calculate tokens only for RequestOutput
+                            total_toks += sum(
+                                len(stp.token_ids) for stp in output.outputs)
                        spd = total_toks / pbar.format_dict["elapsed"]
                        pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
                        pbar.update(1)
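Both generate() and encode() now funnel through _validate_and_prepare_requests, which pairs each prompt either with one shared params object or with a per-prompt list. A small standalone sketch of that pairing rule (plain dicts, no vLLM imports; the names are illustrative, not vLLM's API) shows the shape of the request data the loop above builds:

from typing import Any, List, Union


def prepare_requests(prompts: List[str],
                     params: Union[Any, List[Any]]) -> List[dict]:
    # A per-prompt list of params must match the prompts 1:1; otherwise the
    # same params object is reused for every prompt.
    if isinstance(params, list) and len(params) != len(prompts):
        raise ValueError("The lengths of prompts and params must be the same.")
    return [{
        "prompt": prompt,
        "params": params[i] if isinstance(params, list) else params,
        "prompt_token_ids": None,
        "lora_request": None,
        "multi_modal_data": None,
    } for i, prompt in enumerate(prompts)]


print(prepare_requests(["a", "b"], {"normalize": True})[1]["params"])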
@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
-                                              CompletionRequest, ErrorResponse)
+                                              CompletionRequest,
+                                              EmbeddingRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

@ -32,6 +34,8 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds

openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
+openai_serving_embedding: OpenAIServingEmbedding

logger = init_logger(__name__)

_running_tasks: Set[asyncio.Task] = set()
@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
        return JSONResponse(content=generator.model_dump())


+@app.post("/v1/embeddings")
+async def create_embedding(request: EmbeddingRequest, raw_request: Request):
+    generator = await openai_serving_embedding.create_embedding(
+        request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        return JSONResponse(content=generator.model_dump())
+
+
if __name__ == "__main__":
    args = parse_args()

@ -190,7 +205,8 @@ if __name__ == "__main__":
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, model_config, served_model_names, args.lora_modules)
+    openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
+                                                      served_model_names)
    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
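A quick way to exercise the new /v1/embeddings route once a server is up is a plain HTTP call. The sketch below is not part of this commit; it assumes an OpenAI-compatible vLLM server is already running locally on port 8000 with an embedding model loaded, and the model name is only an example:

import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "intfloat/e5-mistral-7b-instruct",  # assumed served model name
        "input": ["vLLM now serves embeddings"],
        "encoding_format": "float",
    },
    timeout=60,
)
resp.raise_for_status()
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))  # vector length depends on the model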
@ -1,13 +1,14 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
-from typing import Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated

+from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

@ -363,6 +364,24 @@ class CompletionRequest(OpenAIBaseModel):
        return data


+class EmbeddingRequest(BaseModel):
+    # Ordered by official OpenAI API documentation
+    # https://platform.openai.com/docs/api-reference/embeddings
+    model: str
+    input: Union[List[int], List[List[int]], str, List[str]]
+    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
+    dimensions: Optional[int] = None
+    user: Optional[str] = None
+
+    # doc: begin-embedding-pooling-params
+    additional_data: Optional[Any] = None
+
+    # doc: end-embedding-pooling-params
+
+    def to_pooling_params(self):
+        return PoolingParams(additional_data=self.additional_data)
+
+
class LogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
@ -416,6 +435,21 @@ class CompletionStreamResponse(OpenAIBaseModel):
    usage: Optional[UsageInfo] = Field(default=None)


+class EmbeddingResponseData(BaseModel):
+    index: int
+    object: str = "embedding"
+    embedding: List[float]
+
+
+class EmbeddingResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
+    object: str = "list"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    data: List[EmbeddingResponseData]
+    usage: UsageInfo
+
+
class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
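A short sketch of how the new request model is used: pydantic validates the payload and to_pooling_params() turns it into the engine-side PoolingParams. This assumes a vLLM checkout containing this commit is importable; it is not part of the diff itself.

from vllm.entrypoints.openai.protocol import EmbeddingRequest
from vllm.pooling_params import PoolingParams

req = EmbeddingRequest(model="intfloat/e5-mistral-7b-instruct",  # example name
                       input=["hello world"],
                       encoding_format="float")
pooling_params = req.to_pooling_params()
assert isinstance(pooling_params, PoolingParams)
print(pooling_params)  # PoolingParams(additional_metadata=None)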
vllm/entrypoints/openai/serving_embedding.py (new file, 134 lines added)
@ -0,0 +1,134 @@
import time
from typing import AsyncIterator, List, Tuple

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)

TypeTokenIDs = List[int]


def request_output_to_embedding_response(
    final_res_batch: List[EmbeddingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
) -> EmbeddingResponse:
    data = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        assert final_res is not None
        prompt_token_ids = final_res.prompt_token_ids

        embedding_data = EmbeddingResponseData(
            index=idx, embedding=final_res.outputs.embedding)
        data.append(embedding_data)

        num_prompt_tokens += len(prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )


class OpenAIServingEmbedding(OpenAIServing):

    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                 served_model_names: List[str]):
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=None)
        self._check_embedding_mode(model_config.embedding_mode)

    async def create_embedding(self, request: EmbeddingRequest,
                               raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.encoding_format == "base64":
            return self.create_error_response(
                "base64 encoding is not currently supported")
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"cmpl-{random_uuid()}"
        created_time = int(time.monotonic())

        # Schedule the request and get the result generator.
        generators = []
        try:
            prompt_is_tokens, prompts = parse_prompt_format(request.input)
            pooling_params = request.to_pooling_params()

            for i, prompt in enumerate(prompts):
                if prompt_is_tokens:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request, prompt_ids=prompt)
                else:
                    prompt_formats = self._validate_prompt_and_tokenize(
                        request, prompt=prompt)

                prompt_ids, prompt_text = prompt_formats

                generators.append(
                    self.engine.generate(prompt_text,
                                         pooling_params,
                                         f"{request_id}-{i}",
                                         prompt_token_ids=prompt_ids))
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)

        # Non-streaming response
        final_res_batch: EmbeddingRequestOutput = [None] * len(prompts)
        async for i, res in result_generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(f"{request_id}-{i}")
                # TODO: Use a vllm-specific Validation Error
                return self.create_error_response("Client disconnected")
            final_res_batch[i] = res
        response = request_output_to_embedding_response(
            final_res_batch, request_id, created_time, model_name)

        return response

    def _check_embedding_mode(self, embedding_mode: bool):
        if not embedding_mode:
            logger.warning(
                "embedding_mode is False. Embedding API will not work.")
        else:
            logger.info("Activating the server engine with embedding enabled.")
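The helper above only counts prompt tokens in the usage block, since nothing is generated for an embedding request. A self-contained sketch with hypothetical stand-in result objects (SimpleNamespace is not a vLLM type) shows the aggregation:

from types import SimpleNamespace

final_res_batch = [
    SimpleNamespace(prompt_token_ids=[1, 2, 3],
                    outputs=SimpleNamespace(embedding=[0.1] * 4096)),
    SimpleNamespace(prompt_token_ids=[4, 5],
                    outputs=SimpleNamespace(embedding=[0.2] * 4096)),
]

data = [{"index": i, "object": "embedding", "embedding": r.outputs.embedding}
        for i, r in enumerate(final_res_batch)]
num_prompt_tokens = sum(len(r.prompt_token_ids) for r in final_res_batch)
usage = {"prompt_tokens": num_prompt_tokens, "total_tokens": num_prompt_tokens}
print(usage)  # {'prompt_tokens': 5, 'total_tokens': 5}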
@ -9,7 +9,8 @@ from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              CompletionRequest, ErrorResponse,
+                                              CompletionRequest,
+                                              EmbeddingRequest, ErrorResponse,
                                              LogProbs, ModelCard, ModelList,
                                              ModelPermission)
from vllm.logger import init_logger
@ -165,7 +166,8 @@ class OpenAIServing:

    def _validate_prompt_and_tokenize(
            self,
-            request: Union[ChatCompletionRequest, CompletionRequest],
+            request: Union[ChatCompletionRequest, CompletionRequest,
+                           EmbeddingRequest],
            prompt: Optional[str] = None,
            prompt_ids: Optional[List[int]] = None,
            truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
@ -191,6 +193,16 @@ class OpenAIServing:
                                              prompt_ids)
        token_num = len(input_ids)

+        # Note: EmbeddingRequest doesn't have max_tokens
+        if isinstance(request, EmbeddingRequest):
+            if token_num > self.max_model_len:
+                raise ValueError(
+                    f"This model's maximum context length is "
+                    f"{self.max_model_len} tokens. However, you requested "
+                    f"{token_num} tokens in the input for embedding "
+                    f"generation. Please reduce the length of the input.", )
+            return input_ids, input_text
+
        if request.max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
@ -1,9 +1,9 @@
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)
from vllm.worker.worker_base import WorkerWrapperBase
@ -123,8 +123,8 @@ class GPUExecutor(ExecutorBase):
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> List[Union[SamplerOutput, PoolerOutput]]:
        output = self.driver_worker.execute_model(execute_model_req)
        return output

@ -150,7 +150,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
-    ) -> List[SamplerOutput]:
+    ) -> List[Union[SamplerOutput, PoolerOutput]]:
        output = await make_async(self.driver_worker.execute_model
                                  )(execute_model_req=execute_model_req, )
        return output
vllm/model_executor/layers/pooler.py (new file, 56 lines added)
@ -0,0 +1,56 @@
from enum import IntEnum

import torch
import torch.nn as nn

from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                  PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput


class PoolingType(IntEnum):
    """Enumeration for different types of pooling methods."""
    LAST = 0


class Pooler(nn.Module):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.

    Attributes:
        pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
        normalize: Whether to normalize the pooled data.
    """

    def __init__(self, pooling_type: PoolingType, normalize: bool):
        super().__init__()
        self.pooling_type = pooling_type
        self.normalize = normalize

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools specific information from hidden states based on metadata."""
        prompt_lens = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

        if self.pooling_type == PoolingType.LAST:
            last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
            pooled_data = hidden_states[last_token_flat_indices]
        else:
            raise ValueError(f"Invalid pooling type: {self.pooling_type}")

        if self.normalize:
            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)

        pooled_outputs = [
            EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
        ]

        return PoolerOutput(outputs=pooled_outputs)
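The LAST pooling path selects the final hidden state of each prompt from a flattened batch. A torch-only sketch (dummy tensors, not a real model pass) reproduces the indexing and normalization:

import torch
import torch.nn.functional as F

hidden_states = torch.randn(7, 8)   # 7 tokens total, hidden size 8
prompt_lens = torch.tensor([3, 4])  # two prompts of length 3 and 4

# Index of the last token of each prompt in the flattened token dimension.
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1  # tensor([2, 6])
pooled = hidden_states[last_token_flat_indices]                 # shape (2, 8)

pooled = F.normalize(pooled, p=2, dim=1)  # unit-length embeddings
print(pooled.shape, pooled.norm(dim=1))   # torch.Size([2, 8]), ~1.0 each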
@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
from vllm.sampling_params import SamplingType
-from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
-                           SamplerOutput, SequenceGroupOutput, SequenceOutput)
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           PromptLogprobs, SampleLogprobs, SamplerOutput,
+                           SequenceOutput)

# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]
@ -1019,7 +1020,7 @@ def _build_sampler_output(
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
        sampler_output.append(
-            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
+            CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
@ -9,7 +9,7 @@ from vllm.utils import is_hip
logger = init_logger(__name__)

# Architecture -> (module, class).
-_MODELS = {
+_GENERATION_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
@ -58,6 +58,12 @@ _MODELS = {
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
}

+_EMBEDDING_MODELS = {
+    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
+}
+
+_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
+
# Architecture -> type.
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
@ -114,6 +120,10 @@ class ModelRegistry:
        global _OOT_MODELS
        _OOT_MODELS[model_arch] = model_cls

+    @staticmethod
+    def is_embedding_model(model_arch: str) -> bool:
+        return model_arch in _EMBEDDING_MODELS
+

__all__ = [
    "ModelRegistry",
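The registry now merges generation and embedding architectures into one table while keeping the embedding set queryable on its own. A plain-dict sketch of the merge and lookup (trimmed to two entries for brevity) shows the behavior:

_GENERATION_MODELS = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM")}
_EMBEDDING_MODELS = {"MistralModel": ("llama_embedding", "LlamaEmbeddingModel")}
_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}


def is_embedding_model(model_arch: str) -> bool:
    return model_arch in _EMBEDDING_MODELS


print("MistralModel" in _MODELS)               # True
print(is_embedding_model("MistralModel"))      # True
print(is_embedding_model("LlamaForCausalLM"))  # False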
vllm/model_executor/models/llama_embedding.py (new file, 87 lines added)
@ -0,0 +1,87 @@
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput


class LlamaEmbeddingModel(nn.Module):
    """A model that uses Llama with additional embedding functionalities.

    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = LlamaModel(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model.forward(input_ids, positions, kv_caches,
                                  attn_metadata, inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
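load_weights folds per-projection checkpoint tensors into vLLM's fused parameters by renaming them before lookup. A pure-Python sketch of just the renaming rule (the weight names below are illustrative) makes the for/else structure easier to follow:

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def target_param(checkpoint_name: str):
    """Return (fused_param_name, shard_id), or (checkpoint_name, None)."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_name:
            return checkpoint_name.replace(weight_name, param_name), shard_id
    return checkpoint_name, None


print(target_param("layers.0.self_attn.k_proj.weight"))
# ('layers.0.self_attn.qkv_proj.weight', 'k')
print(target_param("layers.0.mlp.down_proj.weight"))
# ('layers.0.mlp.down_proj.weight', None)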
vllm/model_executor/pooling_metadata.py (new file, 69 lines added)
@ -0,0 +1,69 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import torch

from vllm.pooling_params import PoolingParams
from vllm.utils import is_pin_memory_available


class PoolingMetadata:
    """Metadata for pooling operations in the Pooler layer.

    This class holds the necessary information for pooling operations,
    providing context for how to perform pooling and other related operations.

    Attributes:
        seq_groups: List of (seq_ids, pooling_params).
        seq_data: A mapping of sequence ID to additional sequence data.
        prompt_lens: List of the lengths of each prompt.
    """

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], PoolingParams]],
        seq_data: Dict[int, Any],  # Specific data related to sequences
        prompt_lens: List[int],
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens

    def __repr__(self) -> str:
        return ("PoolingMetadata("
                f"seq_groups={self.seq_groups}, "
                f"seq_data={self.seq_data}, "
                f"prompt_lens={self.prompt_lens})")


@dataclass
class PoolingTensors:
    """Tensors for pooling."""

    prompt_lens: torch.Tensor

    @classmethod
    def from_pooling_metadata(
        cls,
        pooling_metadata: "PoolingMetadata",
        device: torch.device,
    ) -> "PoolingTensors":
        """
        Create PoolingTensors from PoolingMetadata.

        Args:
            pooling_metadata: PoolingMetadata instance to convert.
            device: Device to store the tensors.
        """
        # Convert prompt lengths to tensor
        pin_memory = is_pin_memory_available()

        prompt_lens_t = torch.tensor(
            pooling_metadata.prompt_lens,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )

        return cls(prompt_lens=prompt_lens_t.to(device=device,
                                                non_blocking=True), )
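from_pooling_metadata stages the prompt lengths in pinned host memory and copies them to the device asynchronously. A torch-only sketch of the same pattern, falling back gracefully when CUDA is unavailable:

import torch

prompt_lens = [3, 4, 7]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pin_memory = device.type == "cuda"

# Pinned host tensor first, then a non-blocking copy to the target device.
prompt_lens_t = torch.tensor(prompt_lens,
                             device="cpu",
                             dtype=torch.long,
                             pin_memory=pin_memory)
prompt_lens_t = prompt_lens_t.to(device=device, non_blocking=True)
print(prompt_lens_t, prompt_lens_t.device)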
@ -57,8 +57,27 @@ class CompletionOutput:
                f"stop_reason={self.stop_reason})")


+class EmbeddingOutput:
+    """The output data of one completion output of a request.
+
+    Args:
+        embedding: The embedding vector, which is a list of floats. The
+        length of vector depends on the model as listed in the embedding guide.
+    """
+
+    def __init__(
+        self,
+        embedding: List[float],
+    ) -> None:
+        self.embedding = embedding
+
+    def __repr__(self) -> str:
+        return (f"EmbeddingOutput("
+                f"embedding={len(self.embedding)}")
+
+
class RequestOutput:
-    """The output data of a request to the LLM.
+    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
@ -93,6 +112,9 @@ class RequestOutput:

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
+        if seq_group.sampling_params is None:
+            raise ValueError(
+                "Sampling parameters are missing for a CompletionRequest.")
        seqs = seq_group.get_seqs()
        if len(seqs) == 1:
            top_n_seqs = seqs
@ -148,3 +170,61 @@ class RequestOutput:
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request})")
+
+
+class EmbeddingRequestOutput:
+    """
+    The output data of an embedding request to the LLM.
+
+    Args:
+        request_id (str): A unique identifier for the embedding request.
+        outputs (EmbeddingOutput): The embedding results for the given input.
+        prompt_token_ids (List[int]): A list of token IDs used in the prompt.
+        finished (bool): A flag indicating whether the embedding is completed.
+    """
+
+    def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
+                 prompt_token_ids: List[int], finished: bool):
+        self.request_id = request_id
+        self.prompt_token_ids = prompt_token_ids
+        self.finished = finished
+        self.outputs = outputs
+
+    @classmethod
+    def from_seq_group(cls,
+                       seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput":
+        if seq_group.embeddings is None:
+            raise ValueError(
+                "Embeddings are missing in seq_group for EmbeddingRequest.")
+        output = EmbeddingOutput(seq_group.embeddings)
+        prompt_token_ids = seq_group.prompt_token_ids
+        finished = seq_group.is_finished()
+
+        return cls(seq_group.request_id, output, prompt_token_ids, finished)
+
+    def __repr__(self):
+        """
+        Returns a string representation of an EmbeddingRequestOutput instance.
+
+        The representation includes the request_id and the number of outputs,
+        providing a quick overview of the embedding request's results.
+
+        Returns:
+            str: A string representation of the EmbeddingRequestOutput instance.
+        """
+        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
+                f"outputs={repr(self.outputs)}, "
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"finished={self.finished})")
+
+
+class RequestOutputFactory:
+
+    @staticmethod
+    def create(seq_group):
+        # Determine the type based on a condition, for example:
+        if hasattr(seq_group,
+                   'embeddings') and seq_group.embeddings is not None:
+            return EmbeddingRequestOutput.from_seq_group(seq_group)
+        else:
+            return RequestOutput.from_seq_group(seq_group)
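The factory picks the output class by checking whether the sequence group carries embeddings. The sketch below exercises the same rule with hypothetical stand-in objects rather than real SequenceGroups:

from types import SimpleNamespace


def output_kind(seq_group) -> str:
    # Mirrors RequestOutputFactory.create: embeddings present -> embedding output.
    if getattr(seq_group, "embeddings", None) is not None:
        return "EmbeddingRequestOutput"
    return "RequestOutput"


embedding_group = SimpleNamespace(embeddings=[0.1, 0.2, 0.3])
completion_group = SimpleNamespace(embeddings=None)

print(output_kind(embedding_group))   # EmbeddingRequestOutput
print(output_kind(completion_group))  # RequestOutput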
vllm/pooling_params.py (new file, 20 lines added)
@ -0,0 +1,20 @@
from typing import Any, Optional


class PoolingParams:
    """Pooling parameters for pooling.

    Attributes:
        additional_data: Any additional data needed for pooling.
    """

    def __init__(self, additional_data: Optional[Any] = None):
        self.additional_data = additional_data

    def clone(self) -> "PoolingParams":
        """Returns a deep copy of the PoolingParams instance."""
        return PoolingParams(additional_data=self.additional_data, )

    def __repr__(self) -> str:
        return (f"PoolingParams("
                f"additional_metadata={self.additional_data})")
@ -1,11 +1,13 @@
"""Sequence and its related classes."""
import copy
import enum
+from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from vllm.block import LogicalTokenBlock
from vllm.lora.request import LoRARequest
+from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

if TYPE_CHECKING:
@ -402,16 +404,22 @@ class SequenceGroup:
        arrival_time: The arrival time of the request.
        lora_request: LoRA request.
        multi_modal_data: Multi modal data associated with the request.
+        embeddings: The embeddings vectors of the prompt of the sequence group
+            for an embedding model.
+        pooling_params: The pooling parameters used to generate the pooling
+            for an embedding model.
    """

    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
-        sampling_params: SamplingParams,
        arrival_time: float,
+        sampling_params: Optional[SamplingParams] = None,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
+        embeddings: Optional[List[float]] = None,
+        pooling_params: Optional[PoolingParams] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@ -425,6 +433,8 @@ class SequenceGroup:
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.multi_modal_data = multi_modal_data
+        self.embeddings = embeddings
+        self.pooling_params = pooling_params

    @property
    def prompt(self) -> str:
@ -479,12 +489,13 @@ class SequenceGroup:
    def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
-        if self.sampling_params.use_beam_search:
+        if self.sampling_params and self.sampling_params.use_beam_search:
            # For beam search, maximally there will always be `best_of` beam
            # candidates running in the future.
            return self.sampling_params.best_of
        else:
-            if self.sampling_params.best_of > self.num_seqs():
+            if (self.sampling_params
+                    and self.sampling_params.best_of > self.num_seqs()):
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
                # generation stage, we will have `best_of` sequences running.
@ -555,7 +566,7 @@ class SequenceGroup:
        return all(seq.is_finished() for seq in self.get_seqs())

    def is_prefill(self) -> bool:
-        # Every sequences should be in the same stage.
+        # Every sequence should be in the same stage.
        return self.get_seqs()[0].is_prefill()

    def __repr__(self) -> str:
@ -594,6 +605,7 @@ class SequenceGroupMetadata:
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
        do_sample: bool = True,
+        pooling_params: Optional[PoolingParams] = None,
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
@ -605,6 +617,7 @@ class SequenceGroupMetadata:
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
+        self.pooling_params = pooling_params
        self.lora_request = lora_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
@ -669,8 +682,20 @@ class SequenceOutput:
        return equal and log_probs_equal


-class SequenceGroupOutput:
-    """The model output associated with a sequence group."""
+class SequenceGroupOutput(ABC):
+    """The base class for model outputs associated with a sequence group."""
+
+    @abstractmethod
+    def __repr__(self) -> str:
+        pass
+
+    @abstractmethod
+    def __eq__(self, other: object) -> bool:
+        pass
+
+
+class CompletionSequenceGroupOutput(SequenceGroupOutput):
+    """The model output associated with a completion sequence group."""

    def __init__(
        self,
@ -682,26 +707,45 @@ class SequenceGroupOutput:
        self.prompt_logprobs = prompt_logprobs

    def __repr__(self) -> str:
-        return (f"SequenceGroupOutput(samples={self.samples}, "
+        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
                f"prompt_logprobs={self.prompt_logprobs})")

    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, SequenceGroupOutput):
+        if not isinstance(other, CompletionSequenceGroupOutput):
            raise NotImplementedError()
        return (self.samples == other.samples
                and self.prompt_logprobs == other.prompt_logprobs)


+class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
+    """The model output associated with an embedding sequence group."""
+
+    def __init__(
+        self,
+        embeddings: List[float],
+    ) -> None:
+        self.embeddings = embeddings
+
+    def __repr__(self) -> str:
+        return (f"EmbeddingSequenceGroupOutput("
+                f"embeddings_shape={len(self.embeddings)})")
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, EmbeddingSequenceGroupOutput):
+            raise NotImplementedError()
+        return self.embeddings == other.embeddings
+
+
@dataclass
class SamplerOutput:
    """For each sequence group, we generate a list of SequenceOutput object,
    each of which contains one possible candidate for the next token.

-    This datastructure implements methods so it can be used like a list, but
+    This data structure implements methods, so it can be used like a list, but
    also has optional fields for device tensors.
    """

-    outputs: List[SequenceGroupOutput]
+    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional["torch.Tensor"] = None
@ -742,6 +786,27 @@ class SamplerOutput:
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")


+@dataclass
+class PoolerOutput:
+    """The output from a pooling operation in the embedding model."""
+    outputs: List[EmbeddingSequenceGroupOutput]
+
+    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return isinstance(other,
+                          self.__class__) and self.outputs == other.outputs
+
+
@dataclass
class ExecuteModelRequest:
    """The model execution request."""
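PoolerOutput mirrors SamplerOutput's list-like interface so downstream code can index and iterate either one. A short sketch, assuming a vLLM checkout that includes this commit is importable:

from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput

out = PoolerOutput(outputs=[
    EmbeddingSequenceGroupOutput([0.1, 0.2]),
    EmbeddingSequenceGroupOutput([0.3, 0.4]),
])

print(len(out))           # 2
print(out[0].embeddings)  # [0.1, 0.2]
print(out == PoolerOutput(outputs=list(out.outputs)))  # True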
@ -4,7 +4,8 @@ from typing import Dict, List, Tuple
 
 import torch
 
-from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           SamplerOutput, SequenceGroupMetadata,
                            SequenceGroupOutput, SequenceOutput)
 
 SeqId = int
@ -94,7 +95,7 @@ def create_sequence_group_output(
             for topk_logprob_index, _ in enumerate(topk_token_ids)
         })
 
-    return SequenceGroupOutput(
+    return CompletionSequenceGroupOutput(
         samples=[
             SequenceOutput(parent_seq_id=seq_id,
                            output_token=token_id,
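The hunk above only swaps the class the spec-decode helper instantiates. As a hedged illustration, and assuming the constructor keeps the old (samples, prompt_logprobs) signature and that SequenceOutput/Logprob are unchanged, a completion output can be built directly like this:

from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SequenceOutput)

# One sampled token (id 42) for the sequence with parent id 0.
sample = SequenceOutput(parent_seq_id=0,
                        output_token=42,
                        logprobs={42: Logprob(logprob=-0.25)})
group_output = CompletionSequenceGroupOutput(samples=[sample],
                                             prompt_logprobs=None)
print(group_output.samples[0].output_token)  # 42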
266
vllm/worker/embedding_model_runner.py
Normal file
@ -0,0 +1,266 @@
from typing import Dict, List, Optional, Set, Tuple

import torch

from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import BatchType, ModelRunner

logger = init_logger(__name__)


class EmbeddingModelRunner(ModelRunner):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        super().__init__(model_config,
                         parallel_config,
                         scheduler_config,
                         device_config,
                         cache_config,
                         load_config,
                         lora_config=lora_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker,
                         vision_language_config=vision_language_config)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[PoolerOutput]:
        (input_tokens, input_positions, attn_metadata, pooling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})
        hidden_states = model_executable(**execute_model_kwargs)

        return self.model.pooler(hidden_states=hidden_states,
                                 pooling_metadata=pooling_metadata)

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
               Set[LoRARequest], LoRAMapping, torch.Tensor]:
        if self.is_driver_worker:
            prefill_reqs = []
            decode_reqs = []
            for seq_group_meta in seq_group_metadata_list:
                if seq_group_meta.is_prompt:
                    prefill_reqs.append(seq_group_meta)
                else:
                    decode_reqs.append(seq_group_meta)

            # Prepare input tensors.
            (
                input_tokens,
                input_positions,
                prefill_attn_metadata,
                prompt_lens,
                subquery_lens,
                lora_index_mapping,
                lora_prompt_mapping,
                lora_requests,
                multi_modal_input,
                slot_mapping,
            ) = self._prepare_prompt(prefill_reqs)
            (
                decode_input_tokens,
                decode_input_positions,
                decode_attn_metadata,
                decode_lora_index_mapping,
                decode_lora_prompt_mapping,
                decode_lora_requests,
                decode_slot_mapping,
            ) = self._prepare_decode(decode_reqs)

            # Prepare PoolingMetadata
            pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                     prompt_lens)

            if not self.scheduler_config.chunked_prefill_enabled:
                assert (len(prefill_reqs) and len(decode_reqs)) == 0

            num_prefills = len(prompt_lens)
            num_prefill_tokens = len(input_tokens)
            num_decode_tokens = len(decode_input_tokens)

            # Coalesce tensors. Note that attn_metadata is currently not
            # coalesced for simplicity.
            input_tokens.extend(decode_input_tokens)
            input_positions.extend(decode_input_positions)
            slot_mapping.extend(decode_slot_mapping)
            lora_index_mapping.extend(decode_lora_index_mapping)
            lora_prompt_mapping.extend(decode_lora_prompt_mapping)
            lora_requests.update(decode_lora_requests)

            input_tokens = torch.tensor(input_tokens,
                                        dtype=torch.long,
                                        device=self.device)
            input_positions = torch.tensor(input_positions,
                                           dtype=torch.long,
                                           device=self.device)
            slot_mapping = torch.tensor(slot_mapping,
                                        dtype=torch.long,
                                        device=self.device)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            # If batch contains both prefill and decode, it sends 2 broadcasts.
            # If it only contains 1 type, it triggers a single broadcast.
            if (prefill_attn_metadata is not None
                    and decode_attn_metadata is not None):
                batch_type = BatchType.MIXED
            elif prefill_attn_metadata is not None:
                batch_type = BatchType.PREFILL
            else:
                batch_type = BatchType.DECODE

            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
                "batch_type": batch_type,
            }
            if prefill_attn_metadata is not None:
                metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
            else:
                assert decode_attn_metadata is not None
                metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)

            # Broadcast decode attn metadata for mixed batch type.
            # The additional broadcast costs 300us overhead on 4 A10 GPUs.
            # We can potentially reduce the overhead by coelescing tensors.
            if batch_type == BatchType.MIXED:
                assert decode_attn_metadata is not None
                metadata_dict = decode_attn_metadata.asdict_zerocopy()
                broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            slot_mapping = metadata_dict.pop("slot_mapping")
            num_prefills = metadata_dict.pop("num_prefills")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
            num_decode_tokens = metadata_dict.pop("num_decode_tokens")
            batch_type = metadata_dict.pop("batch_type")

            # Create an attention metadata.
            prefill_attn_metadata = None
            decode_attn_metadata = None
            if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
                prefill_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)

            pooling_metadata = PoolingMetadata(seq_groups=None,
                                               seq_data=None,
                                               prompt_lens=None)

            # if it is a mixed batch, decode attn_metadata is broadcasted
            # separately.
            if batch_type == BatchType.MIXED:
                metadata_dict = broadcast_tensor_dict(src=0)
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)

        attn_metadata = AttentionMetadata(
            num_prefills=num_prefills,
            slot_mapping=slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            prefill_metadata=prefill_attn_metadata,
            decode_metadata=decode_attn_metadata,
            kv_cache_dtype=self.kv_cache_dtype,
        )

        return (input_tokens, input_positions, attn_metadata, pooling_metadata,
                lora_requests, lora_mapping, multi_modal_input)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            pooling_params = seq_group_metadata.pooling_params
            seq_groups.append((seq_ids, pooling_params))

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        pooling_metadata = PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

        return pooling_metadata
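Unlike ModelRunner, the runner above returns a PoolerOutput rather than a SamplerOutput. The helper below is a hedged driver-side sketch (not vLLM API; the runner and the scheduled batch are assumed to be constructed elsewhere) showing how the pooled vectors can be collected from one forward pass.

from typing import List

from vllm.sequence import PoolerOutput, SequenceGroupMetadata
from vllm.worker.embedding_model_runner import EmbeddingModelRunner


def run_embedding_step(
        runner: EmbeddingModelRunner,
        seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[List[float]]:
    """Run one embedding forward pass and collect the pooled vectors."""
    # execute_model rebinds kv_caches to [None] * num_layers internally,
    # so an empty list is sufficient for embedding models.
    pooler_output: PoolerOutput = runner.execute_model(
        seq_group_metadata_list, kv_caches=[])
    return [group.embeddings for group in pooler_output.outputs]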
@ -1,6 +1,6 @@
 import time
 from enum import IntEnum
-from typing import Dict, List, NamedTuple, Optional, Set, Tuple
+from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@ -287,18 +287,18 @@ class ModelRunner:
                 lora_requests.add(seq_group_metadata.lora_request)
 
             lora_index_mapping += [lora_id] * (seq_len - context_len)
-            lora_prompt_mapping.extend(
-                [lora_id] *
-                (seq_len - context_len
-                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))
+            lora_prompt_mapping.extend([lora_id] * (
+                seq_len - context_len if seq_group_metadata.sampling_params
+                and seq_group_metadata.sampling_params.prompt_logprobs else 1))
 
             if seq_group_metadata.multi_modal_data:
                 multi_modal_input_list.append(
                     seq_group_metadata.multi_modal_data.data)
 
-            if seq_group_metadata.block_tables is None:
+            if _is_block_tables_empty(seq_group_metadata.block_tables):
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
+                # In embeddings, the block tables are {seq_id: None}.
                 slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                 continue
@ -813,7 +813,6 @@ class ModelRunner:
         sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
 
         # This represents the maximum number of different requests
         # that will have unique loras, an therefore the max amount of memory
         # consumption create dummy lora request copies from the lora request
@ -1139,3 +1138,15 @@ def _prepare_fake_inputs(
     prompt_tokens = [0] * seq_len
     fake_image_input = None
     return SequenceData(prompt_tokens), fake_image_input
+
+
+def _is_block_tables_empty(block_tables: Union[None, Dict]):
+    """
+    Check if block_tables is None or a dictionary with all None values.
+    """
+    if block_tables is None:
+        return True
+    if isinstance(block_tables, dict) and all(
+            value is None for value in block_tables.values()):
+        return True
+    return False
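A quick illustration (not from the diff) of the three cases the new helper distinguishes, assuming these hunks belong to vllm/worker/model_runner.py: embedding requests arrive with per-sequence block tables of None and must be treated like the memory-profiling case.

from vllm.worker.model_runner import _is_block_tables_empty

print(_is_block_tables_empty(None))            # True: memory profiling run
print(_is_block_tables_empty({0: None}))       # True: embedding request
print(_is_block_tables_empty({0: [1, 2, 3]}))  # False: real KV-cache blocks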
@ -1,7 +1,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.distributed
@ -16,8 +16,9 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
     init_custom_ar)
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
 from vllm.worker.cache_engine import CacheEngine
+from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.model_runner import ModelRunner
 from vllm.worker.worker_base import WorkerBase
@ -68,7 +69,9 @@ class Worker(WorkerBase):
             assert not self.lora_config, (
                 "To be tested: vision language model with LoRA settings.")
 
-        self.model_runner = ModelRunner(
+        ModelRunnerClass = (EmbeddingModelRunner if
+                            self.model_config.embedding_mode else ModelRunner)
+        self.model_runner = ModelRunnerClass(
             model_config,
             parallel_config,
             scheduler_config,
@ -83,7 +86,8 @@ class Worker(WorkerBase):
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: CacheEngine
-        self.gpu_cache: List[torch.Tensor]
+        # Initialize gpu_cache as embedding models don't initialize kv_caches
+        self.gpu_cache: Optional[List[torch.tensor]] = None
 
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
@ -209,7 +213,7 @@ class Worker(WorkerBase):
     def execute_model(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[SamplerOutput]:
+    ) -> List[Union[SamplerOutput, PoolerOutput]]:
 
         if execute_model_req is None:
             seq_group_metadata_list = None
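Because Worker.execute_model can now return either kind of output, downstream code that inspects the results needs a small type switch. A hedged sketch of a hypothetical helper (not part of the commit) that splits a result list into embeddings and sampled token ids:

from typing import List, Tuple, Union

from vllm.sequence import PoolerOutput, SamplerOutput


def unpack_worker_outputs(
    outputs: List[Union[SamplerOutput, PoolerOutput]]
) -> Tuple[List[List[float]], List[int]]:
    """Split a Worker.execute_model result into embeddings and token ids."""
    embeddings: List[List[float]] = []
    token_ids: List[int] = []
    for output in outputs:
        if isinstance(output, PoolerOutput):
            # Embedding path: one EmbeddingSequenceGroupOutput per group.
            embeddings.extend(group.embeddings for group in output.outputs)
        else:
            # Completion path: take the first sample of each sequence group.
            token_ids.extend(group.samples[0].output_token
                             for group in output.outputs)
    return embeddings, token_ids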