mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
add new types to shape op (#362)
* add new types to shape op * add all fixed type support
This commit is contained in:
parent
61bbf4bfcc
commit
89f643f04b
3 changed files with 31 additions and 22 deletions
|
|
@ -5,20 +5,10 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
const std::vector<MLDataType> shapeOpTypeConstraints{
|
||||
DataTypeImpl::GetTensorType<bool>(),
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<int16_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<uint16_t>()};
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Shape,
|
||||
1,
|
||||
KernelDefBuilder().TypeConstraint("T", shapeOpTypeConstraints).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -7,16 +7,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
const std::vector<MLDataType> shapeOpTypeConstraints{
|
||||
DataTypeImpl::GetTensorType<bool>(),
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<int16_t>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
DataTypeImpl::GetTensorType<int64_t>(),
|
||||
DataTypeImpl::GetTensorType<uint8_t>(),
|
||||
DataTypeImpl::GetTensorType<uint16_t>()};
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Shape,
|
||||
kOnnxDomain,
|
||||
|
|
@ -24,7 +14,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.OutputMemoryType<OrtMemTypeCPUOutput>(0)
|
||||
.TypeConstraint("T", shapeOpTypeConstraints)
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
|
||||
Shape);
|
||||
|
||||
|
|
|
|||
29
onnxruntime/test/providers/cpu/tensor/shape_op_test.cc
Normal file
29
onnxruntime/test/providers/cpu/tensor/shape_op_test.cc
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
template<typename T>
|
||||
void TestShape(const std::initializer_list<T>& data, const std::vector<int64_t>& shape)
|
||||
{
|
||||
OpTester test("Shape");
|
||||
test.AddInput<T>("data", shape, data);
|
||||
test.AddOutput<int64_t>("output", {static_cast<int64_t>(shape.size())}, shape);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ShapeOpTest, ShapeTestBool) { TestShape <bool> ({true, true, false, false, true, false}, {2, 3}); }
|
||||
TEST(ShapeOpTest, ShapeTestFloat) { TestShape <float> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 6}); }
|
||||
TEST(ShapeOpTest, ShapeTestDouble) { TestShape <double> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {6, 2}); }
|
||||
TEST(ShapeOpTest, ShapeTestInt8) { TestShape <int8_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
|
||||
TEST(ShapeOpTest, ShapeTestInt16) { TestShape <int16_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
|
||||
TEST(ShapeOpTest, ShapeTestInt32) { TestShape <int32_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {4, 3}); }
|
||||
TEST(ShapeOpTest, ShapeTestInt64) { TestShape <int64_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
|
||||
TEST(ShapeOpTest, ShapeTestUint8) { TestShape <uint8_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
|
||||
TEST(ShapeOpTest, ShapeTestUint16) { TestShape <uint16_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
|
||||
TEST(ShapeOpTest, ShapeTestUint32) { TestShape <uint32_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
|
||||
TEST(ShapeOpTest, ShapeTestUint64) { TestShape <uint64_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
|
||||
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue