Add function-body to SoftmaxGrad (#6988)

* Add function body to SoftmaxGrad schema

* Add type context and cleanup

* Add test case with symbolic dimensions

* Add opset specification to function

* handle opset dependence

* Exclude from minimal build
This commit is contained in:
G. Ramalingam 2021-03-25 11:34:06 -07:00 committed by GitHub
parent 53c123dcee
commit cc0e7bee76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 427 additions and 7 deletions

View file

@ -17,7 +17,8 @@ if (onnxruntime_MINIMAL_BUILD)
"${ONNXRUNTIME_ROOT}/core/graph/schema_registry.cc"
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/*defs.h"
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/*defs.cc"
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.h"
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc"
)
# no Function support initially

View file

@ -0,0 +1,52 @@
#include "core/graph/contrib_ops/onnx_function_util.h"
#include "core/util/math.h"
namespace ONNX_NAMESPACE {
// Builds a scalar TensorProto of the given element type holding `value`.
// Only floating-point element types are supported: FLOAT, DOUBLE and FLOAT16
// (the half-precision bit pattern is widened into int32_data, per the ONNX
// tensor storage convention). Any other element type is a programming error:
// asserts in debug builds, returns a tensor with no data in release builds.
TensorProto ToTensor(double value, TensorProto_DataType elem_type) {
  TensorProto t;
  t.set_data_type(elem_type);
  switch (elem_type) {
    case TensorProto_DataType::TensorProto_DataType_FLOAT:
      t.add_float_data(static_cast<float>(value));
      break;
    case TensorProto_DataType::TensorProto_DataType_DOUBLE:
      t.add_double_data(value);
      break;
    case TensorProto_DataType::TensorProto_DataType_FLOAT16:
      t.add_int32_data(onnxruntime::math::floatToHalf(static_cast<float>(value)));
      break;
    default:
      // Unsupported element type: fail loudly in debug builds.
      assert(false);
      break;
  }
  return t;
}
// Appends one NodeProto to `functionProto` for each NodeDef in `node_defs`,
// copying over the op type, input names, output names and attributes.
void BuildNodes(FunctionProto& functionProto, const std::vector<FunctionBodyHelper::NodeDef>& node_defs) {
  for (const auto& node_def : node_defs) {
    auto* node_proto = functionProto.add_node();
    node_proto->set_op_type(node_def.op_type);
    for (const auto& input_name : node_def.inputs)
      node_proto->add_input(input_name);
    for (const auto& output_name : node_def.outputs)
      node_proto->add_output(output_name);
    for (const auto& attribute : node_def.attributes)
      *(node_proto->add_attribute()) = attribute.proto;
  }
}
// Populates `functionProto` with the function body given by `node_defs`, then
// fills in signature information (inputs, outputs, attributes) from `schema`
// and records the opsets the body relies on. The nodes must be added before
// schema.BuildFunction is invoked. Always returns true.
bool BuildFunctionProto(FunctionProto& functionProto, const OpSchema& schema,
const std::vector<FunctionBodyHelper::NodeDef>& node_defs,
const std::vector<OperatorSetIdProto>& relied_opsets) {
BuildNodes(functionProto, node_defs);
schema.BuildFunction(functionProto, relied_opsets);
return true;
}
} // namespace ONNX_NAMESPACE

View file

@ -0,0 +1,25 @@
#pragma once
// Utility functions for building the body of a context-dependent function.
// Temporary placeholder for utilities to be moved into the ONNX repo. TODO.
#include <string>
#include <vector>
#include "onnx/onnx-operators_pb.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/function.h"
namespace ONNX_NAMESPACE {
// For floating-value constants of different precision: builds a scalar
// TensorProto of `elem_type` (FLOAT, DOUBLE or FLOAT16) holding `value`.
TensorProto ToTensor(double value, TensorProto_DataType elem_type);
// Utility function to construct a FunctionProto from an opschema (for the signature information),
// a sequence of NodeDefs (for the function body), and the relied opsets.
// Always returns true.
bool BuildFunctionProto(FunctionProto& functionProto,
const OpSchema& schema,
const std::vector<FunctionBodyHelper::NodeDef>& node_defs,
const std::vector<OperatorSetIdProto>& relied_opsets = {});
} // namespace ONNX_NAMESPACE

View file

@ -180,11 +180,12 @@ static std::unordered_map<std::string, int> CreateOpsetImportsForFunction(const
std::unordered_map<std::string, int> function_opset_imports{graph_opset_imports};
// merge with opset imports in function proto
for (const auto& opset_import : func_proto.opset_import()) {
auto result = function_opset_imports.insert({opset_import.domain(), static_cast<int>(opset_import.version())});
ORT_ENFORCE(result.second,
auto opset_version = static_cast<int>(opset_import.version());
auto result = function_opset_imports.insert({opset_import.domain(), opset_version});
ORT_ENFORCE((result.first->second == opset_version),
"ONNX model does not support multiple opset versions for a domain. Model imports opset version ",
result.first->second, " for domain ", result.first->first, " and function is trying to import opset version ",
opset_import.version(), " for the same domain");
opset_version, " for the same domain");
}
return function_opset_imports;

View file

@ -2371,12 +2371,29 @@ void Graph::InitFunctionBodyForNode(Node& node) {
if (node.op_->HasContextDependentFunction()) {
NodeProto node_proto;
node.ToProto(node_proto);
onnx::FunctionBodyBuildContextImpl function_body_ctx(node_proto);
std::vector<TypeProto> input_types;
for (size_t i = 0, n = node.InputDefs().size(); i < n; i++) {
auto p_node_arg = node.InputDefs().at(i);
if ((nullptr != p_node_arg) && p_node_arg->Exists()) {
auto& type = *(p_node_arg->TypeAsProto());
input_types.emplace_back(type);
} else
input_types.emplace_back();
}
onnx::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types);
node.op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto);
} else {
onnx_function_proto = *(node.op_->GetFunction());
}
// Check function's opset requirements are compatible with model's opset.
auto& graphImports = DomainToVersionMap();
for (const auto& fn_import : onnx_function_proto.opset_import()) {
auto it = graphImports.find(fn_import.domain());
if ((it != graphImports.end()) && (it->second != fn_import.version()))
return; // Incompatible. Do not use this function expansion.
}
auto func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), onnx_function_proto,
logger_);

View file

@ -3,6 +3,7 @@
#include "core/graph/op.h"
#include "core/graph/contrib_ops/contrib_defs.h"
#include "core/graph/contrib_ops/onnx_function_util.h"
#include "core/providers/common.h"
#include "orttraining/core/graph/training_op_defs.h"
#include "orttraining/core/framework/distributed_run_context.h"
@ -343,7 +344,7 @@ void RegisterTrainingOpSchemas() {
.SetDomain(kMSDomain)
.SinceVersion(1)
.Input(0, "dY", "Gradient of output Y", "T")
.Input(1, "X", "Input tensor", "T")
.Input(1, "Y", "Input tensor", "T")
.Output(0, "dX", "Gradient of input X", "T")
.Attr(
"axis",
@ -356,7 +357,54 @@ void RegisterTrainingOpSchemas() {
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
.SetContextDependentFunctionBodyBuilder(
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
// SoftmaxGrad computes dX = Y * ( dY - dot(Y, dY))
// ONNX does not have a dot product, which can be simulated as a pointwise-multiplication ("Mul"),
// followed by a "ReduceSum". Unfortunately, the treatment of "axis" is different in "SoftmaxGrad"
// and "ReduceSum". If axis=k for SoftmaxGrad, we need to specify [k, ..., n-1] as the axes of
// reduction for "ReduceSum", after accounting for negative-axis specification.
// An alternative solution would be to Flatten inputs to 2D and then reshape output back to original shape.
// Hopefully, many of these ops can be optimized away in the common-case of statically-known shapes.
auto* axis_attr = ctx.getAttribute("axis");
int64_t axis = (axis_attr != nullptr) ? axis_attr->i() : 1;
auto zero1d = ToTensor(std::vector<int64_t>({0}));
zero1d.add_dims(1);
// nodes: {outputs, op, inputs, attributes}
// First, convert axis specification k to reduction axes [k, k+1, ..., n-1]
std::vector<FunctionBodyHelper::NodeDef> body{
FunctionBodyHelper::Const<int64_t>("one", 1),
FunctionBodyHelper::Const<int64_t>("k", axis),
{{"axis_zero"}, "Constant", {}, {{"value", zero1d}}},
{{"shape"}, "Shape", {"dY"}},
{{"n_as_vector"}, "Shape", {"shape"}},
{{"n"}, "Squeeze", {"n_as_vector", "axis_zero"}},
};
// For negative axis, add n to axis-value k; then use Range(...).
if (axis >= 0) {
body.push_back({{"reduction_axes"}, "Range", {"k", "n", "one"}});
} else {
body.push_back({{"n_plus_k"}, "Add", {"n", "k"}});
body.push_back({{"reduction_axes"}, "Range", {"n_plus_k", "n", "one"}});
}
// compute dX = Y * ( dY - dot(Y, dY)) = Y * ( dY - ReduceSum(Y * dY))
body.push_back({{"a"}, "Mul", {"Y", "dY"}});
body.push_back({{"b"}, "ReduceSum", {"a", "reduction_axes"}});
body.push_back({{"c"}, "Sub", {"dY", "b"}});
body.push_back({{"dX"}, "Mul", {"Y", "c"}});
OperatorSetIdProto onnx_opset_13;
onnx_opset_13.set_domain("");
onnx_opset_13.set_version(13);
return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
});
ONNX_CONTRIB_OPERATOR_SCHEMA(LogSoftmaxGrad)
.SetDomain(kMSDomain)

View file

@ -0,0 +1,276 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <iostream>
#include <sstream>
#include <memory>
#include "gtest/gtest.h"
#include "core/graph/model.h"
#include "core/graph/contrib_ops/contrib_defs.h"
#include "orttraining/core/graph/training_op_defs.h"
#include "test/test_environment.h"
#include "core/session/inference_session.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "test/framework/test_utils.h"
using namespace ::onnxruntime::common;
namespace onnxruntime {
namespace test {
typedef std::vector<onnxruntime::NodeArg*> ArgMap;
// Registers the training op schemas exactly once across all tests.
// Uses a C++11 "magic static" so initialization is thread-safe, unlike the
// original unguarded check-then-set boolean flag.
static void RegisterSchemas() {
  static const bool registered = []() {
    onnxruntime::training::RegisterTrainingOpSchemas();
    return true;
  }();
  (void)registered;  // silence unused-variable warnings
}
// Builds a tensor TypeProto with the given element type and a fully
// concrete (static) shape.
static ONNX_NAMESPACE::TypeProto TensorType(int32_t elem_type, std::vector<int64_t> dims) {
  ONNX_NAMESPACE::TypeProto type_proto;
  type_proto.mutable_tensor_type()->set_elem_type(elem_type);
  auto* shape_proto = type_proto.mutable_tensor_type()->mutable_shape();
  for (int64_t dim_value : dims)
    shape_proto->add_dim()->set_dim_value(dim_value);
  return type_proto;
}
// Builds a tensor TypeProto with the given element type and a possibly
// symbolic shape: a dimension string that parses entirely as a number becomes
// a concrete dim_value; anything else becomes a symbolic dim_param.
static ONNX_NAMESPACE::TypeProto TensorType(int32_t elem_type, std::vector<std::string> dims) {
  ONNX_NAMESPACE::TypeProto typeProto;
  typeProto.mutable_tensor_type()->set_elem_type(elem_type);
  auto* shape = typeProto.mutable_tensor_type()->mutable_shape();
  for (const auto& dim : dims) {
    uint64_t dimval;
    std::istringstream s(dim);
    // Require the whole string to be consumed: the original `s >> dimval`
    // check would mis-classify symbolic names with a numeric prefix (e.g.
    // "2x") as the concrete dimension 2.
    if ((s >> dimval) && s.eof()) {
      shape->add_dim()->set_dim_value(dimval);
    } else {
      shape->add_dim()->set_dim_param(dim);
    }
  }
  return typeProto;
}
// Serializes `model`, loads it into a fresh InferenceSession on top of the
// test environment, and executes it with `feeds`, returning the fetched
// values for `output_names`. All failures surface as gtest expectations.
static std::vector<OrtValue>
Run(onnxruntime::Model& model, NameMLValMap& feeds, std::vector<std::string> output_names) {
  std::string model_bytes;
  const bool serialized_ok = model.ToProto().SerializeToString(&model_bytes);
  EXPECT_TRUE(serialized_ok) << "Failed to serialize proto to string";

  SessionOptions session_options;
  InferenceSession session{session_options, GetEnvironment()};
  std::stringstream model_stream(model_bytes);
  auto status = session.Load(model_stream);
  EXPECT_TRUE(status.IsOK());
  status = session.Initialize();
  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();

  RunOptions run_options;
  run_options.run_tag = session_options.session_logid;
  std::vector<OrtValue> outputs;
  status = session.Run(run_options, feeds, output_names, &outputs);
  EXPECT_TRUE(status.IsOK()) << "Session Run failed.";
  return outputs;
}
// Restricted to float tensors
static void AssertEqual(const Tensor& tensor1, const Tensor& tensor2) {
auto size = tensor1.Shape().Size();
auto* data1 = tensor1.template Data<float>();
auto* data2 = tensor2.template Data<float>();
float threshold = 0.001f;
for (int i = 0; i < size; ++i) {
ASSERT_NEAR(data1[i], data2[i], threshold) << "as position i:" << i;
}
}
static void AssertEqual(const std::vector<OrtValue>& results1, const std::vector<OrtValue>& results2) {
ASSERT_EQ(results1.size(), results2.size());
for (int i = 0; i < results1.size(); i++) {
auto& value1 = results1[i].Get<Tensor>();
auto& value2 = results2[i].Get<Tensor>();
AssertEqual(value1, value2);
}
}
// Test harness for a single function-op call: builds a one-node model invoking
// `opname` (in the kMSDomain), optionally inlines the op's function body, runs
// both variants, and checks they produce the same outputs.
struct FunctionTestCase {
// Name of the op under test (must be registered in kMSDomain).
const char* opname;
// Graph-level input argument declarations (name + type/shape).
std::vector<NodeArg> input_args;
// Input values in insertion order, plus the same values as a map for Run().
std::vector<std::pair<std::string, OrtValue>> input_values;
NameMLValMap input_value_map;
// Declared outputs: names for fetching, args for the graph node.
std::vector<std::string> output_names;
std::vector<NodeArg> output_args;
// Attributes to attach to the call node (e.g. "axis").
NodeAttributes attributes;
// CPU provider used only as an allocator source for input OrtValues.
std::unique_ptr<IExecutionProvider> provider;
// Domain -> opset version map for the model; defaults are filled in by
// CreateModel() if left empty.
std::unordered_map<std::string, int> opsets;
FunctionTestCase(const char* _opname) : opname(_opname), provider(new CPUExecutionProvider(CPUExecutionProviderInfo())) {}
// Declares a float tensor input. `shape` describes the concrete data; if
// `symshape` is non-empty it is used as the graph-level (symbolic) shape
// instead. NOTE(review): `arg_type` is a local whose address is passed to
// the NodeArg constructor — presumably NodeArg copies the TypeProto rather
// than retaining the pointer; confirm against NodeArg's contract.
void AddInput(std::string input_name, std::vector<int64_t> shape, std::vector<float> data, std::vector<std::string> symshape = {}) {
auto arg_type = (symshape.size() > 0) ? TensorType(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, symshape) : TensorType(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
input_args.emplace_back(input_name, &arg_type);
OrtValue ort_value;
CreateMLValue<float>(provider->GetAllocator(0, OrtMemTypeDefault), shape, data, &ort_value);
input_values.push_back(std::make_pair(input_name, ort_value));
input_value_map.insert(std::make_pair(input_name, ort_value));
}
// Declares an output by name; its type is left to shape/type inference.
void AddOutput(std::string output_name) {
output_names.emplace_back(output_name);
output_args.emplace_back(output_name, nullptr);
}
// Adds an integer attribute (e.g. "axis") to the call node.
void AddAttribute(const char* attr_name, int64_t attr_val) {
ONNX_NAMESPACE::AttributeProto axis_attr;
axis_attr.set_name(attr_name);
axis_attr.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
axis_attr.set_i(attr_val);
attributes[attr_name] = axis_attr;
}
// Adds the single call node for `opname` to `graph`, wiring up the declared
// inputs, outputs and attributes.
onnxruntime::Node& AddCallNodeTo(onnxruntime::Graph& graph) {
std::vector<NodeArg*> input_arg_ptrs;
for (auto& arg : input_args)
input_arg_ptrs.push_back(&arg);
std::vector<NodeArg*> output_arg_ptrs;
for (auto& arg : output_args)
output_arg_ptrs.push_back(&arg);
return graph.AddNode("fncallnode", opname, "function call node", input_arg_ptrs, output_arg_ptrs, &attributes, onnxruntime::kMSDomain);
}
// Builds and resolves a one-node model. If `inline_call` is true, the op's
// function body is expanded into the graph and the graph is re-resolved.
std::unique_ptr<Model> CreateModel(bool inline_call = false) {
RegisterSchemas();
if (opsets.size() == 0) {
// Default opsets
opsets[kOnnxDomain] = 13;
opsets[kMSDomain] = 1;
}
std::unique_ptr<Model> model(new Model("test", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
opsets, {}, DefaultLoggingManager().DefaultLogger()));
onnxruntime::Graph& graph = model->MainGraph();
auto& call_node = AddCallNodeTo(graph);
auto status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
if (inline_call) {
graph.InlineFunction(call_node);
status = graph.Resolve();
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
}
return model;
}
// Runs the un-inlined and inlined models on the same inputs and asserts the
// outputs match — the core check that a function expansion is correct.
void RunTest() {
auto model1 = CreateModel(false);
auto results1 = Run(*model1, input_value_map, output_names);
auto model2 = CreateModel(true);
auto results2 = Run(*model2, input_value_map, output_names);
AssertEqual(results1, results2);
}
};
// Populates `testCase` with SoftmaxGrad inputs "dY" and "Y", both filled with
// the ramp 0, 1, 2, ... over the given shape, and declares the output "dX".
static void InitSoftmaxGradTestCase(FunctionTestCase& testCase, std::vector<int64_t> shape) {
  int64_t num_elements = 1;
  for (int64_t dim : shape)
    num_elements *= dim;
  std::vector<float> data(num_elements);
  for (int64_t i = 0; i < num_elements; ++i)
    data[i] = float(i);
  testCase.AddInput("dY", shape, data);
  testCase.AddInput("Y", shape, data);
  testCase.AddOutput("dX");
}
// Expansion must match the kernel when no axis attribute is given
// (axis defaults to 1).
TEST(SoftmaxGradExpansionTest, DefaultAxis) {
FunctionTestCase testCase("SoftmaxGrad");
InitSoftmaxGradTestCase(testCase, {3, 2});
testCase.RunTest();
}
// Negative axis exercises the Add(n, k) branch of the function body.
TEST(SoftmaxGradExpansionTest, NegativeAxis) {
FunctionTestCase testCase("SoftmaxGrad");
InitSoftmaxGradTestCase(testCase, {3, 2});
testCase.AddAttribute("axis", -1);
testCase.RunTest();
}
// Explicit positive axis exercises the direct Range(k, n, 1) branch.
TEST(SoftmaxGradExpansionTest, PositiveAxis) {
FunctionTestCase testCase("SoftmaxGrad");
InitSoftmaxGradTestCase(testCase, {3, 2});
testCase.AddAttribute("axis", 1);
testCase.RunTest();
}
// Rank-3 input: default axis of 1 reduces over the last two dimensions.
TEST(SoftmaxGradExpansionTest, 3D) {
FunctionTestCase testCase("SoftmaxGrad");
InitSoftmaxGradTestCase(testCase, {3, 2, 2});
testCase.RunTest();
}
// Symbolic dimensions force the Shape/Range-based expansion: the concrete
// shape sizes the data, while the graph-level shape uses the symbolic names
// BatchSize and SeqSize (the string "2" stays a concrete dimension).
TEST(SoftmaxGradExpansionTest, SymbolicShape) {
  FunctionTestCase testCase("SoftmaxGrad");
  std::vector<int64_t> shape{3, 2, 2};
  std::vector<std::string> sym_shape{"BatchSize", "SeqSize", "2"};
  const int size = 12;
  std::vector<float> value(size);
  for (int i = 0; i < size; ++i)
    value[i] = float(i);
  testCase.AddInput("dY", shape, value, sym_shape);
  testCase.AddInput("Y", shape, value, sym_shape);
  testCase.AddOutput("dX");
  testCase.RunTest();
}
// Test (unexpanded) versions for both opset 12 and opset 13 models to ensure
// function-schema does not impact handling of opset 12 models. The current
// expansion requires opset 13, and no expansion should happen in opset 12
// models. Test is required since ORT currently generates function-expansion
// even when op is dispatched to a kernel.
TEST(SoftmaxGradExpansionTest, OpsetTest) {
  FunctionTestCase testCase("SoftmaxGrad");
  // Opset-12 model: the function body (which requires opset 13) must NOT be
  // used, and the op must still run via its kernel.
  testCase.opsets[kOnnxDomain] = 12;
  testCase.opsets[kMSDomain] = 1;
  InitSoftmaxGradTestCase(testCase, {3, 2, 2});
  auto model1 = testCase.CreateModel();
  auto results1 = onnxruntime::test::Run(*model1, testCase.input_value_map, testCase.output_names);
  // Opset-13 model with the same inputs.
  testCase.opsets[kOnnxDomain] = 13;
  testCase.opsets[kMSDomain] = 1;
  auto model2 = testCase.CreateModel();
  // Bug fix: run model2 here. The original ran *model1 a second time, so the
  // comparison was vacuous and never exercised the opset-13 model at all.
  auto results2 = onnxruntime::test::Run(*model2, testCase.input_value_map, testCase.output_names);
  AssertEqual(results1, results2);
}
} // namespace test
} // namespace onnxruntime