mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-21 08:37:04 +08:00
added __init__.py
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
parent
66349c33a1
commit
d5b0db449e
@ -58,11 +58,14 @@ class PDController(EngineClient):
|
|||||||
[ Engine ] <---> [ Engine ]
|
[ Engine ] <---> [ Engine ]
|
||||||
|
|
||||||
After PR #12957, we will support xPyD, so we will
|
After PR #12957, we will support xPyD, so we will
|
||||||
also need to implement a scheduler.
|
also need to implement a scheduler and service
|
||||||
we will need to support multiple
|
discovery for the workers.
|
||||||
|
|
||||||
* TODO: actually handle errors and failure.
|
This PDController may be implemented as a K8s
|
||||||
* TODO: support the full API (logprobs, multimodal).
|
controller. This is intended to be a prototype.
|
||||||
|
|
||||||
|
* TODO: better error handling
|
||||||
|
* TODO: support logprobs, multimodal, etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prefill_addr: str, decode_addr: str,
|
def __init__(self, prefill_addr: str, decode_addr: str,
|
||||||
@ -101,17 +104,16 @@ class PDController(EngineClient):
|
|||||||
|
|
||||||
# Dummy: needed for EngineClient Protocol.
|
# Dummy: needed for EngineClient Protocol.
|
||||||
# TODO: refactor OAI Server to avoid needing this.
|
# TODO: refactor OAI Server to avoid needing this.
|
||||||
init_kwargs = dict(
|
self.tokenizer = TokenizerGroup(
|
||||||
tokenizer_id=self.model_config.tokenizer,
|
**dict(tokenizer_id=self.model_config.tokenizer,
|
||||||
enable_lora=False,
|
enable_lora=False,
|
||||||
max_num_seqs=1024,
|
max_num_seqs=1024,
|
||||||
max_loras=0,
|
max_loras=0,
|
||||||
max_input_length=None,
|
max_input_length=None,
|
||||||
tokenizer_mode=self.model_config.tokenizer_mode,
|
tokenizer_mode=self.model_config.tokenizer_mode,
|
||||||
trust_remote_code=self.model_config.trust_remote_code,
|
trust_remote_code=self.model_config.trust_remote_code,
|
||||||
revision=self.model_config.tokenizer_revision,
|
revision=self.model_config.tokenizer_revision,
|
||||||
truncation_side=self.model_config.truncation_side)
|
truncation_side=self.model_config.truncation_side))
|
||||||
self.tokenizer = TokenizerGroup(**init_kwargs)
|
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
if (ctx := self.ctx) is not None:
|
if (ctx := self.ctx) is not None:
|
||||||
@ -155,7 +157,7 @@ class PDController(EngineClient):
|
|||||||
raise Exception("Unknown response type.")
|
raise Exception("Unknown response type.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO: distinguish between fatal and non-fatal errors.
|
# TODO: distinguish between fatal and non-fatal errors.
|
||||||
for _, q in self.queues.values():
|
for q in self.queues.values():
|
||||||
q.put_nowait(e)
|
q.put_nowait(e)
|
||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
@ -172,17 +174,17 @@ class PDController(EngineClient):
|
|||||||
msg = (PDRequestType.GENERATION, req_bytes)
|
msg = (PDRequestType.GENERATION, req_bytes)
|
||||||
await self.to_prefill.send_multipart(msg, copy=False)
|
await self.to_prefill.send_multipart(msg, copy=False)
|
||||||
|
|
||||||
# Wait for the prefill to be done.
|
# Await completion of the prefill.
|
||||||
response = await q.get()
|
response = await q.get()
|
||||||
if isinstance(response, Exception):
|
if isinstance(response, Exception):
|
||||||
raise response
|
raise response
|
||||||
|
logger.debug("Got Decode Response: %s", request.request_id)
|
||||||
|
|
||||||
async def _run_decode(
|
async def _run_decode(
|
||||||
self,
|
self,
|
||||||
request: PDGenerationRequest,
|
request: PDGenerationRequest,
|
||||||
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
|
q: asyncio.Queue[Union[Exception, PDGenerationResponse]],
|
||||||
) -> AsyncGenerator[PDGenerationResponse]:
|
) -> AsyncGenerator[PDGenerationResponse]:
|
||||||
|
|
||||||
# Send request to the decode instance.
|
# Send request to the decode instance.
|
||||||
req_bytes = self.encoder.encode(request)
|
req_bytes = self.encoder.encode(request)
|
||||||
msg = (PDRequestType.GENERATION, req_bytes)
|
msg = (PDRequestType.GENERATION, req_bytes)
|
||||||
@ -194,6 +196,7 @@ class PDController(EngineClient):
|
|||||||
response = await q.get()
|
response = await q.get()
|
||||||
if isinstance(response, Exception):
|
if isinstance(response, Exception):
|
||||||
raise response
|
raise response
|
||||||
|
logger.debug("Got Decode Response: %s", request.request_id)
|
||||||
finished = response.finish_reason is not None
|
finished = response.finish_reason is not None
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
@ -261,10 +264,10 @@ class PDController(EngineClient):
|
|||||||
|
|
||||||
# (1) Perform the Prefill.
|
# (1) Perform the Prefill.
|
||||||
original_max_tokens = sampling_params.max_tokens
|
original_max_tokens = sampling_params.max_tokens
|
||||||
prompt_token_ids = prompt["prompt_token_ids"]
|
request = PDGenerationRequest(
|
||||||
request = PDGenerationRequest(request_id=request_id,
|
request_id=request_id,
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt["prompt_token_ids"],
|
||||||
sampling_params=sampling_params)
|
sampling_params=sampling_params)
|
||||||
request.sampling_params.max_tokens = 1
|
request.sampling_params.max_tokens = 1
|
||||||
logger.debug("Sending Prefill: %s", request.request_id)
|
logger.debug("Sending Prefill: %s", request.request_id)
|
||||||
pd_response = await self._run_prefill(request, q)
|
pd_response = await self._run_prefill(request, q)
|
||||||
@ -273,8 +276,8 @@ class PDController(EngineClient):
|
|||||||
logger.debug("Sending Decode: %s", request.request_id)
|
logger.debug("Sending Decode: %s", request.request_id)
|
||||||
request.sampling_params.max_tokens = original_max_tokens
|
request.sampling_params.max_tokens = original_max_tokens
|
||||||
async for pd_response in self._run_decode(request, q):
|
async for pd_response in self._run_decode(request, q):
|
||||||
logger.debug("Got Response: %s", request.request_id)
|
yield self._to_request_output(pd_response,
|
||||||
yield self._to_request_output(pd_response, prompt_token_ids)
|
prompt["prompt_token_ids"])
|
||||||
|
|
||||||
async def beam_search(
|
async def beam_search(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user