mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 20:37:46 +08:00
Use for loops to initialize the FP8 linear layers in tests
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
56a05cd818
commit
9ff9b44e0d
@ -73,45 +73,36 @@ class TestModel(torch.nn.Module):
|
||||
]
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear_1 = TestFP8Layer(
|
||||
self.activation_quant_key,
|
||||
self.weight_quant_key,
|
||||
self.w[0],
|
||||
self.wscale[0],
|
||||
self.scale[0],
|
||||
)
|
||||
self.fp8_linear_2 = TestFP8Layer(
|
||||
self.activation_quant_key,
|
||||
self.weight_quant_key,
|
||||
self.w[1],
|
||||
self.wscale[1],
|
||||
self.scale[1],
|
||||
)
|
||||
self.fp8_linear_3 = TestFP8Layer(
|
||||
self.activation_quant_key,
|
||||
self.weight_quant_key,
|
||||
self.w[2],
|
||||
self.wscale[2],
|
||||
self.scale[2],
|
||||
)
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
self.activation_quant_key,
|
||||
self.weight_quant_key,
|
||||
self.w[i],
|
||||
self.wscale[i],
|
||||
input_scale=self.scale[i],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||
0
|
||||
].is_quant_fp8_enabled()
|
||||
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
x = resid = torch.relu(x)
|
||||
y = self.norm[0](x)
|
||||
|
||||
x2 = self.fp8_linear_1(y)
|
||||
x2 = self.fp8_linear_layers[0](y)
|
||||
# make sure resid is used for replacement to work
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
x3 = self.fp8_linear_2(y2)
|
||||
x3 = self.fp8_linear_layers[1](y2)
|
||||
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
x4 = self.fp8_linear_3(y3)
|
||||
x4 = self.fp8_linear_layers[2](y3)
|
||||
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
@ -90,29 +90,16 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.fp8_linear_1 = TestFP8Layer(
|
||||
self.quant_key,
|
||||
self.quant_key,
|
||||
self.weight[0],
|
||||
self.wscale[0],
|
||||
input_scale=self.input_scale[0],
|
||||
)
|
||||
|
||||
self.fp8_linear_2 = TestFP8Layer(
|
||||
self.quant_key,
|
||||
self.quant_key,
|
||||
self.weight[1],
|
||||
self.wscale[1],
|
||||
input_scale=self.input_scale[1],
|
||||
)
|
||||
|
||||
self.fp8_linear_3 = TestFP8Layer(
|
||||
self.quant_key,
|
||||
self.quant_key,
|
||||
self.weight[2],
|
||||
self.wscale[2],
|
||||
input_scale=self.input_scale[2],
|
||||
)
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
self.quant_key,
|
||||
self.quant_key,
|
||||
self.weight[i],
|
||||
self.wscale[i],
|
||||
input_scale=self.input_scale[i],
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
@ -120,17 +107,17 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = self.fp8_linear_1(y)
|
||||
z2 = self.fp8_linear_layers[0](y)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = self.fp8_linear_2(y2)
|
||||
z3 = self.fp8_linear_layers[1](y2)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
z4 = self.fp8_linear_3(y3)
|
||||
z4 = self.fp8_linear_layers[2](y3)
|
||||
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
@ -143,7 +130,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default
|
||||
if self.fp8_linear.is_quant_fp8_enabled()
|
||||
if self.fp8_linear_layers[0].is_quant_fp8_enabled()
|
||||
else torch.ops.aten.reciprocal.default,
|
||||
]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user