mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Implement Inverse(12) for CPU and CUDA (#3485)
This commit is contained in:
parent
38a18023c7
commit
db9566f70d
7 changed files with 407 additions and 1 deletions
90
onnxruntime/contrib_ops/cpu/inverse.cc
Normal file
90
onnxruntime/contrib_ops/cpu/inverse.cc
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "Eigen/src/Core/Map.h"
|
||||
#include "Eigen/LU"
|
||||
#include <functional>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
class Inverse final : public OpKernel {
|
||||
public:
|
||||
explicit Inverse(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
struct ComputeImpl;
|
||||
};
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Inverse,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", BuildKernelDefConstraints<float, double, MLFloat16>()),
|
||||
Inverse);
|
||||
|
||||
template <typename T>
|
||||
using MatrixT = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
|
||||
template <typename T>
|
||||
struct Inverse::ComputeImpl {
|
||||
void operator()(const Tensor* input, Tensor* output,
|
||||
int64_t batch_num, int64_t rows, int64_t cols) const {
|
||||
auto batch_offset = batch_num * rows * cols;
|
||||
const auto* input_data = input->Data<T>() + batch_offset;
|
||||
auto* output_data = output->MutableData<T>() + batch_offset;
|
||||
|
||||
Eigen::Map<const MatrixT<T>> input_matrix(input_data, rows, cols);
|
||||
Eigen::Map<MatrixT<T>> output_matrix(output_data, rows, cols);
|
||||
output_matrix = input_matrix.inverse();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Inverse::ComputeImpl<MLFloat16> {
|
||||
void operator()(const Tensor* input, Tensor* output,
|
||||
int64_t batch_num, int64_t rows, int64_t cols) const {
|
||||
auto batch_offset = batch_num * rows * cols;
|
||||
// Direct cast to half as it just as MLFloat16 containes only uint16_t
|
||||
const auto* input_data = reinterpret_cast<const Eigen::half*>(input->Data<MLFloat16>() + batch_offset);
|
||||
auto* output_data = reinterpret_cast<Eigen::half*>(output->MutableData<MLFloat16>() + batch_offset);
|
||||
|
||||
Eigen::Map<const MatrixT<Eigen::half>> input_matrix(input_data, rows, cols);
|
||||
Eigen::Map<MatrixT<Eigen::half>> output_matrix(output_data, rows, cols);
|
||||
output_matrix = input_matrix.inverse();
|
||||
}
|
||||
};
|
||||
|
||||
Status Inverse::Compute(OpKernelContext* ctx) const {
|
||||
const auto& input = ctx->Input<Tensor>(0);
|
||||
const auto elem_type = input->GetElementType();
|
||||
const auto& input_shape = input->Shape();
|
||||
const auto num_dim = input_shape.NumDimensions();
|
||||
auto* output = ctx->Output(0, input_shape);
|
||||
|
||||
int64_t num_batches = 1;
|
||||
const int64_t rows = input_shape.GetDims()[num_dim - 2];
|
||||
const int64_t cols = input_shape.GetDims()[num_dim - 1];
|
||||
if (num_dim > 2) {
|
||||
num_batches = input_shape.SizeToDimension(num_dim - 2);
|
||||
}
|
||||
|
||||
std::function<void(ptrdiff_t)> fn = [elem_type, input, output, rows, cols](ptrdiff_t batch_num) {
|
||||
utils::MLTypeCallDispatcher<ComputeImpl, float, double, MLFloat16> t_disp(elem_type);
|
||||
t_disp.Invoke(input, output, batch_num, rows, cols);
|
||||
};
|
||||
|
||||
concurrency::ThreadPool::TryBatchParallelFor(ctx->GetOperatorThreadPool(), num_batches, std::move(fn), 0);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -62,6 +62,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
|
||||
|
||||
Status RegisterNchwcKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -127,6 +128,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
160
onnxruntime/contrib_ops/cuda/inverse.cc
Normal file
160
onnxruntime/contrib_ops/cuda/inverse.cc
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
class Inverse final : public ::onnxruntime::cuda::CudaKernel {
|
||||
public:
|
||||
explicit Inverse(const OpKernelInfo& info) : CudaKernel{info} {
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
using Base = CudaKernel;
|
||||
using CublasHandle = cublasHandle_t;
|
||||
|
||||
template <typename T>
|
||||
struct ComputeImpl;
|
||||
};
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Inverse,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", BuildKernelDefConstraints<float, double, MLFloat16>()),
|
||||
Inverse);
|
||||
|
||||
namespace inverse_internal {
|
||||
|
||||
template <typename T>
|
||||
Status ComputeMatrixOffsets(T* workspace_data, size_t num_batches, size_t rows, IAllocatorUniquePtr<T*>& matrix_ptrs) {
|
||||
std::vector<T*> cuda_ptrs;
|
||||
const size_t matrix_size = rows * rows;
|
||||
for (size_t i = 0; i < num_batches; ++i) {
|
||||
cuda_ptrs.push_back(workspace_data);
|
||||
workspace_data += matrix_size;
|
||||
}
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(matrix_ptrs.get(), cuda_ptrs.data(), sizeof(T*) * num_batches,
|
||||
cudaMemcpyHostToDevice));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckForSingularity(const IAllocatorUniquePtr<int>& info, const std::unique_ptr<int[]>& info_cpu, size_t num_batches) {
|
||||
// Let's check if any of the info values is non-zero
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(info_cpu.get(), info.get(), sizeof(int) * num_batches,
|
||||
cudaMemcpyDeviceToHost));
|
||||
for (size_t i = 0; i < num_batches; ++i) {
|
||||
if (info_cpu[i] != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Matrix is singular at batch:", i);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace inverse_internal
|
||||
|
||||
template <typename T>
|
||||
struct Inverse::ComputeImpl {
|
||||
Status operator()(Inverse::CublasHandle cublas_h, const Inverse* inst, const Tensor& input, Tensor& output,
|
||||
const IAllocatorUniquePtr<int>& info, const IAllocatorUniquePtr<int>& pivots,
|
||||
size_t num_batches, size_t rows) const {
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace inverse_internal;
|
||||
using CudaT = typename ToCudaType<T>::MappedType;
|
||||
const size_t input_count = static_cast<size_t>(input.Shape().Size());
|
||||
auto info_cpu = onnxruntime::make_unique<int[]>(num_batches);
|
||||
const auto dim = static_cast<int>(rows);
|
||||
const auto n_batches = static_cast<int>(num_batches);
|
||||
|
||||
// Make a copy of the input which will serve as a workspace as well.
|
||||
if (std::is_same<T, float>::value || std::is_same<T, MLFloat16>::value) {
|
||||
IAllocatorUniquePtr<float> input_workspace = inst->GetScratchBuffer<float>(input_count);
|
||||
if (std::is_same<T, MLFloat16>::value) {
|
||||
// Convert from MLFloat16(half) to float
|
||||
Impl_Cast<CudaT, float>(reinterpret_cast<const CudaT*>(input.Data<MLFloat16>()), input_workspace.get(), input_count);
|
||||
} else {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(input_workspace.get(), input.Data<float>(), sizeof(float) * input_count,
|
||||
cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
IAllocatorUniquePtr<float*> matrix_ptrs = inst->GetScratchBuffer<float*>(n_batches);
|
||||
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<float>(input_workspace.get(), num_batches, rows, matrix_ptrs));
|
||||
// Do LU factorization
|
||||
CUBLAS_RETURN_IF_ERROR(cublasSgetrfBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), info.get(), n_batches));
|
||||
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
|
||||
|
||||
// Need to compute ptrs for output buffers
|
||||
// Output for MLFloat
|
||||
IAllocatorUniquePtr<float*> output_ptrs = inst->GetScratchBuffer<float*>(n_batches);
|
||||
if (std::is_same<T, MLFloat16>::value) {
|
||||
IAllocatorUniquePtr<float> ml_float_output = inst->GetScratchBuffer<float>(input_count);
|
||||
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<float>(ml_float_output.get(), num_batches, rows, output_ptrs));
|
||||
// Do the inverse
|
||||
CUBLAS_RETURN_IF_ERROR(cublasSgetriBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), output_ptrs.get(), dim, info.get(), n_batches));
|
||||
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
|
||||
// Copy the result to output with casting
|
||||
Impl_Cast<float, CudaT>(ml_float_output.get(), reinterpret_cast<CudaT*>(output.MutableData<MLFloat16>()), input_count);
|
||||
// We are done here
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<float>(output.MutableData<float>(), num_batches, rows, output_ptrs));
|
||||
// Do the inverse
|
||||
CUBLAS_RETURN_IF_ERROR(cublasSgetriBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), output_ptrs.get(), dim, info.get(), n_batches));
|
||||
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
|
||||
// We are done here
|
||||
}
|
||||
} else if (std::is_same<T, double>::value) {
|
||||
IAllocatorUniquePtr<double> input_workspace = inst->GetScratchBuffer<double>(static_cast<int>(input_count));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(input_workspace.get(), input.Data<double>(), sizeof(double) * input_count,
|
||||
cudaMemcpyDeviceToDevice));
|
||||
|
||||
IAllocatorUniquePtr<double*> matrix_ptrs = inst->GetScratchBuffer<double*>(n_batches);
|
||||
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<double>(input_workspace.get(), num_batches, rows, matrix_ptrs));
|
||||
// Do LU factorization
|
||||
CUBLAS_RETURN_IF_ERROR(cublasDgetrfBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), info.get(), n_batches));
|
||||
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
|
||||
|
||||
// Need to compute ptrs for output buffers
|
||||
IAllocatorUniquePtr<double*> output_ptrs = inst->GetScratchBuffer<double*>(n_batches);
|
||||
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<double>(output.MutableData<double>(), num_batches, rows, output_ptrs));
|
||||
CUBLAS_RETURN_IF_ERROR(cublasDgetriBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), output_ptrs.get(), dim, info.get(), n_batches));
|
||||
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
|
||||
// We are done here
|
||||
} else {
|
||||
ORT_THROW("Type is not supported");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
Status Inverse::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const auto* input = ctx->Input<Tensor>(0);
|
||||
const auto& input_shape = input->Shape();
|
||||
const auto num_dim = input_shape.NumDimensions();
|
||||
auto* output = ctx->Output(0, input_shape);
|
||||
|
||||
size_t num_batches = 1;
|
||||
const size_t rows = static_cast<size_t>(input_shape.GetDims()[num_dim - 2]);
|
||||
const size_t cols = static_cast<size_t>(input_shape.GetDims()[num_dim - 1]);
|
||||
ORT_ENFORCE(rows == cols, "Expecting square matrices");
|
||||
if (num_dim > 2) {
|
||||
num_batches = static_cast<size_t>(input_shape.SizeToDimension(num_dim - 2));
|
||||
}
|
||||
|
||||
IAllocatorUniquePtr<int> info = GetScratchBuffer<int>(num_batches);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemset(info.get(), 0, num_batches));
|
||||
IAllocatorUniquePtr<int> pivots = GetScratchBuffer<int>(rows * num_batches);
|
||||
|
||||
utils::MLTypeCallDispatcherRet<Status, ComputeImpl, float, double, MLFloat16> t_disp(input->GetElementType());
|
||||
return t_disp.Invoke(Base::CublasHandle(), this, *input, *output, info, pivots, num_batches, rows);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -56,6 +56,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_float, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear);
|
||||
|
|
@ -110,6 +111,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear)>,
|
||||
|
|
|
|||
|
|
@ -2357,6 +2357,62 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i
|
|||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
||||
// Used to be ONNX 1.7 Inverse(12)
|
||||
// Comment out docs not to increase the binary size
|
||||
//
|
||||
// static const char* Inverse_ver1_doc = R"DOC(
|
||||
//Calculates inverse of a square matrix or batches of square matrices.
|
||||
//Inverse takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions,
|
||||
//and the inner-most 2 dimensions form square matrices. These matrices must be invertible (full-rank).
|
||||
//The behavior where one of the matrices is not invertible is undefined. The implementation can choose
|
||||
//to throw an error or output (garbage) results as is. The output is a tensor of shape `[*, M, M]`,
|
||||
//containing the individual inverses of all input submatrices.
|
||||
//)DOC";
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Inverse)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
|
||||
.Input(0, "X", "Input tensor. Every matrix in the batch must be invertible.", "T")
|
||||
.Output(0, "Y", "Output tensor of the same type and shape as the input tensor.", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)",
|
||||
"tensor(float)",
|
||||
"tensor(double)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
// Type inference
|
||||
using namespace ONNX_NAMESPACE;
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
// Shape inference
|
||||
if (hasInputShape(ctx, 0)) {
|
||||
const TensorShapeProto& input_shape =
|
||||
ctx.getInputType(0)->tensor_type().shape();
|
||||
const int rank = static_cast<int>(input_shape.dim_size());
|
||||
|
||||
if (rank < 2) {
|
||||
fail_shape_inference("Input rank must be >= 2.")
|
||||
}
|
||||
|
||||
const auto mat_w = input_shape.dim(rank - 1);
|
||||
const auto mat_h = input_shape.dim(rank - 2);
|
||||
if (mat_w.has_dim_value() && mat_h.has_dim_value() &&
|
||||
(mat_w.dim_value() != mat_h.dim_value())) {
|
||||
fail_shape_inference(
|
||||
"The inner-most 2 dimensions must have the same size (mat_w:",
|
||||
mat_w.dim_value(),
|
||||
" != mat_h:",
|
||||
mat_h.dim_value(),
|
||||
").");
|
||||
}
|
||||
|
||||
// Shape inference
|
||||
propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
}
|
||||
});
|
||||
|
||||
RegisterBertSchemas();
|
||||
}
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -1052,7 +1052,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12,Clip)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Min)>,
|
||||
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Max)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, MaxPool)>,
|
||||
|
|
|
|||
95
onnxruntime/test/contrib_ops/inverse_test.cc
Normal file
95
onnxruntime/test/contrib_ops/inverse_test.cc
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "core/util/math.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(InverseContribOpTest, two_by_two_float) {
|
||||
OpTester test("Inverse", 1, kMSDomain);
|
||||
test.AddInput<float>("X", {2, 2}, {4, 7, 2, 6});
|
||||
test.AddOutput<float>("Y", {2, 2}, {0.6f, -0.7f, -0.2f, 0.4f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(InverseContribOpTest, two_by_two_double) {
|
||||
OpTester test("Inverse", 1, kMSDomain);
|
||||
test.AddInput<double>("X", {2, 2}, {4, 7, 2, 6});
|
||||
test.AddOutput<double>("Y", {2, 2}, {0.6f, -0.7f, -0.2f, 0.4f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(InverseContribOpTest, two_by_two_float16) {
|
||||
OpTester test("Inverse", 1, kMSDomain);
|
||||
|
||||
auto input_float = {4.f, 7.f, 2.f, 6.f};
|
||||
std::vector<MLFloat16> input;
|
||||
std::transform(
|
||||
input_float.begin(), input_float.end(), std::back_inserter(input),
|
||||
[](float v) {
|
||||
return MLFloat16(math::floatToHalf(v));
|
||||
});
|
||||
|
||||
auto output_float = {0.6f, -0.7f, -0.2f, 0.4f};
|
||||
std::vector<MLFloat16> output;
|
||||
std::transform(
|
||||
output_float.begin(), output_float.end(), std::back_inserter(output), [](float v) {
|
||||
return MLFloat16(math::floatToHalf(v));
|
||||
});
|
||||
|
||||
test.AddInput<MLFloat16>("X", {2, 2}, input);
|
||||
test.AddOutput<MLFloat16>("Y", {2, 2}, output);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(InverseContribOpTest, four_by_four_float) {
|
||||
OpTester test("Inverse", 1, kMSDomain);
|
||||
test.AddInput<float>("X", {4, 4},
|
||||
{4.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 2.f, 0.f,
|
||||
0.f, 1.f, 2.f, 0.f,
|
||||
1.f, 0.f, 0.f, 1.f
|
||||
});
|
||||
test.AddOutput<float>("Y", {4, 4}, {
|
||||
0.25f, 0.f, 0.f, 0.f,
|
||||
0.f, -1.f, 1.f, 0.f,
|
||||
0.f, 0.5f, 0.f, 0.f,
|
||||
-0.25f, 0.f, 0.f, 1.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(InverseContribOpTest, four_by_four_batches_float) {
|
||||
OpTester test("Inverse", 1, kMSDomain);
|
||||
|
||||
auto one_input_matrix_4x4 = {
|
||||
4.f, 0.f, 0.f, 0.f,
|
||||
0.f, 0.f, 2.f, 0.f,
|
||||
0.f, 1.f, 2.f, 0.f,
|
||||
1.f, 0.f, 0.f, 1.f};
|
||||
|
||||
// batches 3x4 i.e. 12 batches so the full shape is 3x4x4x4
|
||||
std::vector<float> input;
|
||||
for (int64_t i = 0; i < 3 * 4; ++i) {
|
||||
std::copy(one_input_matrix_4x4.begin(), one_input_matrix_4x4.end(), std::back_inserter(input));
|
||||
}
|
||||
|
||||
auto one_output_matrix_4x4 = {
|
||||
0.25f, 0.f, 0.f, 0.f,
|
||||
0.f, -1.f, 1.f, 0.f,
|
||||
0.f, 0.5f, 0.f, 0.f,
|
||||
-0.25f, 0.f, 0.f, 1.f};
|
||||
|
||||
std::vector<float> output;
|
||||
for (int64_t i = 0; i < 3 * 4; ++i) {
|
||||
std::copy(one_output_matrix_4x4.begin(), one_output_matrix_4x4.end(), std::back_inserter(output));
|
||||
}
|
||||
|
||||
test.AddInput<float>("Input", {3, 4, 4, 4}, input);
|
||||
test.AddOutput<float>("Output", {3, 4, 4, 4}, output);
|
||||
test.Run();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue