Detect whether the node has been inserted cast nodes twice (#4811)

* check whether the node has been casted before

* check casted node logically

* better naming convention

* nit: extra space

* change to skip for Cast Node

* remove hasNodeBeenCast

* Add a Unit test

* Add test onnx file

* nit: naming convention and comments

* check CI: try to remove test

* move test to existing test file
This commit is contained in:
Chun-Wei Chen 2020-08-24 07:25:41 -07:00 committed by GitHub
parent 47c4144bd1
commit 744809ceae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 29 deletions

View file

@ -8,44 +8,30 @@
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
class IdGenerator {
public:
int Next() {
return id++;
}
private:
int id = 0;
};
bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const {
//If the node's input is float16 and currently the node is not assigned to any XP.
//we need insert a cast to float, and put the node on CPU for default behavior.
//TODO: a better check is to check does the CPU kernel with float exist or not.
// If the node's input is float16 and currently the node is not assigned to any XP.
// we need insert a cast to float, and put the node on CPU for default behavior.
// TODO: a better check is to check does the CPU kernel with float exist or not.
return input->Type() != nullptr &&
DataTypeImpl::TypeFromProto(*input->TypeAsProto()) == DataTypeImpl::GetTensorType<MLFloat16>() &&
node->GetExecutionProviderType().empty();
}
onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph,
IdGenerator& id_generator,
onnxruntime::NodeArg* old_arg,
TypeProto* new_type,
bool new_on_input,
int64_t to_type,
onnxruntime::ProviderType providerType) {
//insert cast op to cast input
int id = id_generator.Next();
// insert cast op to cast input
std::string node_name = graph.GenerateNodeName("Inserted_Cast");
char str[32];
snprintf(str, 32, "CastDef_%d", id);
auto* new_arg = &graph.GetOrCreateNodeArg(str, new_type);
auto* new_arg = &graph.GetOrCreateNodeArg(node_name, new_type);
std::vector<onnxruntime::NodeArg*> input_defs = {new_on_input ? new_arg : old_arg};
std::vector<onnxruntime::NodeArg*> output_defs = {new_on_input ? old_arg : new_arg};
auto& cast_node = graph.AddNode(str, "Cast", "cast node to cast from float16 to float32 on cpu", input_defs, output_defs);
auto& cast_node = graph.AddNode(node_name, "Cast", "cast node to cast from float16 to float32 on cpu", input_defs, output_defs);
cast_node.AddAttribute("to", to_type);
cast_node.SetExecutionProviderType(providerType);
return new_arg;
@ -84,7 +70,7 @@ Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph) {
}
for (auto& node : graph.Nodes()) {
if (IsSingleInputNodeFloat16Node(node)) {
if (node.OpType() != "Cast" && IsSingleInputNodeFloat16Node(node)) {
node.SetExecutionProviderType("");
}
}
@ -205,8 +191,8 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
TypeProto float_tensor_proto;
float_16_tensor_proto.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16);
float_tensor_proto.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
IdGenerator id_generator;
std::map<onnxruntime::NodeArg*, onnxruntime::NodeArg*> input_def_updates;
for (onnxruntime::NodeIndex i : order) {
auto node = graph.GetNode(i);
if (!node)
@ -221,9 +207,8 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
if (input_def_updates.count(src_arg)) {
replacement_defs[src_arg] = input_def_updates[src_arg];
} else {
//insert cast op to cast input
// insert cast op to cast input
auto dst_arg = AddCastNode(graph,
id_generator,
src_arg,
&float_tensor_proto,
false,
@ -271,10 +256,9 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
if (output->Type() &&
DataTypeImpl::TypeFromProto(*output->TypeAsProto()) == DataTypeImpl::GetTensorType<MLFloat16>() &&
casted) {
//insert cast op to cast output back to float16
// insert cast op to cast output back to float16
auto dst_arg = output;
auto src_arg = AddCastNode(graph,
id_generator,
dst_arg,
&float_tensor_proto,
true,
@ -283,7 +267,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
replacement_defs[dst_arg] = src_arg;
}
}
node->ReplaceDefs(replacement_defs);
modified = modified || casted;

View file

@ -23,7 +23,6 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer {
private:
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const;
// Currently because we only have very few cpu kernels support float16, place those nodes on float16

View file

@ -174,5 +174,36 @@ TEST(TransformerTest, MultinomialWithFloat16Input) {
EXPECT_TRUE(status.IsOK()) << status;
}
// This test is to test insert_cast_transform the same graph twice
// insert_cast_transform needs to detect existing Cast Node
// Prevent inserting the same Cast node twice
TEST(TransformerTest, InsertCastNodeTwice) {
auto model_uri = MODEL_FOLDER ORT_TSTR("insert_cast_twice.onnx");
std::shared_ptr<Model> model;
auto status = Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(status.IsOK()) << status;
Graph& graph = model->MainGraph();
InsertCastTransformer transformer("Test");
// First insert
bool modified = false;
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(status.IsOK()) << status;
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_TRUE(modified) << "Transformer should have added some Cast nodes";
EXPECT_TRUE(op_to_count["Cast"] == 5) << "Insert 3 more Cast nodes.";
// Second insert
modified = false;
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(status.IsOK()) << status;
op_to_count = CountOpsInGraph(graph);
// Same graph without modification; The number of Cast node remains
EXPECT_TRUE(!modified) << "Transformer should not modify the modfied graph again";
EXPECT_TRUE(op_to_count["Cast"] == 5) << "Remain the same number of Cast node";
}
} // namespace test
} // namespace onnxruntime

Binary file not shown.