diff --git a/docs/ExportPyTorchCustomOps.md b/docs/ExportPyTorchCustomOps.md index 49a10e86bc..20835ad76f 100644 --- a/docs/ExportPyTorchCustomOps.md +++ b/docs/ExportPyTorchCustomOps.md @@ -34,11 +34,11 @@ If you are adding a symbolic function for a new custom op, add the function to t ### 2. Extending ONNX Runtime with Custom Ops The next step is to add op schema and kernel implementation in ONNX Runtime. -THe exmaple Inverse custom op is added in: +Consider the Inverse custom op as an example added in: https://github.com/microsoft/onnxruntime/pull/3485 -Custom op schema and shape inference function should be added under ```onnxruntime/core/graph/contrib_ops/contrib_defs.cc ``` +Custom op schema and shape inference function should be added in ```onnxruntime/core/graph/contrib_ops/contrib_defs.cc ``` using ```ONNX_CONTRIB_OPERATOR_SCHEMA```. ```c++ @@ -48,6 +48,17 @@ ONNX_CONTRIB_OPERATOR_SCHEMA(Inverse) ... ``` +To comply with ONNX guideline for new operators, a new operator should have complete reference implementation tests and +shape inference tests. + +Reference implementation python tests should be added in: +``onnxruntime/test/python/contrib_ops`` +E.g.: ``onnxruntime/test/python/contrib_ops/onnx_test_trilu.py`` + +Shape inference C++ tests should be added in: +``onnxruntime/test/contrib_ops`` +E.g.: ``onnxruntime/test/contrib_ops/trilu_shape_inference_test.cc`` + The operator kernel should be implemented using ```Compute``` function under contrib namespace in ```onnxruntime/contrib_ops/cpu/.cc``` for CPU and ```onnxruntime/contrib_ops/cuda/.cc``` for CUDA. @@ -90,7 +101,7 @@ Now you should be able to build and install ONNX Runtime to start using your cus ##### ONNX Runtime Tests -ONNX Runtime custom op tests should be added in: ```onnxruntime/test/contrib_ops/_test.cc ``` +ONNX Runtime custom op kernel tests should be added in: ```onnxruntime/test/contrib_ops/_test.cc ``` ```c++ namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index d75686bfcc..d11cd2f6a0 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -74,6 +74,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu); Status RegisterNchwcKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -158,6 +159,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/trilu.cc b/onnxruntime/contrib_ops/cpu/trilu.cc new file mode 100644 index 0000000000..35cf14afcd --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/trilu.cc @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" +#include "Eigen/src/Core/Map.h" +#include "trilu.h" +#include + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + Trilu, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()), + Trilu); + +template +static Status TriluImpl(const Tensor* X, Tensor* Y, int64_t k_val, bool up) { + const auto& X_shape = X->Shape(); + int64_t X_num_dims = static_cast(X_shape.NumDimensions()); + + const auto* X_data = reinterpret_cast(X->DataRaw()); + int64_t matrix_h = static_cast(X_shape[X_num_dims - 2]); + int64_t matrix_w = static_cast(X_shape[X_num_dims - 1]); + + int64_t batch_size = 1; + for (int64_t i = 0; i < X_num_dims - 2; ++i) { + batch_size *= X_shape[i]; + } + + int64_t num_matrix_elems = matrix_h * matrix_w; + auto* Y_data = reinterpret_cast(Y->MutableDataRaw()); + for (int64_t b = 0; b < batch_size; b++) { // can be parallelized if need to + auto X_batch_data = X_data + (b * num_matrix_elems); + auto Y_batch_data = Y_data + (b * num_matrix_elems); + + auto input_mat = ConstEigenMatrixMapRowMajor(X_batch_data, matrix_h, matrix_w); + auto output_mat = EigenMatrixMapRowMajor(Y_batch_data, matrix_h, matrix_w); + + if (X_batch_data != Y_batch_data) { + output_mat = input_mat; + } + + if (up) { + int64_t start_i = k_val > 0 ? 0 : 1 - k_val; + for (int64_t i = start_i; i < matrix_h; i++) { + for (int64_t j = 0; j < i + k_val && j < matrix_w; j++) { + output_mat(i, j) = 0; + } + } + } else { + int64_t end_i = std::min(matrix_h, matrix_w - k_val); + for (int64_t i = 0; i < end_i; i++) { + for (int64_t j = std::max(static_cast(0), i + k_val + 1); j < matrix_w; j++) { + output_mat(i, j) = 0; + } + } + } + } + return Status::OK(); +} + +Status Trilu::Compute(OpKernelContext* ctx) const { + Status status; + const auto* X = ctx->Input(0); + const auto* k = ctx->Input(1); + + bool up = upper_; + int64_t k_val = 0; + if (k) { + ORT_ENFORCE(IsScalarOr1ElementVector(k), "k should be a 1-D or 0-D tensor."); + k_val = *(k->template Data()); + } + + const auto& X_shape = X->Shape(); + auto* Y = ctx->Output(0, X_shape); + + int64_t X_num_dims = static_cast(X_shape.NumDimensions()); + // input validation + if (X_num_dims < 2) { // this is getting capture by shape inference code as well + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensor should have a rank of at least 2"); + } + + MLDataType data_type = X->DataType(); + const auto element_size = data_type->Size(); + switch (element_size) { + case sizeof(float): + status = TriluImpl(X, Y, k_val, up); + break; + case sizeof(double): + status = TriluImpl(X, Y, k_val, up); + break; + default: + ORT_THROW("Unsupported input data type of ", data_type); + } + return status; +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/trilu.h b/onnxruntime/contrib_ops/cpu/trilu.h new file mode 100644 index 0000000000..bc767c239a --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/trilu.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace onnxruntime { +namespace contrib { + +class Trilu final : public OpKernel { + public: + explicit Trilu(const OpKernelInfo& info) : OpKernel(info) { + int64_t temp; + ORT_ENFORCE(info.GetAttr("upper", &temp).IsOK()); + upper_ = temp != 0; + } + Status Compute(OpKernelContext* ctx) const override; + + private: + bool upper_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index ba89d76ec8..e514e21aaf 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2848,6 +2848,76 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i } }); + static const char* Trilu_ver1_doc = R"DOC( + Returns the upper or lower triangular part of a 2-D matrix, or batches of 2-D matrices. If the attribute "upper" is set to true, + the upper triangular matrix is retained. Lower triangular matrix is retained otherwise. Default value for upper is true. + Trilu takes one input tensor of shape [*, N, M], where * is zero or more batch dimensions. The upper triangular part consists + of the elements on and above the given diagonal (k). The lower triangular part consists of elements on and below the diagonal. + All other elements in the matrix are set to zero. + If k = 0, the triangular part on and above/below the main diagonal is retained. + If upper is set to true, a positive k retains the upper triangular matrix excluding k diagonals above + the main diagonal. A negative k value includes as many diagonals below the main diagonal. + If upper is set to false, a positive k retains the lower triangular matrix including k diagonals above + the main diagonal. A negative k value excludes as many diagonals below the main diagonal. + )DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(Trilu) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(Trilu_ver1_doc) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .Attr("upper", + "Boolean. Indicates whether upper or lower part of matrix is retained. Default is true.", + AttributeProto::INT, + static_cast(1)) + .Input( + 0, + "X", + "Input tensor of rank 2 or higher.", + "T") + .Input( + 1, + "k", + "A 0-D tensor containing a single value corresponding to the number diagonals above or the main diagonal to exclude or include." + "Default value is 0 if it's not specified.", + "tensor(int64)", + OpSchema::Optional) + .Output( + 0, + "Y", + "Output tensor of the same type and shape as the input tensor.", + "T") + .TypeConstraint( + "T", + {"tensor(float16)", + "tensor(float)", + "tensor(double)", + "tensor(bfloat16)", + "tensor(uint8)", + "tensor(uint16)", + "tensor(uint32)", + "tensor(uint64)", + "tensor(int8)", + "tensor(int16)", + "tensor(int32)", + "tensor(int64)", + "tensor(bool)"}, + "Constrain input and output types to all numeric tensors and bool tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + using namespace ONNX_NAMESPACE; + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + if (hasInputShape(ctx, 0)) { + const TensorShapeProto& input_shape = + ctx.getInputType(0)->tensor_type().shape(); + const int rank = static_cast(input_shape.dim_size()); + if (rank < 2) { + fail_shape_inference("Input rank must be >= 2.") + } + propagateShapeFromInputToOutput(ctx, 0, 0); + } + }); + RegisterBertSchemas(); } } // namespace contrib diff --git a/onnxruntime/test/contrib_ops/shape_inference_test_helper.h b/onnxruntime/test/contrib_ops/shape_inference_test_helper.h new file mode 100644 index 0000000000..23691051b2 --- /dev/null +++ b/onnxruntime/test/contrib_ops/shape_inference_test_helper.h @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "onnx/shape_inference/implementation.h" +#include "onnx/checker.h" + +namespace onnxruntime { +namespace test { + +auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); + +const std::string MS_DOMAIN = "com.microsoft"; + +void CheckShapeEquality(ONNX_NAMESPACE::TensorShapeProto* shape1, ONNX_NAMESPACE::TensorShapeProto* shape2) { + EXPECT_NE(shape1, nullptr); + EXPECT_NE(shape2, nullptr); + if ((shape1 != nullptr) && (shape2 != nullptr)) { + EXPECT_EQ(shape1->dim_size(), shape2->dim_size()) << "Shapes do not have same rank"; + auto min_dims = std::min(shape1->dim_size(), shape2->dim_size()); + for (int i = 0; i < min_dims; ++i) { + auto dim1 = shape1->dim(i); + auto dim2 = shape2->dim(i); + EXPECT_EQ(dim1.has_dim_value(), dim2.has_dim_value()); + if (dim1.has_dim_value()) { + EXPECT_EQ(dim1.dim_value(), dim2.dim_value()); + } + EXPECT_EQ(dim1.has_dim_param(), dim2.has_dim_param()); + if (dim1.has_dim_param()) { + EXPECT_EQ(dim1.dim_param(), dim2.dim_param()); + } + } + } +} + +inline void CreateValueInfo( + ONNX_NAMESPACE::ValueInfoProto& value_info, + const std::string& name, + const ONNX_NAMESPACE::TensorProto_DataType& elem_type, + const std::vector shape) { + value_info.set_name(name); + ONNX_NAMESPACE::TypeProto* type = value_info.mutable_type(); + ONNX_NAMESPACE::TypeProto_Tensor* tensor_type = type->mutable_tensor_type(); + tensor_type->set_elem_type(elem_type); + ONNX_NAMESPACE::TensorShapeProto* value_info_shape = tensor_type->mutable_shape(); + + for (int64_t dim_value : shape) { + value_info_shape->add_dim()->set_dim_value(dim_value); + } +} + +inline void TestShapeInference( + const std::string& op_type, + const std::vector& inputs, + const std::vector& attributes, + ONNX_NAMESPACE::ValueInfoProto& output) { + ONNX_NAMESPACE::ModelProto model; + // Set opset (domain + version) + ONNX_NAMESPACE::OperatorSetIdProto* op_set_id = model.add_opset_import(); + op_set_id->set_domain(MS_DOMAIN); + op_set_id->set_version(1); + model.set_ir_version(6); + model.set_producer_name("onnx"); + + // Set model graph + ONNX_NAMESPACE::GraphProto* graph = model.mutable_graph(); + graph->set_name("test-op"); + graph->add_value_info(); + + // Set add operator node to graph + auto& node = *graph->add_node(); + node.set_op_type(op_type); + node.set_domain(MS_DOMAIN); + node.set_name("test_node"); + + // Add node inputs and graph inputs + for (auto const& n_ : inputs) { + node.add_input(n_.name()); + *graph->add_input() = n_; + } + + // Add node attributes + for (auto const& attr : attributes) { + node.add_attribute()->CopyFrom(attr); + } + + node.add_output("Output"); + + ONNX_NAMESPACE::checker::check_model(model); + ONNX_NAMESPACE::shape_inference::InferShapes(model, false, schema_registry); + + auto inferredGraph = model.graph(); + int index = static_cast(inputs.size()); // index for value_info of output + auto inferred_output = inferredGraph.value_info(index); + + auto elem_type = output.mutable_type()->mutable_tensor_type()->elem_type(); + auto inferred_elem_type = inferred_output.mutable_type()->mutable_tensor_type()->elem_type(); + EXPECT_EQ(elem_type, inferred_elem_type); + + auto shape = output.mutable_type()->mutable_tensor_type()->mutable_shape(); + auto inferred_shape = inferred_output.mutable_type()->mutable_tensor_type()->mutable_shape(); + CheckShapeEquality(shape, inferred_shape); +} + +} // namespace test +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/contrib_ops/trilu_shape_inference_test.cc b/onnxruntime/test/contrib_ops/trilu_shape_inference_test.cc new file mode 100644 index 0000000000..561f0bb0cc --- /dev/null +++ b/onnxruntime/test/contrib_ops/trilu_shape_inference_test.cc @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/providers/provider_test_utils.h" +#include "onnx/shape_inference/implementation.h" +#include "onnx/checker.h" +#include "shape_inference_test_helper.h" + +namespace onnxruntime { +namespace test { + +TEST(ShapeInferenceTests, tri_upper_float) { + std::vector shape = {4, 7}; + ONNX_NAMESPACE::ValueInfoProto input; + CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape); + std::vector inputs = {input}; + + ONNX_NAMESPACE::AttributeProto upper; + upper.set_name("upper"); + upper.set_type(ONNX_NAMESPACE::AttributeProto::INT); + upper.set_i(1); // upper + std::vector attributes = {upper}; + + ONNX_NAMESPACE::ValueInfoProto output; + CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape); + + TestShapeInference("Trilu", inputs, attributes, output); +} + +TEST(ShapeInferenceTests, tri_upper_zero_dim_int) { + std::vector shape = {4, 7, 0}; + ONNX_NAMESPACE::ValueInfoProto input; + CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape); + std::vector inputs = {input}; + + ONNX_NAMESPACE::AttributeProto upper; + upper.set_name("upper"); + upper.set_type(ONNX_NAMESPACE::AttributeProto::INT); + upper.set_i(1); // upper + std::vector attributes = {upper}; + + ONNX_NAMESPACE::ValueInfoProto output; + CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape); + + TestShapeInference("Trilu", inputs, attributes, output); +} + +TEST(ShapeInferenceTests, tri_upper_4d_long) { + std::vector shape = {2, 3, 7, 11}; + ONNX_NAMESPACE::ValueInfoProto input; + CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape); + std::vector inputs = {input}; + + ONNX_NAMESPACE::AttributeProto upper; + upper.set_name("upper"); + upper.set_type(ONNX_NAMESPACE::AttributeProto::INT); + upper.set_i(1); // upper + std::vector attributes = {upper}; + + ONNX_NAMESPACE::ValueInfoProto output; + CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape); + + TestShapeInference("Trilu", inputs, attributes, output); +} + +TEST(ShapeInferenceTests, tri_lower_float) { + std::vector shape = {4, 7}; + ONNX_NAMESPACE::ValueInfoProto input; + CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape); + std::vector inputs = {input}; + + ONNX_NAMESPACE::AttributeProto upper; + upper.set_name("upper"); + upper.set_type(ONNX_NAMESPACE::AttributeProto::INT); + upper.set_i(0); // lower + std::vector attributes = {upper}; + + ONNX_NAMESPACE::ValueInfoProto output; + CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_FLOAT, shape); + + TestShapeInference("Trilu", inputs, attributes, output); +} + +TEST(ShapeInferenceTests, tri_lower_4d_int) { + std::vector shape = {2, 3, 7, 11}; + ONNX_NAMESPACE::ValueInfoProto input; + CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape); + std::vector inputs = {input}; + + ONNX_NAMESPACE::AttributeProto upper; + upper.set_name("upper"); + upper.set_type(ONNX_NAMESPACE::AttributeProto::INT); + upper.set_i(0); // lower + std::vector attributes = {upper}; + + ONNX_NAMESPACE::ValueInfoProto output; + CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT32, shape); + + TestShapeInference("Trilu", inputs, attributes, output); +} + +TEST(ShapeInferenceTests, tri_lower_zero_dim_long) { + std::vector shape = {4, 7, 0}; + ONNX_NAMESPACE::ValueInfoProto input; + CreateValueInfo(input, "X", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape); + std::vector inputs = {input}; + + ONNX_NAMESPACE::AttributeProto upper; + upper.set_name("upper"); + upper.set_type(ONNX_NAMESPACE::AttributeProto::INT); + upper.set_i(0); // lower + std::vector attributes = {upper}; + + ONNX_NAMESPACE::ValueInfoProto output; + CreateValueInfo(output, "Y", ONNX_NAMESPACE::TensorProto_DataType_INT64, shape); + + TestShapeInference("Trilu", inputs, attributes, output); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/trilu_test.cc b/onnxruntime/test/contrib_ops/trilu_test.cc new file mode 100644 index 0000000000..24d6499c62 --- /dev/null +++ b/onnxruntime/test/contrib_ops/trilu_test.cc @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "core/util/math.h" + +namespace onnxruntime { +namespace test { + +TEST(TriluContribOpTest, two_by_two_float_upper) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 1; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 2}, {4.f, 7.f, 2.f, 6.f}); + test.AddOutput("Y", {2, 2}, {4.f, 7.f, 0.f, 6.f}); + test.Run(); +} + +TEST(TriluContribOpTest, two_by_two_float_lower) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 2}, {4.f, 7.f, 2.f, 6.f}); + test.AddOutput("Y", {2, 2}, {4.f, 0.f, 2.f, 6.f}); + test.Run(); +} + +TEST(TriluContribOpTest, two_by_two_double_upper) { + OpTester test("Trilu", 1, kMSDomain); + test.AddInput("X", {2, 2}, {4, 7, 2, 6}); + test.AddInput("k", {1}, {1}); + test.AddOutput("Y", {2, 2}, {0, 7, 0, 0}); + test.Run(); +} + +TEST(TriluContribOpTest, two_by_two_double_lower) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 2}, {4, 7, 2, 6}); + test.AddInput("k", {1}, {1}); + test.AddOutput("Y", {2, 2}, {4, 7, 2, 6}); + test.Run(); +} + +TEST(TriluContribOpTest, two_by_two_long_upper) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 1; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 2}, {4, 7, 2, 6}); + test.AddOutput("Y", {2, 2}, {4, 7, 0, 6}); + test.Run(); +} + +TEST(TriluContribOpTest, two_by_two_long_lower) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 2}, {4, 7, 2, 6}); + test.AddOutput("Y", {2, 2}, {4, 0, 2, 6}); + test.Run(); +} + +TEST(TriluContribOpTest, three_dim_float_upper) { + OpTester test("Trilu", 1, kMSDomain); + test.AddInput("X", {2, 3, 4}, + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + }); + test.AddInput("k", {1}, {1}); + test.AddOutput("Y", {2, 3, 4}, + {0.f, 1.f, 5.f, 8.f, + 0.f, 0.f, 2.f, 4.f, + 0.f, 0.f, 0.f, 3.f, + 0.f, 6.f, 2.f, 1.f, + 0.f, 0.f, 5.f, 8.f, + 0.f, 0.f, 0.f, 4.f, + }); + test.Run(); +} + +TEST(TriluContribOpTest, three_dim_float_lower) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 3, 4}, + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + }); + test.AddInput("k", {1}, {1}); + test.AddOutput("Y", {2, 3, 4}, + {4.f, 1.f, 0.f, 0.f, + 4.f, 3.f, 2.f, 0.f, + 6.f, 1.f, 2.f, 3.f, + 1.f, 6.f, 0.f, 0.f, + 4.f, 1.f, 5.f, 0.f, + 4.f, 3.f, 2.f, 4.f, + }); + test.Run(); +} + +TEST(TriluContribOpTest, neg_k_float_upper) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 1; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 3, 4}, + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + }); + test.AddInput("k", {1}, {-1}); + test.AddOutput("Y", {2, 3, 4}, + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 0.f, 1.f, 2.f, 3.f, + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 0.f, 3.f, 2.f, 4.f, + }); + test.Run(); +} + +TEST(TriluContribOpTest, neg_k_float_lower) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 3, 4}, + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + }); + test.AddInput("k", {1}, {-1}); + test.AddOutput("Y", {2, 3, 4}, + {0.f, 0.f, 0.f, 0.f, + 4.f, 0.f, 0.f, 0.f, + 6.f, 1.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + 4.f, 0.f, 0.f, 0.f, + 4.f, 3.f, 0.f, 0.f, + }); + test.Run(); +} + +TEST(TriluContribOpTest, small_k_float_upper) { + OpTester test("Trilu", 1, kMSDomain); + test.AddInput("X", {2, 3, 4}, + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + }); + test.AddInput("k", {1}, {-5}); + test.AddOutput("Y", {2, 3, 4}, + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + }); + test.Run(); +} + +TEST(TriluContribOpTest, small_k_float_lower) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 3, 4}, + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + }); + test.AddInput("k", {1}, {-5}); + test.AddOutput("Y", {2, 3, 4}, + {0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + }); + test.Run(); +} + +TEST(TriluContribOpTest, zero_dim_upper) { + OpTester test("Trilu", 1, kMSDomain); + test.AddInput("X", {2, 3, 0}, {}); + test.AddInput("k", {1}, {0}); + test.AddOutput("Y", {2, 3, 0}, {}); + test.Run(); +} + +TEST(TriluContribOpTest, zero_dim_lower) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 3, 0}, {}); + test.AddInput("k", {1}, {0}); + test.AddOutput("Y", {2, 3, 0}, {}); + test.Run(); +} + +TEST(TriluContribOpTest, zero_dim_2_upper) { + OpTester test("Trilu", 1, kMSDomain); + test.AddInput("X", {2, 0, 0}, {}); + test.AddInput("k", {1}, {-5}); + test.AddOutput("Y", {2, 0, 0}, {}); + test.Run(); +} + +TEST(TriluContribOpTest, zero_dim_2_lower) { + OpTester test("Trilu", 1, kMSDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 0, 0}, {}); + test.AddInput("k", {1}, {-5}); + test.AddOutput("Y", {2, 0, 0}, {}); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py b/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py new file mode 100644 index 0000000000..25f53ce878 --- /dev/null +++ b/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# +# Helper functions for generating ONNX model and data to test ONNX Runtime contrib ops + +import onnx +import os +from onnx import numpy_helper +import subprocess +import shutil + +TOP_DIR = os.path.realpath(os.path.dirname(__file__)) +DATA_DIR = os.path.join(TOP_DIR, '..', 'testdata/') + + +def prepare_dir(path): + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path) + + +def _extract_value_info(arr, name, ele_type=None): + return onnx.helper.make_tensor_value_info( + name=name, + elem_type=ele_type if ele_type else onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[arr.dtype], + shape=arr.shape) + + +def generate_data(graph, inputs, outputs, name): + output_dir = os.path.join(DATA_DIR, name) + prepare_dir(output_dir) + model = onnx.helper.make_model(graph) + with open(os.path.join(output_dir, 'model.onnx'), 'wb') as f: + f.write(model.SerializeToString()) + data_set = os.path.join(output_dir, 'test_data_set_0') + prepare_dir(data_set) + for j, input_np in enumerate(inputs): + tensor = numpy_helper.from_array( + input_np, model.graph.input[j].name) + with open(os.path.join( + data_set, 'input_{}.pb'.format(j)), 'wb') as f: + f.write(tensor.SerializeToString()) + for j, output_np in enumerate(outputs): + tensor = numpy_helper.from_array( + output_np, model.graph.output[j].name) + with open(os.path.join( + data_set, 'output_{}.pb'.format(j)), 'wb') as f: + f.write(tensor.SerializeToString()) + + +def expect(node, # type: onnx.NodeProto + inputs, + outputs, + name, + **kwargs + ): # type: (...) -> None + present_inputs = [x for x in node.input if (x != '')] + present_outputs = [x for x in node.output if (x != '')] + input_types = [None] * len(inputs) + if 'input_types' in kwargs: + input_types = kwargs[str('input_types')] + del kwargs[str('input_types')] + output_types = [None] * len(outputs) + if 'output_types' in kwargs: + output_types = kwargs[str('output_types')] + del kwargs[str('output_types')] + inputs_vi = [_extract_value_info(arr, arr_name, input_type) + for arr, arr_name, input_type in zip(inputs, present_inputs, input_types)] + outputs_vi = [_extract_value_info(arr, arr_name, output_type) + for arr, arr_name, output_type in zip(outputs, present_outputs, output_types)] + graph = onnx.helper.make_graph( + nodes=[node], + name=name, + inputs=inputs_vi, + outputs=outputs_vi) + + generate_data(graph, inputs, outputs, name) + + cwd = os.getcwd() + onnx_test_runner = os.path.join(cwd, 'onnx_test_runner') + subprocess.run([onnx_test_runner, DATA_DIR+name], check=True, cwd=cwd) diff --git a/onnxruntime/test/python/contrib_ops/onnx_test_trilu.py b/onnxruntime/test/python/contrib_ops/onnx_test_trilu.py new file mode 100644 index 0000000000..87e9a1a819 --- /dev/null +++ b/onnxruntime/test/python/contrib_ops/onnx_test_trilu.py @@ -0,0 +1,289 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# +# Test reference implementation and model for ONNX Runtime conrtib op trilu + +import onnx +import unittest +import numpy as np +from onnx_contrib_ops_helper import expect + + +def triu_reference_implementation(x, k=0): + return np.triu(x, k) + + +def tril_reference_implementation(x, k=0): + return np.tril(x, k) + + +class ONNXReferenceImplementationTest(unittest.TestCase): + def test_triu(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = triu_reference_implementation(x) + expect(node, inputs=[x], outputs=[y], name='test_triu') + + def test_triu_neg(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + k = np.array([-1]).astype(np.int64) + y = triu_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_triu_neg') + + def test_triu_out_neg(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + k = np.array([-7]).astype(np.int64) + y = triu_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_triu_out_neg') + + def test_triu_pos(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + k = np.array([2]).astype(np.int64) + y = triu_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_triu_pos') + + def test_triu_out_pos(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + k = np.array([6]).astype(np.int64) + y = triu_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_triu_out_pos') + + def test_triu_square(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 5, 5).astype(np.float32) + y = triu_reference_implementation(x) + expect(node, inputs=[x], outputs=[y], name='test_triu_square') + + def test_triu_square_neg(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 5, 5).astype(np.float32) + k = np.array([-1]).astype(np.int64) + y = triu_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_triu_square_neg') + + def test_triu_one_row_neg(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 1, 5).astype(np.float32) + k = np.array([-7]).astype(np.int64) + y = triu_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_triu_one_row_neg') + + def test_triu_square_pos(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 5, 5).astype(np.float32) + k = np.array([2]).astype(np.int64) + y = triu_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_triu_square_pos') + + def test_triu_zero(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + domain="com.microsoft", + ) + + x = np.random.randn(3, 0, 5).astype(np.float32) + k = np.array([6]).astype(np.int64) + y = triu_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_triu_zero') + + def test_tril(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + y = tril_reference_implementation(x) + expect(node, inputs=[x], outputs=[y], name='test_tril') + + def test_tril_neg(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + k = np.array([-1]).astype(np.int64) + y = tril_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_tril_neg') + + def test_tril_out_neg(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + k = np.array([-7]).astype(np.int64) + y = tril_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_tril_out_neg') + + def test_tril_pos(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + k = np.array([2]).astype(np.int64) + y = tril_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_tril_pos') + + def test_tril_out_pos(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 4, 5).astype(np.float32) + k = np.array([6]).astype(np.int64) + y = tril_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_tril_out_pos') + + def test_tril_square(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 5, 5).astype(np.float32) + y = tril_reference_implementation(x) + expect(node, inputs=[x], outputs=[y], name='test_tril_square') + + def test_tril_square_neg(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 5, 5).astype(np.float32) + k = np.array([-1]).astype(np.int64) + y = tril_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_tril_square_neg') + + def test_tril_one_row_neg(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 1, 5).astype(np.float32) + k = np.array([-7]).astype(np.int64) + y = tril_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_tril_one_row_neg') + + def test_tril_square_pos(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 5, 5).astype(np.float32) + k = np.array([2]).astype(np.int64) + y = tril_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_tril_square_pos') + + def test_tril_zero(self): + node = onnx.helper.make_node( + 'Trilu', + inputs=['x', 'k'], + outputs=['y'], + upper=0, + domain="com.microsoft", + ) + + x = np.random.randn(3, 0, 5).astype(np.float32) + k = np.array([6]).astype(np.int64) + y = tril_reference_implementation(x, k) + expect(node, inputs=[x, k], outputs=[y], name='test_tril_zero') + + +if __name__ == '__main__': + unittest.main(module=__name__, buffer=True) diff --git a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh index 57a7c451a8..660baeb3af 100755 --- a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh +++ b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh @@ -1,3 +1,4 @@ +#!/bin/bash pip3 install --user --upgrade pip @@ -7,3 +8,9 @@ pip3 install --user /build/Release/dist/*.whl export PYTHONPATH=/onnxruntime_src/tools:/usr/local/lib/python3.6/site-packages:$PYTHONPATH python3 -m pytest -v /onnxruntime_src/tools/test/test_custom_ops_pytorch_exporter.py || exit 1 + +for filename in /onnxruntime_src/onnxruntime/test/python/contrib_ops/onnx_test_* ; do + cd /build/Release && python3 -m pytest -v $filename || exit 1 +done + +cd /build/Release && ./onnxruntime_test_all --gtest_filter=ShapeInferenceTests.* || exit 1 diff --git a/tools/python/register_custom_ops_pytorch_exporter.py b/tools/python/register_custom_ops_pytorch_exporter.py index 2be2404f53..0a95f38770 100644 --- a/tools/python/register_custom_ops_pytorch_exporter.py +++ b/tools/python/register_custom_ops_pytorch_exporter.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# +# Register pytorch symbolic for export using ONNX Runtime contrib ops + from torch.onnx import register_custom_op_symbolic @@ -17,6 +22,14 @@ def register_custom_op(): def gelu(g, self): return g.op("com.microsoft::Gelu", self) + def triu(g, self, diagonal): + return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1) + + def tril(g, self, diagonal): + return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0) + # Op Registration register_custom_op_symbolic('::inverse', inverse, _onnx_opset_version) register_custom_op_symbolic('::gelu', gelu, _onnx_opset_version) + register_custom_op_symbolic('::triu', triu, _onnx_opset_version) + register_custom_op_symbolic('::tril', tril, _onnx_opset_version) diff --git a/tools/test/test_custom_ops_pytorch_exporter.py b/tools/test/test_custom_ops_pytorch_exporter.py index 6b905c8594..035dc6c59a 100644 --- a/tools/test/test_custom_ops_pytorch_exporter.py +++ b/tools/test/test_custom_ops_pytorch_exporter.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# +# Test export of pytorch operators using ONNX Runtime contrib ops + import torch import onnxruntime import numpy as np @@ -32,7 +37,7 @@ def ort_test_with_input(ort_sess, input, output, rtol, atol): # These set of tests verify ONNX model export and compare onnxruntime outputs to pytorch. # To register custom ops and run the tests, you should set PYTHONPATH as: -# PYTHONPATH= pytest -v test_custom_ops_pytorch_exporter.py +# PYTHONPATH= python -m pytest -v test_custom_ops_pytorch_exporter.py class ONNXExporterTest(unittest.TestCase): from torch.onnx.symbolic_helper import _export_onnx_opset_version opset_version = _export_onnx_opset_version @@ -105,21 +110,68 @@ class ONNXExporterTest(unittest.TestCase): x = torch.randn(3, 3) self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + def test_triu(self): + for i in range(-5, 5): + class Module(torch.nn.Module): + def forward(self, input): + return input.triu(diagonal=i) -# opset 10 tests -TestONNXRuntime_opset10 = type(str("TestONNXRuntime_opset10"), - (unittest.TestCase,), - dict(ONNXExporterTest.__dict__, opset_version=10)) + model = Module() + x = torch.randn(5, 4, 7, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) -# opset 11 tests -ONNXExporterTest_opset11 = type(str("TestONNXRuntime_opset11"), - (unittest.TestCase,), - dict(ONNXExporterTest.__dict__, opset_version=11)) + x = torch.randn(5, 4, 0, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + x = torch.randn(5, 0, 0, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + for i in range(-5, 5): + class Module2D(torch.nn.Module): + def forward(self, input): + return input.triu(diagonal=i) + + model = Module2D() + x = torch.randn(4, 7, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + x = torch.randn(0, 7, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + x = torch.randn(0, 0, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + def test_tril(self): + for i in range(-5, 5): + class Module(torch.nn.Module): + def forward(self, input): + return input.tril(diagonal=i) + + model = Module() + x = torch.randn(5, 4, 7, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + x = torch.randn(5, 4, 0, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + x = torch.randn(5, 0, 0, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + for i in range(-5, 5): + class Module2D(torch.nn.Module): + def forward(self, input): + return input.tril(diagonal=i) + + model = Module2D() + x = torch.randn(4, 7, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + x = torch.randn(0, 7, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) + + x = torch.randn(0, 0, dtype=torch.float32) + self.run_test(model, x, custom_opsets={'com.microsoft': 1}) -# opset 12 tests -ONNXExporterTest_opset12 = type(str("TestONNXRuntime_opset12"), - (unittest.TestCase,), - dict(ONNXExporterTest.__dict__, opset_version=12)) # opset 9 tests, with keep_initializers_as_inputs=False for # IR version 4 style export. @@ -128,27 +180,6 @@ ONNXExporterTest_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"), dict(ONNXExporterTest.__dict__, keep_initializers_as_inputs=False)) -# opset 10 tests, with keep_initializers_as_inputs=False for -# IR version 4 style export. -ONNXExporterTest_opset10_IRv4 = type(str("TestONNXRuntime_opset10_IRv4"), - (unittest.TestCase,), - dict(ONNXExporterTest.__dict__, opset_version=10, - keep_initializers_as_inputs=False)) - - -# opset 11 tests, with keep_initializers_as_inputs=False for -# IR version 4 style export. -ONNXExporterTest_opset11_IRv4 = type(str("TestONNXRuntime_opset11_IRv4"), - (unittest.TestCase,), - dict(ONNXExporterTest.__dict__, opset_version=11, - keep_initializers_as_inputs=False)) - -# opset 12 tests, with keep_initializers_as_inputs=False for -# IR version 4 style export. -ONNXExporterTest_opset12_IRv4 = type(str("TestONNXRuntime_opset12_IRv4"), - (unittest.TestCase,), - dict(ONNXExporterTest.__dict__, opset_version=12, - keep_initializers_as_inputs=False)) if __name__ == '__main__': unittest.main()