diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 315a66edd1..37ab9ac9ff 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -17,7 +17,8 @@ if (onnxruntime_MINIMAL_BUILD) "${ONNXRUNTIME_ROOT}/core/graph/schema_registry.cc" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/*defs.h" "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/*defs.cc" - + "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.h" + "${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc" ) # no Function support initially diff --git a/onnxruntime/core/graph/contrib_ops/onnx_function_util.cc b/onnxruntime/core/graph/contrib_ops/onnx_function_util.cc new file mode 100644 index 0000000000..dc4925fb0d --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/onnx_function_util.cc @@ -0,0 +1,52 @@ +#include "core/graph/contrib_ops/onnx_function_util.h" +#include "core/util/math.h" + +namespace ONNX_NAMESPACE { + +TensorProto ToTensor(double value, TensorProto_DataType elem_type) { + TensorProto t; + t.set_data_type(elem_type); + switch (elem_type) { + case TensorProto_DataType::TensorProto_DataType_FLOAT: + t.add_float_data((float)value); + break; + case TensorProto_DataType::TensorProto_DataType_DOUBLE: + t.add_double_data(value); + break; + case TensorProto_DataType::TensorProto_DataType_FLOAT16: + t.add_int32_data(onnxruntime::math::floatToHalf((float)value)); + break; + default: + assert(false); + } + + return t; +} + +void BuildNodes(FunctionProto& functionProto, const std::vector& node_defs) { + for (size_t i = 0; i < node_defs.size(); i++) { + const FunctionBodyHelper::NodeDef& node = node_defs[i]; + auto* np = functionProto.add_node(); + + np->set_op_type(node.op_type); + for (const auto& inp : node.inputs) { + np->add_input(inp); + } + for (const auto& o : node.outputs) { + np->add_output(o); + } + for (const auto& attr : node.attributes) { + *(np->add_attribute()) = attr.proto; + } + } +} + +bool 
BuildFunctionProto(FunctionProto& functionProto, const OpSchema& schema, + const std::vector& node_defs, + const std::vector& relied_opsets) { + BuildNodes(functionProto, node_defs); + schema.BuildFunction(functionProto, relied_opsets); + return true; +} + +} // namespace ONNX_NAMESPACE \ No newline at end of file diff --git a/onnxruntime/core/graph/contrib_ops/onnx_function_util.h b/onnxruntime/core/graph/contrib_ops/onnx_function_util.h new file mode 100644 index 0000000000..babe1f6537 --- /dev/null +++ b/onnxruntime/core/graph/contrib_ops/onnx_function_util.h @@ -0,0 +1,25 @@ +#pragma once + +// Utility functions for building the body of a context-dependent function. +// Temporary placeholder for utilities to be moved into ONNX repo. TODO. + +#include +#include + +#include "onnx/onnx-operators_pb.h" +#include "onnx/defs/schema.h" +#include "onnx/defs/function.h" + +namespace ONNX_NAMESPACE { + +// For floating-value constants of different precision: +TensorProto ToTensor(double value, TensorProto_DataType elem_type); + +// Utility function to construct a FunctionProto from an opschema (for the signature information), +// a sequence of NodeDefs (for the function body), and the relied opsets. 
+bool BuildFunctionProto(FunctionProto& functionProto, + const OpSchema& schema, + const std::vector& node_defs, + const std::vector& relied_opsets = {}); + +} // namespace ONNX_NAMESPACE \ No newline at end of file diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc index 4da98bf0de..6f538d3c58 100644 --- a/onnxruntime/core/graph/function.cc +++ b/onnxruntime/core/graph/function.cc @@ -180,11 +180,12 @@ static std::unordered_map CreateOpsetImportsForFunction(const std::unordered_map function_opset_imports{graph_opset_imports}; // merge with opset imports in function proto for (const auto& opset_import : func_proto.opset_import()) { - auto result = function_opset_imports.insert({opset_import.domain(), static_cast(opset_import.version())}); - ORT_ENFORCE(result.second, + auto opset_version = static_cast(opset_import.version()); + auto result = function_opset_imports.insert({opset_import.domain(), opset_version}); + ORT_ENFORCE((result.first->second == opset_version), "ONNX model does not support multiple opset versions for a domain. 
Model imports opset version ", result.first->second, " for domain ", result.first->first, " and function is trying to import opset version ", - opset_import.version(), " for the same domain"); + opset_version, " for the same domain"); } return function_opset_imports; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 70eddde7cf..5f3de37128 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2371,12 +2371,29 @@ void Graph::InitFunctionBodyForNode(Node& node) { if (node.op_->HasContextDependentFunction()) { NodeProto node_proto; node.ToProto(node_proto); - onnx::FunctionBodyBuildContextImpl function_body_ctx(node_proto); + std::vector input_types; + for (size_t i = 0, n = node.InputDefs().size(); i < n; i++) { + auto p_node_arg = node.InputDefs().at(i); + if ((nullptr != p_node_arg) && p_node_arg->Exists()) { + auto& type = *(p_node_arg->TypeAsProto()); + input_types.emplace_back(type); + } else + input_types.emplace_back(); + } + onnx::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types); node.op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto); } else { onnx_function_proto = *(node.op_->GetFunction()); } + // Check function's opset requirements are compatible with model's opset. + auto& graphImports = DomainToVersionMap(); + for (const auto& fn_import : onnx_function_proto.opset_import()) { + auto it = graphImports.find(fn_import.domain()); + if ((it != graphImports.end()) && (it->second != fn_import.version())) + return; // Incompatible. Do not use this function expansion. 
+ } + auto func_ptr = onnxruntime::make_unique(*this, node.Index(), onnx_function_proto, logger_); diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index cf78dc7e50..c2bcce3eef 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3,6 +3,7 @@ #include "core/graph/op.h" #include "core/graph/contrib_ops/contrib_defs.h" +#include "core/graph/contrib_ops/onnx_function_util.h" #include "core/providers/common.h" #include "orttraining/core/graph/training_op_defs.h" #include "orttraining/core/framework/distributed_run_context.h" @@ -343,7 +344,7 @@ void RegisterTrainingOpSchemas() { .SetDomain(kMSDomain) .SinceVersion(1) .Input(0, "dY", "Gradient of output Y", "T") - .Input(1, "X", "Input tensor", "T") + .Input(1, "Y", "Input tensor", "T") .Output(0, "dX", "Gradient of input X", "T") .Attr( "axis", @@ -356,7 +357,54 @@ void RegisterTrainingOpSchemas() { "T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") - .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput); + .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput) + .SetContextDependentFunctionBodyBuilder( + [](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) { + // SoftmaxGrad computes dX = Y * ( dY - dot(Y, dY)) + // ONNX does not have a dot product, which can be simulated as a pointwise-multiplication ("Mul"), + // followed by a "ReduceSum". Unfortunately, the treatment of "axis" is different in "SoftmaxGrad" + // and "ReduceSum". If axis=k for SoftmaxGrad, we need to specify [k, ..., n-1] as the axes of + // reduction for "ReduceSum", after accounting for negative-axis specification. + // An alternative solution would be to Flatten inputs to 2D and then reshape output back to original shape. 
+ // Hopefully, many of these ops can be optimized away in the common-case of statically-known shapes. + + auto* axis_attr = ctx.getAttribute("axis"); + int64_t axis = (axis_attr != nullptr) ? axis_attr->i() : 1; + auto zero1d = ToTensor(std::vector({0})); + zero1d.add_dims(1); + + // nodes: {outputs, op, inputs, attributes} + + // First, convert axis specification k to reduction axes [k, k+1, ..., n-1] + std::vector body{ + FunctionBodyHelper::Const("one", 1), + FunctionBodyHelper::Const("k", axis), + {{"axis_zero"}, "Constant", {}, {{"value", zero1d}}}, + {{"shape"}, "Shape", {"dY"}}, + {{"n_as_vector"}, "Shape", {"shape"}}, + {{"n"}, "Squeeze", {"n_as_vector", "axis_zero"}}, + }; + + // For negative axis, add n to axis-value k; then use Range(...). + if (axis >= 0) { + body.push_back({{"reduction_axes"}, "Range", {"k", "n", "one"}}); + } else { + body.push_back({{"n_plus_k"}, "Add", {"n", "k"}}); + body.push_back({{"reduction_axes"}, "Range", {"n_plus_k", "n", "one"}}); + } + + // compute dX = Y * ( dY - dot(Y, dY)) = Y * ( dY - ReduceSum(Y * dY)) + body.push_back({{"a"}, "Mul", {"Y", "dY"}}); + body.push_back({{"b"}, "ReduceSum", {"a", "reduction_axes"}}); + body.push_back({{"c"}, "Sub", {"dY", "b"}}); + body.push_back({{"dX"}, "Mul", {"Y", "c"}}); + + OperatorSetIdProto onnx_opset_13; + onnx_opset_13.set_domain(""); + onnx_opset_13.set_version(13); + + return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13}); + }); ONNX_CONTRIB_OPERATOR_SCHEMA(LogSoftmaxGrad) .SetDomain(kMSDomain) diff --git a/orttraining/orttraining/test/gradient/function_ops_test.cc b/orttraining/orttraining/test/gradient/function_ops_test.cc new file mode 100644 index 0000000000..6eb06b2820 --- /dev/null +++ b/orttraining/orttraining/test/gradient/function_ops_test.cc @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include + +#include "gtest/gtest.h" +#include "core/graph/model.h" +#include "core/graph/contrib_ops/contrib_defs.h" +#include "orttraining/core/graph/training_op_defs.h" +#include "test/test_environment.h" + +#include "core/session/inference_session.h" +#include "core/providers/cpu/cpu_execution_provider.h" + +#include "test/framework/test_utils.h" + +using namespace ::onnxruntime::common; + +namespace onnxruntime { +namespace test { + +typedef std::vector ArgMap; + +static void RegisterSchemas() { + static bool registered = false; + if (!registered) { + onnxruntime::training::RegisterTrainingOpSchemas(); + registered = true; + } +} + +static ONNX_NAMESPACE::TypeProto TensorType(int32_t elem_type, std::vector dims) { + ONNX_NAMESPACE::TypeProto typeProto; + typeProto.mutable_tensor_type()->set_elem_type(elem_type); + auto* shape = typeProto.mutable_tensor_type()->mutable_shape(); + for (auto dim : dims) + shape->add_dim()->set_dim_value(dim); + return typeProto; +} + +static ONNX_NAMESPACE::TypeProto TensorType(int32_t elem_type, std::vector dims) { + ONNX_NAMESPACE::TypeProto typeProto; + typeProto.mutable_tensor_type()->set_elem_type(elem_type); + auto* shape = typeProto.mutable_tensor_type()->mutable_shape(); + for (auto dim : dims) { + uint64_t dimval; + std::istringstream s(dim); + if (s >> dimval) { + shape->add_dim()->set_dim_value(dimval); + } else { + shape->add_dim()->set_dim_param(dim); + } + } + return typeProto; +} + +static std::vector +Run(onnxruntime::Model& model, NameMLValMap& feeds, std::vector output_names) { + SessionOptions session_options; + InferenceSession session_object{session_options, GetEnvironment()}; + + std::string serialized_model; + const bool serialization_status = model.ToProto().SerializeToString(&serialized_model); + EXPECT_TRUE(serialization_status) << "Failed to serialize proto to string"; + std::stringstream sstr(serialized_model); + auto status = session_object.Load(sstr); + 
EXPECT_TRUE(status.IsOK()); + status = session_object.Initialize(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + RunOptions run_options; + run_options.run_tag = session_options.session_logid; + + std::vector fetches; + + status = session_object.Run(run_options, feeds, output_names, &fetches); + EXPECT_TRUE(status.IsOK()) << "Session Run failed."; + + return fetches; +} + +// Restricted to float tensors +static void AssertEqual(const Tensor& tensor1, const Tensor& tensor2) { + auto size = tensor1.Shape().Size(); + auto* data1 = tensor1.template Data(); + auto* data2 = tensor2.template Data(); + + float threshold = 0.001f; + + for (int i = 0; i < size; ++i) { + ASSERT_NEAR(data1[i], data2[i], threshold) << "as position i:" << i; + } +} + +static void AssertEqual(const std::vector& results1, const std::vector& results2) { + ASSERT_EQ(results1.size(), results2.size()); + for (int i = 0; i < results1.size(); i++) { + auto& value1 = results1[i].Get(); + auto& value2 = results2[i].Get(); + AssertEqual(value1, value2); + } +} + +struct FunctionTestCase { + const char* opname; + + std::vector input_args; + std::vector> input_values; + NameMLValMap input_value_map; + + std::vector output_names; + std::vector output_args; + + NodeAttributes attributes; + std::unique_ptr provider; + + std::unordered_map opsets; + + FunctionTestCase(const char* _opname) : opname(_opname), provider(new CPUExecutionProvider(CPUExecutionProviderInfo())) {} + + void AddInput(std::string input_name, std::vector shape, std::vector data, std::vector symshape = {}) { + auto arg_type = (symshape.size() > 0) ? 
TensorType(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, symshape) : TensorType(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape); + input_args.emplace_back(input_name, &arg_type); + + OrtValue ort_value; + CreateMLValue(provider->GetAllocator(0, OrtMemTypeDefault), shape, data, &ort_value); + input_values.push_back(std::make_pair(input_name, ort_value)); + input_value_map.insert(std::make_pair(input_name, ort_value)); + } + + void AddOutput(std::string output_name) { + output_names.emplace_back(output_name); + output_args.emplace_back(output_name, nullptr); + } + + void AddAttribute(const char* attr_name, int64_t attr_val) { + ONNX_NAMESPACE::AttributeProto axis_attr; + axis_attr.set_name(attr_name); + axis_attr.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT); + axis_attr.set_i(attr_val); + attributes[attr_name] = axis_attr; + } + + onnxruntime::Node& AddCallNodeTo(onnxruntime::Graph& graph) { + std::vector input_arg_ptrs; + + for (auto& arg : input_args) + input_arg_ptrs.push_back(&arg); + + std::vector output_arg_ptrs; + for (auto& arg : output_args) + output_arg_ptrs.push_back(&arg); + + return graph.AddNode("fncallnode", opname, "function call node", input_arg_ptrs, output_arg_ptrs, &attributes, onnxruntime::kMSDomain); + } + + std::unique_ptr CreateModel(bool inline_call = false) { + RegisterSchemas(); + if (opsets.size() == 0) { + // Default opsets + opsets[kOnnxDomain] = 13; + opsets[kMSDomain] = 1; + } + + std::unique_ptr model(new Model("test", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + opsets, {}, DefaultLoggingManager().DefaultLogger())); + + onnxruntime::Graph& graph = model->MainGraph(); + auto& call_node = AddCallNodeTo(graph); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + if (inline_call) { + graph.InlineFunction(call_node); + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + } + + return model; + 
} + + void RunTest() { + auto model1 = CreateModel(false); + auto results1 = Run(*model1, input_value_map, output_names); + + auto model2 = CreateModel(true); + auto results2 = Run(*model2, input_value_map, output_names); + + AssertEqual(results1, results2); + } +}; + +static void InitSoftmaxGradTestCase(FunctionTestCase& testCase, std::vector shape) { + int64_t size = 1; + for (auto dim : shape) + size *= dim; + + std::vector value(size); + for (int64_t i = 0; i < size; i++) + value[i] = float(i); + + testCase.AddInput("dY", shape, value); + testCase.AddInput("Y", shape, value); + testCase.AddOutput("dX"); +} + +TEST(SoftmaxGradExpansionTest, DefaultAxis) { + FunctionTestCase testCase("SoftmaxGrad"); + InitSoftmaxGradTestCase(testCase, {3, 2}); + testCase.RunTest(); +} + +TEST(SoftmaxGradExpansionTest, NegativeAxis) { + FunctionTestCase testCase("SoftmaxGrad"); + InitSoftmaxGradTestCase(testCase, {3, 2}); + testCase.AddAttribute("axis", -1); + testCase.RunTest(); +} + +TEST(SoftmaxGradExpansionTest, PositiveAxis) { + FunctionTestCase testCase("SoftmaxGrad"); + InitSoftmaxGradTestCase(testCase, {3, 2}); + testCase.AddAttribute("axis", 1); + testCase.RunTest(); +} + +TEST(SoftmaxGradExpansionTest, 3D) { + FunctionTestCase testCase("SoftmaxGrad"); + InitSoftmaxGradTestCase(testCase, {3, 2, 2}); + testCase.RunTest(); +} + +TEST(SoftmaxGradExpansionTest, SymbolicShape) { + FunctionTestCase testCase("SoftmaxGrad"); + std::vector shape{3, 2, 2}; + std::vector sym_shape{"BatchSize", "SeqSize", "2"}; + int size = 12; + std::vector value(size); + for (int64_t i = 0; i < size; i++) + value[i] = float(i); + + testCase.AddInput("dY", shape, value, sym_shape); + testCase.AddInput("Y", shape, value, sym_shape); + testCase.AddOutput("dX"); + testCase.RunTest(); +} + +// Test (unexpanded) versions for both opset 12 and opset 13 models to ensure +// function-schema does not impact handling of opset 12 models. 
The current +// expansion requires opset 13, and no expansion should happen in opset 12 +// models. Test is required since ORT currently generates function-expansion +// even when op is dispatched to a kernel. + +TEST(SoftmaxGradExpansionTest, OpsetTest) { + FunctionTestCase testCase("SoftmaxGrad"); + testCase.opsets[kOnnxDomain] = 12; + testCase.opsets[kMSDomain] = 1; + InitSoftmaxGradTestCase(testCase, {3, 2, 2}); + + auto model1 = testCase.CreateModel(); + auto results1 = onnxruntime::test::Run(*model1, testCase.input_value_map, testCase.output_names); + + testCase.opsets[kOnnxDomain] = 13; + testCase.opsets[kMSDomain] = 1; + + auto model2 = testCase.CreateModel(); + auto results2 = onnxruntime::test::Run(*model2, testCase.input_value_map, testCase.output_names); + + AssertEqual(results1, results2); +} + +} // namespace test +} // namespace onnxruntime \ No newline at end of file