diff --git a/onnxruntime/core/providers/cpu/tensor/shape_op.cc b/onnxruntime/core/providers/cpu/tensor/shape_op.cc index aba5b45c85..4433333ee7 100644 --- a/onnxruntime/core/providers/cpu/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/shape_op.cc @@ -5,20 +5,10 @@ namespace onnxruntime { -const std::vector shapeOpTypeConstraints{ - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}; - ONNX_CPU_OPERATOR_KERNEL( Shape, 1, - KernelDefBuilder().TypeConstraint("T", shapeOpTypeConstraints).TypeConstraint("T1", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/shape_op.cc b/onnxruntime/core/providers/cuda/tensor/shape_op.cc index e9ccd80764..d1969785b4 100644 --- a/onnxruntime/core/providers/cuda/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/shape_op.cc @@ -7,16 +7,6 @@ namespace onnxruntime { namespace cuda { -const std::vector shapeOpTypeConstraints{ - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}; - ONNX_OPERATOR_KERNEL_EX( Shape, kOnnxDomain, @@ -24,7 +14,7 @@ ONNX_OPERATOR_KERNEL_EX( kCudaExecutionProvider, KernelDefBuilder() .OutputMemoryType(0) - .TypeConstraint("T", shapeOpTypeConstraints) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); diff --git a/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc b/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc new file mode 100644 index 0000000000..afd73b9c9f --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc @@ -0,0 +1,29 @@ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +template +void TestShape(const std::initializer_list& data, const std::vector& shape) +{ + OpTester test("Shape"); + test.AddInput("data", shape, data); + test.AddOutput("output", {static_cast(shape.size())}, shape); + test.Run(); +} + +TEST(ShapeOpTest, ShapeTestBool) { TestShape ({true, true, false, false, true, false}, {2, 3}); } +TEST(ShapeOpTest, ShapeTestFloat) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 6}); } +TEST(ShapeOpTest, ShapeTestDouble) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {6, 2}); } +TEST(ShapeOpTest, ShapeTestInt8) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); } +TEST(ShapeOpTest, ShapeTestInt16) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); } +TEST(ShapeOpTest, ShapeTestInt32) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {4, 3}); } +TEST(ShapeOpTest, ShapeTestInt64) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); } +TEST(ShapeOpTest, ShapeTestUint8) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); } +TEST(ShapeOpTest, ShapeTestUint16) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); } +TEST(ShapeOpTest, ShapeTestUint32) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); } +TEST(ShapeOpTest, ShapeTestUint64) { TestShape ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); } + +} +} \ No newline at end of file