[Chore] code lint

This commit is contained in:
i-yuanyukun 2025-12-18 15:56:43 +08:00
parent cd16bcff1e
commit f74bb82909
3 changed files with 16 additions and 7 deletions

View File

@ -130,8 +130,10 @@ class GPUFFNModelRunner(LoRAModelRunnerMixin):
try:
hidden_states, recv_metadata = self.connector.recv_attn_output()
if hasattr(self.connector, 'dp_metadata_list'):
dp_metadata = self.connector.dp_metadata_list.get(recv_metadata.stage_idx, None)
if hasattr(self.connector, "dp_metadata_list"):
dp_metadata = self.connector.dp_metadata_list.get(
recv_metadata.stage_idx, None
)
else:
dp_metadata = None
current_layer_idx = recv_metadata.layer_idx

View File

@ -3192,7 +3192,9 @@ class GPUModelRunner(
# Mark KV scales as calculated after the first forward pass
self.calculate_kv_scales = False
afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded)
afd_metadata = self._build_afd_metadata(
ubatch_slices_padded, num_tokens_unpadded
)
self.profiler.step()
# Run the model.
@ -4326,7 +4328,9 @@ class GPUModelRunner(
if num_tokens_across_dp is not None:
num_tokens_across_dp[:] = num_tokens_padded
afd_metadata = self._build_afd_metadata(ubatch_slices_padded, num_tokens_unpadded)
afd_metadata = self._build_afd_metadata(
ubatch_slices_padded, num_tokens_unpadded
)
with (
self.maybe_randomize_inputs(input_ids, inputs_embeds),

View File

@ -405,9 +405,12 @@ class UBatchWrapper:
afd_metadata.input_ids_list.append(sliced_input_ids)
afd_metadata.positions_list.append(sliced_positions)
afd_metadata.inputs_embeds_list.append(sliced_inputs_embeds)
afd_metadata.intermediate_tensors_list.append(sliced_intermediate_tensors)
afd_metadata.intermediate_tensors_list.append(
sliced_intermediate_tensors
)
afd_metadata.attn_metadata_list.append(
attn_metadata[i] if attn_metadata is not None else None)
attn_metadata[i] if attn_metadata is not None else None
)
afd_metadata.dp_metadata_list.append(ubatch_dp_metadata)
return afd_metadata
@ -481,7 +484,7 @@ class UBatchWrapper:
# num_tokens, we don't have a non-ubatched one. Without this
# check, the cudagraph wrapper will try to capture a cudagraph
# for this shape during a normal run.
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
assert batch_descriptor is not None
if batch_descriptor.num_tokens in self.cudagraphs: