Fix cast propagation to not change casts from bool type. (#8925)

* Added new models to test bool->float and bool->float16 casts

* Fixed bool casts. Added new test cases.
This commit is contained in:
satyajandhyala 2021-09-01 15:15:37 -07:00 committed by GitHub
parent 6299a60bf8
commit 9e661b64ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 42 additions and 5 deletions

View file

@ -127,12 +127,16 @@ static bool IsFP16Allow(const std::string& op_type, size_t level) {
return fp16_allow;
}
// Check whether the node is cast operation to the specified data type
// Check whether the node is a cast operation from float16/float to the specified data_type.
static bool IsCastTo(const Node* node, TensorProto_DataType data_type) {
if (node->OpType() == "Cast") {
const NodeAttributes& attributes = node->GetAttributes();
ORT_ENFORCE(attributes.find("to") != attributes.end());
return attributes.at("to").i() == static_cast<int64_t>(data_type);
const NodeArg* input = node->InputDefs()[0];
auto input_data_type = static_cast<TensorProto_DataType>(input->TypeAsProto()->tensor_type().elem_type());
// Allow cast nodes with same input and output type float/float16 to eliminate such casts.
return (input_data_type == TensorProto::FLOAT16 || input_data_type == TensorProto::FLOAT) &&
attributes.at("to").i() == static_cast<int64_t>(data_type);
}
return false;
}
@ -1132,9 +1136,9 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
* V V
*/
static bool RemoveInputOutputUpDownCasts(Graph& graph, Node* node,
std::deque<NodeIndex>& removed_nodes,
size_t level,
const logging::Logger& logger) {
std::deque<NodeIndex>& removed_nodes,
size_t level,
const logging::Logger& logger) {
bool modified = false;
bool has_float_outputs = false;
bool has_float_inputs = false;

View file

@ -4170,6 +4170,9 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) {
std::vector<std::string> allow_matmul_transpose = {"MatMul", "Transpose"};
std::vector<std::string> allow_matmul_transpose_add = {"Add", "MatMul", "Transpose"};
const std::vector<PropagateCastOpsTestSpecs> test_cases = {
// Negative testcase to test that the transformer will not move cast bool to float/float16.
{MODEL_FOLDER "propagate_cast/negative_test_case_bool_fp_cast.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}, {"Add"}},
{MODEL_FOLDER "propagate_cast/negative_test_case_bool_fp16_cast.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}, {"Add"}},
// Test fusing back to back casts functionality
{MODEL_FOLDER "propagate_cast/fuse_back2back_casts_float16_float16.onnx", {{insertAndReduce0, 1}, {floodFill1, 1}, {floodFill2, 1}}},
{MODEL_FOLDER "propagate_cast/fuse_back2back_casts_float16_float.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}},

View file

@ -351,6 +351,32 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second
model_path += "_add_products" if add_products else ""
save(model_path, nodes, inputs, outputs, [])
def gen_bool_to_float16_cast(model_path):
X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1])
X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1])
X3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT, [1, 1])
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1])
less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1')
cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT16)
cast2 = helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT16)
add1 = helper.make_node('Add', ['cast1', 'cast2'], ['output'])
save(model_path, [less1, cast1, cast2, add1], [X1, X2, X3], [Y], [])
def gen_bool_to_float_cast(model_path):
X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1])
X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1])
X3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT16, [1, 1])
Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1])
less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1')
cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT)
cast2 = helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT)
add1 = helper.make_node('Add', ['cast1', 'cast2'], ['add1'])
cast3 = helper.make_node('Cast', ['add1'], ['output'], name='cast3', to=TensorProto.FLOAT16)
save(model_path, [less1, cast1, cast2, cast3, add1], [X1, X2, X3], [Y], [])
for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)):
if not insert_add and (cast_sum or cast_input2):
@ -369,3 +395,7 @@ for (transpose, transpose_before_cast, second_matmul, add_products, cast_inputs)
continue
gen_matmul_two_products("matmul_two_outputs", transpose,
transpose_before_cast, second_matmul, cast_inputs)
gen_bool_to_float16_cast("negative_test_case_bool_fp16_cast")
gen_bool_to_float_cast("negative_test_case_bool_fp_cast")