mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
move test case to util
This commit is contained in:
parent
c1a8f0d81e
commit
7a32847761
3 changed files with 51 additions and 49 deletions
|
|
@ -210,7 +210,7 @@ static bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, con
|
|||
const auto& op_type = node_unit.OpType();
|
||||
auto quant_op_type = GetQuantizedOpType(node_unit);
|
||||
|
||||
ORT_ENFORCE(quant_op_type != QuantizedOpType::QLinearMatMul, "[", op_type, "] is not a quantized op");
|
||||
ORT_ENFORCE(quant_op_type != QuantizedOpType::Unknown, "[", op_type, "] is not a quantized op");
|
||||
|
||||
bool is_quant_conv = IsQuantizedConv(quant_op_type);
|
||||
bool is_quant_matmul = (quant_op_type == QuantizedOpType::QLinearMatMul);
|
||||
|
|
|
|||
|
|
@ -83,5 +83,53 @@ GetQDQTestCaseFn BuildQDQResizeTestCase(const std::vector<int64_t>& input_shape,
|
|||
const std::string& mode = "nearest",
|
||||
const std::string& coordinate_transformation_mode = "half_pixel");
|
||||
|
||||
template <typename Input1Type, typename Input2Type, typename OutputType>
|
||||
GetQDQTestCaseFn BuildBinaryOpTestCase(const std::vector<int64_t>& input_shape,
|
||||
const std::string& op_type) {
|
||||
return [input_shape, op_type](ModelTestBuilder& builder) {
|
||||
auto* input1_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
|
||||
auto* input2_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
|
||||
auto* output_arg = builder.MakeOutput();
|
||||
|
||||
// add QDQ 1
|
||||
auto* q1_output = builder.MakeIntermediate();
|
||||
auto* dq1_output = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<Input1Type>(input1_arg,
|
||||
.004f,
|
||||
std::numeric_limits<Input1Type>::max() / 2,
|
||||
q1_output);
|
||||
builder.AddDequantizeLinearNode<Input1Type>(q1_output,
|
||||
.0039f,
|
||||
std::numeric_limits<Input1Type>::max() / 2,
|
||||
dq1_output);
|
||||
|
||||
// add QDQ 2
|
||||
auto* q2_output = builder.MakeIntermediate();
|
||||
auto* dq2_output = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<Input2Type>(input2_arg,
|
||||
.004f,
|
||||
std::numeric_limits<Input2Type>::max() / 2,
|
||||
q2_output);
|
||||
builder.AddDequantizeLinearNode<Input2Type>(q2_output,
|
||||
.0039f,
|
||||
std::numeric_limits<Input2Type>::max() / 2,
|
||||
dq2_output);
|
||||
|
||||
// add binary operator
|
||||
auto* binary_op_output = builder.MakeIntermediate();
|
||||
builder.AddNode(op_type, {dq1_output, dq2_output}, {binary_op_output});
|
||||
|
||||
// add QDQ output
|
||||
auto* q3_output = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<OutputType>(binary_op_output,
|
||||
.0038f,
|
||||
std::numeric_limits<OutputType>::max() / 2,
|
||||
q3_output);
|
||||
builder.AddDequantizeLinearNode<OutputType>(q3_output,
|
||||
.0039f,
|
||||
std::numeric_limits<OutputType>::max() / 2,
|
||||
output_arg);
|
||||
};
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -290,51 +290,6 @@ TEST(QDQTransformerTests, AveragePool_U8S8) {
|
|||
template <typename Input1Type, typename Input2Type, typename OutputType>
|
||||
void QDQTransformerBinaryOpTests(const std::string& op_type) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input_shape) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input1_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
|
||||
auto* input2_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
|
||||
auto* output_arg = builder.MakeOutput();
|
||||
|
||||
// add QDQ 1
|
||||
auto* q1_output = builder.MakeIntermediate();
|
||||
auto* dq1_output = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<Input1Type>(input1_arg,
|
||||
.004f,
|
||||
std::numeric_limits<Input1Type>::max() / 2,
|
||||
q1_output);
|
||||
builder.AddDequantizeLinearNode<Input1Type>(q1_output,
|
||||
.0039f,
|
||||
std::numeric_limits<Input1Type>::max() / 2,
|
||||
dq1_output);
|
||||
|
||||
// add QDQ 2
|
||||
auto* q2_output = builder.MakeIntermediate();
|
||||
auto* dq2_output = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<Input2Type>(input2_arg,
|
||||
.004f,
|
||||
std::numeric_limits<Input2Type>::max() / 2,
|
||||
q2_output);
|
||||
builder.AddDequantizeLinearNode<Input2Type>(q2_output,
|
||||
.0039f,
|
||||
std::numeric_limits<Input2Type>::max() / 2,
|
||||
dq2_output);
|
||||
|
||||
// add binary operator
|
||||
auto* binary_op_output = builder.MakeIntermediate();
|
||||
builder.AddNode(op_type, {dq1_output, dq2_output}, {binary_op_output});
|
||||
|
||||
// add QDQ output
|
||||
auto* q3_output = builder.MakeIntermediate();
|
||||
builder.AddQuantizeLinearNode<OutputType>(binary_op_output,
|
||||
.0038f,
|
||||
std::numeric_limits<OutputType>::max() / 2,
|
||||
q3_output);
|
||||
builder.AddDequantizeLinearNode<OutputType>(q3_output,
|
||||
.0039f,
|
||||
std::numeric_limits<OutputType>::max() / 2,
|
||||
output_arg);
|
||||
};
|
||||
|
||||
auto check_graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
if (std::is_same<Input1Type, Input2Type>::value &&
|
||||
|
|
@ -351,7 +306,7 @@ void QDQTransformerBinaryOpTests(const std::string& op_type) {
|
|||
}
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case,
|
||||
TransformerTester(BuildBinaryOpTestCase<Input1Type, Input2Type, OutputType>(input_shape, op_type),
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
|
|
@ -614,8 +569,7 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one
|
|||
|
||||
auto check_binary_op_graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>)&&
|
||||
(!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one)) &&
|
||||
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>)&&(!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one)) &&
|
||||
(std::is_same_v<Input1Type, uint8_t> || std::is_same_v<Input2Type, int8_t>)) {
|
||||
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
|
||||
EXPECT_EQ(op_to_count["Gemm"], 0);
|
||||
|
|
|
|||
Loading…
Reference in a new issue