diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
index 688ea4a1f8..0310eb7db0 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
@@ -433,7 +433,7 @@ namespace Microsoft.ML.OnnxRuntime
}
if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR)
{
- return new NodeMetadata(valueType, new int[] { }, typeof(NamedOnnxValue));
+ return new NodeMetadata(valueType, new int[] { }, new string[] { }, typeof(NamedOnnxValue));
}
IntPtr tensorInfo;
@@ -453,14 +453,26 @@ namespace Microsoft.ML.OnnxRuntime
TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width);
UIntPtr numDimensions;
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions));
+
long[] dimensions = new long[(int)numDimensions];
- NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions);
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions));
int[] intDimensions = new int[(int)numDimensions];
for (var i = 0; i < (long)numDimensions; i++)
{
intDimensions[i] = (int)dimensions[i];
}
- return new NodeMetadata(valueType, intDimensions, dotnetType);
+
+ IntPtr[] dimensionNamePtrs = new IntPtr[(int)numDimensions];
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.OrtGetSymbolicDimensions(tensorInfo, dimensionNamePtrs, numDimensions));
+
+ string[] symbolicDimensions = new string[(int)numDimensions];
+ for (var i = 0; i < (int)numDimensions; i++)
+ {
+ symbolicDimensions[i] = Marshal.PtrToStringAnsi(dimensionNamePtrs[i]); //assumes charset = ANSI
+ }
+
+ return new NodeMetadata(valueType, intDimensions, symbolicDimensions, dotnetType);
}
#endregion
@@ -514,12 +526,14 @@ namespace Microsoft.ML.OnnxRuntime
{
private OnnxValueType _onnxValueType;
private int[] _dimensions;
+ private string[] _symbolicDimensions;
private Type _type;
- internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, Type type)
+ internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] symbolicDimensions, Type type)
{
_onnxValueType = onnxValueType;
_dimensions = dimensions;
+ _symbolicDimensions = symbolicDimensions;
_type = type;
}
@@ -538,6 +552,15 @@ namespace Microsoft.ML.OnnxRuntime
return _dimensions;
}
}
+
+ public string[] SymbolicDimensions
+ {
+ get
+ {
+ return _symbolicDimensions;
+ }
+ }
+
public System.Type ElementType
{
get
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 03729247f7..2d0bad091c 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -86,6 +86,7 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr GetTensorElementType;
public IntPtr GetDimensionsCount;
public IntPtr GetDimensions;
+ public IntPtr GetSymbolicDimensions;
public IntPtr GetTensorShapeElementCount;
public IntPtr GetTensorTypeAndShape;
public IntPtr GetTypeInfo;
@@ -220,6 +221,7 @@ namespace Microsoft.ML.OnnxRuntime
OrtGetTensorElementType = (DOrtGetTensorElementType)Marshal.GetDelegateForFunctionPointer(api_.GetTensorElementType, typeof(DOrtGetTensorElementType));
OrtGetDimensionsCount = (DOrtGetDimensionsCount)Marshal.GetDelegateForFunctionPointer(api_.GetDimensionsCount, typeof(DOrtGetDimensionsCount));
OrtGetDimensions = (DOrtGetDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetDimensions, typeof(DOrtGetDimensions));
+ OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions));
OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount));
OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue));
}
@@ -627,6 +629,23 @@ namespace Microsoft.ML.OnnxRuntime
UIntPtr dim_values_length);
public static DOrtGetDimensions OrtGetDimensions;
+ /**
+ * Get the symbolic dimension names for dimensions with a value of -1.
+ * Order and number of entries is the same as values returned by GetDimensions.
+ * The name may be empty for an unnamed symbolic dimension.
+ * e.g.
+ * If OrtGetDimensions returns [-1, -1, 2], OrtGetSymbolicDimensions would return an array with 3 entries.
+ * If the values returned were ['batch', '', ''] it would indicate that
+ * - the first dimension was a named symbolic dimension (-1 dim value and name in symbolic dimensions),
+ * - the second dimension was an unnamed symbolic dimension (-1 dim value and empty string),
+ * - the entry for the third dimension should be ignored as it is not a symbolic dimension (dim value >= 0).
+ */
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtGetSymbolicDimensions(
+ IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
+ IntPtr[] dim_params, /* const char* values, converted to string by caller */
+ UIntPtr dim_params_length);
+ public static DOrtGetSymbolicDimensions OrtGetSymbolicDimensions;
+
/**
* How many elements does this tensor have.
* May return a negative value
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index ee0977d76e..d77fa59bd2 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -569,6 +569,39 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
+ [Fact]
+ private void TestSymbolicDimsMetadata()
+ {
+ string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "capi_symbolic_dims.onnx");
+ using (var session = new InferenceSession(modelPath))
+ {
+ var inputs = session.InputMetadata;
+ var outputs = session.OutputMetadata;
+
+ Assert.Equal(2, inputs.Count);
+ Assert.Equal(1, session.OutputMetadata.Count);
+ Assert.True(inputs.ContainsKey("A"));
+ Assert.True(inputs.ContainsKey("B"));
+ Assert.True(outputs.ContainsKey("C"));
+
+ var inputA = inputs["A"];
+ var inputB = inputs["B"];
+ var outputC = outputs["C"];
+
+ // dimension values and any symbolic dimension info should have the same length
+ Assert.Equal(inputA.Dimensions.Length, inputA.SymbolicDimensions.Length);
+ Assert.Equal(inputB.Dimensions.Length, inputB.SymbolicDimensions.Length);
+ Assert.Equal(outputC.Dimensions.Length, outputC.SymbolicDimensions.Length);
+
+ Assert.Equal(inputA.Dimensions, new int[] { -1, 2 });
+ Assert.Equal(inputA.SymbolicDimensions, new string[] { "n", "" });
+ Assert.Equal(inputB.Dimensions, new int[] { -1 });
+ Assert.Equal(inputB.SymbolicDimensions, new string[] { "m" });
+ Assert.Equal(outputC.Dimensions, new int[] { -1 });
+ Assert.Equal(outputC.SymbolicDimensions, new string[] { "" }); // unnamed symbolic dim
+ }
+ }
+
[Fact]
private void TestModelInputFloat()
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj
index 9ecd184d61..05bfb0cf80 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/Microsoft.ML.OnnxRuntime.Tests.csproj
@@ -75,7 +75,10 @@
Always
false
-
+
+ Always
+ false
+
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 2c632a6914..23acc15be8 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -468,6 +468,7 @@ struct OrtApi {
OrtStatus*(ORT_API_CALL* GetTensorElementType)(_In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* GetDimensionsCount)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length)NO_EXCEPTION;
+ OrtStatus*(ORT_API_CALL* GetSymbolicDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ const char** dim_params, size_t dim_params_length)NO_EXCEPTION;
/**
* Return the number of elements specified by the tensor shape.
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 5e57b928df..f07fc9c6af 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -200,6 +200,8 @@ struct TensorTypeAndShapeInfo : Base {
size_t GetDimensionsCount() const;
void GetDimensions(int64_t* values, size_t values_count) const;
+ void GetSymbolicDimensions(const char** values, size_t values_count) const;
+
std::vector GetShape() const;
};
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index e7569b5739..52514e6ad9 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -322,6 +322,10 @@ inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values
ThrowOnError(g_api->GetDimensions(p_, values, values_count));
}
+inline void TensorTypeAndShapeInfo::GetSymbolicDimensions(const char** values, size_t values_count) const {
+ ThrowOnError(g_api->GetSymbolicDimensions(p_, values, values_count));
+}
+
inline std::vector TensorTypeAndShapeInfo::GetShape() const {
std::vector out(GetDimensionsCount(), 0);
GetDimensions(out.data(), out.size());
diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc
index 3e00d82f6c..3771e00e09 100644
--- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc
+++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc
@@ -41,72 +41,74 @@ ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) {
delete ptr;
}
-OrtStatus* GetTensorShapeAndType(const TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out);
-OrtStatus* GetTensorShapeAndType(const TensorShape* shape, const ONNX_NAMESPACE::TypeProto* type_proto, OrtTensorTypeAndShapeInfo** out);
+OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const onnxruntime::DataTypeImpl& tensor_data_type,
+ OrtTensorTypeAndShapeInfo** out);
+OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const std::vector* dim_params,
+ const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out);
-OrtStatus* OrtTypeInfo::FromDataTypeImpl(const onnxruntime::DataTypeImpl* input, const TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out) {
- if (input == nullptr) {
+OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) {
+ onnxruntime::MLDataType type = value.Type();
+ if (type == nullptr) {
*out = new OrtTypeInfo(ONNX_TYPE_UNKNOWN, nullptr);
return nullptr;
}
+
// GetType and GetType do not have TypeProto populated because they return a static
// TensorBase/SparseTensorBase instances, but other types are real MLDataTypes and they do have real protos
// unless they are primitive data types, in which case we as before return them not implemented
// however, this way we can support Opaque and we can avoid excessive calls to GetType()
- if (input->IsTensorType()) {
+ if (type->IsTensorType()) {
OrtTensorTypeAndShapeInfo* info = nullptr;
+ const Tensor& tensor = value.Get();
+ const auto* tensor_data_type = tensor.DataType();
if (tensor_data_type != nullptr) {
- OrtStatus* st = GetTensorShapeAndType(shape, tensor_data_type, &info);
- if (st != nullptr) return st;
+ OrtStatus* st = GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, &info);
+ if (st != nullptr)
+ return st;
}
*out = new OrtTypeInfo(ONNX_TYPE_TENSOR, info);
return nullptr;
}
- if (input->IsSparseTensorType()) {
+
+ if (type->IsSparseTensorType()) {
OrtTensorTypeAndShapeInfo* info = nullptr;
+ const SparseTensor& tensor = value.Get();
+ const auto* tensor_data_type = tensor.Values().DataType();
if (tensor_data_type != nullptr) {
- OrtStatus* st = GetTensorShapeAndType(shape, tensor_data_type, &info);
+ OrtStatus* st = GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, &info);
if (st != nullptr) return st;
}
*out = new OrtTypeInfo(ONNX_TYPE_SPARSETENSOR, info);
return nullptr;
}
- const auto* type_proto = input->GetTypeProto();
+
+ const auto* type_proto = type->GetTypeProto();
if (type_proto != nullptr) {
- // Place Opaque first as tensors will be
- // mostly handled above and maps and sequences
- // are not common
+ // Place Opaque first as tensors will be mostly handled above and maps and sequences are not common
switch (type_proto->value_case()) {
case on::TypeProto::kOpaqueType: {
*out = new OrtTypeInfo(ONNX_TYPE_OPAQUE, nullptr);
return nullptr;
- } break;
+ }
case on::TypeProto::kMapType: {
*out = new OrtTypeInfo(ONNX_TYPE_MAP, nullptr);
return nullptr;
- } break;
+ }
case on::TypeProto::kSequenceType: {
*out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr);
return nullptr;
- } break;
+ }
// Real Tensor support
case on::TypeProto::kTensorType:
case on::TypeProto::kSparseTensorType: {
- OrtTensorTypeAndShapeInfo* info = nullptr;
- OrtStatus* st = GetTensorShapeAndType(shape, type_proto, &info);
- if (st != nullptr) return st;
- if (type_proto->value_case() == on::TypeProto::kTensorType) {
- *out = new OrtTypeInfo(ONNX_TYPE_TENSOR, info);
- } else {
- *out = new OrtTypeInfo(ONNX_TYPE_SPARSETENSOR, nullptr);
- }
- return nullptr;
- } break;
+ return OrtApis::CreateStatus(ORT_FAIL, "Tensor types should have been handled already");
+ }
default:
// NOT_IMPLEMENTED
break;
}
}
+
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented");
}
@@ -146,7 +148,7 @@ const DataTypeImpl* OrtTypeInfo::ElementTypeFromProto(int type) {
}
}
-OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input, OrtTypeInfo** out) {
+OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, OrtTypeInfo** out) {
auto value_case = input->value_case();
switch (value_case) {
case on::TypeProto::kTensorType:
@@ -168,11 +170,13 @@ OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input,
sp = &sparse_type->shape();
}
}
+
OrtStatus* st = nullptr;
OrtTensorTypeAndShapeInfo* info = nullptr;
if (sp != nullptr) {
const on::TensorShapeProto& s = *sp;
std::vector dims(s.dim_size());
+ std::vector dim_params(s.dim_size());
TensorShape shape_data(std::move(dims));
for (int i = 0; i < s.dim_size(); ++i) {
auto& t = s.dim(i);
@@ -181,6 +185,8 @@ OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input,
shape_data[i] = t.dim_value();
break;
case on::TensorShapeProto::Dimension::kDimParam:
+ dim_params[i] = t.dim_param();
+ // fall through
case on::TensorShapeProto::Dimension::VALUE_NOT_SET:
shape_data[i] = -1;
break;
@@ -188,9 +194,9 @@ OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input,
assert(false);
}
}
- st = GetTensorShapeAndType(&shape_data, input, &info);
+ st = GetTensorShapeAndType(shape_data, &dim_params, *input, &info);
} else {
- st = GetTensorShapeAndType(nullptr, input, &info);
+ st = GetTensorShapeAndType(TensorShape(), nullptr, *input, &info);
}
if (st != nullptr) return st;
*out = new OrtTypeInfo(ten_type, info);
diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h
index ae05f366f4..d615840dcb 100644
--- a/onnxruntime/core/framework/onnxruntime_typeinfo.h
+++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h
@@ -29,9 +29,8 @@ struct OrtTypeInfo {
OrtTypeInfo(const OrtTypeInfo& other) = delete;
OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete;
- static OrtStatus* FromDataTypeImpl(const onnxruntime::DataTypeImpl* input, const onnxruntime::TensorShape* shape,
- const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out);
- static OrtStatus* FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out);
+ static OrtStatus* FromOrtValue(const OrtValue& value, OrtTypeInfo** out);
+ static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out);
static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type);
diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc
index 191f0d6741..c7c81a5275 100644
--- a/onnxruntime/core/framework/tensor_type_and_shape.cc
+++ b/onnxruntime/core/framework/tensor_type_and_shape.cc
@@ -61,6 +61,15 @@ ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndSh
return nullptr;
}
+ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info,
+ _Out_ const char** names, size_t dim_params_length) {
+ for (size_t idx = 0, end = std::min(info->dim_params.size(), dim_params_length); idx < end; ++idx) {
+ names[idx] = info->dim_params[idx].c_str();
+ }
+
+ return nullptr;
+}
+
ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) {
*out = static_cast(this_ptr->shape.Size());
return nullptr;
@@ -159,7 +168,8 @@ ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementDataType(
return type;
}
-OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape* shape, OrtTensorTypeAndShapeInfo** out) {
+OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape shape,
+ const std::vector* dim_params, OrtTensorTypeAndShapeInfo** out) {
OrtTensorTypeAndShapeInfo* ret;
if (auto* status = OrtApis::CreateTensorTypeAndShapeInfo(&ret))
return status;
@@ -167,37 +177,47 @@ OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onn
OrtApis::ReleaseTensorTypeAndShapeInfo(ret);
return status;
}
- if (shape != nullptr) {
- auto* status = OrtApis::SetDimensions(ret, shape->GetDims().data(), shape->GetDims().size());
- if (status != nullptr) {
- OrtApis::ReleaseTensorTypeAndShapeInfo(ret);
- return status;
- }
+
+ auto* status = OrtApis::SetDimensions(ret, shape.GetDims().data(), shape.GetDims().size());
+ if (status != nullptr) {
+ OrtApis::ReleaseTensorTypeAndShapeInfo(ret);
+ return status;
}
+
+ if (dim_params != nullptr) {
+ ret->dim_params = *dim_params;
+ } else {
+ // we expect to be being called with a concrete shape so validate that
+ assert(shape.Size() >= 0);
+ ret->dim_params.resize(shape.NumDimensions(), "");
+ }
+
*out = ret;
return nullptr;
}
-OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape,
- const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out) {
- ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(tensor_data_type);
+OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape,
+ const onnxruntime::DataTypeImpl& tensor_data_type, OrtTensorTypeAndShapeInfo** out) {
+ ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type);
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented");
}
- return GetTensorShapeAndTypeHelper(type, shape, out);
+ return GetTensorShapeAndTypeHelper(type, shape, nullptr, out);
}
-OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape,
- const ONNX_NAMESPACE::TypeProto* type_proto, OrtTensorTypeAndShapeInfo** out) {
- assert(type_proto != nullptr);
- auto value_case = type_proto->value_case();
- assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType || value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType);
- auto dtype = (value_case == ONNX_NAMESPACE::TypeProto::kTensorType) ? type_proto->tensor_type().elem_type() : type_proto->sparse_tensor_type().elem_type();
+OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const std::vector* dim_params,
+ const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out) {
+ auto value_case = type_proto.value_case();
+ assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType ||
+ value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType);
+
+ auto dtype = (value_case == ONNX_NAMESPACE::TypeProto::kTensorType) ? type_proto.tensor_type().elem_type()
+ : type_proto.sparse_tensor_type().elem_type();
ONNXTensorElementDataType type = TensorDataTypeToOnnxRuntimeTensorElementDataType(dtype);
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented");
}
- return GetTensorShapeAndTypeHelper(type, shape, out);
+ return GetTensorShapeAndTypeHelper(type, shape, dim_params, out);
}
ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out_ OrtTensorTypeAndShapeInfo** out) {
@@ -216,7 +236,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out
shape = &tensor.Shape();
data_type = tensor.Values().DataType();
}
- return GetTensorShapeAndType(shape, data_type, out);
+ return GetTensorShapeAndType(*shape, *data_type, out);
} else {
ORT_THROW("Argument is not a tensor");
}
@@ -225,10 +245,11 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out
ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXType* out) {
API_IMPL_BEGIN
- onnxruntime::MLDataType type = v->Type();
OrtTypeInfo* type_info;
- if (auto status = OrtTypeInfo::FromDataTypeImpl(type, nullptr, nullptr, &type_info))
+ auto status = OrtTypeInfo::FromOrtValue(*v, &type_info);
+ if (status != nullptr)
return status;
+
*out = type_info->type;
OrtApis::ReleaseTypeInfo(type_info);
return nullptr;
@@ -236,29 +257,21 @@ ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXTyp
}
/**
- * Get the type information of an OrtValue
- * \param value
- * \return The returned value should be freed by OrtReleaseTypeInfo after use
- */
+* Get the type information of an OrtValue
+* \param value
+* \return The returned value should be freed by OrtReleaseTypeInfo after use
+*/
ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, struct OrtTypeInfo** out) {
- onnxruntime::MLDataType type = v->Type();
- if (type == nullptr) {
+ API_IMPL_BEGIN
+ // TODO: This is consistent with the previous implementation but inconsistent with GetValueType which returns
+ // ONNX_TYPE_UNKNOWN if v->Type() is null. Should we instead just call OrtTypeInfo::FromOrtValue and
+ // return an OrtTypeInfo value in 'out' with type set to ONNX_TYPE_UNKNOWN? Or is the inconsistency fine?
+ if (v->Type() == nullptr) {
*out = nullptr;
return nullptr;
}
- if (type->IsTensorType() || type->IsSparseTensorType()) {
- const onnxruntime::TensorShape* shape = nullptr;
- onnxruntime::MLDataType data_type = nullptr;
- if (type->IsTensorType()) {
- const Tensor& tensor = v->Get();
- shape = &tensor.Shape();
- data_type = tensor.DataType();
- } else {
- const SparseTensor& tensor = v->Get();
- shape = &tensor.Shape();
- data_type = tensor.Values().DataType();
- }
- return OrtTypeInfo::FromDataTypeImpl(type, shape, data_type, out);
- }
- return OrtTypeInfo::FromDataTypeImpl(type, nullptr, nullptr, out);
+
+ auto status = OrtTypeInfo::FromOrtValue(*v, out);
+ return status;
+ API_IMPL_END
}
diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h b/onnxruntime/core/framework/tensor_type_and_shape.h
index 9c829215b9..28431a9d61 100644
--- a/onnxruntime/core/framework/tensor_type_and_shape.h
+++ b/onnxruntime/core/framework/tensor_type_and_shape.h
@@ -6,6 +6,9 @@ struct OrtTensorTypeAndShapeInfo {
public:
ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
onnxruntime::TensorShape shape;
+ // dim_param values. empty string if dim_value or no dim_param was specified.
+ // one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs
+ std::vector dim_params;
OrtTensorTypeAndShapeInfo() = default;
OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete;
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index a52184fb9a..41456bbb6f 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -599,7 +599,7 @@ static OrtStatus* GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefListFn
if (p.second->size() <= index)
return OrtApis::CreateStatus(ORT_FAIL, "out of index");
const ONNX_NAMESPACE::TypeProto* type_proto = (*p.second)[index]->TypeAsProto();
- return OrtTypeInfo::FromDataTypeImpl(type_proto, out);
+ return OrtTypeInfo::FromTypeProto(type_proto, out);
API_IMPL_END
}
@@ -1310,6 +1310,7 @@ static constexpr OrtApi ort_api_1 = {
&OrtApis::GetTensorElementType,
&OrtApis::GetDimensionsCount,
&OrtApis::GetDimensions,
+ &OrtApis::GetSymbolicDimensions,
&OrtApis::GetTensorShapeElementCount,
&OrtApis::GetTensorTypeAndShape,
&OrtApis::GetTypeInfo,
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 07e09fdcb7..e2fbcbde82 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -104,6 +104,7 @@ ORT_API_STATUS_IMPL(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const i
ORT_API_STATUS_IMPL(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out);
ORT_API_STATUS_IMPL(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out);
ORT_API_STATUS_IMPL(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
+ORT_API_STATUS_IMPL(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ const char* dim_params[], size_t dim_params_length);
ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out);
ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_ OrtTypeInfo** out);
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 318b6d4048..70d7f81157 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -120,6 +120,8 @@ void TestInference(Ort::Env& env, T model_uri,
static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx");
static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx");
static constexpr PATH_TYPE OVERRIDABLE_INITIALIZER_MODEL_URI = TSTR("testdata/overridable_initializer.onnx");
+static constexpr PATH_TYPE NAMED_AND_ANON_DIM_PARAM_URI = TSTR("testdata/capi_symbolic_dims.onnx");
+
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx");
#endif
@@ -145,6 +147,34 @@ TEST_P(CApiTestWithProvider, simple) {
TestInference(env_, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), nullptr);
}
+TEST_F(CApiTest, dim_param) {
+ Ort::SessionOptions session_options;
+ Ort::Session session(env_, NAMED_AND_ANON_DIM_PARAM_URI, session_options);
+
+ auto in0 = session.GetInputTypeInfo(0);
+ auto in0_ttsi = in0.GetTensorTypeAndShapeInfo();
+
+ auto num_input_dims = in0_ttsi.GetDimensionsCount();
+ ASSERT_GE(num_input_dims, 1);
+ // reading 1st dimension only so don't need to malloc int64_t* or const char** values for the Get*Dimensions calls
+ int64_t dim_value = 0;
+ const char* dim_param = nullptr;
+ in0_ttsi.GetDimensions(&dim_value, 1);
+ in0_ttsi.GetSymbolicDimensions(&dim_param, 1);
+ ASSERT_EQ(dim_value, -1) << "symbolic dimension should be -1";
+ ASSERT_EQ(strcmp(dim_param, "n"), 0) << "Expected 'n'. Got: " << dim_param;
+
+ auto out0 = session.GetOutputTypeInfo(0);
+ auto out0_ttsi = out0.GetTensorTypeAndShapeInfo();
+ auto num_output_dims = out0_ttsi.GetDimensionsCount();
+ ASSERT_EQ(num_output_dims, 1);
+
+ out0_ttsi.GetDimensions(&dim_value, 1);
+ out0_ttsi.GetSymbolicDimensions(&dim_param, 1);
+ ASSERT_EQ(dim_value, -1) << "symbolic dimension should be -1";
+ ASSERT_EQ(strcmp(dim_param, ""), 0);
+}
+
INSTANTIATE_TEST_CASE_P(CApiTestWithProviders,
CApiTestWithProvider,
::testing::Values(0, 1, 2, 3, 4));
diff --git a/onnxruntime/test/testdata/capi_symbolic_dims.onnx b/onnxruntime/test/testdata/capi_symbolic_dims.onnx
new file mode 100644
index 0000000000..e9f8d1666b
Binary files /dev/null and b/onnxruntime/test/testdata/capi_symbolic_dims.onnx differ
diff --git a/onnxruntime/test/testdata/capi_symbolic_dims.py b/onnxruntime/test/testdata/capi_symbolic_dims.py
new file mode 100644
index 0000000000..d3af986534
--- /dev/null
+++ b/onnxruntime/test/testdata/capi_symbolic_dims.py
@@ -0,0 +1,40 @@
+import onnx
+from onnx import helper
+from onnx import TensorProto
+from onnx import shape_inference
+
+# create output with rank but unnamed symbolic dim
+output = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1])
+output.type.tensor_type.shape.Clear()
+dim = output.type.tensor_type.shape.dim.add()
+print(dim)
+
+graph_def = helper.make_graph(
+ nodes = [
+ helper.make_node(op_type = "Reshape", inputs = ['A', 'B'], outputs = ['C'], name = 'reshape'),
+ ],
+ name = 'test-model',
+ inputs = [
+ # create inputs with symbolic dims
+ helper.make_tensor_value_info("A", TensorProto.FLOAT, ['n', 2]),
+ helper.make_tensor_value_info("B", TensorProto.INT64, ['m']),
+ ],
+ outputs = [
+ output
+ ],
+ initializer = [
+ ]
+)
+
+model = helper.make_model(graph_def, opset_imports=[helper.make_operatorsetid("", 11)])
+onnx.checker.check_model(model)
+
+inferred_model = shape_inference.infer_shapes(model)
+onnx.checker.check_model(inferred_model)
+
+onnx.save_model(model, "capi_symbolic_dims.onnx")
+
+import onnxruntime as rt
+sess = rt.InferenceSession("capi_symbolic_dims.onnx")
+print([i.shape for i in sess.get_inputs()])
+print([i.shape for i in sess.get_outputs()])