diff --git a/onnxruntime/core/optimizer/propagate_cast_ops.cc b/onnxruntime/core/optimizer/propagate_cast_ops.cc index 2a41c6315f..5bbfa1585b 100644 --- a/onnxruntime/core/optimizer/propagate_cast_ops.cc +++ b/onnxruntime/core/optimizer/propagate_cast_ops.cc @@ -127,12 +127,16 @@ static bool IsFP16Allow(const std::string& op_type, size_t level) { return fp16_allow; } -// Check whether the node is cast operation to the specified data type +// Check whether the node is a cast operation from float16/float to the specified data_type. static bool IsCastTo(const Node* node, TensorProto_DataType data_type) { if (node->OpType() == "Cast") { const NodeAttributes& attributes = node->GetAttributes(); ORT_ENFORCE(attributes.find("to") != attributes.end()); - return attributes.at("to").i() == static_cast(data_type); + const NodeArg* input = node->InputDefs()[0]; + auto input_data_type = static_cast(input->TypeAsProto()->tensor_type().elem_type()); + // Allow cast nodes with same input and output type float/float16 to eliminate such casts. + return (input_data_type == TensorProto::FLOAT16 || input_data_type == TensorProto::FLOAT) && + attributes.at("to").i() == static_cast(data_type); } return false; } @@ -1132,9 +1136,9 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node, * V V */ static bool RemoveInputOutputUpDownCasts(Graph& graph, Node* node, - std::deque& removed_nodes, - size_t level, - const logging::Logger& logger) { + std::deque& removed_nodes, + size_t level, + const logging::Logger& logger) { bool modified = false; bool has_float_outputs = false; bool has_float_inputs = false; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a62b754b4d..839fe3adbb 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4170,6 +4170,9 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) { std::vector allow_matmul_transpose = {"MatMul", "Transpose"}; std::vector allow_matmul_transpose_add = {"Add", "MatMul", "Transpose"}; const std::vector test_cases = { + // Negative testcase to test that the transformer will not move cast bool to float/float16. + {MODEL_FOLDER "propagate_cast/negative_test_case_bool_fp_cast.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}, {"Add"}}, + {MODEL_FOLDER "propagate_cast/negative_test_case_bool_fp16_cast.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}, {"Add"}}, // Test fusing back to back casts functionality {MODEL_FOLDER "propagate_cast/fuse_back2back_casts_float16_float16.onnx", {{insertAndReduce0, 1}, {floodFill1, 1}, {floodFill2, 1}}}, {MODEL_FOLDER "propagate_cast/fuse_back2back_casts_float16_float.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}}, diff --git a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py index 53195a0e54..5160ac9900 100644 --- a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py +++ b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py @@ -351,6 +351,32 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second model_path += "_add_products" if add_products else "" save(model_path, nodes, inputs, outputs, []) +def gen_bool_to_float16_cast(model_path): + X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1]) + X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1]) + X3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT, [1, 1]) + Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1]) + + less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1') + cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT16) + cast2 = helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT16) + add1 = helper.make_node('Add', ['cast1', 'cast2'], ['output']) + + save(model_path, [less1, cast1, cast2, add1], [X1, X2, X3], [Y], []) + +def gen_bool_to_float_cast(model_path): + X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1]) + X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1]) + X3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT16, [1, 1]) + Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1]) + + less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1') + cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT) + cast2 = helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT) + add1 = helper.make_node('Add', ['cast1', 'cast2'], ['add1']) + cast3 = helper.make_node('Cast', ['add1'], ['output'], name='cast3', to=TensorProto.FLOAT16) + + save(model_path, [less1, cast1, cast2, cast3, add1], [X1, X2, X3], [Y], []) for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)): if not insert_add and (cast_sum or cast_input2): @@ -369,3 +395,7 @@ for (transpose, transpose_before_cast, second_matmul, add_products, cast_inputs) continue gen_matmul_two_products("matmul_two_outputs", transpose, transpose_before_cast, second_matmul, cast_inputs) + + +gen_bool_to_float16_cast("negative_test_case_bool_fp16_cast") +gen_bool_to_float_cast("negative_test_case_bool_fp_cast") \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp16_cast.onnx b/onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp16_cast.onnx new file mode 100644 index 0000000000..635670035d Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp16_cast.onnx differ diff --git a/onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp_cast.onnx b/onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp_cast.onnx new file mode 100644 index 0000000000..73504cb7f1 Binary files /dev/null and b/onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp_cast.onnx differ