[Chore] code lint

This commit is contained in:
i-yuanyukun 2025-12-18 15:56:43 +08:00
parent cd16bcff1e
commit f74bb82909
3 changed files with 16 additions and 7 deletions

View File

@ -130,8 +130,10 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
try: try:
hidden_states, recv_metadata = self.connector.recv_attn_output() hidden_states, recv_metadata = self.connector.recv_attn_output()
if hasattr(self.connector, 'dp_metadata_list'): if hasattr(self.connector, "dp_metadata_list"):
dp_metadata = self.connector.dp_metadata_list.get(recv_metadata.stage_idx, None) dp_metadata = self.connector.dp_metadata_list.get(
recv_metadata.stage_idx, None
)
else: else:
dp_metadata = None dp_metadata = None
current_layer_idx = recv_metadata.layer_idx current_layer_idx = recv_metadata.layer_idx

View File

@ -3192,7 +3192,9 @@ class GPUModelRunner(
# Mark KV scales as calculated after the first forward pass # Mark KV scales as calculated after the first forward pass
self.calculate_kv_scales = False self.calculate_kv_scales = False
afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded) afd_metadata = self._build_afd_metadata(
ubatch_slices_padded, num_tokens_unpadded
)
self.profiler.step() self.profiler.step()
# Run the model. # Run the model.
@ -4326,7 +4328,9 @@ class GPUModelRunner(
if num_tokens_across_dp is not None: if num_tokens_across_dp is not None:
num_tokens_across_dp[:] = num_tokens_padded num_tokens_across_dp[:] = num_tokens_padded
afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded) afd_metadata = self._build_afd_metadata(
ubatch_slices_padded, num_tokens_unpadded
)
with ( with (
self.maybe_randomize_inputs(input_ids, inputs_embeds), self.maybe_randomize_inputs(input_ids, inputs_embeds),

View File

@ -405,9 +405,12 @@ class UBatchWrapper:
afd_metadata.input_ids_list.append(sliced_input_ids) afd_metadata.input_ids_list.append(sliced_input_ids)
afd_metadata.positions_list.append(sliced_positions) afd_metadata.positions_list.append(sliced_positions)
afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds) afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds)
afd_metadata.intermediate_tensors_list.append(sliced_intermediate_tensors) afd_metadata.intermediate_tensors_list.append(
sliced_intermediate_tensors
)
afd_metadata.attn_metadata_list.append( afd_metadata.attn_metadata_list.append(
attn_metadata[i] if attn_metadata is not None else None) attn_metadata[i] if attn_metadata is not None else None
)
afd_metadata.dp_metadata_list.append(ubatch_dp_metadata) afd_metadata.dp_metadata_list.append(ubatch_dp_metadata)
return afd_metadata return afd_metadata
@ -481,7 +484,7 @@ class UBatchWrapper:
# num_tokens, we don't have a non-ubatched one. Without this # num_tokens, we don't have a non-ubatched one. Without this
# check, the cudagraph wrapper will try to capture a cudagraph # check, the cudagraph wrapper will try to capture a cudagraph
# for this shape during a normal run. # for this shape during a normal run.
if cudagraph_runtime_mode is CUDAGraphMode.FULL: if cudagraph_runtime_mode is CUDAGraphMode.FULL:
assert batch_descriptor is not None assert batch_descriptor is not None
if batch_descriptor.num_tokens in self.cudagraphs: if batch_descriptor.num_tokens in self.cudagraphs: