Expand: add additional supported types. (#364)

This commit is contained in:
Bowen Bao 2019-01-22 19:07:36 -08:00 committed by Changming Sun
parent ea816615eb
commit d040b452cb
3 changed files with 142 additions and 6 deletions

View file

@ -210,6 +210,17 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Uns
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int8_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int16_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int32_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int64_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint8_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint16_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint32_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint64_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, bool, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, MLFloat16, Expand);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If);
@ -440,6 +451,17 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, float, Upsample)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, int32_t, Upsample)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, float, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, double, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int8_t, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int16_t, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int32_t, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, int64_t, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint8_t, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint16_t, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint32_t, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, uint64_t, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, bool, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, MLFloat16, Expand)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 8, 8, Scan)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, If)>());

View file

@ -1001,12 +1001,26 @@ Status Expand_8<T>::Compute(OpKernelContext* context) const {
return Status::OK();
}
ONNX_CPU_OPERATOR_TYPED_KERNEL(
Expand,
8,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Expand_8<float>);
#define REG_EXPAND_KERNEL(TYPE) \
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
Expand, \
8, \
TYPE, \
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<TYPE>()), \
Expand_8<TYPE>);
REG_EXPAND_KERNEL(float)
REG_EXPAND_KERNEL(double)
REG_EXPAND_KERNEL(int8_t)
REG_EXPAND_KERNEL(int16_t)
REG_EXPAND_KERNEL(int32_t)
REG_EXPAND_KERNEL(int64_t)
REG_EXPAND_KERNEL(uint8_t)
REG_EXPAND_KERNEL(uint16_t)
REG_EXPAND_KERNEL(uint32_t)
REG_EXPAND_KERNEL(uint64_t)
REG_EXPAND_KERNEL(bool)
REG_EXPAND_KERNEL(MLFloat16)
template <>
Status Scale<float>::Compute(OpKernelContext* ctx) const {

View file

@ -3,6 +3,7 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "core/util/math.h"
namespace onnxruntime {
namespace test {
@ -877,6 +878,105 @@ TEST(MathOpTest, Expand_8_1x3) {
test.Run();
}
TEST(MathOpTest, Expand_8_3x3_int32) {
OpTester test("Expand", 8);
test.AddInput<int32_t>("data_0", {1}, {1});
test.AddInput<int64_t>("data_1", {2}, {3, 3});
test.AddOutput<int32_t>("result", {3, 3},
{1, 1, 1,
1, 1, 1,
1, 1, 1});
test.Run();
}
TEST(MathOpTest, Expand_8_3x1_int32) {
OpTester test("Expand", 8);
test.AddInput<int32_t>("data_0", {3}, {1, 2, 3});
test.AddInput<int64_t>("data_1", {2}, {3, 1});
test.AddOutput<int32_t>("result", {3, 3},
{1, 2, 3,
1, 2, 3,
1, 2, 3});
test.Run();
}
TEST(MathOpTest, Expand_8_1x3_int32) {
OpTester test("Expand", 8);
test.AddInput<int32_t>("data_0", {3, 1}, {1, 2, 3});
test.AddInput<int64_t>("data_1", {2}, {1, 3});
test.AddOutput<int32_t>("result", {3, 3},
{1, 1, 1,
2, 2, 2,
3, 3, 3});
test.Run();
}
TEST(MathOpTest, Expand_8_3x3_int64) {
OpTester test("Expand", 8);
test.AddInput<int64_t>("data_0", {1}, {1});
test.AddInput<int64_t>("data_1", {2}, {3, 3});
test.AddOutput<int64_t>("result", {3, 3},
{1, 1, 1,
1, 1, 1,
1, 1, 1});
test.Run();
}
TEST(MathOpTest, Expand_8_3x1_int64) {
OpTester test("Expand", 8);
test.AddInput<int64_t>("data_0", {3}, {1, 2, 3});
test.AddInput<int64_t>("data_1", {2}, {3, 1});
test.AddOutput<int64_t>("result", {3, 3},
{1, 2, 3,
1, 2, 3,
1, 2, 3});
test.Run();
}
TEST(MathOpTest, Expand_8_1x3_int64) {
OpTester test("Expand", 8);
test.AddInput<int64_t>("data_0", {3, 1}, {1, 2, 3});
test.AddInput<int64_t>("data_1", {2}, {1, 3});
test.AddOutput<int64_t>("result", {3, 3},
{1, 1, 1,
2, 2, 2,
3, 3, 3});
test.Run();
}
TEST(MathOpTest, Expand_8_3x3_float16) {
OpTester test("Expand", 8);
test.AddInput<MLFloat16>("data_0", {1}, {MLFloat16(math::floatToHalf(1.0f))});
test.AddInput<int64_t>("data_1", {2}, {3, 3});
test.AddOutput<MLFloat16>("result", {3, 3},
{MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f))});
test.Run();
}
TEST(MathOpTest, Expand_8_3x1_float16) {
OpTester test("Expand", 8);
test.AddInput<MLFloat16>("data_0", {3}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
test.AddInput<int64_t>("data_1", {2}, {3, 1});
test.AddOutput<MLFloat16>("result", {3, 3},
{MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)),
MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f)),
MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
test.Run();
}
TEST(MathOpTest, Expand_8_1x3_float16) {
OpTester test("Expand", 8);
test.AddInput<MLFloat16>("data_0", {3, 1}, {MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(3.0f))});
test.AddInput<int64_t>("data_1", {2}, {1, 3});
test.AddOutput<MLFloat16>("result", {3, 3},
{MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)), MLFloat16(math::floatToHalf(1.0f)),
MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)), MLFloat16(math::floatToHalf(2.0f)),
MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f)), MLFloat16(math::floatToHalf(3.0f))});
test.Run();
}
TEST(MathOpTest, Scale) {
OpTester test("Scale");
std::vector<int64_t> dims{2, 2};