mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Add Trilu custom op (#4537)
Co-authored-by: neginraoof <neginmr@utexas.edu>
This commit is contained in:
parent
1ce2982f65
commit
ea37a4d89b
13 changed files with 1137 additions and 37 deletions
|
|
@ -34,11 +34,11 @@ If you are adding a symbolic function for a new custom op, add the function to t
|
|||
|
||||
### 2. Extending ONNX Runtime with Custom Ops
|
||||
The next step is to add op schema and kernel implementation in ONNX Runtime.
|
||||
THe exmaple Inverse custom op is added in:
|
||||
Consider the Inverse custom op as an example added in:
|
||||
https://github.com/microsoft/onnxruntime/pull/3485
|
||||
|
||||
|
||||
Custom op schema and shape inference function should be added under ```onnxruntime/core/graph/contrib_ops/contrib_defs.cc ```
|
||||
Custom op schema and shape inference function should be added in ```onnxruntime/core/graph/contrib_ops/contrib_defs.cc ```
|
||||
using ```ONNX_CONTRIB_OPERATOR_SCHEMA```.
|
||||
|
||||
```c++
|
||||
|
|
@ -48,6 +48,17 @@ ONNX_CONTRIB_OPERATOR_SCHEMA(Inverse)
|
|||
...
|
||||
```
|
||||
|
||||
To comply with ONNX guideline for new operators, a new operator should have complete reference implementation tests and
|
||||
shape inference tests.
|
||||
|
||||
Reference implementation python tests should be added in:
|
||||
``onnxruntime/test/python/contrib_ops``
|
||||
E.g.: ``onnxruntime/test/python/contrib_ops/onnx_test_trilu.py``
|
||||
|
||||
Shape inference C++ tests should be added in:
|
||||
``onnxruntime/test/contrib_ops``
|
||||
E.g.: ``onnxruntime/test/contrib_ops/trilu_shape_inference_test.cc``
|
||||
|
||||
The operator kernel should be implemented using ```Compute``` function
|
||||
under contrib namespace in ```onnxruntime/contrib_ops/cpu/<operator>.cc```
|
||||
for CPU and ```onnxruntime/contrib_ops/cuda/<operator>.cc``` for CUDA.
|
||||
|
|
@ -90,7 +101,7 @@ Now you should be able to build and install ONNX Runtime to start using your cus
|
|||
|
||||
##### ONNX Runtime Tests
|
||||
|
||||
ONNX Runtime custom op tests should be added in: ```onnxruntime/test/contrib_ops/<operator>_test.cc ```
|
||||
ONNX Runtime custom op kernel tests should be added in: ```onnxruntime/test/contrib_ops/<operator>_test.cc ```
|
||||
|
||||
```c++
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);
|
||||
|
||||
Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -158,6 +159,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
105
onnxruntime/contrib_ops/cpu/trilu.cc
Normal file
105
onnxruntime/contrib_ops/cpu/trilu.cc
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "Eigen/src/Core/Map.h"
|
||||
#include "trilu.h"
|
||||
#include <functional>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Trilu,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints<float, double, int64_t>()),
|
||||
Trilu);
|
||||
|
||||
template <typename T>
|
||||
static Status TriluImpl(const Tensor* X, Tensor* Y, int64_t k_val, bool up) {
|
||||
const auto& X_shape = X->Shape();
|
||||
int64_t X_num_dims = static_cast<int64_t>(X_shape.NumDimensions());
|
||||
|
||||
const auto* X_data = reinterpret_cast<const T*>(X->DataRaw());
|
||||
int64_t matrix_h = static_cast<int64_t>(X_shape[X_num_dims - 2]);
|
||||
int64_t matrix_w = static_cast<int64_t>(X_shape[X_num_dims - 1]);
|
||||
|
||||
int64_t batch_size = 1;
|
||||
for (int64_t i = 0; i < X_num_dims - 2; ++i) {
|
||||
batch_size *= X_shape[i];
|
||||
}
|
||||
|
||||
int64_t num_matrix_elems = matrix_h * matrix_w;
|
||||
auto* Y_data = reinterpret_cast<T*>(Y->MutableDataRaw());
|
||||
for (int64_t b = 0; b < batch_size; b++) { // can be parallelized if need to
|
||||
auto X_batch_data = X_data + (b * num_matrix_elems);
|
||||
auto Y_batch_data = Y_data + (b * num_matrix_elems);
|
||||
|
||||
auto input_mat = ConstEigenMatrixMapRowMajor<T>(X_batch_data, matrix_h, matrix_w);
|
||||
auto output_mat = EigenMatrixMapRowMajor<T>(Y_batch_data, matrix_h, matrix_w);
|
||||
|
||||
if (X_batch_data != Y_batch_data) {
|
||||
output_mat = input_mat;
|
||||
}
|
||||
|
||||
if (up) {
|
||||
int64_t start_i = k_val > 0 ? 0 : 1 - k_val;
|
||||
for (int64_t i = start_i; i < matrix_h; i++) {
|
||||
for (int64_t j = 0; j < i + k_val && j < matrix_w; j++) {
|
||||
output_mat(i, j) = 0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int64_t end_i = std::min(matrix_h, matrix_w - k_val);
|
||||
for (int64_t i = 0; i < end_i; i++) {
|
||||
for (int64_t j = std::max(static_cast<int64_t>(0), i + k_val + 1); j < matrix_w; j++) {
|
||||
output_mat(i, j) = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Trilu::Compute(OpKernelContext* ctx) const {
|
||||
Status status;
|
||||
const auto* X = ctx->Input<Tensor>(0);
|
||||
const auto* k = ctx->Input<Tensor>(1);
|
||||
|
||||
bool up = upper_;
|
||||
int64_t k_val = 0;
|
||||
if (k) {
|
||||
ORT_ENFORCE(IsScalarOr1ElementVector(k), "k should be a 1-D or 0-D tensor.");
|
||||
k_val = *(k->template Data<int64_t>());
|
||||
}
|
||||
|
||||
const auto& X_shape = X->Shape();
|
||||
auto* Y = ctx->Output(0, X_shape);
|
||||
|
||||
int64_t X_num_dims = static_cast<int64_t>(X_shape.NumDimensions());
|
||||
// input validation
|
||||
if (X_num_dims < 2) { // this is getting capture by shape inference code as well
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensor should have a rank of at least 2");
|
||||
}
|
||||
|
||||
MLDataType data_type = X->DataType();
|
||||
const auto element_size = data_type->Size();
|
||||
switch (element_size) {
|
||||
case sizeof(float):
|
||||
status = TriluImpl<float>(X, Y, k_val, up);
|
||||
break;
|
||||
case sizeof(double):
|
||||
status = TriluImpl<double>(X, Y, k_val, up);
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unsupported input data type of ", data_type);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
21
onnxruntime/contrib_ops/cpu/trilu.h
Normal file
21
onnxruntime/contrib_ops/cpu/trilu.h
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class Trilu final : public OpKernel {
|
||||
public:
|
||||
explicit Trilu(const OpKernelInfo& info) : OpKernel(info) {
|
||||
int64_t temp;
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("upper", &temp).IsOK());
|
||||
upper_ = temp != 0;
|
||||
}
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
bool upper_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -2848,6 +2848,76 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i
|
|||
}
|
||||
});
|
||||
|
||||
static const char* Trilu_ver1_doc = R"DOC(
|
||||
Returns the upper or lower triangular part of a 2-D matrix, or batches of 2-D matrices. If the attribute "upper" is set to true,
|
||||
the upper triangular matrix is retained. Lower triangular matrix is retained otherwise. Default value for upper is true.
|
||||
Trilu takes one input tensor of shape [*, N, M], where * is zero or more batch dimensions. The upper triangular part consists
|
||||
of the elements on and above the given diagonal (k). The lower triangular part consists of elements on and below the diagonal.
|
||||
All other elements in the matrix are set to zero.
|
||||
If k = 0, the triangular part on and above/below the main diagonal is retained.
|
||||
If upper is set to true, a positive k retains the upper triangular matrix excluding k diagonals above
|
||||
the main diagonal. A negative k value includes as many diagonals below the main diagonal.
|
||||
If upper is set to false, a positive k retains the lower triangular matrix including k diagonals above
|
||||
the main diagonal. A negative k value excludes as many diagonals below the main diagonal.
|
||||
)DOC";
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Trilu)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(Trilu_ver1_doc)
|
||||
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
|
||||
.Attr("upper",
|
||||
"Boolean. Indicates whether upper or lower part of matrix is retained. Default is true.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(1))
|
||||
.Input(
|
||||
0,
|
||||
"X",
|
||||
"Input tensor of rank 2 or higher.",
|
||||
"T")
|
||||
.Input(
|
||||
1,
|
||||
"k",
|
||||
"A 0-D tensor containing a single value corresponding to the number diagonals above or the main diagonal to exclude or include."
|
||||
"Default value is 0 if it's not specified.",
|
||||
"tensor(int64)",
|
||||
OpSchema::Optional)
|
||||
.Output(
|
||||
0,
|
||||
"Y",
|
||||
"Output tensor of the same type and shape as the input tensor.",
|
||||
"T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)",
|
||||
"tensor(float)",
|
||||
"tensor(double)",
|
||||
"tensor(bfloat16)",
|
||||
"tensor(uint8)",
|
||||
"tensor(uint16)",
|
||||
"tensor(uint32)",
|
||||
"tensor(uint64)",
|
||||
"tensor(int8)",
|
||||
"tensor(int16)",
|
||||
"tensor(int32)",
|
||||
"tensor(int64)",
|
||||
"tensor(bool)"},
|
||||
"Constrain input and output types to all numeric tensors and bool tensors.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
using namespace ONNX_NAMESPACE;
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
if (hasInputShape(ctx, 0)) {
|
||||
const TensorShapeProto& input_shape =
|
||||
ctx.getInputType(0)->tensor_type().shape();
|
||||
const int rank = static_cast<int>(input_shape.dim_size());
|
||||
if (rank < 2) {
|
||||
fail_shape_inference("Input rank must be >= 2.")
|
||||
}
|
||||
propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
}
|
||||
});
|
||||
|
||||
RegisterBertSchemas();
|
||||
}
|
||||
} // namespace contrib
|
||||
|
|
|
|||
107
onnxruntime/test/contrib_ops/shape_inference_test_helper.h
Normal file
107
onnxruntime/test/contrib_ops/shape_inference_test_helper.h
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "onnx/shape_inference/implementation.h"
|
||||
#include "onnx/checker.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
|
||||
|
||||
const std::string MS_DOMAIN = "com.microsoft";
|
||||
|
||||
void CheckShapeEquality(ONNX_NAMESPACE::TensorShapeProto* shape1, ONNX_NAMESPACE::TensorShapeProto* shape2) {
|
||||
EXPECT_NE(shape1, nullptr);
|
||||
EXPECT_NE(shape2, nullptr);
|
||||
if ((shape1 != nullptr) && (shape2 != nullptr)) {
|
||||
EXPECT_EQ(shape1->dim_size(), shape2->dim_size()) << "Shapes do not have same rank";
|
||||
auto min_dims = std::min(shape1->dim_size(), shape2->dim_size());
|
||||
for (int i = 0; i < min_dims; ++i) {
|
||||
auto dim1 = shape1->dim(i);
|
||||
auto dim2 = shape2->dim(i);
|
||||
EXPECT_EQ(dim1.has_dim_value(), dim2.has_dim_value());
|
||||
if (dim1.has_dim_value()) {
|
||||
EXPECT_EQ(dim1.dim_value(), dim2.dim_value());
|
||||
}
|
||||
EXPECT_EQ(dim1.has_dim_param(), dim2.has_dim_param());
|
||||
if (dim1.has_dim_param()) {
|
||||
EXPECT_EQ(dim1.dim_param(), dim2.dim_param());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void CreateValueInfo(
|
||||
ONNX_NAMESPACE::ValueInfoProto& value_info,
|
||||
const std::string& name,
|
||||
const ONNX_NAMESPACE::TensorProto_DataType& elem_type,
|
||||
const std::vector<int64_t> shape) {
|
||||
value_info.set_name(name);
|
||||
ONNX_NAMESPACE::TypeProto* type = value_info.mutable_type();
|
||||
ONNX_NAMESPACE::TypeProto_Tensor* tensor_type = type->mutable_tensor_type();
|
||||
tensor_type->set_elem_type(elem_type);
|
||||
ONNX_NAMESPACE::TensorShapeProto* value_info_shape = tensor_type->mutable_shape();
|
||||
|
||||
for (int64_t dim_value : shape) {
|
||||
value_info_shape->add_dim()->set_dim_value(dim_value);
|
||||
}
|
||||
}
|
||||
|
||||
inline void TestShapeInference(
|
||||
const std::string& op_type,
|
||||
const std::vector<ONNX_NAMESPACE::ValueInfoProto>& inputs,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes,
|
||||
ONNX_NAMESPACE::ValueInfoProto& output) {
|
||||
ONNX_NAMESPACE::ModelProto model;
|
||||
// Set opset (domain + version)
|
||||
ONNX_NAMESPACE::OperatorSetIdProto* op_set_id = model.add_opset_import();
|
||||
op_set_id->set_domain(MS_DOMAIN);
|
||||
op_set_id->set_version(1);
|
||||
model.set_ir_version(6);
|
||||
model.set_producer_name("onnx");
|
||||
|
||||
// Set model graph
|
||||
ONNX_NAMESPACE::GraphProto* graph = model.mutable_graph();
|
||||
graph->set_name("test-op");
|
||||
graph->add_value_info();
|
||||
|
||||
// Set add operator node to graph
|
||||
auto& node = *graph->add_node();
|
||||
node.set_op_type(op_type);
|
||||
node.set_domain(MS_DOMAIN);
|
||||
node.set_name("test_node");
|
||||
|
||||
// Add node inputs and graph inputs
|
||||
for (auto const& n_ : inputs) {
|
||||
node.add_input(n_.name());
|
||||
*graph->add_input() = n_;
|
||||
}
|
||||
|
||||
// Add node attributes
|
||||
for (auto const& attr : attributes) {
|
||||
node.add_attribute()->CopyFrom(attr);
|
||||
}
|
||||
|
||||
node.add_output("Output");
|
||||
|
||||
ONNX_NAMESPACE::checker::check_model(model);
|
||||
ONNX_NAMESPACE::shape_inference::InferShapes(model, false, schema_registry);
|
||||
|
||||
auto inferredGraph = model.graph();
|
||||
int index = static_cast<int>(inputs.size()); // index for value_info of output
|
||||
auto inferred_output = inferredGraph.value_info(index);
|
||||
|
||||
auto elem_type = output.mutable_type()->mutable_tensor_type()->elem_type();
|
||||
auto inferred_elem_type = inferred_output.mutable_type()->mutable_tensor_type()->elem_type();
|
||||
EXPECT_EQ(elem_type, inferred_elem_type);
|
||||
|
||||
auto shape = output.mutable_type()->mutable_tensor_type()->mutable_shape();
|
||||
auto inferred_shape = inferred_output.mutable_type()->mutable_tensor_type()->mutable_shape();
|
||||
CheckShapeEquality(shape, inferred_shape);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
121
onnxruntime/test/contrib_ops/trilu_shape_inference_test.cc
Normal file
121
onnxruntime/test/contrib_ops/trilu_shape_inference_test.cc
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "onnx/shape_inference/implementation.h"
|
||||
#include "onnx/checker.h"
|
||||
#include "shape_inference_test_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(ShapeInferenceTests, tri_upper_float) {
|
||||
std::vector<int64_t> shape = {4, 7};
|
||||
ONNX_NAMESPACE::ValueInfoProto input;
|
||||
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
|
||||
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
|
||||
|
||||
ONNX_NAMESPACE::AttributeProto upper;
|
||||
upper.set_name("upper");
|
||||
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
upper.set_i(1); // upper
|
||||
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
|
||||
|
||||
ONNX_NAMESPACE::ValueInfoProto output;
|
||||
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
|
||||
|
||||
TestShapeInference("Trilu", inputs, attributes, output);
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTests, tri_upper_zero_dim_int) {
|
||||
std::vector<int64_t> shape = {4, 7, 0};
|
||||
ONNX_NAMESPACE::ValueInfoProto input;
|
||||
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape);
|
||||
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
|
||||
|
||||
ONNX_NAMESPACE::AttributeProto upper;
|
||||
upper.set_name("upper");
|
||||
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
upper.set_i(1); // upper
|
||||
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
|
||||
|
||||
ONNX_NAMESPACE::ValueInfoProto output;
|
||||
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape);
|
||||
|
||||
TestShapeInference("Trilu", inputs, attributes, output);
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTests, tri_upper_4d_long) {
|
||||
std::vector<int64_t> shape = {2, 3, 7, 11};
|
||||
ONNX_NAMESPACE::ValueInfoProto input;
|
||||
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape);
|
||||
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
|
||||
|
||||
ONNX_NAMESPACE::AttributeProto upper;
|
||||
upper.set_name("upper");
|
||||
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
upper.set_i(1); // upper
|
||||
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
|
||||
|
||||
ONNX_NAMESPACE::ValueInfoProto output;
|
||||
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape);
|
||||
|
||||
TestShapeInference("Trilu", inputs, attributes, output);
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTests, tri_lower_float) {
|
||||
std::vector<int64_t> shape = {4, 7};
|
||||
ONNX_NAMESPACE::ValueInfoProto input;
|
||||
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
|
||||
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
|
||||
|
||||
ONNX_NAMESPACE::AttributeProto upper;
|
||||
upper.set_name("upper");
|
||||
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
upper.set_i(0); // lower
|
||||
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
|
||||
|
||||
ONNX_NAMESPACE::ValueInfoProto output;
|
||||
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape);
|
||||
|
||||
TestShapeInference("Trilu", inputs, attributes, output);
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTests, tri_lower_4d_int) {
|
||||
std::vector<int64_t> shape = {2, 3, 7, 11};
|
||||
ONNX_NAMESPACE::ValueInfoProto input;
|
||||
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape);
|
||||
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
|
||||
|
||||
ONNX_NAMESPACE::AttributeProto upper;
|
||||
upper.set_name("upper");
|
||||
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
upper.set_i(0); // lower
|
||||
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
|
||||
|
||||
ONNX_NAMESPACE::ValueInfoProto output;
|
||||
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape);
|
||||
|
||||
TestShapeInference("Trilu", inputs, attributes, output);
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTests, tri_lower_zero_dim_long) {
|
||||
std::vector<int64_t> shape = {4, 7, 0};
|
||||
ONNX_NAMESPACE::ValueInfoProto input;
|
||||
CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape);
|
||||
std::vector<ONNX_NAMESPACE::ValueInfoProto> inputs = {input};
|
||||
|
||||
ONNX_NAMESPACE::AttributeProto upper;
|
||||
upper.set_name("upper");
|
||||
upper.set_type(ONNX_NAMESPACE::AttributeProto::INT);
|
||||
upper.set_i(0); // lower
|
||||
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {upper};
|
||||
|
||||
ONNX_NAMESPACE::ValueInfoProto output;
|
||||
CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape);
|
||||
|
||||
TestShapeInference("Trilu", inputs, attributes, output);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
242
onnxruntime/test/contrib_ops/trilu_test.cc
Normal file
242
onnxruntime/test/contrib_ops/trilu_test.cc
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "core/util/math.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(TriluContribOpTest, two_by_two_float_upper) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 1;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 2}, {4.f, 7.f, 2.f, 6.f});
|
||||
test.AddOutput<float>("Y", {2, 2}, {4.f, 7.f, 0.f, 6.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, two_by_two_float_lower) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 2}, {4.f, 7.f, 2.f, 6.f});
|
||||
test.AddOutput<float>("Y", {2, 2}, {4.f, 0.f, 2.f, 6.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, two_by_two_double_upper) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
test.AddInput<double>("X", {2, 2}, {4, 7, 2, 6});
|
||||
test.AddInput<int64_t>("k", {1}, {1});
|
||||
test.AddOutput<double>("Y", {2, 2}, {0, 7, 0, 0});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, two_by_two_double_lower) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<double>("X", {2, 2}, {4, 7, 2, 6});
|
||||
test.AddInput<int64_t>("k", {1}, {1});
|
||||
test.AddOutput<double>("Y", {2, 2}, {4, 7, 2, 6});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, two_by_two_long_upper) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 1;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<int64_t>("X", {2, 2}, {4, 7, 2, 6});
|
||||
test.AddOutput<int64_t>("Y", {2, 2}, {4, 7, 0, 6});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, two_by_two_long_lower) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<int64_t>("X", {2, 2}, {4, 7, 2, 6});
|
||||
test.AddOutput<int64_t>("Y", {2, 2}, {4, 0, 2, 6});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, three_dim_float_upper) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
});
|
||||
test.AddInput<int64_t>("k", {1}, {1});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{0.f, 1.f, 5.f, 8.f,
|
||||
0.f, 0.f, 2.f, 4.f,
|
||||
0.f, 0.f, 0.f, 3.f,
|
||||
0.f, 6.f, 2.f, 1.f,
|
||||
0.f, 0.f, 5.f, 8.f,
|
||||
0.f, 0.f, 0.f, 4.f,
|
||||
});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, three_dim_float_lower) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
});
|
||||
test.AddInput<int64_t>("k", {1}, {1});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{4.f, 1.f, 0.f, 0.f,
|
||||
4.f, 3.f, 2.f, 0.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
1.f, 6.f, 0.f, 0.f,
|
||||
4.f, 1.f, 5.f, 0.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, neg_k_float_upper) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 1;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
});
|
||||
test.AddInput<int64_t>("k", {1}, {-1});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
0.f, 1.f, 2.f, 3.f,
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
0.f, 3.f, 2.f, 4.f,
|
||||
});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, neg_k_float_lower) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
});
|
||||
test.AddInput<int64_t>("k", {1}, {-1});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{0.f, 0.f, 0.f, 0.f,
|
||||
4.f, 0.f, 0.f, 0.f,
|
||||
6.f, 1.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
4.f, 0.f, 0.f, 0.f,
|
||||
4.f, 3.f, 0.f, 0.f,
|
||||
});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, small_k_float_upper) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
});
|
||||
test.AddInput<int64_t>("k", {1}, {-5});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, small_k_float_lower) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 3, 4},
|
||||
{4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
6.f, 1.f, 2.f, 3.f,
|
||||
1.f, 6.f, 2.f, 1.f,
|
||||
4.f, 1.f, 5.f, 8.f,
|
||||
4.f, 3.f, 2.f, 4.f,
|
||||
});
|
||||
test.AddInput<int64_t>("k", {1}, {-5});
|
||||
test.AddOutput<float>("Y", {2, 3, 4},
|
||||
{0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 0.f, 0.f,
|
||||
});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, zero_dim_upper) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
test.AddInput<float>("X", {2, 3, 0}, {});
|
||||
test.AddInput<int64_t>("k", {1}, {0});
|
||||
test.AddOutput<float>("Y", {2, 3, 0}, {});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, zero_dim_lower) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 3, 0}, {});
|
||||
test.AddInput<int64_t>("k", {1}, {0});
|
||||
test.AddOutput<float>("Y", {2, 3, 0}, {});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, zero_dim_2_upper) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
test.AddInput<float>("X", {2, 0, 0}, {});
|
||||
test.AddInput<int64_t>("k", {1}, {-5});
|
||||
test.AddOutput<float>("Y", {2, 0, 0}, {});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(TriluContribOpTest, zero_dim_2_lower) {
|
||||
OpTester test("Trilu", 1, kMSDomain);
|
||||
int64_t up = 0;
|
||||
test.AddAttribute("upper", up);
|
||||
test.AddInput<float>("X", {2, 0, 0}, {});
|
||||
test.AddInput<int64_t>("k", {1}, {-5});
|
||||
test.AddOutput<float>("Y", {2, 0, 0}, {});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
# Helper functions for generating ONNX model and data to test ONNX Runtime contrib ops
|
||||
|
||||
import onnx
|
||||
import os
|
||||
from onnx import numpy_helper
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
TOP_DIR = os.path.realpath(os.path.dirname(__file__))
|
||||
DATA_DIR = os.path.join(TOP_DIR, '..', 'testdata/')
|
||||
|
||||
|
||||
def prepare_dir(path):
|
||||
if os.path.exists(path):
|
||||
shutil.rmtree(path)
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
def _extract_value_info(arr, name, ele_type=None):
|
||||
return onnx.helper.make_tensor_value_info(
|
||||
name=name,
|
||||
elem_type=ele_type if ele_type else onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[arr.dtype],
|
||||
shape=arr.shape)
|
||||
|
||||
|
||||
def generate_data(graph, inputs, outputs, name):
|
||||
output_dir = os.path.join(DATA_DIR, name)
|
||||
prepare_dir(output_dir)
|
||||
model = onnx.helper.make_model(graph)
|
||||
with open(os.path.join(output_dir, 'model.onnx'), 'wb') as f:
|
||||
f.write(model.SerializeToString())
|
||||
data_set = os.path.join(output_dir, 'test_data_set_0')
|
||||
prepare_dir(data_set)
|
||||
for j, input_np in enumerate(inputs):
|
||||
tensor = numpy_helper.from_array(
|
||||
input_np, model.graph.input[j].name)
|
||||
with open(os.path.join(
|
||||
data_set, 'input_{}.pb'.format(j)), 'wb') as f:
|
||||
f.write(tensor.SerializeToString())
|
||||
for j, output_np in enumerate(outputs):
|
||||
tensor = numpy_helper.from_array(
|
||||
output_np, model.graph.output[j].name)
|
||||
with open(os.path.join(
|
||||
data_set, 'output_{}.pb'.format(j)), 'wb') as f:
|
||||
f.write(tensor.SerializeToString())
|
||||
|
||||
|
||||
def expect(node, # type: onnx.NodeProto
|
||||
inputs,
|
||||
outputs,
|
||||
name,
|
||||
**kwargs
|
||||
): # type: (...) -> None
|
||||
present_inputs = [x for x in node.input if (x != '')]
|
||||
present_outputs = [x for x in node.output if (x != '')]
|
||||
input_types = [None] * len(inputs)
|
||||
if 'input_types' in kwargs:
|
||||
input_types = kwargs[str('input_types')]
|
||||
del kwargs[str('input_types')]
|
||||
output_types = [None] * len(outputs)
|
||||
if 'output_types' in kwargs:
|
||||
output_types = kwargs[str('output_types')]
|
||||
del kwargs[str('output_types')]
|
||||
inputs_vi = [_extract_value_info(arr, arr_name, input_type)
|
||||
for arr, arr_name, input_type in zip(inputs, present_inputs, input_types)]
|
||||
outputs_vi = [_extract_value_info(arr, arr_name, output_type)
|
||||
for arr, arr_name, output_type in zip(outputs, present_outputs, output_types)]
|
||||
graph = onnx.helper.make_graph(
|
||||
nodes=[node],
|
||||
name=name,
|
||||
inputs=inputs_vi,
|
||||
outputs=outputs_vi)
|
||||
|
||||
generate_data(graph, inputs, outputs, name)
|
||||
|
||||
cwd = os.getcwd()
|
||||
onnx_test_runner = os.path.join(cwd, 'onnx_test_runner')
|
||||
subprocess.run([onnx_test_runner, DATA_DIR+name], check=True, cwd=cwd)
|
||||
289
onnxruntime/test/python/contrib_ops/onnx_test_trilu.py
Normal file
289
onnxruntime/test/python/contrib_ops/onnx_test_trilu.py
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
# Test reference implementation and model for ONNX Runtime conrtib op trilu
|
||||
|
||||
import onnx
|
||||
import unittest
|
||||
import numpy as np
|
||||
from onnx_contrib_ops_helper import expect
|
||||
|
||||
|
||||
def triu_reference_implementation(x, k=0):
|
||||
return np.triu(x, k)
|
||||
|
||||
|
||||
def tril_reference_implementation(x, k=0):
|
||||
return np.tril(x, k)
|
||||
|
||||
|
||||
class ONNXReferenceImplementationTest(unittest.TestCase):
|
||||
def test_triu(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
y = triu_reference_implementation(x)
|
||||
expect(node, inputs=[x], outputs=[y], name='test_triu')
|
||||
|
||||
def test_triu_neg(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
k = np.array([-1]).astype(np.int64)
|
||||
y = triu_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_triu_neg')
|
||||
|
||||
def test_triu_out_neg(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
k = np.array([-7]).astype(np.int64)
|
||||
y = triu_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_triu_out_neg')
|
||||
|
||||
def test_triu_pos(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
k = np.array([2]).astype(np.int64)
|
||||
y = triu_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_triu_pos')
|
||||
|
||||
def test_triu_out_pos(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
k = np.array([6]).astype(np.int64)
|
||||
y = triu_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_triu_out_pos')
|
||||
|
||||
def test_triu_square(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 5, 5).astype(np.float32)
|
||||
y = triu_reference_implementation(x)
|
||||
expect(node, inputs=[x], outputs=[y], name='test_triu_square')
|
||||
|
||||
def test_triu_square_neg(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 5, 5).astype(np.float32)
|
||||
k = np.array([-1]).astype(np.int64)
|
||||
y = triu_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_triu_square_neg')
|
||||
|
||||
def test_triu_one_row_neg(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 1, 5).astype(np.float32)
|
||||
k = np.array([-7]).astype(np.int64)
|
||||
y = triu_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_triu_one_row_neg')
|
||||
|
||||
def test_triu_square_pos(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 5, 5).astype(np.float32)
|
||||
k = np.array([2]).astype(np.int64)
|
||||
y = triu_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_triu_square_pos')
|
||||
|
||||
def test_triu_zero(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 0, 5).astype(np.float32)
|
||||
k = np.array([6]).astype(np.int64)
|
||||
y = triu_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_triu_zero')
|
||||
|
||||
def test_tril(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
y = tril_reference_implementation(x)
|
||||
expect(node, inputs=[x], outputs=[y], name='test_tril')
|
||||
|
||||
def test_tril_neg(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
k = np.array([-1]).astype(np.int64)
|
||||
y = tril_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_tril_neg')
|
||||
|
||||
def test_tril_out_neg(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
k = np.array([-7]).astype(np.int64)
|
||||
y = tril_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_tril_out_neg')
|
||||
|
||||
def test_tril_pos(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
k = np.array([2]).astype(np.int64)
|
||||
y = tril_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_tril_pos')
|
||||
|
||||
def test_tril_out_pos(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 4, 5).astype(np.float32)
|
||||
k = np.array([6]).astype(np.int64)
|
||||
y = tril_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_tril_out_pos')
|
||||
|
||||
def test_tril_square(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 5, 5).astype(np.float32)
|
||||
y = tril_reference_implementation(x)
|
||||
expect(node, inputs=[x], outputs=[y], name='test_tril_square')
|
||||
|
||||
def test_tril_square_neg(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 5, 5).astype(np.float32)
|
||||
k = np.array([-1]).astype(np.int64)
|
||||
y = tril_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_tril_square_neg')
|
||||
|
||||
def test_tril_one_row_neg(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 1, 5).astype(np.float32)
|
||||
k = np.array([-7]).astype(np.int64)
|
||||
y = tril_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_tril_one_row_neg')
|
||||
|
||||
def test_tril_square_pos(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 5, 5).astype(np.float32)
|
||||
k = np.array([2]).astype(np.int64)
|
||||
y = tril_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_tril_square_pos')
|
||||
|
||||
def test_tril_zero(self):
|
||||
node = onnx.helper.make_node(
|
||||
'Trilu',
|
||||
inputs=['x', 'k'],
|
||||
outputs=['y'],
|
||||
upper=0,
|
||||
domain="com.microsoft",
|
||||
)
|
||||
|
||||
x = np.random.randn(3, 0, 5).astype(np.float32)
|
||||
k = np.array([6]).astype(np.int64)
|
||||
y = tril_reference_implementation(x, k)
|
||||
expect(node, inputs=[x, k], outputs=[y], name='test_tril_zero')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(module=__name__, buffer=True)
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
#!/bin/bash
|
||||
|
||||
pip3 install --user --upgrade pip
|
||||
|
||||
|
|
@ -7,3 +8,9 @@ pip3 install --user /build/Release/dist/*.whl
|
|||
export PYTHONPATH=/onnxruntime_src/tools:/usr/local/lib/python3.6/site-packages:$PYTHONPATH
|
||||
|
||||
python3 -m pytest -v /onnxruntime_src/tools/test/test_custom_ops_pytorch_exporter.py || exit 1
|
||||
|
||||
for filename in /onnxruntime_src/onnxruntime/test/python/contrib_ops/onnx_test_* ; do
|
||||
cd /build/Release && python3 -m pytest -v $filename || exit 1
|
||||
done
|
||||
|
||||
cd /build/Release && ./onnxruntime_test_all --gtest_filter=ShapeInferenceTests.* || exit 1
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
# Register pytorch symbolic for export using ONNX Runtime contrib ops
|
||||
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
|
||||
|
||||
|
|
@ -17,6 +22,14 @@ def register_custom_op():
|
|||
def gelu(g, self):
|
||||
return g.op("com.microsoft::Gelu", self)
|
||||
|
||||
def triu(g, self, diagonal):
|
||||
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1)
|
||||
|
||||
def tril(g, self, diagonal):
|
||||
return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0)
|
||||
|
||||
# Op Registration
|
||||
register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version)
|
||||
register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version)
|
||||
register_custom_op_symbolic('::triu', triu, _onnx_opset_version)
|
||||
register_custom_op_symbolic('::tril', tril, _onnx_opset_version)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
#
|
||||
# Test export of pytorch operators using ONNX Runtime contrib ops
|
||||
|
||||
import torch
|
||||
import onnxruntime
|
||||
import numpy as np
|
||||
|
|
@ -32,7 +37,7 @@ def ort_test_with_input(ort_sess, input, output, rtol, atol):
|
|||
|
||||
# These set of tests verify ONNX model export and compare onnxruntime outputs to pytorch.
|
||||
# To register custom ops and run the tests, you should set PYTHONPATH as:
|
||||
# PYTHONPATH=<path_to_onnxruntime/tools> pytest -v test_custom_ops_pytorch_exporter.py
|
||||
# PYTHONPATH=<path_to_onnxruntime/tools> python -m pytest -v test_custom_ops_pytorch_exporter.py
|
||||
class ONNXExporterTest(unittest.TestCase):
|
||||
from torch.onnx.symbolic_helper import _export_onnx_opset_version
|
||||
opset_version = _export_onnx_opset_version
|
||||
|
|
@ -105,21 +110,68 @@ class ONNXExporterTest(unittest.TestCase):
|
|||
x = torch.randn(3, 3)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
def test_triu(self):
|
||||
for i in range(-5, 5):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return input.triu(diagonal=i)
|
||||
|
||||
# opset 10 tests
|
||||
TestONNXRuntime_opset10 = type(str("TestONNXRuntime_opset10"),
|
||||
(unittest.TestCase,),
|
||||
dict(ONNXExporterTest.__dict__, opset_version=10))
|
||||
model = Module()
|
||||
x = torch.randn(5, 4, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
# opset 11 tests
|
||||
ONNXExporterTest_opset11 = type(str("TestONNXRuntime_opset11"),
|
||||
(unittest.TestCase,),
|
||||
dict(ONNXExporterTest.__dict__, opset_version=11))
|
||||
x = torch.randn(5, 4, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
x = torch.randn(5, 0, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
for i in range(-5, 5):
|
||||
class Module2D(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return input.triu(diagonal=i)
|
||||
|
||||
model = Module2D()
|
||||
x = torch.randn(4, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
x = torch.randn(0, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
x = torch.randn(0, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
def test_tril(self):
|
||||
for i in range(-5, 5):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return input.tril(diagonal=i)
|
||||
|
||||
model = Module()
|
||||
x = torch.randn(5, 4, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
x = torch.randn(5, 4, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
x = torch.randn(5, 0, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
for i in range(-5, 5):
|
||||
class Module2D(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return input.tril(diagonal=i)
|
||||
|
||||
model = Module2D()
|
||||
x = torch.randn(4, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
x = torch.randn(0, 7, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
x = torch.randn(0, 0, dtype=torch.float32)
|
||||
self.run_test(model, x, custom_opsets={'com.microsoft': 1})
|
||||
|
||||
# opset 12 tests
|
||||
ONNXExporterTest_opset12 = type(str("TestONNXRuntime_opset12"),
|
||||
(unittest.TestCase,),
|
||||
dict(ONNXExporterTest.__dict__, opset_version=12))
|
||||
|
||||
# opset 9 tests, with keep_initializers_as_inputs=False for
|
||||
# IR version 4 style export.
|
||||
|
|
@ -128,27 +180,6 @@ ONNXExporterTest_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"),
|
|||
dict(ONNXExporterTest.__dict__,
|
||||
keep_initializers_as_inputs=False))
|
||||
|
||||
# opset 10 tests, with keep_initializers_as_inputs=False for
|
||||
# IR version 4 style export.
|
||||
ONNXExporterTest_opset10_IRv4 = type(str("TestONNXRuntime_opset10_IRv4"),
|
||||
(unittest.TestCase,),
|
||||
dict(ONNXExporterTest.__dict__, opset_version=10,
|
||||
keep_initializers_as_inputs=False))
|
||||
|
||||
|
||||
# opset 11 tests, with keep_initializers_as_inputs=False for
|
||||
# IR version 4 style export.
|
||||
ONNXExporterTest_opset11_IRv4 = type(str("TestONNXRuntime_opset11_IRv4"),
|
||||
(unittest.TestCase,),
|
||||
dict(ONNXExporterTest.__dict__, opset_version=11,
|
||||
keep_initializers_as_inputs=False))
|
||||
|
||||
# opset 12 tests, with keep_initializers_as_inputs=False for
|
||||
# IR version 4 style export.
|
||||
ONNXExporterTest_opset12_IRv4 = type(str("TestONNXRuntime_opset12_IRv4"),
|
||||
(unittest.TestCase,),
|
||||
dict(ONNXExporterTest.__dict__, opset_version=12,
|
||||
keep_initializers_as_inputs=False))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue