mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Fix cast propagation to not change casts from bool type. (#8925)
* Added new models to test bool->float and bool->float16 casts.
* Fixed bool casts. Added new test cases.
This commit is contained in:
parent
6299a60bf8
commit
9e661b64ae
5 changed files with 42 additions and 5 deletions
|
|
@ -127,12 +127,16 @@ static bool IsFP16Allow(const std::string& op_type, size_t level) {
|
|||
return fp16_allow;
|
||||
}
|
||||
|
||||
// Check whether the node is cast operation to the specified data type
|
||||
// Check whether the node is a cast operation from float16/float to the specified data_type.
|
||||
static bool IsCastTo(const Node* node, TensorProto_DataType data_type) {
|
||||
if (node->OpType() == "Cast") {
|
||||
const NodeAttributes& attributes = node->GetAttributes();
|
||||
ORT_ENFORCE(attributes.find("to") != attributes.end());
|
||||
return attributes.at("to").i() == static_cast<int64_t>(data_type);
|
||||
const NodeArg* input = node->InputDefs()[0];
|
||||
auto input_data_type = static_cast<TensorProto_DataType>(input->TypeAsProto()->tensor_type().elem_type());
|
||||
// Allow cast nodes with same input and output type float/float16 to eliminate such casts.
|
||||
return (input_data_type == TensorProto::FLOAT16 || input_data_type == TensorProto::FLOAT) &&
|
||||
attributes.at("to").i() == static_cast<int64_t>(data_type);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1132,9 +1136,9 @@ static bool PropagateFP16CastsFromOutputsToInputs(Graph& graph, Node* node,
|
|||
* V V
|
||||
*/
|
||||
static bool RemoveInputOutputUpDownCasts(Graph& graph, Node* node,
|
||||
std::deque<NodeIndex>& removed_nodes,
|
||||
size_t level,
|
||||
const logging::Logger& logger) {
|
||||
std::deque<NodeIndex>& removed_nodes,
|
||||
size_t level,
|
||||
const logging::Logger& logger) {
|
||||
bool modified = false;
|
||||
bool has_float_outputs = false;
|
||||
bool has_float_inputs = false;
|
||||
|
|
|
|||
|
|
@ -4170,6 +4170,9 @@ TEST_F(GraphTransformationTests, PropagateCastOpsTests) {
|
|||
std::vector<std::string> allow_matmul_transpose = {"MatMul", "Transpose"};
|
||||
std::vector<std::string> allow_matmul_transpose_add = {"Add", "MatMul", "Transpose"};
|
||||
const std::vector<PropagateCastOpsTestSpecs> test_cases = {
|
||||
// Negative test cases: the transformer must not move casts from bool to float/float16.
|
||||
{MODEL_FOLDER "propagate_cast/negative_test_case_bool_fp_cast.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}, {"Add"}},
|
||||
{MODEL_FOLDER "propagate_cast/negative_test_case_bool_fp16_cast.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}, {"Add"}},
|
||||
// Test fusing back to back casts functionality
|
||||
{MODEL_FOLDER "propagate_cast/fuse_back2back_casts_float16_float16.onnx", {{insertAndReduce0, 1}, {floodFill1, 1}, {floodFill2, 1}}},
|
||||
{MODEL_FOLDER "propagate_cast/fuse_back2back_casts_float16_float.onnx", {{insertAndReduce0, 2}, {floodFill1, 2}, {floodFill2, 2}}},
|
||||
|
|
|
|||
|
|
@ -351,6 +351,32 @@ def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second
|
|||
model_path += "_add_products" if add_products else ""
|
||||
save(model_path, nodes, inputs, outputs, [])
|
||||
|
||||
def gen_bool_to_float16_cast(model_path):
    """Build and save the bool -> float16 cast test model.

    Graph: Less(x1, x2) -> Cast(to FLOAT16) and x3 -> Cast(to FLOAT16), with the
    two cast results summed by an Add node that produces 'output'.

    model_path: passed unchanged to save() as the destination of the model.
    """
    graph_inputs = [
        helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1]),
        helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1]),
        helper.make_tensor_value_info('x3', TensorProto.FLOAT, [1, 1]),
    ]
    graph_outputs = [
        helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1]),
    ]

    graph_nodes = [
        # Comparison result that feeds the cast under test.
        helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1'),
        helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT16),
        # Ordinary float -> float16 cast on the other Add operand.
        helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT16),
        helper.make_node('Add', ['cast1', 'cast2'], ['output']),
    ]

    save(model_path, graph_nodes, graph_inputs, graph_outputs, [])
|
||||
|
||||
def gen_bool_to_float_cast(model_path):
    """Build and save the bool -> float cast test model.

    Graph: Less(x1, x2) -> Cast(to FLOAT) and x3 (float16) -> Cast(to FLOAT); the
    two cast results are summed by Add, and the sum is cast back to FLOAT16 as
    'output'.

    model_path: passed unchanged to save() as the destination of the model.
    """
    X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1])
    X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1])
    X3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT16, [1, 1])
    Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1])

    less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1')
    cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT)
    cast2 = helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT)
    add1 = helper.make_node('Add', ['cast1', 'cast2'], ['add1'])
    cast3 = helper.make_node('Cast', ['add1'], ['output'], name='cast3', to=TensorProto.FLOAT16)

    # List the nodes in topological order: add1 must precede cast3, which consumes
    # its output. The ONNX IR requires a graph's node list to be topologically sorted.
    # NOTE(review): presumably save() forwards this list unchanged into the graph - confirm.
    save(model_path, [less1, cast1, cast2, add1, cast3], [X1, X2, X3], [Y], [])
|
||||
|
||||
for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)):
|
||||
if not insert_add and (cast_sum or cast_input2):
|
||||
|
|
@ -369,3 +395,7 @@ for (transpose, transpose_before_cast, second_matmul, add_products, cast_inputs)
|
|||
continue
|
||||
gen_matmul_two_products("matmul_two_outputs", transpose,
|
||||
transpose_before_cast, second_matmul, cast_inputs)
|
||||
|
||||
|
||||
gen_bool_to_float16_cast("negative_test_case_bool_fp16_cast")
|
||||
gen_bool_to_float_cast("negative_test_case_bool_fp_cast")
|
||||
BIN
onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp16_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp16_cast.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp_cast.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/propagate_cast/negative_test_case_bool_fp_cast.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue