mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 18:27:04 +08:00
more fixes
This commit is contained in:
parent
5b0249b86e
commit
62da375465
@ -33,6 +33,7 @@ from time import sleep
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_open_port, FlexibleArgumentParser
|
||||
from vllm import LLM, EngineArgs
|
||||
import torch
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
@ -82,7 +83,7 @@ def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
] * 100
|
||||
]
|
||||
|
||||
# with DP, each rank should process different prompts.
|
||||
# usually all the DP ranks process a full dataset,
|
||||
|
||||
@ -972,7 +972,9 @@ def pplx_finalize():
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
_all_to_all_cache)
|
||||
_all_to_all_cache.destroy()
|
||||
for cache in _all_to_all_cache:
|
||||
cache.destroy()
|
||||
# _all_to_all_cache.destroy()
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
|
||||
@ -453,7 +453,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
if isinstance(prepare_finalize,
|
||||
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||
print("BatchedTritonExperts %s", self.moe)
|
||||
# print("BatchedTritonExperts %s", self.moe)
|
||||
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=MOE_DP_CHUNK_SIZE,
|
||||
@ -466,7 +466,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
block_shape=None,
|
||||
)
|
||||
else:
|
||||
print("TritonExperts %s", self.moe)
|
||||
# print("TritonExperts %s", self.moe)
|
||||
experts = TritonExperts(
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
|
||||
@ -135,7 +135,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
|
||||
torch.cuda.synchronize()
|
||||
return expert_x, expert_x_scale, expert_num_tokens
|
||||
|
||||
def finalize(
|
||||
@ -179,6 +179,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
||||
combine(False)
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
# torch.cuda.synchronize()
|
||||
@ -481,7 +481,7 @@ class MLACommonMetadataBuilder(Generic[M]):
|
||||
device = self.runner.device
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[req_slice]
|
||||
print(f"num_reqs: {num_reqs} bloc_table_shape: {block_table_tensor.shape}")
|
||||
# print(f"num_reqs: {num_reqs} bloc_table_shape: {block_table_tensor.shape}")
|
||||
slot_mapping = block_table.slot_mapping_cpu[token_slice].to(
|
||||
device, non_blocking=True).long()
|
||||
|
||||
|
||||
@ -1301,16 +1301,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(token_slice, use_dummy_input)
|
||||
with context:
|
||||
if isinstance(context, UBatchContext):
|
||||
print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}")
|
||||
# if isinstance(context, UBatchContext):
|
||||
# print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}")
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if isinstance(context, UBatchContext):
|
||||
print(f"Ran ubatch {context.id}putput {model_output.shape}")
|
||||
# if isinstance(context, UBatchContext):
|
||||
# print(f"Ran ubatch {context.id}putput {model_output.shape}")
|
||||
if isinstance(context, UBatchContext):
|
||||
# Clone before we leave the ubatch context
|
||||
model_output = model_output.clone()
|
||||
@ -1342,7 +1342,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
|
||||
assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
|
||||
|
||||
print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run)
|
||||
# print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run)
|
||||
|
||||
num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start)
|
||||
ubatch_ctxs[i].forward_context = create_forward_context(
|
||||
@ -1363,6 +1363,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.set_stream(root_stream)
|
||||
return torch.cat(results, dim=0)
|
||||
|
||||
|
||||
@ -289,8 +289,10 @@ class Worker(WorkerBase):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
if is_start:
|
||||
assert False
|
||||
self.profiler.start()
|
||||
else:
|
||||
assert False
|
||||
self.profiler.stop()
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
|
||||
@ -70,9 +70,10 @@ class UBatchContext:
|
||||
torch.cuda.set_stream(self.current_stream)
|
||||
|
||||
def ctx_valid_state(self):
|
||||
assert forward_context._forward_context == self.forward_context
|
||||
assert current_stream() == self.current_stream
|
||||
assert not self.cpu_wait_event.is_set()
|
||||
# assert forward_context._forward_context == self.forward_context
|
||||
# assert current_stream() == self.current_stream
|
||||
# assert not self.cpu_wait_event.is_set()
|
||||
pass
|
||||
def _signal_comm_done(self):
|
||||
self.ctx_valid_state()
|
||||
self.gpu_comm_done_event.record(self.comm_stream)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user