add new types to shape op (#362)

* add new types to shape op

* add all fixed type support
This commit is contained in:
Randy 2019-01-24 09:51:09 -08:00 committed by GitHub
parent 61bbf4bfcc
commit 89f643f04b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 22 deletions

View file

@ -5,20 +5,10 @@
namespace onnxruntime {
const std::vector<MLDataType> shapeOpTypeConstraints{
DataTypeImpl::GetTensorType<bool>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<uint16_t>()};
ONNX_CPU_OPERATOR_KERNEL(
Shape,
1,
KernelDefBuilder().TypeConstraint("T", shapeOpTypeConstraints).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Shape);
} // namespace onnxruntime

View file

@ -7,16 +7,6 @@
namespace onnxruntime {
namespace cuda {
const std::vector<MLDataType> shapeOpTypeConstraints{
DataTypeImpl::GetTensorType<bool>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<uint16_t>()};
ONNX_OPERATOR_KERNEL_EX(
Shape,
kOnnxDomain,
@ -24,7 +14,7 @@ ONNX_OPERATOR_KERNEL_EX(
kCudaExecutionProvider,
KernelDefBuilder()
.OutputMemoryType<OrtMemTypeCPUOutput>(0)
.TypeConstraint("T", shapeOpTypeConstraints)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Shape);

View file

@ -0,0 +1,29 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
template<typename T>
void TestShape(const std::initializer_list<T>& data, const std::vector<int64_t>& shape)
{
OpTester test("Shape");
test.AddInput<T>("data", shape, data);
test.AddOutput<int64_t>("output", {static_cast<int64_t>(shape.size())}, shape);
test.Run();
}
TEST(ShapeOpTest, ShapeTestBool) { TestShape <bool> ({true, true, false, false, true, false}, {2, 3}); }
TEST(ShapeOpTest, ShapeTestFloat) { TestShape <float> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 6}); }
TEST(ShapeOpTest, ShapeTestDouble) { TestShape <double> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {6, 2}); }
TEST(ShapeOpTest, ShapeTestInt8) { TestShape <int8_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
TEST(ShapeOpTest, ShapeTestInt16) { TestShape <int16_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
TEST(ShapeOpTest, ShapeTestInt32) { TestShape <int32_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {4, 3}); }
TEST(ShapeOpTest, ShapeTestInt64) { TestShape <int64_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
TEST(ShapeOpTest, ShapeTestUint8) { TestShape <uint8_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
TEST(ShapeOpTest, ShapeTestUint16) { TestShape <uint16_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
TEST(ShapeOpTest, ShapeTestUint32) { TestShape <uint32_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
TEST(ShapeOpTest, ShapeTestUint64) { TestShape <uint64_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
}
}