From 89f643f04ba990cb0dac045183dc46396fd1280b Mon Sep 17 00:00:00 2001 From: Randy <45701928+RandyShuai@users.noreply.github.com> Date: Thu, 24 Jan 2019 09:51:09 -0800 Subject: [PATCH] add new types to shape op (#362) * add new types to shape op * add all fixed type support --- .../core/providers/cpu/tensor/shape_op.cc | 12 +------- .../core/providers/cuda/tensor/shape_op.cc | 12 +------- .../providers/cpu/tensor/shape_op_test.cc | 29 +++++++++++++++++++ 3 files changed, 31 insertions(+), 22 deletions(-) create mode 100644 onnxruntime/test/providers/cpu/tensor/shape_op_test.cc diff --git a/onnxruntime/core/providers/cpu/tensor/shape_op.cc b/onnxruntime/core/providers/cpu/tensor/shape_op.cc index aba5b45c85..4433333ee7 100644 --- a/onnxruntime/core/providers/cpu/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/shape_op.cc @@ -5,20 +5,10 @@ namespace onnxruntime { -const std::vector shapeOpTypeConstraints{ - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}; - ONNX_CPU_OPERATOR_KERNEL( Shape, 1, - KernelDefBuilder().TypeConstraint("T", shapeOpTypeConstraints).TypeConstraint("T1", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/shape_op.cc b/onnxruntime/core/providers/cuda/tensor/shape_op.cc index e9ccd80764..d1969785b4 100644 --- a/onnxruntime/core/providers/cuda/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/shape_op.cc @@ -7,16 +7,6 @@ namespace onnxruntime { namespace cuda { -const std::vector shapeOpTypeConstraints{ - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}; - ONNX_OPERATOR_KERNEL_EX( Shape, kOnnxDomain, @@ -24,7 +14,7 @@ ONNX_OPERATOR_KERNEL_EX( kCudaExecutionProvider, KernelDefBuilder() .OutputMemoryType(0) - .TypeConstraint("T", shapeOpTypeConstraints) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); diff --git a/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc b/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc new file mode 100644 index 0000000000..afd73b9c9f --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc @@ -0,0 +1,29 @@ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +template +void TestShape(const std::initializer_list& data, const std::vector& shape) +{ + OpTester test("Shape"); + test.AddInput("data", shape, data); + test.AddOutput("output", {static_cast(shape.size())}, shape); + test.Run(); +} + +TEST(ShapeOpTest, ShapeTestBool) { TestShape ({true, true, false, false, true, false}, {2, 3}); } +TEST(ShapeOpTest, ShapeTestFloat) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 6}); } +TEST(ShapeOpTest, ShapeTestDouble) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {6, 2}); } +TEST(ShapeOpTest, ShapeTestInt8) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); } +TEST(ShapeOpTest, ShapeTestInt16) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); } +TEST(ShapeOpTest, ShapeTestInt32) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {4, 3}); } +TEST(ShapeOpTest, ShapeTestInt64) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); } +TEST(ShapeOpTest, ShapeTestUint8) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); } +TEST(ShapeOpTest, ShapeTestUint16) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); } +TEST(ShapeOpTest, ShapeTestUint32) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); } +TEST(ShapeOpTest, ShapeTestUint64) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); } + +} +} \ No newline at end of file