Add Trilu custom op (#4537)

Co-authored-by: neginraoof <neginmr@utexas.edu>
This commit is contained in:
Ksenija Stanojevic 2020-08-17 14:42:26 -07:00 committed by GitHub
parent 1ce2982f65
commit ea37a4d89b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 1137 additions and 37 deletions

View file

@ -34,11 +34,11 @@ If you are adding a symbolic function for a new custom op, add the function to t
### 2. Extending ONNX Runtime with Custom Ops
The next step is to add op schema and kernel implementation in ONNX Runtime.
THe exmaple Inverse custom op is added in:
Consider the Inverse custom op as an example added in:
https://github.com/microsoft/onnxruntime/pull/3485
Custom op schema and shape inference function should be added under ```onnxruntime/core/graph/contrib_ops/contrib_defs.cc ```
Custom op schema and shape inference function should be added in ```onnxruntime/core/graph/contrib_ops/contrib_defs.cc ```
using ```ONNX_CONTRIB_OPERATOR_SCHEMA```.
```c++
@ -48,6 +48,17 @@ ONNX_CONTRIB_OPERATOR_SCHEMA(Inverse)
...
```
To comply with ONNX guideline for new operators, a new operator should have complete reference implementation tests and
shape inference tests.
Reference implementation python tests should be added in:
``onnxruntime/test/python/contrib_ops``
E.g.: ``onnxruntime/test/python/contrib_ops/onnx_test_trilu.py``
Shape inference C++ tests should be added in:
``onnxruntime/test/contrib_ops``
E.g.: ``onnxruntime/test/contrib_ops/trilu_shape_inference_test.cc``
The operator kernel should be implemented using ```Compute``` function
under contrib namespace in ```onnxruntime/contrib_ops/cpu/<operator>.cc```
for CPU and ```onnxruntime/contrib_ops/cuda/<operator>.cc``` for CUDA.
@ -90,7 +101,7 @@ Now you should be able to build and install ONNX Runtime to start using your cus
##### ONNX Runtime Tests
ONNX Runtime custom op tests should be added in: ```onnxruntime/test/contrib_ops/<operator>_test.cc ```
ONNX Runtime custom op kernel tests should be added in: ```onnxruntime/test/contrib_ops/<operator>_test.cc ```
```c++
namespace onnxruntime {

View file

@ -74,6 +74,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);
Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@ -158,6 +159,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -0,0 +1,105 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"
#include "Eigen/src/Core/Map.h"
#include "trilu.h"
#include <functional>
namespace onnxruntime {
namespace contrib {
ONNX_OPERATOR_KERNEL_EX(
Trilu,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints<float, double, int64_t>()),
Trilu);
template <typename T>
static Status TriluImpl(const Tensor* X, Tensor* Y, int64_t k_val, bool up) {
const auto& X_shape = X->Shape();
int64_t X_num_dims = static_cast<int64_t>(X_shape.NumDimensions());
const auto* X_data = reinterpret_cast<const T*>(X->DataRaw());
int64_t matrix_h = static_cast<int64_t>(X_shape[X_num_dims - 2]);
int64_t matrix_w = static_cast<int64_t>(X_shape[X_num_dims - 1]);
int64_t batch_size = 1;
for (int64_t i = 0; i < X_num_dims - 2; ++i) {
batch_size *= X_shape[i];
}
int64_t num_matrix_elems = matrix_h * matrix_w;
auto* Y_data = reinterpret_cast<T*>(Y->MutableDataRaw());
for (int64_t b = 0; b < batch_size; b++) { // can be parallelized if need to
auto X_batch_data = X_data + (b * num_matrix_elems);
auto Y_batch_data = Y_data + (b * num_matrix_elems);
auto input_mat = ConstEigenMatrixMapRowMajor<T>(X_batch_data, matrix_h, matrix_w);
auto output_mat = EigenMatrixMapRowMajor<T>(Y_batch_data, matrix_h, matrix_w);
if (X_batch_data != Y_batch_data) {
output_mat = input_mat;
}
if (up) {
int64_t start_i = k_val > 0 ? 0 : 1 - k_val;
for (int64_t i = start_i; i < matrix_h; i++) {
for (int64_t j = 0; j < i + k_val && j < matrix_w; j++) {
output_mat(i, j) = 0;
}
}
} else {
int64_t end_i = std::min(matrix_h, matrix_w - k_val);
for (int64_t i = 0; i < end_i; i++) {
for (int64_t j = std::max(static_cast<int64_t>(0), i + k_val + 1); j < matrix_w; j++) {
output_mat(i, j) = 0;
}
}
}
}
return Status::OK();
}
Status Trilu::Compute(OpKernelContext* ctx) const {
Status status;
const auto* X = ctx->Input<Tensor>(0);
const auto* k = ctx->Input<Tensor>(1);
bool up = upper_;
int64_t k_val = 0;
if (k) {
ORT_ENFORCE(IsScalarOr1ElementVector(k), "k should be a 1-D or 0-D tensor.");
k_val = *(k->template Data<int64_t>());
}
const auto& X_shape = X->Shape();
auto* Y = ctx->Output(0, X_shape);
int64_t X_num_dims = static_cast<int64_t>(X_shape.NumDimensions());
// input validation
if (X_num_dims < 2) { // this is getting capture by shape inference code as well
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensor should have a rank of at least 2");
}
MLDataType data_type = X->DataType();
const auto element_size = data_type->Size();
switch (element_size) {
case sizeof(float):
status = TriluImpl<float>(X, Y, k_val, up);
break;
case sizeof(double):
status = TriluImpl<double>(X, Y, k_val, up);
break;
default:
ORT_THROW("Unsupported input data type of ", data_type);
}
return status;
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
namespace onnxruntime {
namespace contrib {
class Trilu final : public OpKernel {
public:
explicit Trilu(const OpKernelInfo& info) : OpKernel(info) {
int64_t temp;
ORT_ENFORCE(info.GetAttr<int64_t>("upper", &temp).IsOK());
upper_ = temp != 0;
}
Status Compute(OpKernelContext* ctx) const override;
private:
bool upper_;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -2848,6 +2848,76 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i
}
});
static const char* Trilu_ver1_doc = R"DOC(
Returns the upper or lower triangular part of a 2-D matrix, or batches of 2-D matrices. If the attribute "upper" is set to true,
the upper triangular matrix is retained. Lower triangular matrix is retained otherwise. Default value for upper is true.
Trilu takes one input tensor of shape [*, N, M], where * is zero or more batch dimensions. The upper triangular part consists
of the elements on and above the given diagonal (k). The lower triangular part consists of elements on and below the diagonal.
All other elements in the matrix are set to zero.
If k = 0, the triangular part on and above/below the main diagonal is retained.
If upper is set to true, a positive k retains the upper triangular matrix excluding k diagonals above
the main diagonal. A negative k value includes as many diagonals below the main diagonal.
If upper is set to false, a positive k retains the lower triangular matrix including k diagonals above
the main diagonal. A negative k value excludes as many diagonals below the main diagonal.
)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(Trilu)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(Trilu_ver1_doc)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.Attr("upper",
"Boolean. Indicates whether upper or lower part of matrix is retained. Default is true.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Input(
0,
"X",
"Input tensor of rank 2 or higher.",
"T")
.Input(
1,
"k",
"A 0-D tensor containing a single value corresponding to the number diagonals above or the main diagonal to exclude or include."
"Default value is 0 if it's not specified.",
"tensor(int64)",
OpSchema::Optional)
.Output(
0,
"Y",
"Output tensor of the same type and shape as the input tensor.",
"T")
.TypeConstraint(
"T",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(bfloat16)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(bool)"},
"Constrain input and output types to all numeric tensors and bool tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
using namespace ONNX_NAMESPACE;
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (hasInputShape(ctx, 0)) {
const TensorShapeProto& input_shape =
ctx.getInputType(0)->tensor_type().shape();
const int rank = static_cast<int>(input_shape.dim_size());
if (rank < 2) {
fail_shape_inference("Input rank must be >= 2.")
}
propagateShapeFromInputToOutput(ctx, 0, 0);
}
});
RegisterBertSchemas();
}
} // namespace contrib

View file

@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "onnx/shape_inference/implementation.h"
#include "onnx/checker.h"
namespace onnxruntime {
namespace test {
auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
const std::string MS_DOMAIN = "com.microsoft";
void CheckShapeEquality(ONNX_NAMESPACE::TensorShapeProto* shape1, ONNX_NAMESPACE::TensorShapeProto* shape2) {
EXPECT_NE(shape1, nullptr);
EXPECT_NE(shape2, nullptr);
if ((shape1 != nullptr) && (shape2 != nullptr)) {
EXPECT_EQ(shape1->dim_size(), shape2->dim_size()) << "Shapes do not have same rank";
auto min_dims = std::min(shape1->dim_size(), shape2->dim_size());
for (int i = 0; i < min_dims; ++i) {
auto dim1 = shape1->dim(i);
auto dim2 = shape2->dim(i);
EXPECT_EQ(dim1.has_dim_value(), dim2.has_dim_value());
if (dim1.has_dim_value()) {
EXPECT_EQ(dim1.dim_value(), dim2.dim_value());
}
EXPECT_EQ(dim1.has_dim_param(), dim2.has_dim_param());
if (dim1.has_dim_param()) {
EXPECT_EQ(dim1.dim_param(), dim2.dim_param());
}
}
}
}
inline void CreateValueInfo(
ONNX_NAMESPACE::ValueInfoProto& value_info,
const std::string& name,
const ONNX_NAMESPACE::TensorProto_DataType& elem_type,
const std::vector<int64_t> shape) {
value_info.set_name(name);
ONNX_NAMESPACE::TypeProto* type = value_info.mutable_type();
ONNX_NAMESPACE::TypeProto_Tensor* tensor_type = type->mutable_tensor_type();
tensor_type->set_elem_type(elem_type);
ONNX_NAMESPACE::TensorShapeProto* value_info_shape = tensor_type->mutable_shape();
for (int64_t dim_value : shape) {
value_info_shape->add_dim()->set_dim_value(dim_value);
}
}
inline void TestShapeInference(
const std::string& op_type,
const std::vector<ONNX_NAMESPACE::ValueInfoProto>& inputs,
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes,
ONNX_NAMESPACE::ValueInfoProto& output) {
ONNX_NAMESPACE::ModelProto model;
// Set opset (domain + version)
ONNX_NAMESPACE::OperatorSetIdProto* op_set_id = model.add_opset_import();
op_set_id->set_domain(MS_DOMAIN);
op_set_id->set_version(1);
model.set_ir_version(6);
model.set_producer_name("onnx");
// Set model graph
ONNX_NAMESPACE::GraphProto* graph = model.mutable_graph();
graph->set_name("test-op");
graph->add_value_info();
// Set add operator node to graph
auto& node = *graph->add_node();
node.set_op_type(op_type);
node.set_domain(MS_DOMAIN);
node.set_name("test_node");
// Add node inputs and graph inputs
for (auto const& n_ : inputs) {
node.add_input(n_.name());
*graph->add_input() = n_;
}
// Add node attributes
for (auto const& attr : attributes) {
node.add_attribute()->CopyFrom(attr);
}
node.add_output("Output");
ONNX_NAMESPACE::checker::check_model(model);
ONNX_NAMESPACE::shape_inference::InferShapes(model, false, schema_registry);
auto inferredGraph = model.graph();
int index = static_cast<int>(inputs.size()); // index for value_info of output
auto inferred_output = inferredGraph.value_info(index);
auto elem_type = output.mutable_type()->mutable_tensor_type()->elem_type();
auto inferred_elem_type = inferred_output.mutable_type()->mutable_tensor_type()->elem_type();
EXPECT_EQ(elem_type, inferred_elem_type);
auto shape = output.mutable_type()->mutable_tensor_type()->mutable_shape();
auto inferred_shape = inferred_output.mutable_type()->mutable_tensor_type()->mutable_shape();
CheckShapeEquality(shape, inferred_shape);
}
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "test/providers/provider_test_utils.h"
#include "onnx/shape_inference/implementation.h"
#include "onnx/checker.h"
#include "shape_inference_test_helper.h"
namespace onnxruntime {
namespace test {
TEST(ShapeInferenceTests, tri_upper_float) {
std::vector<int64_t> shape = {4, 7};
ONNX_NAMESPACE::ValueInfoProto input;
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
ONNX_NAMESPACE::AttributeProto upper;
upper.set_name("upper");
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
upper.set_i(1); // upper
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
ONNX_NAMESPACE::ValueInfoProto output;
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
TestShapeInference("Trilu", inputs, attributes, output);
}
TEST(ShapeInferenceTests, tri_upper_zero_dim_int) {
std::vector<int64_t> shape = {4, 7, 0};
ONNX_NAMESPACE::ValueInfoProto input;
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape);
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
ONNX_NAMESPACE::AttributeProto upper;
upper.set_name("upper");
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
upper.set_i(1); // upper
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
ONNX_NAMESPACE::ValueInfoProto output;
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape);
TestShapeInference("Trilu", inputs, attributes, output);
}
TEST(ShapeInferenceTests, tri_upper_4d_long) {
std::vector<int64_t> shape = {2, 3, 7, 11};
ONNX_NAMESPACE::ValueInfoProto input;
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape);
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
ONNX_NAMESPACE::AttributeProto upper;
upper.set_name("upper");
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
upper.set_i(1); // upper
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
ONNX_NAMESPACE::ValueInfoProto output;
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape);
TestShapeInference("Trilu", inputs, attributes, output);
}
TEST(ShapeInferenceTests, tri_lower_float) {
std::vector<int64_t> shape = {4, 7};
ONNX_NAMESPACE::ValueInfoProto input;
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
ONNX_NAMESPACE::AttributeProto upper;
upper.set_name("upper");
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
upper.set_i(0); // lower
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
ONNX_NAMESPACE::ValueInfoProto output;
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
TestShapeInference("Trilu", inputs, attributes, output);
}
TEST(ShapeInferenceTests, tri_lower_4d_int) {
std::vector<int64_t> shape = {2, 3, 7, 11};
ONNX_NAMESPACE::ValueInfoProto input;
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape);
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
ONNX_NAMESPACE::AttributeProto upper;
upper.set_name("upper");
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
upper.set_i(0); // lower
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
ONNX_NAMESPACE::ValueInfoProto output;
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape);
TestShapeInference("Trilu", inputs, attributes, output);
}
TEST(ShapeInferenceTests, tri_lower_zero_dim_long) {
std::vector<int64_t> shape = {4, 7, 0};
ONNX_NAMESPACE::ValueInfoProto input;
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape);
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
ONNX_NAMESPACE::AttributeProto upper;
upper.set_name("upper");
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
upper.set_i(0); // lower
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
ONNX_NAMESPACE::ValueInfoProto output;
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape);
TestShapeInference("Trilu", inputs, attributes, output);
}
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,242 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "core/util/math.h"
namespace onnxruntime {
namespace test {
TEST(TriluContribOpTest, two_by_two_float_upper) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 1;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 2}, {4.f, 7.f, 2.f, 6.f});
test.AddOutput<float>("Y", {2, 2}, {4.f, 7.f, 0.f, 6.f});
test.Run();
}
TEST(TriluContribOpTest, two_by_two_float_lower) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 2}, {4.f, 7.f, 2.f, 6.f});
test.AddOutput<float>("Y", {2, 2}, {4.f, 0.f, 2.f, 6.f});
test.Run();
}
TEST(TriluContribOpTest, two_by_two_double_upper) {
OpTester test("Trilu", 1, kMSDomain);
test.AddInput<double>("X", {2, 2}, {4, 7, 2, 6});
test.AddInput<int64_t>("k", {1}, {1});
test.AddOutput<double>("Y", {2, 2}, {0, 7, 0, 0});
test.Run();
}
TEST(TriluContribOpTest, two_by_two_double_lower) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<double>("X", {2, 2}, {4, 7, 2, 6});
test.AddInput<int64_t>("k", {1}, {1});
test.AddOutput<double>("Y", {2, 2}, {4, 7, 2, 6});
test.Run();
}
TEST(TriluContribOpTest, two_by_two_long_upper) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 1;
test.AddAttribute("upper", up);
test.AddInput<int64_t>("X", {2, 2}, {4, 7, 2, 6});
test.AddOutput<int64_t>("Y", {2, 2}, {4, 7, 0, 6});
test.Run();
}
TEST(TriluContribOpTest, two_by_two_long_lower) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<int64_t>("X", {2, 2}, {4, 7, 2, 6});
test.AddOutput<int64_t>("Y", {2, 2}, {4, 0, 2, 6});
test.Run();
}
TEST(TriluContribOpTest, three_dim_float_upper) {
OpTester test("Trilu", 1, kMSDomain);
test.AddInput<float>("X", {2, 3, 4},
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
});
test.AddInput<int64_t>("k", {1}, {1});
test.AddOutput<float>("Y", {2, 3, 4},
{0.f, 1.f, 5.f, 8.f,
0.f, 0.f, 2.f, 4.f,
0.f, 0.f, 0.f, 3.f,
0.f, 6.f, 2.f, 1.f,
0.f, 0.f, 5.f, 8.f,
0.f, 0.f, 0.f, 4.f,
});
test.Run();
}
TEST(TriluContribOpTest, three_dim_float_lower) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 3, 4},
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
});
test.AddInput<int64_t>("k", {1}, {1});
test.AddOutput<float>("Y", {2, 3, 4},
{4.f, 1.f, 0.f, 0.f,
4.f, 3.f, 2.f, 0.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 0.f, 0.f,
4.f, 1.f, 5.f, 0.f,
4.f, 3.f, 2.f, 4.f,
});
test.Run();
}
TEST(TriluContribOpTest, neg_k_float_upper) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 1;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 3, 4},
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
});
test.AddInput<int64_t>("k", {1}, {-1});
test.AddOutput<float>("Y", {2, 3, 4},
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
0.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
0.f, 3.f, 2.f, 4.f,
});
test.Run();
}
TEST(TriluContribOpTest, neg_k_float_lower) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 3, 4},
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
});
test.AddInput<int64_t>("k", {1}, {-1});
test.AddOutput<float>("Y", {2, 3, 4},
{0.f, 0.f, 0.f, 0.f,
4.f, 0.f, 0.f, 0.f,
6.f, 1.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
4.f, 0.f, 0.f, 0.f,
4.f, 3.f, 0.f, 0.f,
});
test.Run();
}
TEST(TriluContribOpTest, small_k_float_upper) {
OpTester test("Trilu", 1, kMSDomain);
test.AddInput<float>("X", {2, 3, 4},
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
});
test.AddInput<int64_t>("k", {1}, {-5});
test.AddOutput<float>("Y", {2, 3, 4},
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
});
test.Run();
}
TEST(TriluContribOpTest, small_k_float_lower) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 3, 4},
{4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
6.f, 1.f, 2.f, 3.f,
1.f, 6.f, 2.f, 1.f,
4.f, 1.f, 5.f, 8.f,
4.f, 3.f, 2.f, 4.f,
});
test.AddInput<int64_t>("k", {1}, {-5});
test.AddOutput<float>("Y", {2, 3, 4},
{0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f,
});
test.Run();
}
TEST(TriluContribOpTest, zero_dim_upper) {
OpTester test("Trilu", 1, kMSDomain);
test.AddInput<float>("X", {2, 3, 0}, {});
test.AddInput<int64_t>("k", {1}, {0});
test.AddOutput<float>("Y", {2, 3, 0}, {});
test.Run();
}
TEST(TriluContribOpTest, zero_dim_lower) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 3, 0}, {});
test.AddInput<int64_t>("k", {1}, {0});
test.AddOutput<float>("Y", {2, 3, 0}, {});
test.Run();
}
TEST(TriluContribOpTest, zero_dim_2_upper) {
OpTester test("Trilu", 1, kMSDomain);
test.AddInput<float>("X", {2, 0, 0}, {});
test.AddInput<int64_t>("k", {1}, {-5});
test.AddOutput<float>("Y", {2, 0, 0}, {});
test.Run();
}
TEST(TriluContribOpTest, zero_dim_2_lower) {
OpTester test("Trilu", 1, kMSDomain);
int64_t up = 0;
test.AddAttribute("upper", up);
test.AddInput<float>("X", {2, 0, 0}, {});
test.AddInput<int64_t>("k", {1}, {-5});
test.AddOutput<float>("Y", {2, 0, 0}, {});
test.Run();
}
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,81 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Helper functions for generating ONNX model and data to test ONNX Runtime contrib ops
import onnx
import os
from onnx import numpy_helper
import subprocess
import shutil
TOP_DIR = os.path.realpath(os.path.dirname(__file__))
DATA_DIR = os.path.join(TOP_DIR, '..', 'testdata/')
def prepare_dir(path):
if os.path.exists(path):
shutil.rmtree(path)
os.makedirs(path)
def _extract_value_info(arr, name, ele_type=None):
return onnx.helper.make_tensor_value_info(
name=name,
elem_type=ele_type if ele_type else onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[arr.dtype],
shape=arr.shape)
def generate_data(graph, inputs, outputs, name):
output_dir = os.path.join(DATA_DIR, name)
prepare_dir(output_dir)
model = onnx.helper.make_model(graph)
with open(os.path.join(output_dir, 'model.onnx'), 'wb') as f:
f.write(model.SerializeToString())
data_set = os.path.join(output_dir, 'test_data_set_0')
prepare_dir(data_set)
for j, input_np in enumerate(inputs):
tensor = numpy_helper.from_array(
input_np, model.graph.input[j].name)
with open(os.path.join(
data_set, 'input_{}.pb'.format(j)), 'wb') as f:
f.write(tensor.SerializeToString())
for j, output_np in enumerate(outputs):
tensor = numpy_helper.from_array(
output_np, model.graph.output[j].name)
with open(os.path.join(
data_set, 'output_{}.pb'.format(j)), 'wb') as f:
f.write(tensor.SerializeToString())
def expect(node, # type: onnx.NodeProto
inputs,
outputs,
name,
**kwargs
): # type: (...) -> None
present_inputs = [x for x in node.input if (x != '')]
present_outputs = [x for x in node.output if (x != '')]
input_types = [None] * len(inputs)
if 'input_types' in kwargs:
input_types = kwargs[str('input_types')]
del kwargs[str('input_types')]
output_types = [None] * len(outputs)
if 'output_types' in kwargs:
output_types = kwargs[str('output_types')]
del kwargs[str('output_types')]
inputs_vi = [_extract_value_info(arr, arr_name, input_type)
for arr, arr_name, input_type in zip(inputs, present_inputs, input_types)]
outputs_vi = [_extract_value_info(arr, arr_name, output_type)
for arr, arr_name, output_type in zip(outputs, present_outputs, output_types)]
graph = onnx.helper.make_graph(
nodes=[node],
name=name,
inputs=inputs_vi,
outputs=outputs_vi)
generate_data(graph, inputs, outputs, name)
cwd = os.getcwd()
onnx_test_runner = os.path.join(cwd, 'onnx_test_runner')
subprocess.run([onnx_test_runner, DATA_DIR+name], check=True, cwd=cwd)

View file

@ -0,0 +1,289 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Test reference implementation and model for ONNX Runtime conrtib op trilu
import onnx
import unittest
import numpy as np
from onnx_contrib_ops_helper import expect
def triu_reference_implementation(x, k=0):
return np.triu(x, k)
def tril_reference_implementation(x, k=0):
return np.tril(x, k)
class ONNXReferenceImplementationTest(unittest.TestCase):
def test_triu(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
y = triu_reference_implementation(x)
expect(node, inputs=[x], outputs=[y], name='test_triu')
def test_triu_neg(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
k = np.array([-1]).astype(np.int64)
y = triu_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_triu_neg')
def test_triu_out_neg(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
k = np.array([-7]).astype(np.int64)
y = triu_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_triu_out_neg')
def test_triu_pos(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
k = np.array([2]).astype(np.int64)
y = triu_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_triu_pos')
def test_triu_out_pos(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
k = np.array([6]).astype(np.int64)
y = triu_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_triu_out_pos')
def test_triu_square(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 5, 5).astype(np.float32)
y = triu_reference_implementation(x)
expect(node, inputs=[x], outputs=[y], name='test_triu_square')
def test_triu_square_neg(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 5, 5).astype(np.float32)
k = np.array([-1]).astype(np.int64)
y = triu_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_triu_square_neg')
def test_triu_one_row_neg(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 1, 5).astype(np.float32)
k = np.array([-7]).astype(np.int64)
y = triu_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_triu_one_row_neg')
def test_triu_square_pos(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 5, 5).astype(np.float32)
k = np.array([2]).astype(np.int64)
y = triu_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_triu_square_pos')
def test_triu_zero(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
domain="com.microsoft",
)
x = np.random.randn(3, 0, 5).astype(np.float32)
k = np.array([6]).astype(np.int64)
y = triu_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_triu_zero')
def test_tril(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
y = tril_reference_implementation(x)
expect(node, inputs=[x], outputs=[y], name='test_tril')
def test_tril_neg(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
k = np.array([-1]).astype(np.int64)
y = tril_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_tril_neg')
def test_tril_out_neg(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
k = np.array([-7]).astype(np.int64)
y = tril_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_tril_out_neg')
def test_tril_pos(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
k = np.array([2]).astype(np.int64)
y = tril_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_tril_pos')
def test_tril_out_pos(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 4, 5).astype(np.float32)
k = np.array([6]).astype(np.int64)
y = tril_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_tril_out_pos')
def test_tril_square(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 5, 5).astype(np.float32)
y = tril_reference_implementation(x)
expect(node, inputs=[x], outputs=[y], name='test_tril_square')
def test_tril_square_neg(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 5, 5).astype(np.float32)
k = np.array([-1]).astype(np.int64)
y = tril_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_tril_square_neg')
def test_tril_one_row_neg(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 1, 5).astype(np.float32)
k = np.array([-7]).astype(np.int64)
y = tril_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_tril_one_row_neg')
def test_tril_square_pos(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 5, 5).astype(np.float32)
k = np.array([2]).astype(np.int64)
y = tril_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_tril_square_pos')
def test_tril_zero(self):
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0,
domain="com.microsoft",
)
x = np.random.randn(3, 0, 5).astype(np.float32)
k = np.array([6]).astype(np.int64)
y = tril_reference_implementation(x, k)
expect(node, inputs=[x, k], outputs=[y], name='test_tril_zero')
if __name__ == '__main__':
unittest.main(module=__name__, buffer=True)

View file

@ -1,3 +1,4 @@
#!/bin/bash
pip3 install --user --upgrade pip
@ -7,3 +8,9 @@ pip3 install --user /build/Release/dist/*.whl
export PYTHONPATH=/onnxruntime_src/tools:/usr/local/lib/python3.6/site-packages:$PYTHONPATH
python3 -m pytest -v /onnxruntime_src/tools/test/test_custom_ops_pytorch_exporter.py || exit 1
for filename in /onnxruntime_src/onnxruntime/test/python/contrib_ops/onnx_test_* ; do
cd /build/Release && python3 -m pytest -v $filename || exit 1
done
cd /build/Release && ./onnxruntime_test_all --gtest_filter=ShapeInferenceTests.* || exit 1

View file

@ -1,3 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Register pytorch symbolic for export using ONNX Runtime contrib ops
from torch.onnx import register_custom_op_symbolic
@ -17,6 +22,14 @@ def register_custom_op():
def gelu(g, self):
return g.op("com.microsoft::Gelu", self)
def triu(g, self, diagonal):
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1)
def tril(g, self, diagonal):
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0)
# Op Registration
register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version)
register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version)
register_custom_op_symbolic('::triu', triu, _onnx_opset_version)
register_custom_op_symbolic('::tril', tril, _onnx_opset_version)

View file

@ -1,3 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
# Test export of pytorch operators using ONNX Runtime contrib ops
import torch
import onnxruntime
import numpy as np
@ -32,7 +37,7 @@ def ort_test_with_input(ort_sess, input, output, rtol, atol):
# These set of tests verify ONNX model export and compare onnxruntime outputs to pytorch.
# To register custom ops and run the tests, you should set PYTHONPATH as:
# PYTHONPATH=<path_to_onnxruntime/tools> pytest -v test_custom_ops_pytorch_exporter.py
# PYTHONPATH=<path_to_onnxruntime/tools> python -m pytest -v test_custom_ops_pytorch_exporter.py
class ONNXExporterTest(unittest.TestCase):
from torch.onnx.symbolic_helper import _export_onnx_opset_version
opset_version = _export_onnx_opset_version
@ -105,21 +110,68 @@ class ONNXExporterTest(unittest.TestCase):
x = torch.randn(3, 3)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
def test_triu(self):
for i in range(-5, 5):
class Module(torch.nn.Module):
def forward(self, input):
return input.triu(diagonal=i)
# opset 10 tests
TestONNXRuntime_opset10 = type(str("TestONNXRuntime_opset10"),
(unittest.TestCase,),
dict(ONNXExporterTest.__dict__, opset_version=10))
model = Module()
x = torch.randn(5, 4, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
# opset 11 tests
ONNXExporterTest_opset11 = type(str("TestONNXRuntime_opset11"),
(unittest.TestCase,),
dict(ONNXExporterTest.__dict__, opset_version=11))
x = torch.randn(5, 4, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
x = torch.randn(5, 0, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
for i in range(-5, 5):
class Module2D(torch.nn.Module):
def forward(self, input):
return input.triu(diagonal=i)
model = Module2D()
x = torch.randn(4, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
x = torch.randn(0, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
x = torch.randn(0, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
def test_tril(self):
for i in range(-5, 5):
class Module(torch.nn.Module):
def forward(self, input):
return input.tril(diagonal=i)
model = Module()
x = torch.randn(5, 4, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
x = torch.randn(5, 4, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
x = torch.randn(5, 0, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
for i in range(-5, 5):
class Module2D(torch.nn.Module):
def forward(self, input):
return input.tril(diagonal=i)
model = Module2D()
x = torch.randn(4, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
x = torch.randn(0, 7, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
x = torch.randn(0, 0, dtype=torch.float32)
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
# opset 12 tests
ONNXExporterTest_opset12 = type(str("TestONNXRuntime_opset12"),
(unittest.TestCase,),
dict(ONNXExporterTest.__dict__, opset_version=12))
# opset 9 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
@ -128,27 +180,6 @@ ONNXExporterTest_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
dict(ONNXExporterTest.__dict__,
keep_initializers_as_inputs=False))
# opset 10 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
ONNXExporterTest_opset10_IRv4 = type(str("TestONNXRuntime_opset10_IRv4"),
(unittest.TestCase,),
dict(ONNXExporterTest.__dict__, opset_version=10,
keep_initializers_as_inputs=False))
# opset 11 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
ONNXExporterTest_opset11_IRv4 = type(str("TestONNXRuntime_opset11_IRv4"),
(unittest.TestCase,),
dict(ONNXExporterTest.__dict__, opset_version=11,
keep_initializers_as_inputs=False))
# opset 12 tests, with keep_initializers_as_inputs=False for
# IR version 4 style export.
ONNXExporterTest_opset12_IRv4 = type(str("TestONNXRuntime_opset12_IRv4"),
(unittest.TestCase,),
dict(ONNXExporterTest.__dict__, opset_version=12,
keep_initializers_as_inputs=False))
if __name__ == '__main__':
unittest.main()