Merge branch 'main' into woosuk/input-prep

This commit is contained in:
Woosuk Kwon 2025-08-17 19:28:38 -07:00
commit 699bd7928e
2 changed files with 8 additions and 7 deletions

View File

@@ -56,7 +56,7 @@ def get_moe_quant_method(
# Dynamic per module/layer rules may override base config
override_config(cloned_config, prefix=prefix)
return moe_method_cls(cloned_config)
return moe_method_cls(cloned_config, layer.moe_config)
return None

View File

@@ -231,8 +231,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
self.cudagraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes))
if self.compilation_config.cudagraph_capture_sizes and \
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
self.cudagraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes))
# Cache the device properties.
self._init_device_properties()
@@ -1657,7 +1659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output,
scheduler_output.num_scheduled_tokens,
)
# Get the valid generated tokens.
@@ -1999,7 +2001,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_prompt_logprobs_dict(
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
num_scheduled_tokens: dict[str, int],
) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
@@ -2012,8 +2014,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_tokens = num_scheduled_tokens[req_id]
# Get metadata for this request.
request = self.requests[req_id]