mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Detect whether the node has been inserted cast nodes twice (#4811)
* check whether the node has been casted before
* check casted node logically
* better naming convention
* nit: extra space
* change to skip for Cast Node
* remove hasNodeBeenCast
* Add a Unit test
* Add test onnx file
* nit: naming convention and comments
* check CI: try to remove test
* move test to existing test file
This commit is contained in:
parent
47c4144bd1
commit
744809ceae
4 changed files with 43 additions and 29 deletions
|
|
@ -8,44 +8,30 @@
|
|||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
// Hands out consecutive integer ids, starting at 0.
// Each generator instance keeps its own counter.
class IdGenerator {
 public:
  // Returns the current id and advances the counter by one.
  int Next() {
    return id_++;
  }

 private:
  int id_ = 0;  // next id to hand out
};
|
||||
|
||||
// Decides whether a cast to float has to be inserted for `input` of `node`.
// Returns true only when all of the following hold:
//   - `input` has a known type,
//   - that type is a float16 (MLFloat16) tensor, and
//   - `node` has not been assigned to any execution provider (XP).
// For such a node we need to insert a cast to float and put the node on CPU
// as the default behavior.
// TODO: a better check would be whether a CPU kernel with float exists.
bool InsertCastTransformer::NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const {
  // Without type information there is nothing to compare against.
  if (input->Type() == nullptr) {
    return false;
  }

  const bool input_is_float16_tensor =
      DataTypeImpl::TypeFromProto(*input->TypeAsProto()) == DataTypeImpl::GetTensorType<MLFloat16>();

  // An empty provider type means the node is still unassigned.
  return input_is_float16_tensor && node->GetExecutionProviderType().empty();
}
|
||||
|
||||
onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph,
|
||||
IdGenerator& id_generator,
|
||||
onnxruntime::NodeArg* old_arg,
|
||||
TypeProto* new_type,
|
||||
bool new_on_input,
|
||||
int64_t to_type,
|
||||
onnxruntime::ProviderType providerType) {
|
||||
//insert cast op to cast input
|
||||
int id = id_generator.Next();
|
||||
// insert cast op to cast input
|
||||
std::string node_name = graph.GenerateNodeName("Inserted_Cast");
|
||||
|
||||
char str[32];
|
||||
snprintf(str, 32, "CastDef_%d", id);
|
||||
|
||||
auto* new_arg = &graph.GetOrCreateNodeArg(str, new_type);
|
||||
auto* new_arg = &graph.GetOrCreateNodeArg(node_name, new_type);
|
||||
|
||||
std::vector<onnxruntime::NodeArg*> input_defs = {new_on_input ? new_arg : old_arg};
|
||||
std::vector<onnxruntime::NodeArg*> output_defs = {new_on_input ? old_arg : new_arg};
|
||||
|
||||
auto& cast_node = graph.AddNode(str, "Cast", "cast node to cast from float16 to float32 on cpu", input_defs, output_defs);
|
||||
auto& cast_node = graph.AddNode(node_name, "Cast", "cast node to cast from float16 to float32 on cpu", input_defs, output_defs);
|
||||
cast_node.AddAttribute("to", to_type);
|
||||
cast_node.SetExecutionProviderType(providerType);
|
||||
return new_arg;
|
||||
|
|
@ -84,7 +70,7 @@ Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph) {
|
|||
}
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (IsSingleInputNodeFloat16Node(node)) {
|
||||
if (node.OpType() != "Cast" && IsSingleInputNodeFloat16Node(node)) {
|
||||
node.SetExecutionProviderType("");
|
||||
}
|
||||
}
|
||||
|
|
@ -205,8 +191,8 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
|
|||
TypeProto float_tensor_proto;
|
||||
float_16_tensor_proto.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16);
|
||||
float_tensor_proto.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
IdGenerator id_generator;
|
||||
std::map<onnxruntime::NodeArg*, onnxruntime::NodeArg*> input_def_updates;
|
||||
|
||||
for (onnxruntime::NodeIndex i : order) {
|
||||
auto node = graph.GetNode(i);
|
||||
if (!node)
|
||||
|
|
@ -221,9 +207,8 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
|
|||
if (input_def_updates.count(src_arg)) {
|
||||
replacement_defs[src_arg] = input_def_updates[src_arg];
|
||||
} else {
|
||||
//insert cast op to cast input
|
||||
// insert cast op to cast input
|
||||
auto dst_arg = AddCastNode(graph,
|
||||
id_generator,
|
||||
src_arg,
|
||||
&float_tensor_proto,
|
||||
false,
|
||||
|
|
@ -271,10 +256,9 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
|
|||
if (output->Type() &&
|
||||
DataTypeImpl::TypeFromProto(*output->TypeAsProto()) == DataTypeImpl::GetTensorType<MLFloat16>() &&
|
||||
casted) {
|
||||
//insert cast op to cast output back to float16
|
||||
// insert cast op to cast output back to float16
|
||||
auto dst_arg = output;
|
||||
auto src_arg = AddCastNode(graph,
|
||||
id_generator,
|
||||
dst_arg,
|
||||
&float_tensor_proto,
|
||||
true,
|
||||
|
|
@ -283,7 +267,7 @@ Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modifie
|
|||
replacement_defs[dst_arg] = src_arg;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
node->ReplaceDefs(replacement_defs);
|
||||
modified = modified || casted;
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ class InsertCastTransformer : public onnxruntime::GraphTransformer {
|
|||
|
||||
private:
|
||||
Status ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
bool NeedInsertCast(const onnxruntime::Node* node, const onnxruntime::NodeArg* input) const;
|
||||
|
||||
// Currently because we only have very few cpu kernels support float16, place those nodes on float16
|
||||
|
|
|
|||
|
|
@ -174,5 +174,36 @@ TEST(TransformerTest, MultinomialWithFloat16Input) {
|
|||
EXPECT_TRUE(status.IsOK()) << status;
|
||||
}
|
||||
|
||||
// This test is to test insert_cast_transform the same graph twice
|
||||
// insert_cast_transform needs to detect existing Cast Node
|
||||
// Prevent inserting the same Cast node twice
|
||||
TEST(TransformerTest, InsertCastNodeTwice) {
|
||||
auto model_uri = MODEL_FOLDER ORT_TSTR("insert_cast_twice.onnx");
|
||||
std::shared_ptr<Model> model;
|
||||
auto status = Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
Graph& graph = model->MainGraph();
|
||||
InsertCastTransformer transformer("Test");
|
||||
|
||||
// First insert
|
||||
bool modified = false;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
EXPECT_TRUE(modified) << "Transformer should have added some Cast nodes";
|
||||
EXPECT_TRUE(op_to_count["Cast"] == 5) << "Insert 3 more Cast nodes.";
|
||||
|
||||
// Second insert
|
||||
modified = false;
|
||||
status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger());
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
// Same graph without modification; The number of Cast node remains
|
||||
EXPECT_TRUE(!modified) << "Transformer should not modify the modfied graph again";
|
||||
EXPECT_TRUE(op_to_count["Cast"] == 5) << "Remain the same number of Cast node";
|
||||
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/insert_cast_twice.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/insert_cast_twice.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue