From 9e661b64ae59c976d72650faebec38bd3d0fdde9 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Wed, 1 Sep 2021 15:15:37 -0700 Subject: [PATCH] Fix cast propagation to not change casts from bool type. (#8925) * Added new models to test bool->float and bool->float16 casts * Fixed bool casts. Added new test cases. --- .../core/optimizer/propagate_cast_ops.cc | 14 +++++--- .../test/optimizer/graph_transform_test.cc | 3 ++ .../propagate_cast/gen_propagate_cast.py | 30 ++++++++++++++++++ .../negative_test_case_bool_fp16_cast.onnx | Bin 0 -> 295 bytes .../negative_test_case_bool_fp_cast.onnx | Bin 0 -> 333 bytes 5 files changed, 42 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp16_cast.onnx create mode 100644 onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp_cast.onnx diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc index 2a41c6315f..5bbfa1585b 100644 --- a/onnxruntime/core/optimizer/propagate_cast_ops.cc +++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc @@ -127,12 +127,16 @@ static bool IsFP16Allow(const std::string& op_type, size_t level) { return fp16_allow; } -// Check whether the node is cast operation to the specified data type +// Check whether the node is a cast operation from float16/float to the specified data_type. static bool IsCastTo(const Node* node, TensorProto_DataType data_type) { if (node->OpType() == "Cast") { const NodeAttributes& attributes = node->GetAttributes(); ORT_ENFORCE(attributes.find("to") != attributes.end()); - return attributes.at("to").i() == static_cast(data_type); + const NodeArg* input = node->InputDefs()[0]; + auto input_data_type = static_cast(input->TypeAsProto()->tensor_type().elem_type()); + // Allow cast nodes with same input and output type float/float16 to eliminate such casts. + return (input_data_type == TensorProto::FLOAT16 || input_data_type == TensorProto::FLOAT) && + attributes.at("to").i() == static_cast(data_type); } return false; } @@ -1132,9 +1136,9 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node, * V V */ static bool RemoveInputOutputUpDownCasts(Graph& graph, Node* node, - std::deque& removed_nodes, - size_t level, - const logging::Logger& logger) { + std::deque& removed_nodes, + size_t level, + const logging::Logger& logger) { bool modified = false; bool has_float_outputs = false; bool has_float_inputs = false; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a62b754b4d..839fe3adbb 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4170,6 +4170,9 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) { std::vector allow_matmul_transpose = {"MatMul", "Transpose"}; std::vector allow_matmul_transpose_add = {"Add", "MatMul", "Transpose"}; const std::vector test_cases = { + // Negative testcase to test that the transformer will not move cast bool to float/float16. + {MODEL_FOLDER "propagate_cast/negative_test_case_bool_fp_cast.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}, {"Add"}}, + {MODEL_FOLDER "propagate_cast/negative_test_case_bool_fp16_cast.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}, {"Add"}}, // Test fusing back to back casts functionality {MODEL_FOLDER "propagate_cast/fuse_back2back_casts_float16_float16.onnx", {{insertAndReduce0, 1}, {floodFill1, 1}, {floodFill2, 1}}}, {MODEL_FOLDER "propagate_cast/fuse_back2back_casts_float16_float.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}}, diff --git a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py index 53195a0e54..5160ac9900 100644 --- a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py +++ b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py @@ -351,6 +351,32 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second model_path += "_add_products" if add_products else "" save(model_path, nodes, inputs, outputs, []) +def gen_bool_to_float16_cast(model_path): + X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1]) + X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1]) + X3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT, [1, 1]) + Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1]) + + less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1') + cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT16) + cast2 = helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT16) + add1 = helper.make_node('Add', ['cast1', 'cast2'], ['output']) + + save(model_path, [less1, cast1, cast2, add1], [X1, X2, X3], [Y], []) + +def gen_bool_to_float_cast(model_path): + X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1]) + X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1]) + X3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT16, [1, 1]) + Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1]) + + less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1') + cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT) + cast2 = helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT) + add1 = helper.make_node('Add', ['cast1', 'cast2'], ['add1']) + cast3 = helper.make_node('Cast', ['add1'], ['output'], name='cast3', to=TensorProto.FLOAT16) + + save(model_path, [less1, cast1, cast2, cast3, add1], [X1, X2, X3], [Y], []) for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)): if not insert_add and (cast_sum or cast_input2): @@ -369,3 +395,7 @@ for (transpose, transpose_before_cast, second_matmul, add_products, cast_inputs) continue gen_matmul_two_products("matmul_two_outputs", transpose, transpose_before_cast, second_matmul, cast_inputs) + + +gen_bool_to_float16_cast("negative_test_case_bool_fp16_cast") +gen_bool_to_float_cast("negative_test_case_bool_fp_cast") \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp16_cast.onnx b/onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp16_cast.onnx new file mode 100644 index 0000000000000000000000000000000000000000..635670035df5b19bb8fe5867801f40d945a4efbd GIT binary patch literal 295 zcmZurJqyAx5H+#Yc&OCiAnH`XK~OMmZtCjb;N(_o6u~y7xk7)A|I*bo5gc?l9`EDc z9b-IDX#q0O3{0cE5b ntBSws%Btj>JrCe>!#ePFwjR`i_O=I|SS5?(9_LD{D+c5ZYVZL0{jDa{r9GJ+q~v=PKV!|{0U ze!0t0%50Hkx1!AOk|uLZBrfj+CO~e!D$ZKhL`r|G#WlCefd|^MRuoE9&01V@qui_u z1jW#J7zCzjLZBh8A^r<_V1{w*v$iNPFHv*B2M5)#5$H8qIICZ|%~c#Tr&_!#L>`_) zO!n%A!=O5V^}(VYMuAYHsygL=)1Xe_%^rvF`q6i1Z9>Dg)(-R{kuK6pR0t_fXd6#O CbVro{ literal 0 HcmV?d00001