mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add function-body to SoftmaxGrad (#6988)
* Add function body to SoftmaxGrad schema * Add type context and cleanup * Add test case with symbolic dimensions * Add opset specification to function * handle opset dependence * Exclude from minimal build
This commit is contained in:
parent
53c123dcee
commit
cc0e7bee76
7 changed files with 427 additions and 7 deletions
|
|
@ -17,7 +17,8 @@ if (onnxruntime_MINIMAL_BUILD)
|
|||
"${ONNXRUNTIME_ROOT}/core/graph/schema_registry.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/*defs.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/*defs.cc"
|
||||
|
||||
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/graph/contrib_ops/onnx_function_util.cc"
|
||||
)
|
||||
|
||||
# no Function support initially
|
||||
|
|
|
|||
52
onnxruntime/core/graph/contrib_ops/onnx_function_util.cc
Normal file
52
onnxruntime/core/graph/contrib_ops/onnx_function_util.cc
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
#include "core/graph/contrib_ops/onnx_function_util.h"
|
||||
#include "core/util/math.h"
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
|
||||
TensorProto ToTensor(double value, TensorProto_DataType elem_type) {
|
||||
TensorProto t;
|
||||
t.set_data_type(elem_type);
|
||||
switch (elem_type) {
|
||||
case TensorProto_DataType::TensorProto_DataType_FLOAT:
|
||||
t.add_float_data((float)value);
|
||||
break;
|
||||
case TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
||||
t.add_double_data(value);
|
||||
break;
|
||||
case TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
||||
t.add_int32_data(onnxruntime::math::floatToHalf((float)value));
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
|
||||
return t;
|
||||
}
|
||||
|
||||
void BuildNodes(FunctionProto& functionProto, const std::vector<FunctionBodyHelper::NodeDef>& node_defs) {
|
||||
for (size_t i = 0; i < node_defs.size(); i++) {
|
||||
const FunctionBodyHelper::NodeDef& node = node_defs[i];
|
||||
auto* np = functionProto.add_node();
|
||||
|
||||
np->set_op_type(node.op_type);
|
||||
for (const auto& inp : node.inputs) {
|
||||
np->add_input(inp);
|
||||
}
|
||||
for (const auto& o : node.outputs) {
|
||||
np->add_output(o);
|
||||
}
|
||||
for (const auto& attr : node.attributes) {
|
||||
*(np->add_attribute()) = attr.proto;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool BuildFunctionProto(FunctionProto& functionProto, const OpSchema& schema,
|
||||
const std::vector<FunctionBodyHelper::NodeDef>& node_defs,
|
||||
const std::vector<OperatorSetIdProto>& relied_opsets) {
|
||||
BuildNodes(functionProto, node_defs);
|
||||
schema.BuildFunction(functionProto, relied_opsets);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace ONNX_NAMESPACE
|
||||
25
onnxruntime/core/graph/contrib_ops/onnx_function_util.h
Normal file
25
onnxruntime/core/graph/contrib_ops/onnx_function_util.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#pragma once
|
||||
|
||||
// Utility functions for building the body of a context-dependent function.
|
||||
// Temporary placeholder for utilities to be moved into ONNX repo. TODO.
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "onnx/onnx-operators_pb.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
#include "onnx/defs/function.h"
|
||||
|
||||
namespace ONNX_NAMESPACE {
|
||||
|
||||
// For floating-value constants of different precision:
|
||||
TensorProto ToTensor(double value, TensorProto_DataType elem_type);
|
||||
|
||||
// Utility function to construct a FunctionProto from an opschema (for the signature information),
|
||||
// a sequence of NodeDefs (for the function body), and the relied opsets.
|
||||
bool BuildFunctionProto(FunctionProto& functionProto,
|
||||
const OpSchema& schema,
|
||||
const std::vector<FunctionBodyHelper::NodeDef>& node_defs,
|
||||
const std::vector<OperatorSetIdProto>& relied_opsets = {});
|
||||
|
||||
} // namespace ONNX_NAMESPACE
|
||||
|
|
@ -180,11 +180,12 @@ static std::unordered_map<std::string, int> CreateOpsetImportsForFunction(const
|
|||
std::unordered_map<std::string, int> function_opset_imports{graph_opset_imports};
|
||||
// merge with opset imports in function proto
|
||||
for (const auto& opset_import : func_proto.opset_import()) {
|
||||
auto result = function_opset_imports.insert({opset_import.domain(), static_cast<int>(opset_import.version())});
|
||||
ORT_ENFORCE(result.second,
|
||||
auto opset_version = static_cast<int>(opset_import.version());
|
||||
auto result = function_opset_imports.insert({opset_import.domain(), opset_version});
|
||||
ORT_ENFORCE((result.first->second == opset_version),
|
||||
"ONNX model does not support multiple opset versions for a domain. Model imports opset version ",
|
||||
result.first->second, " for domain ", result.first->first, " and function is trying to import opset version ",
|
||||
opset_import.version(), " for the same domain");
|
||||
opset_version, " for the same domain");
|
||||
}
|
||||
|
||||
return function_opset_imports;
|
||||
|
|
|
|||
|
|
@ -2371,12 +2371,29 @@ void Graph::InitFunctionBodyForNode(Node& node) {
|
|||
if (node.op_->HasContextDependentFunction()) {
|
||||
NodeProto node_proto;
|
||||
node.ToProto(node_proto);
|
||||
onnx::FunctionBodyBuildContextImpl function_body_ctx(node_proto);
|
||||
std::vector<TypeProto> input_types;
|
||||
for (size_t i = 0, n = node.InputDefs().size(); i < n; i++) {
|
||||
auto p_node_arg = node.InputDefs().at(i);
|
||||
if ((nullptr != p_node_arg) && p_node_arg->Exists()) {
|
||||
auto& type = *(p_node_arg->TypeAsProto());
|
||||
input_types.emplace_back(type);
|
||||
} else
|
||||
input_types.emplace_back();
|
||||
}
|
||||
onnx::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types);
|
||||
node.op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto);
|
||||
} else {
|
||||
onnx_function_proto = *(node.op_->GetFunction());
|
||||
}
|
||||
|
||||
// Check function's opset requirements are compatible with model's opset.
|
||||
auto& graphImports = DomainToVersionMap();
|
||||
for (const auto& fn_import : onnx_function_proto.opset_import()) {
|
||||
auto it = graphImports.find(fn_import.domain());
|
||||
if ((it != graphImports.end()) && (it->second != fn_import.version()))
|
||||
return; // Incompatible. Do not use this function expansion.
|
||||
}
|
||||
|
||||
auto func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), onnx_function_proto,
|
||||
logger_);
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "core/graph/op.h"
|
||||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "core/graph/contrib_ops/onnx_function_util.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "orttraining/core/graph/training_op_defs.h"
|
||||
#include "orttraining/core/framework/distributed_run_context.h"
|
||||
|
|
@ -343,7 +344,7 @@ void RegisterTrainingOpSchemas() {
|
|||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.Input(0, "dY", "Gradient of output Y", "T")
|
||||
.Input(1, "X", "Input tensor", "T")
|
||||
.Input(1, "Y", "Input tensor", "T")
|
||||
.Output(0, "dX", "Gradient of input X", "T")
|
||||
.Attr(
|
||||
"axis",
|
||||
|
|
@ -356,7 +357,54 @@ void RegisterTrainingOpSchemas() {
|
|||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);
|
||||
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
|
||||
.SetContextDependentFunctionBodyBuilder(
|
||||
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
|
||||
// SoftmaxGrad computes dX = Y * ( dY - dot(Y, dY))
|
||||
// ONNX does not have a dot product, which can be simulated as a pointwise-multiplication ("Mul"),
|
||||
// followed by a "ReduceSum". Unfortunately, the treatment of "axis" is different in "SoftmaxGrad"
|
||||
// and "ReduceSum". If axis=k for SoftmaxGrad, we need to specify [k, ..., n-1] as the axes of
|
||||
// reduction for "ReduceSum", after accounting for negative-axis specification.
|
||||
// An alternative solution would be to Flatten inputs to 2D and then reshape output back to original shape.
|
||||
// Hopefully, many of these ops can be optimized away in the common-case of statically-known shapes.
|
||||
|
||||
auto* axis_attr = ctx.getAttribute("axis");
|
||||
int64_t axis = (axis_attr != nullptr) ? axis_attr->i() : 1;
|
||||
auto zero1d = ToTensor(std::vector<int64_t>({0}));
|
||||
zero1d.add_dims(1);
|
||||
|
||||
// nodes: {outputs, op, inputs, attributes}
|
||||
|
||||
// First, convert axis specification k to reduction axes [k, k+1, ..., n-1]
|
||||
std::vector<FunctionBodyHelper::NodeDef> body{
|
||||
FunctionBodyHelper::Const<int64_t>("one", 1),
|
||||
FunctionBodyHelper::Const<int64_t>("k", axis),
|
||||
{{"axis_zero"}, "Constant", {}, {{"value", zero1d}}},
|
||||
{{"shape"}, "Shape", {"dY"}},
|
||||
{{"n_as_vector"}, "Shape", {"shape"}},
|
||||
{{"n"}, "Squeeze", {"n_as_vector", "axis_zero"}},
|
||||
};
|
||||
|
||||
// For negative axis, add n to axis-value k; then use Range(...).
|
||||
if (axis >= 0) {
|
||||
body.push_back({{"reduction_axes"}, "Range", {"k", "n", "one"}});
|
||||
} else {
|
||||
body.push_back({{"n_plus_k"}, "Add", {"n", "k"}});
|
||||
body.push_back({{"reduction_axes"}, "Range", {"n_plus_k", "n", "one"}});
|
||||
}
|
||||
|
||||
// compute dX = Y * ( dY - dot(Y, dY)) = Y * ( dY - ReduceSum(Y * dY))
|
||||
body.push_back({{"a"}, "Mul", {"Y", "dY"}});
|
||||
body.push_back({{"b"}, "ReduceSum", {"a", "reduction_axes"}});
|
||||
body.push_back({{"c"}, "Sub", {"dY", "b"}});
|
||||
body.push_back({{"dX"}, "Mul", {"Y", "c"}});
|
||||
|
||||
OperatorSetIdProto onnx_opset_13;
|
||||
onnx_opset_13.set_domain("");
|
||||
onnx_opset_13.set_version(13);
|
||||
|
||||
return ONNX_NAMESPACE::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(LogSoftmaxGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
|
|
|
|||
276
orttraining/orttraining/test/gradient/function_ops_test.cc
Normal file
276
orttraining/orttraining/test/gradient/function_ops_test.cc
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/graph/contrib_ops/contrib_defs.h"
|
||||
#include "orttraining/core/graph/training_op_defs.h"
|
||||
#include "test/test_environment.h"
|
||||
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
|
||||
#include "test/framework/test_utils.h"
|
||||
|
||||
using namespace ::onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
typedef std::vector<onnxruntime::NodeArg*> ArgMap;
|
||||
|
||||
static void RegisterSchemas() {
|
||||
static bool registered = false;
|
||||
if (!registered) {
|
||||
onnxruntime::training::RegisterTrainingOpSchemas();
|
||||
registered = true;
|
||||
}
|
||||
}
|
||||
|
||||
static ONNX_NAMESPACE::TypeProto TensorType(int32_t elem_type, std::vector<int64_t> dims) {
|
||||
ONNX_NAMESPACE::TypeProto typeProto;
|
||||
typeProto.mutable_tensor_type()->set_elem_type(elem_type);
|
||||
auto* shape = typeProto.mutable_tensor_type()->mutable_shape();
|
||||
for (auto dim : dims)
|
||||
shape->add_dim()->set_dim_value(dim);
|
||||
return typeProto;
|
||||
}
|
||||
|
||||
static ONNX_NAMESPACE::TypeProto TensorType(int32_t elem_type, std::vector<std::string> dims) {
|
||||
ONNX_NAMESPACE::TypeProto typeProto;
|
||||
typeProto.mutable_tensor_type()->set_elem_type(elem_type);
|
||||
auto* shape = typeProto.mutable_tensor_type()->mutable_shape();
|
||||
for (auto dim : dims) {
|
||||
uint64_t dimval;
|
||||
std::istringstream s(dim);
|
||||
if (s >> dimval) {
|
||||
shape->add_dim()->set_dim_value(dimval);
|
||||
} else {
|
||||
shape->add_dim()->set_dim_param(dim);
|
||||
}
|
||||
}
|
||||
return typeProto;
|
||||
}
|
||||
|
||||
static std::vector<OrtValue>
|
||||
Run(onnxruntime::Model& model, NameMLValMap& feeds, std::vector<std::string> output_names) {
|
||||
SessionOptions session_options;
|
||||
InferenceSession session_object{session_options, GetEnvironment()};
|
||||
|
||||
std::string serialized_model;
|
||||
const bool serialization_status = model.ToProto().SerializeToString(&serialized_model);
|
||||
EXPECT_TRUE(serialization_status) << "Failed to serialize proto to string";
|
||||
std::stringstream sstr(serialized_model);
|
||||
auto status = session_object.Load(sstr);
|
||||
EXPECT_TRUE(status.IsOK());
|
||||
status = session_object.Initialize();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
RunOptions run_options;
|
||||
run_options.run_tag = session_options.session_logid;
|
||||
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
status = session_object.Run(run_options, feeds, output_names, &fetches);
|
||||
EXPECT_TRUE(status.IsOK()) << "Session Run failed.";
|
||||
|
||||
return fetches;
|
||||
}
|
||||
|
||||
// Restricted to float tensors
|
||||
static void AssertEqual(const Tensor& tensor1, const Tensor& tensor2) {
|
||||
auto size = tensor1.Shape().Size();
|
||||
auto* data1 = tensor1.template Data<float>();
|
||||
auto* data2 = tensor2.template Data<float>();
|
||||
|
||||
float threshold = 0.001f;
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
ASSERT_NEAR(data1[i], data2[i], threshold) << "as position i:" << i;
|
||||
}
|
||||
}
|
||||
|
||||
static void AssertEqual(const std::vector<OrtValue>& results1, const std::vector<OrtValue>& results2) {
|
||||
ASSERT_EQ(results1.size(), results2.size());
|
||||
for (int i = 0; i < results1.size(); i++) {
|
||||
auto& value1 = results1[i].Get<Tensor>();
|
||||
auto& value2 = results2[i].Get<Tensor>();
|
||||
AssertEqual(value1, value2);
|
||||
}
|
||||
}
|
||||
|
||||
struct FunctionTestCase {
|
||||
const char* opname;
|
||||
|
||||
std::vector<NodeArg> input_args;
|
||||
std::vector<std::pair<std::string, OrtValue>> input_values;
|
||||
NameMLValMap input_value_map;
|
||||
|
||||
std::vector<std::string> output_names;
|
||||
std::vector<NodeArg> output_args;
|
||||
|
||||
NodeAttributes attributes;
|
||||
std::unique_ptr<IExecutionProvider> provider;
|
||||
|
||||
std::unordered_map<std::string, int> opsets;
|
||||
|
||||
FunctionTestCase(const char* _opname) : opname(_opname), provider(new CPUExecutionProvider(CPUExecutionProviderInfo())) {}
|
||||
|
||||
void AddInput(std::string input_name, std::vector<int64_t> shape, std::vector<float> data, std::vector<std::string> symshape = {}) {
|
||||
auto arg_type = (symshape.size() > 0) ? TensorType(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, symshape) : TensorType(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
|
||||
input_args.emplace_back(input_name, &arg_type);
|
||||
|
||||
OrtValue ort_value;
|
||||
CreateMLValue<float>(provider->GetAllocator(0, OrtMemTypeDefault), shape, data, &ort_value);
|
||||
input_values.push_back(std::make_pair(input_name, ort_value));
|
||||
input_value_map.insert(std::make_pair(input_name, ort_value));
|
||||
}
|
||||
|
||||
void AddOutput(std::string output_name) {
|
||||
output_names.emplace_back(output_name);
|
||||
output_args.emplace_back(output_name, nullptr);
|
||||
}
|
||||
|
||||
void AddAttribute(const char* attr_name, int64_t attr_val) {
|
||||
ONNX_NAMESPACE::AttributeProto axis_attr;
|
||||
axis_attr.set_name(attr_name);
|
||||
axis_attr.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
|
||||
axis_attr.set_i(attr_val);
|
||||
attributes[attr_name] = axis_attr;
|
||||
}
|
||||
|
||||
onnxruntime::Node& AddCallNodeTo(onnxruntime::Graph& graph) {
|
||||
std::vector<NodeArg*> input_arg_ptrs;
|
||||
|
||||
for (auto& arg : input_args)
|
||||
input_arg_ptrs.push_back(&arg);
|
||||
|
||||
std::vector<NodeArg*> output_arg_ptrs;
|
||||
for (auto& arg : output_args)
|
||||
output_arg_ptrs.push_back(&arg);
|
||||
|
||||
return graph.AddNode("fncallnode", opname, "function call node", input_arg_ptrs, output_arg_ptrs, &attributes, onnxruntime::kMSDomain);
|
||||
}
|
||||
|
||||
std::unique_ptr<Model> CreateModel(bool inline_call = false) {
|
||||
RegisterSchemas();
|
||||
if (opsets.size() == 0) {
|
||||
// Default opsets
|
||||
opsets[kOnnxDomain] = 13;
|
||||
opsets[kMSDomain] = 1;
|
||||
}
|
||||
|
||||
std::unique_ptr<Model> model(new Model("test", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
opsets, {}, DefaultLoggingManager().DefaultLogger()));
|
||||
|
||||
onnxruntime::Graph& graph = model->MainGraph();
|
||||
auto& call_node = AddCallNodeTo(graph);
|
||||
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
|
||||
if (inline_call) {
|
||||
graph.InlineFunction(call_node);
|
||||
status = graph.Resolve();
|
||||
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
void RunTest() {
|
||||
auto model1 = CreateModel(false);
|
||||
auto results1 = Run(*model1, input_value_map, output_names);
|
||||
|
||||
auto model2 = CreateModel(true);
|
||||
auto results2 = Run(*model2, input_value_map, output_names);
|
||||
|
||||
AssertEqual(results1, results2);
|
||||
}
|
||||
};
|
||||
|
||||
static void InitSoftmaxGradTestCase(FunctionTestCase& testCase, std::vector<int64_t> shape) {
|
||||
int64_t size = 1;
|
||||
for (auto dim : shape)
|
||||
size *= dim;
|
||||
|
||||
std::vector<float> value(size);
|
||||
for (int64_t i = 0; i < size; i++)
|
||||
value[i] = float(i);
|
||||
|
||||
testCase.AddInput("dY", shape, value);
|
||||
testCase.AddInput("Y", shape, value);
|
||||
testCase.AddOutput("dX");
|
||||
}
|
||||
|
||||
TEST(SoftmaxGradExpansionTest, DefaultAxis) {
|
||||
FunctionTestCase testCase("SoftmaxGrad");
|
||||
InitSoftmaxGradTestCase(testCase, {3, 2});
|
||||
testCase.RunTest();
|
||||
}
|
||||
|
||||
TEST(SoftmaxGradExpansionTest, NegativeAxis) {
|
||||
FunctionTestCase testCase("SoftmaxGrad");
|
||||
InitSoftmaxGradTestCase(testCase, {3, 2});
|
||||
testCase.AddAttribute("axis", -1);
|
||||
testCase.RunTest();
|
||||
}
|
||||
|
||||
TEST(SoftmaxGradExpansionTest, PositiveAxis) {
|
||||
FunctionTestCase testCase("SoftmaxGrad");
|
||||
InitSoftmaxGradTestCase(testCase, {3, 2});
|
||||
testCase.AddAttribute("axis", 1);
|
||||
testCase.RunTest();
|
||||
}
|
||||
|
||||
TEST(SoftmaxGradExpansionTest, 3D) {
|
||||
FunctionTestCase testCase("SoftmaxGrad");
|
||||
InitSoftmaxGradTestCase(testCase, {3, 2, 2});
|
||||
testCase.RunTest();
|
||||
}
|
||||
|
||||
TEST(SoftmaxGradExpansionTest, SymbolicShape) {
|
||||
FunctionTestCase testCase("SoftmaxGrad");
|
||||
std::vector<int64_t> shape{3, 2, 2};
|
||||
std::vector<std::string> sym_shape{"BatchSize", "SeqSize", "2"};
|
||||
int size = 12;
|
||||
std::vector<float> value(size);
|
||||
for (int64_t i = 0; i < size; i++)
|
||||
value[i] = float(i);
|
||||
|
||||
testCase.AddInput("dY", shape, value, sym_shape);
|
||||
testCase.AddInput("Y", shape, value, sym_shape);
|
||||
testCase.AddOutput("dX");
|
||||
testCase.RunTest();
|
||||
}
|
||||
|
||||
// Test (unexpanded) versions for both opset 12 and opset 13 models to ensure
|
||||
// function-schema does not impact handling of opset 12 models. The current
|
||||
// expansion requires opset 13, and no expansion should happen in opset 12
|
||||
// models. Test is required since ORT currently generates function-expansion
|
||||
// even when op is dispatched to a kernel.
|
||||
|
||||
TEST(SoftmaxGradExpansionTest, OpsetTest) {
|
||||
FunctionTestCase testCase("SoftmaxGrad");
|
||||
testCase.opsets[kOnnxDomain] = 12;
|
||||
testCase.opsets[kMSDomain] = 1;
|
||||
InitSoftmaxGradTestCase(testCase, {3, 2, 2});
|
||||
|
||||
auto model1 = testCase.CreateModel();
|
||||
auto results1 = onnxruntime::test::Run(*model1, testCase.input_value_map, testCase.output_names);
|
||||
|
||||
testCase.opsets[kOnnxDomain] = 13;
|
||||
testCase.opsets[kMSDomain] = 1;
|
||||
|
||||
auto model2 = testCase.CreateModel();
|
||||
auto results2 = onnxruntime::test::Run(*model1, testCase.input_value_map, testCase.output_names);
|
||||
|
||||
AssertEqual(results1, results2);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue