From eb24617d2e195f3e1d4c419065b0bdfca6baa354 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sun, 13 Oct 2019 08:46:28 +1000 Subject: [PATCH] Add ability to get symbolic dimension info for graph inputs and outputs. (#2051) * Add ability to get symbolic dimension info for graph inputs and outputs. WIP to get initial feedback. * Fix linxu build error. Update C# API and add unit test * Clarify the two different ways Tensor shape and type info is created. One is from concrete values and one is from a type proto where symbolic dimensions may exist. Doing so allows a change to default to empty strings for the symbolic dimensions if not provided. --- .../InferenceSession.cs | 31 +++++- .../Microsoft.ML.OnnxRuntime/NativeMethods.cs | 19 ++++ .../InferenceTest.cs | 33 ++++++ .../Microsoft.ML.OnnxRuntime.Tests.csproj | 5 +- .../core/session/onnxruntime_c_api.h | 1 + .../core/session/onnxruntime_cxx_api.h | 2 + .../core/session/onnxruntime_cxx_inline.h | 4 + .../core/framework/onnxruntime_typeinfo.cc | 64 ++++++------ .../core/framework/onnxruntime_typeinfo.h | 5 +- .../core/framework/tensor_type_and_shape.cc | 97 ++++++++++-------- .../core/framework/tensor_type_and_shape.h | 3 + onnxruntime/core/session/onnxruntime_c_api.cc | 3 +- onnxruntime/core/session/ort_apis.h | 1 + onnxruntime/test/shared_lib/test_inference.cc | 30 ++++++ .../test/testdata/capi_symbolic_dims.onnx | Bin 0 -> 106 bytes .../test/testdata/capi_symbolic_dims.py | 40 ++++++++ 16 files changed, 258 insertions(+), 80 deletions(-) create mode 100644 onnxruntime/test/testdata/capi_symbolic_dims.onnx create mode 100644 onnxruntime/test/testdata/capi_symbolic_dims.py diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index 688ea4a1f8..0310eb7db0 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -433,7 +433,7 @@ namespace Microsoft.ML.OnnxRuntime } if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR) { - return new NodeMetadata(valueType, new int[] { }, typeof(NamedOnnxValue)); + return new NodeMetadata(valueType, new int[] { }, new string[] { }, typeof(NamedOnnxValue)); } IntPtr tensorInfo; @@ -453,14 +453,26 @@ namespace Microsoft.ML.OnnxRuntime TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width); UIntPtr numDimensions; NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions)); + long[] dimensions = new long[(int)numDimensions]; - NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions); + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions)); int[] intDimensions = new int[(int)numDimensions]; for (var i = 0; i < (long)numDimensions; i++) { intDimensions[i] = (int)dimensions[i]; } - return new NodeMetadata(valueType, intDimensions, dotnetType); + + IntPtr[] dimensionNamePtrs = new IntPtr[(int)numDimensions]; + NativeApiStatus.VerifySuccess( + NativeMethods.OrtGetSymbolicDimensions(tensorInfo, dimensionNamePtrs, numDimensions)); + + string[] symbolicDimensions = new string[(int)numDimensions]; + for (var i = 0; i < (int)numDimensions; i++) + { + symbolicDimensions[i] = Marshal.PtrToStringAnsi(dimensionNamePtrs[i]); //assumes charset = ANSI + } + + return new NodeMetadata(valueType, intDimensions, symbolicDimensions, dotnetType); } #endregion @@ -514,12 +526,14 @@ namespace Microsoft.ML.OnnxRuntime { private OnnxValueType _onnxValueType; private int[] _dimensions; + private string[] _symbolicDimensions; private Type _type; - internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, Type type) + internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] symbolicDimensions, Type type) { _onnxValueType = onnxValueType; _dimensions = dimensions; + _symbolicDimensions = symbolicDimensions; _type = type; } @@ -538,6 +552,15 @@ namespace Microsoft.ML.OnnxRuntime return _dimensions; } } + + public string[] SymbolicDimensions + { + get + { + return _symbolicDimensions; + } + } + public System.Type ElementType { get diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index 03729247f7..2d0bad091c 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -86,6 +86,7 @@ namespace Microsoft.ML.OnnxRuntime public IntPtr GetTensorElementType; public IntPtr GetDimensionsCount; public IntPtr GetDimensions; + public IntPtr GetSymbolicDimensions; public IntPtr GetTensorShapeElementCount; public IntPtr GetTensorTypeAndShape; public IntPtr GetTypeInfo; @@ -220,6 +221,7 @@ namespace Microsoft.ML.OnnxRuntime OrtGetTensorElementType = (DOrtGetTensorElementType)Marshal.GetDelegateForFunctionPointer(api_.GetTensorElementType, typeof(DOrtGetTensorElementType)); OrtGetDimensionsCount = (DOrtGetDimensionsCount)Marshal.GetDelegateForFunctionPointer(api_.GetDimensionsCount, typeof(DOrtGetDimensionsCount)); OrtGetDimensions = (DOrtGetDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetDimensions, typeof(DOrtGetDimensions)); + OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions)); OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount)); OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue)); } @@ -627,6 +629,23 @@ namespace Microsoft.ML.OnnxRuntime UIntPtr dim_values_length); public static DOrtGetDimensions OrtGetDimensions; + /** + * Get the symbolic dimension names for dimensions with a value of -1. + * Order and number of entries is the same as values returned by GetDimensions. + * The name may be empty for an unnamed symbolic dimension. + * e.g. + * If OrtGetDimensions returns [-1, -1, 2], OrtGetSymbolicDimensions would return an array with 3 entries. + * If the values returned were ['batch', '', ''] it would indicate that + * - the first dimension was a named symbolic dimension (-1 dim value and name in symbolic dimensions), + * - the second dimension was an unnamed symbolic dimension (-1 dim value and empty string), + * - the entry for the third dimension should be ignored as it is not a symbolic dimension (dim value >= 0). + */ + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetSymbolicDimensions( + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + IntPtr[] dim_params, /* const char* values, converted to string by caller */ + UIntPtr dim_params_length); + public static DOrtGetSymbolicDimensions OrtGetSymbolicDimensions; + /** * How many elements does this tensor have. * May return a negative value diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index ee0977d76e..d77fa59bd2 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -569,6 +569,39 @@ namespace Microsoft.ML.OnnxRuntime.Tests } } + [Fact] + private void TestSymbolicDimsMetadata() + { + string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "capi_symbolic_dims.onnx"); + using (var session = new InferenceSession(modelPath)) + { + var inputs = session.InputMetadata; + var outputs = session.OutputMetadata; + + Assert.Equal(2, inputs.Count); + Assert.Equal(1, session.OutputMetadata.Count); + Assert.True(inputs.ContainsKey("A")); + Assert.True(inputs.ContainsKey("B")); + Assert.True(outputs.ContainsKey("C")); + + var inputA = inputs["A"]; + var inputB = inputs["B"]; + var outputC = outputs["C"]; + + // dimension values and any symbolic dimension info should have the same length + Assert.Equal(inputA.Dimensions.Length, inputA.SymbolicDimensions.Length); + Assert.Equal(inputB.Dimensions.Length, inputB.SymbolicDimensions.Length); + Assert.Equal(outputC.Dimensions.Length, outputC.SymbolicDimensions.Length); + + Assert.Equal(inputA.Dimensions, new int[] { -1, 2 }); + Assert.Equal(inputA.SymbolicDimensions, new string[] { "n", "" }); + Assert.Equal(inputB.Dimensions, new int[] { -1 }); + Assert.Equal(inputB.SymbolicDimensions, new string[] { "m" }); + Assert.Equal(outputC.Dimensions, new int[] { -1 }); + Assert.Equal(outputC.SymbolicDimensions, new string[] { "" }); // unnamed symbolic dim + } + } + [Fact] private void TestModelInputFloat() diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj index 9ecd184d61..05bfb0cf80 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj @@ -75,7 +75,10 @@ Always false - + + Always + false + diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2c632a6914..23acc15be8 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -468,6 +468,7 @@ struct OrtApi { OrtStatus*(ORT_API_CALL* GetTensorElementType)(_In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* GetDimensionsCount)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length)NO_EXCEPTION; + OrtStatus*(ORT_API_CALL* GetSymbolicDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ const char** dim_params, size_t dim_params_length)NO_EXCEPTION; /** * Return the number of elements specified by the tensor shape. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 5e57b928df..f07fc9c6af 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -200,6 +200,8 @@ struct TensorTypeAndShapeInfo : Base { size_t GetDimensionsCount() const; void GetDimensions(int64_t* values, size_t values_count) const; + void GetSymbolicDimensions(const char** values, size_t values_count) const; + std::vector GetShape() const; }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index e7569b5739..52514e6ad9 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -322,6 +322,10 @@ inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values ThrowOnError(g_api->GetDimensions(p_, values, values_count)); } +inline void TensorTypeAndShapeInfo::GetSymbolicDimensions(const char** values, size_t values_count) const { + ThrowOnError(g_api->GetSymbolicDimensions(p_, values, values_count)); +} + inline std::vector TensorTypeAndShapeInfo::GetShape() const { std::vector out(GetDimensionsCount(), 0); GetDimensions(out.data(), out.size()); diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 3e00d82f6c..3771e00e09 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -41,72 +41,74 @@ ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { delete ptr; } -OrtStatus* GetTensorShapeAndType(const TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out); -OrtStatus* GetTensorShapeAndType(const TensorShape* shape, const ONNX_NAMESPACE::TypeProto* type_proto, OrtTensorTypeAndShapeInfo** out); +OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const onnxruntime::DataTypeImpl& tensor_data_type, + OrtTensorTypeAndShapeInfo** out); +OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const std::vector* dim_params, + const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out); -OrtStatus* OrtTypeInfo::FromDataTypeImpl(const onnxruntime::DataTypeImpl* input, const TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out) { - if (input == nullptr) { +OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) { + onnxruntime::MLDataType type = value.Type(); + if (type == nullptr) { *out = new OrtTypeInfo(ONNX_TYPE_UNKNOWN, nullptr); return nullptr; } + // GetType and GetType do not have TypeProto populated because they return a static // TensorBase/SparseTensorBase instances, but other types are real MLDataTypes and they do have real protos // unless they are primitive data types, in which case we as before return them not implemented // however, this way we can support Opaque and we can avoid excessive calls to GetType() - if (input->IsTensorType()) { + if (type->IsTensorType()) { OrtTensorTypeAndShapeInfo* info = nullptr; + const Tensor& tensor = value.Get(); + const auto* tensor_data_type = tensor.DataType(); if (tensor_data_type != nullptr) { - OrtStatus* st = GetTensorShapeAndType(shape, tensor_data_type, &info); - if (st != nullptr) return st; + OrtStatus* st = GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, &info); + if (st != nullptr) + return st; } *out = new OrtTypeInfo(ONNX_TYPE_TENSOR, info); return nullptr; } - if (input->IsSparseTensorType()) { + + if (type->IsSparseTensorType()) { OrtTensorTypeAndShapeInfo* info = nullptr; + const SparseTensor& tensor = value.Get(); + const auto* tensor_data_type = tensor.Values().DataType(); if (tensor_data_type != nullptr) { - OrtStatus* st = GetTensorShapeAndType(shape, tensor_data_type, &info); + OrtStatus* st = GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, &info); if (st != nullptr) return st; } *out = new OrtTypeInfo(ONNX_TYPE_SPARSETENSOR, info); return nullptr; } - const auto* type_proto = input->GetTypeProto(); + + const auto* type_proto = type->GetTypeProto(); if (type_proto != nullptr) { - // Place Opaque first as tensors will be - // mostly handled above and maps and sequences - // are not common + // Place Opaque first as tensors will be mostly handled above and maps and sequences are not common switch (type_proto->value_case()) { case on::TypeProto::kOpaqueType: { *out = new OrtTypeInfo(ONNX_TYPE_OPAQUE, nullptr); return nullptr; - } break; + } case on::TypeProto::kMapType: { *out = new OrtTypeInfo(ONNX_TYPE_MAP, nullptr); return nullptr; - } break; + } case on::TypeProto::kSequenceType: { *out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr); return nullptr; - } break; + } // Real Tensor support case on::TypeProto::kTensorType: case on::TypeProto::kSparseTensorType: { - OrtTensorTypeAndShapeInfo* info = nullptr; - OrtStatus* st = GetTensorShapeAndType(shape, type_proto, &info); - if (st != nullptr) return st; - if (type_proto->value_case() == on::TypeProto::kTensorType) { - *out = new OrtTypeInfo(ONNX_TYPE_TENSOR, info); - } else { - *out = new OrtTypeInfo(ONNX_TYPE_SPARSETENSOR, nullptr); - } - return nullptr; - } break; + return OrtApis::CreateStatus(ORT_FAIL, "Tensor types should have been handled already"); + } default: // NOT_IMPLEMENTED break; } } + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented"); } @@ -146,7 +148,7 @@ const DataTypeImpl* OrtTypeInfo::ElementTypeFromProto(int type) { } } -OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input, OrtTypeInfo** out) { +OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, OrtTypeInfo** out) { auto value_case = input->value_case(); switch (value_case) { case on::TypeProto::kTensorType: @@ -168,11 +170,13 @@ OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input, sp = &sparse_type->shape(); } } + OrtStatus* st = nullptr; OrtTensorTypeAndShapeInfo* info = nullptr; if (sp != nullptr) { const on::TensorShapeProto& s = *sp; std::vector dims(s.dim_size()); + std::vector dim_params(s.dim_size()); TensorShape shape_data(std::move(dims)); for (int i = 0; i < s.dim_size(); ++i) { auto& t = s.dim(i); @@ -181,6 +185,8 @@ OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input, shape_data[i] = t.dim_value(); break; case on::TensorShapeProto::Dimension::kDimParam: + dim_params[i] = t.dim_param(); + // fall through case on::TensorShapeProto::Dimension::VALUE_NOT_SET: shape_data[i] = -1; break; @@ -188,9 +194,9 @@ OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input, assert(false); } } - st = GetTensorShapeAndType(&shape_data, input, &info); + st = GetTensorShapeAndType(shape_data, &dim_params, *input, &info); } else { - st = GetTensorShapeAndType(nullptr, input, &info); + st = GetTensorShapeAndType(TensorShape(), nullptr, *input, &info); } if (st != nullptr) return st; *out = new OrtTypeInfo(ten_type, info); diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index ae05f366f4..d615840dcb 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -29,9 +29,8 @@ struct OrtTypeInfo { OrtTypeInfo(const OrtTypeInfo& other) = delete; OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete; - static OrtStatus* FromDataTypeImpl(const onnxruntime::DataTypeImpl* input, const onnxruntime::TensorShape* shape, - const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out); - static OrtStatus* FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out); + static OrtStatus* FromOrtValue(const OrtValue& value, OrtTypeInfo** out); + static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out); static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 191f0d6741..c7c81a5275 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -61,6 +61,15 @@ ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndSh return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, + _Out_ const char** names, size_t dim_params_length) { + for (size_t idx = 0, end = std::min(info->dim_params.size(), dim_params_length); idx < end; ++idx) { + names[idx] = info->dim_params[idx].c_str(); + } + + return nullptr; +} + ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) { *out = static_cast(this_ptr->shape.Size()); return nullptr; @@ -159,7 +168,8 @@ ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementDataType( return type; } -OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape* shape, OrtTensorTypeAndShapeInfo** out) { +OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape shape, + const std::vector* dim_params, OrtTensorTypeAndShapeInfo** out) { OrtTensorTypeAndShapeInfo* ret; if (auto* status = OrtApis::CreateTensorTypeAndShapeInfo(&ret)) return status; @@ -167,37 +177,47 @@ OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onn OrtApis::ReleaseTensorTypeAndShapeInfo(ret); return status; } - if (shape != nullptr) { - auto* status = OrtApis::SetDimensions(ret, shape->GetDims().data(), shape->GetDims().size()); - if (status != nullptr) { - OrtApis::ReleaseTensorTypeAndShapeInfo(ret); - return status; - } + + auto* status = OrtApis::SetDimensions(ret, shape.GetDims().data(), shape.GetDims().size()); + if (status != nullptr) { + OrtApis::ReleaseTensorTypeAndShapeInfo(ret); + return status; } + + if (dim_params != nullptr) { + ret->dim_params = *dim_params; + } else { + // we expect to be being called with a concrete shape so validate that + assert(shape.Size() >= 0); + ret->dim_params.resize(shape.NumDimensions(), ""); + } + *out = ret; return nullptr; } -OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape, - const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out) { - ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(tensor_data_type); +OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, + const onnxruntime::DataTypeImpl& tensor_data_type, OrtTensorTypeAndShapeInfo** out) { + ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented"); } - return GetTensorShapeAndTypeHelper(type, shape, out); + return GetTensorShapeAndTypeHelper(type, shape, nullptr, out); } -OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape, - const ONNX_NAMESPACE::TypeProto* type_proto, OrtTensorTypeAndShapeInfo** out) { - assert(type_proto != nullptr); - auto value_case = type_proto->value_case(); - assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType || value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType); - auto dtype = (value_case == ONNX_NAMESPACE::TypeProto::kTensorType) ? type_proto->tensor_type().elem_type() : type_proto->sparse_tensor_type().elem_type(); +OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const std::vector* dim_params, + const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out) { + auto value_case = type_proto.value_case(); + assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType || + value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType); + + auto dtype = (value_case == ONNX_NAMESPACE::TypeProto::kTensorType) ? type_proto.tensor_type().elem_type() + : type_proto.sparse_tensor_type().elem_type(); ONNXTensorElementDataType type = TensorDataTypeToOnnxRuntimeTensorElementDataType(dtype); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented"); } - return GetTensorShapeAndTypeHelper(type, shape, out); + return GetTensorShapeAndTypeHelper(type, shape, dim_params, out); } ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out_ OrtTensorTypeAndShapeInfo** out) { @@ -216,7 +236,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out shape = &tensor.Shape(); data_type = tensor.Values().DataType(); } - return GetTensorShapeAndType(shape, data_type, out); + return GetTensorShapeAndType(*shape, *data_type, out); } else { ORT_THROW("Argument is not a tensor"); } @@ -225,10 +245,11 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXType* out) { API_IMPL_BEGIN - onnxruntime::MLDataType type = v->Type(); OrtTypeInfo* type_info; - if (auto status = OrtTypeInfo::FromDataTypeImpl(type, nullptr, nullptr, &type_info)) + auto status = OrtTypeInfo::FromOrtValue(*v, &type_info); + if (status != nullptr) return status; + *out = type_info->type; OrtApis::ReleaseTypeInfo(type_info); return nullptr; @@ -236,29 +257,21 @@ ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXTyp } /** - * Get the type information of an OrtValue - * \param value - * \return The returned value should be freed by OrtReleaseTypeInfo after use - */ +* Get the type information of an OrtValue +* \param value +* \return The returned value should be freed by OrtReleaseTypeInfo after use +*/ ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, struct OrtTypeInfo** out) { - onnxruntime::MLDataType type = v->Type(); - if (type == nullptr) { + API_IMPL_BEGIN + // TODO: This is consistent with the previous implementation but inconsistent with GetValueType which returns + // ONNX_TYPE_UNKNOWN if v->Type() is null. Should we instead just call OrtTypeInfo::FromOrtValue and + // return an OrtTypeInfo value in 'out' with type set to ONNX_TYPE_UNKNOWN? Or is the inconsistency fine? + if (v->Type() == nullptr) { *out = nullptr; return nullptr; } - if (type->IsTensorType() || type->IsSparseTensorType()) { - const onnxruntime::TensorShape* shape = nullptr; - onnxruntime::MLDataType data_type = nullptr; - if (type->IsTensorType()) { - const Tensor& tensor = v->Get(); - shape = &tensor.Shape(); - data_type = tensor.DataType(); - } else { - const SparseTensor& tensor = v->Get(); - shape = &tensor.Shape(); - data_type = tensor.Values().DataType(); - } - return OrtTypeInfo::FromDataTypeImpl(type, shape, data_type, out); - } - return OrtTypeInfo::FromDataTypeImpl(type, nullptr, nullptr, out); + + auto status = OrtTypeInfo::FromOrtValue(*v, out); + return status; + API_IMPL_END } diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h index 9c829215b9..28431a9d61 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -6,6 +6,9 @@ struct OrtTensorTypeAndShapeInfo { public: ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; onnxruntime::TensorShape shape; + // dim_param values. empty string if dim_value or no dim_param was specified. + // one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs + std::vector dim_params; OrtTensorTypeAndShapeInfo() = default; OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a52184fb9a..41456bbb6f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -599,7 +599,7 @@ static OrtStatus* GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefListFn if (p.second->size() <= index) return OrtApis::CreateStatus(ORT_FAIL, "out of index"); const ONNX_NAMESPACE::TypeProto* type_proto = (*p.second)[index]->TypeAsProto(); - return OrtTypeInfo::FromDataTypeImpl(type_proto, out); + return OrtTypeInfo::FromTypeProto(type_proto, out); API_IMPL_END } @@ -1310,6 +1310,7 @@ static constexpr OrtApi ort_api_1 = { &OrtApis::GetTensorElementType, &OrtApis::GetDimensionsCount, &OrtApis::GetDimensions, + &OrtApis::GetSymbolicDimensions, &OrtApis::GetTensorShapeElementCount, &OrtApis::GetTensorTypeAndShape, &OrtApis::GetTypeInfo, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 07e09fdcb7..e2fbcbde82 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -104,6 +104,7 @@ ORT_API_STATUS_IMPL(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const i ORT_API_STATUS_IMPL(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out); ORT_API_STATUS_IMPL(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); ORT_API_STATUS_IMPL(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); +ORT_API_STATUS_IMPL(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ const char* dim_params[], size_t dim_params_length); ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_ OrtTypeInfo** out); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 318b6d4048..70d7f81157 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -120,6 +120,8 @@ void TestInference(Ort::Env& env, T model_uri, static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx"); static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx"); static constexpr PATH_TYPE OVERRIDABLE_INITIALIZER_MODEL_URI = TSTR("testdata/overridable_initializer.onnx"); +static constexpr PATH_TYPE NAMED_AND_ANON_DIM_PARAM_URI = TSTR("testdata/capi_symbolic_dims.onnx"); + #ifdef ENABLE_LANGUAGE_INTEROP_OPS static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx"); #endif @@ -145,6 +147,34 @@ TEST_P(CApiTestWithProvider, simple) { TestInference(env_, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), nullptr); } +TEST_F(CApiTest, dim_param) { + Ort::SessionOptions session_options; + Ort::Session session(env_, NAMED_AND_ANON_DIM_PARAM_URI, session_options); + + auto in0 = session.GetInputTypeInfo(0); + auto in0_ttsi = in0.GetTensorTypeAndShapeInfo(); + + auto num_input_dims = in0_ttsi.GetDimensionsCount(); + ASSERT_GE(num_input_dims, 1); + // reading 1st dimension only so don't need to malloc int64_t* or const char** values for the Get*Dimensions calls + int64_t dim_value = 0; + const char* dim_param = nullptr; + in0_ttsi.GetDimensions(&dim_value, 1); + in0_ttsi.GetSymbolicDimensions(&dim_param, 1); + ASSERT_EQ(dim_value, -1) << "symbolic dimension should be -1"; + ASSERT_EQ(strcmp(dim_param, "n"), 0) << "Expected 'n'. Got: " << dim_param; + + auto out0 = session.GetOutputTypeInfo(0); + auto out0_ttsi = out0.GetTensorTypeAndShapeInfo(); + auto num_output_dims = out0_ttsi.GetDimensionsCount(); + ASSERT_EQ(num_output_dims, 1); + + out0_ttsi.GetDimensions(&dim_value, 1); + out0_ttsi.GetSymbolicDimensions(&dim_param, 1); + ASSERT_EQ(dim_value, -1) << "symbolic dimension should be -1"; + ASSERT_EQ(strcmp(dim_param, ""), 0); +} + INSTANTIATE_TEST_CASE_P(CApiTestWithProviders, CApiTestWithProvider, ::testing::Values(0, 1, 2, 3, 4)); diff --git a/onnxruntime/test/testdata/capi_symbolic_dims.onnx b/onnxruntime/test/testdata/capi_symbolic_dims.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e9f8d1666b880f8feb8d20d688ea74dcb9986c8b GIT binary patch literal 106 zcmd;Jvr6ES=3;c@VssK>be3W-N-fSvEJ#&i4}vg+xJpusOLTMdQ&MxHM1Tr~__=sF y7=<{wn1vYgxR^MYq69#Cxw$ww*o9a@qPa=DT#U{_99(Qbbxd3gPApsu0^9&b$`GCa literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/capi_symbolic_dims.py b/onnxruntime/test/testdata/capi_symbolic_dims.py new file mode 100644 index 0000000000..d3af986534 --- /dev/null +++ b/onnxruntime/test/testdata/capi_symbolic_dims.py @@ -0,0 +1,40 @@ +import onnx +from onnx import helper +from onnx import TensorProto +from onnx import shape_inference + +# create output with rank but unnamed symbolic dim +output = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1]) +output.type.tensor_type.shape.Clear() +dim = output.type.tensor_type.shape.dim.add() +print(dim) + +graph_def = helper.make_graph( + nodes = [ + helper.make_node(op_type = "Reshape", inputs = ['A', 'B'], outputs = ['C'], name = 'reshape'), + ], + name = 'test-model', + inputs = [ + # create inputs with symbolic dims + helper.make_tensor_value_info("A", TensorProto.FLOAT, ['n', 2]), + helper.make_tensor_value_info("B", TensorProto.INT64, ['m']), + ], + outputs = [ + output + ], + initializer = [ + ] +) + +model = helper.make_model(graph_def, opset_imports=[helper.make_operatorsetid("", 11)]) +onnx.checker.check_model(model) + +inferred_model = shape_inference.infer_shapes(model) +onnx.checker.check_model(inferred_model) + +onnx.save_model(model, "capi_symbolic_dims.onnx") + +import onnxruntime as rt +sess = rt.InferenceSession("capi_symbolic_dims.onnx") +print([i.shape for i in sess.get_inputs()]) +print([i.shape for i in sess.get_outputs()])