diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index a8ac8eb576da2..6d27d10f687ab 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -73,45 +73,36 @@ class TestModel(torch.nn.Module): ] with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear_1 = TestFP8Layer( - self.activation_quant_key, - self.weight_quant_key, - self.w[0], - self.wscale[0], - self.scale[0], - ) - self.fp8_linear_2 = TestFP8Layer( - self.activation_quant_key, - self.weight_quant_key, - self.w[1], - self.wscale[1], - self.scale[1], - ) - self.fp8_linear_3 = TestFP8Layer( - self.activation_quant_key, - self.weight_quant_key, - self.w[2], - self.wscale[2], - self.scale[2], - ) + self.fp8_linear_layers = [ + TestFP8Layer( + self.activation_quant_key, + self.weight_quant_key, + self.w[i], + self.wscale[i], + input_scale=self.scale[i], + ) + for i in range(3) + ] self.enable_rms_norm_custom_op = self.norm[0].enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear_layers[ + 0 + ].is_quant_fp8_enabled() def forward(self, x): # avoid having graph input be an arg to a pattern directly x = resid = torch.relu(x) y = self.norm[0](x) - x2 = self.fp8_linear_1(y) + x2 = self.fp8_linear_layers[0](y) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear_2(y2) + x3 = self.fp8_linear_layers[1](y2) y3, resid = self.norm[2](x3, resid) # use resid here - x4 = self.fp8_linear_3(y3) + x4 = self.fp8_linear_layers[2](y3) y4, resid = self.norm[3](x4, resid) # use resid here return y4 diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index bda2620d3e2fe..a539f4a160384 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -90,29 +90,16 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): for _ in range(3) ] - self.fp8_linear_1 = TestFP8Layer( - self.quant_key, - self.quant_key, - self.weight[0], - self.wscale[0], - input_scale=self.input_scale[0], - ) - - self.fp8_linear_2 = TestFP8Layer( - self.quant_key, - self.quant_key, - self.weight[1], - self.wscale[1], - input_scale=self.input_scale[1], - ) - - self.fp8_linear_3 = TestFP8Layer( - self.quant_key, - self.quant_key, - self.weight[2], - self.wscale[2], - input_scale=self.input_scale[2], - ) + self.fp8_linear_layers = [ + TestFP8Layer( + self.quant_key, + self.quant_key, + self.weight[i], + self.wscale[i], + input_scale=self.input_scale[i], + ) + for i in range(3) + ] def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly @@ -120,17 +107,17 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): x = resid = tensor_model_parallel_all_reduce(z) y = self.norm[0](x) - z2 = self.fp8_linear_1(y) + z2 = self.fp8_linear_layers[0](y) x2 = tensor_model_parallel_all_reduce(z2) y2, resid = self.norm[1](x2, resid) - z3 = self.fp8_linear_2(y2) + z3 = self.fp8_linear_layers[1](y2) x3 = tensor_model_parallel_all_reduce(z3) y3, resid = self.norm[2](x3, resid) # use resid here - z4 = self.fp8_linear_3(y3) + z4 = self.fp8_linear_layers[2](y3) x4 = tensor_model_parallel_all_reduce(z4) y4, resid = self.norm[3](x4, resid) # use resid here @@ -143,7 +130,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): return [ torch.ops.vllm.all_reduce.default, torch.ops._C.static_scaled_fp8_quant.default - if self.fp8_linear.is_quant_fp8_enabled() + if self.fp8_linear_layers[0].is_quant_fp8_enabled() else torch.ops.aten.reciprocal.default, ]