use for loops for fp8 linear layers init in tests

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-07 07:39:08 +00:00
parent 56a05cd818
commit 9ff9b44e0d
2 changed files with 30 additions and 52 deletions

View File

@@ -73,45 +73,36 @@ class TestModel(torch.nn.Module):
]
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear_1 = TestFP8Layer(
self.activation_quant_key,
self.weight_quant_key,
self.w[0],
self.wscale[0],
self.scale[0],
)
self.fp8_linear_2 = TestFP8Layer(
self.activation_quant_key,
self.weight_quant_key,
self.w[1],
self.wscale[1],
self.scale[1],
)
self.fp8_linear_3 = TestFP8Layer(
self.activation_quant_key,
self.weight_quant_key,
self.w[2],
self.wscale[2],
self.scale[2],
)
self.fp8_linear_layers = [
TestFP8Layer(
self.activation_quant_key,
self.weight_quant_key,
self.w[i],
self.wscale[i],
input_scale=self.scale[i],
)
for i in range(3)
]
self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = self.norm[0](x)
x2 = self.fp8_linear_1(y)
x2 = self.fp8_linear_layers[0](y)
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear_2(y2)
x3 = self.fp8_linear_layers[1](y2)
y3, resid = self.norm[2](x3, resid) # use resid here
x4 = self.fp8_linear_3(y3)
x4 = self.fp8_linear_layers[2](y3)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4

View File

@@ -90,29 +90,16 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
for _ in range(3)
]
self.fp8_linear_1 = TestFP8Layer(
self.quant_key,
self.quant_key,
self.weight[0],
self.wscale[0],
input_scale=self.input_scale[0],
)
self.fp8_linear_2 = TestFP8Layer(
self.quant_key,
self.quant_key,
self.weight[1],
self.wscale[1],
input_scale=self.input_scale[1],
)
self.fp8_linear_3 = TestFP8Layer(
self.quant_key,
self.quant_key,
self.weight[2],
self.wscale[2],
input_scale=self.input_scale[2],
)
self.fp8_linear_layers = [
TestFP8Layer(
self.quant_key,
self.quant_key,
self.weight[i],
self.wscale[i],
input_scale=self.input_scale[i],
)
for i in range(3)
]
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
@@ -120,17 +107,17 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear_1(y)
z2 = self.fp8_linear_layers[0](y)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear_2(y2)
z3 = self.fp8_linear_layers[1](y2)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear_3(y3)
z4 = self.fp8_linear_layers[2](y3)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
@@ -143,7 +130,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.is_quant_fp8_enabled()
if self.fp8_linear_layers[0].is_quant_fp8_enabled()
else torch.ops.aten.reciprocal.default,
]