Merge branch 'main' into woosuk/input-prep

This commit is contained in:
Woosuk Kwon 2025-08-17 19:28:38 -07:00
commit 699bd7928e
2 changed files with 8 additions and 7 deletions

View File

@ -56,7 +56,7 @@ def get_moe_quant_method(
# Dynamic per module/layer rules may override base config # Dynamic per module/layer rules may override base config
override_config(cloned_config, prefix=prefix) override_config(cloned_config, prefix=prefix)
return moe_method_cls(cloned_config) return moe_method_cls(cloned_config, layer.moe_config)
return None return None

View File

@ -231,8 +231,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The convention is different. # The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order. # self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order. # The batch sizes in the config are in descending order.
self.cudagraph_batch_sizes = list( if self.compilation_config.cudagraph_capture_sizes and \
reversed(self.compilation_config.cudagraph_capture_sizes)) self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
self.cudagraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes))
# Cache the device properties. # Cache the device properties.
self._init_device_properties() self._init_device_properties()
@ -1657,7 +1659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute prompt logprobs if needed. # Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict( prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens], hidden_states[:num_scheduled_tokens],
scheduler_output, scheduler_output.num_scheduled_tokens,
) )
# Get the valid generated tokens. # Get the valid generated tokens.
@ -1999,7 +2001,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_prompt_logprobs_dict( def _get_prompt_logprobs_dict(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput", num_scheduled_tokens: dict[str, int],
) -> dict[str, Optional[LogprobsTensors]]: ) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict: if not num_prompt_logprobs_dict:
@ -2012,8 +2014,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# maintainable loop over optimal performance. # maintainable loop over optimal performance.
completed_prefill_reqs = [] completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
num_tokens = num_scheduled_tokens[req_id]
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# Get metadata for this request. # Get metadata for this request.
request = self.requests[req_id] request = self.requests[req_id]