mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Add ability to get symbolic dimension info for graph inputs and outputs. (#2051)
* Add ability to get symbolic dimension info for graph inputs and outputs. WIP to get initial feedback. * Fix linxu build error. Update C# API and add unit test * Clarify the two different ways Tensor shape and type info is created. One is from concrete values and one is from a type proto where symbolic dimensions may exist. Doing so allows a change to default to empty strings for the symbolic dimensions if not provided.
This commit is contained in:
parent
20515363e5
commit
eb24617d2e
16 changed files with 258 additions and 80 deletions
|
|
@ -433,7 +433,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
}
|
||||
if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR)
|
||||
{
|
||||
return new NodeMetadata(valueType, new int[] { }, typeof(NamedOnnxValue));
|
||||
return new NodeMetadata(valueType, new int[] { }, new string[] { }, typeof(NamedOnnxValue));
|
||||
}
|
||||
|
||||
IntPtr tensorInfo;
|
||||
|
|
@ -453,14 +453,26 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width);
|
||||
UIntPtr numDimensions;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions));
|
||||
|
||||
long[] dimensions = new long[(int)numDimensions];
|
||||
NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions);
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions));
|
||||
int[] intDimensions = new int[(int)numDimensions];
|
||||
for (var i = 0; i < (long)numDimensions; i++)
|
||||
{
|
||||
intDimensions[i] = (int)dimensions[i];
|
||||
}
|
||||
return new NodeMetadata(valueType, intDimensions, dotnetType);
|
||||
|
||||
IntPtr[] dimensionNamePtrs = new IntPtr[(int)numDimensions];
|
||||
NativeApiStatus.VerifySuccess(
|
||||
NativeMethods.OrtGetSymbolicDimensions(tensorInfo, dimensionNamePtrs, numDimensions));
|
||||
|
||||
string[] symbolicDimensions = new string[(int)numDimensions];
|
||||
for (var i = 0; i < (int)numDimensions; i++)
|
||||
{
|
||||
symbolicDimensions[i] = Marshal.PtrToStringAnsi(dimensionNamePtrs[i]); //assumes charset = ANSI
|
||||
}
|
||||
|
||||
return new NodeMetadata(valueType, intDimensions, symbolicDimensions, dotnetType);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
|
@ -514,12 +526,14 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
{
|
||||
private OnnxValueType _onnxValueType;
|
||||
private int[] _dimensions;
|
||||
private string[] _symbolicDimensions;
|
||||
private Type _type;
|
||||
|
||||
internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, Type type)
|
||||
internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] symbolicDimensions, Type type)
|
||||
{
|
||||
_onnxValueType = onnxValueType;
|
||||
_dimensions = dimensions;
|
||||
_symbolicDimensions = symbolicDimensions;
|
||||
_type = type;
|
||||
}
|
||||
|
||||
|
|
@ -538,6 +552,15 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
return _dimensions;
|
||||
}
|
||||
}
|
||||
|
||||
public string[] SymbolicDimensions
|
||||
{
|
||||
get
|
||||
{
|
||||
return _symbolicDimensions;
|
||||
}
|
||||
}
|
||||
|
||||
public System.Type ElementType
|
||||
{
|
||||
get
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public IntPtr GetTensorElementType;
|
||||
public IntPtr GetDimensionsCount;
|
||||
public IntPtr GetDimensions;
|
||||
public IntPtr GetSymbolicDimensions;
|
||||
public IntPtr GetTensorShapeElementCount;
|
||||
public IntPtr GetTensorTypeAndShape;
|
||||
public IntPtr GetTypeInfo;
|
||||
|
|
@ -220,6 +221,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtGetTensorElementType = (DOrtGetTensorElementType)Marshal.GetDelegateForFunctionPointer(api_.GetTensorElementType, typeof(DOrtGetTensorElementType));
|
||||
OrtGetDimensionsCount = (DOrtGetDimensionsCount)Marshal.GetDelegateForFunctionPointer(api_.GetDimensionsCount, typeof(DOrtGetDimensionsCount));
|
||||
OrtGetDimensions = (DOrtGetDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetDimensions, typeof(DOrtGetDimensions));
|
||||
OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions));
|
||||
OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount));
|
||||
OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue));
|
||||
}
|
||||
|
|
@ -627,6 +629,23 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
UIntPtr dim_values_length);
|
||||
public static DOrtGetDimensions OrtGetDimensions;
|
||||
|
||||
/**
|
||||
* Get the symbolic dimension names for dimensions with a value of -1.
|
||||
* Order and number of entries is the same as values returned by GetDimensions.
|
||||
* The name may be empty for an unnamed symbolic dimension.
|
||||
* e.g.
|
||||
* If OrtGetDimensions returns [-1, -1, 2], OrtGetSymbolicDimensions would return an array with 3 entries.
|
||||
* If the values returned were ['batch', '', ''] it would indicate that
|
||||
* - the first dimension was a named symbolic dimension (-1 dim value and name in symbolic dimensions),
|
||||
* - the second dimension was an unnamed symbolic dimension (-1 dim value and empty string),
|
||||
* - the entry for the third dimension should be ignored as it is not a symbolic dimension (dim value >= 0).
|
||||
*/
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetSymbolicDimensions(
|
||||
IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo,
|
||||
IntPtr[] dim_params, /* const char* values, converted to string by caller */
|
||||
UIntPtr dim_params_length);
|
||||
public static DOrtGetSymbolicDimensions OrtGetSymbolicDimensions;
|
||||
|
||||
/**
|
||||
* How many elements does this tensor have.
|
||||
* May return a negative value
|
||||
|
|
|
|||
|
|
@ -569,6 +569,39 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
private void TestSymbolicDimsMetadata()
|
||||
{
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "capi_symbolic_dims.onnx");
|
||||
using (var session = new InferenceSession(modelPath))
|
||||
{
|
||||
var inputs = session.InputMetadata;
|
||||
var outputs = session.OutputMetadata;
|
||||
|
||||
Assert.Equal(2, inputs.Count);
|
||||
Assert.Equal(1, session.OutputMetadata.Count);
|
||||
Assert.True(inputs.ContainsKey("A"));
|
||||
Assert.True(inputs.ContainsKey("B"));
|
||||
Assert.True(outputs.ContainsKey("C"));
|
||||
|
||||
var inputA = inputs["A"];
|
||||
var inputB = inputs["B"];
|
||||
var outputC = outputs["C"];
|
||||
|
||||
// dimension values and any symbolic dimension info should have the same length
|
||||
Assert.Equal(inputA.Dimensions.Length, inputA.SymbolicDimensions.Length);
|
||||
Assert.Equal(inputB.Dimensions.Length, inputB.SymbolicDimensions.Length);
|
||||
Assert.Equal(outputC.Dimensions.Length, outputC.SymbolicDimensions.Length);
|
||||
|
||||
Assert.Equal(inputA.Dimensions, new int[] { -1, 2 });
|
||||
Assert.Equal(inputA.SymbolicDimensions, new string[] { "n", "" });
|
||||
Assert.Equal(inputB.Dimensions, new int[] { -1 });
|
||||
Assert.Equal(inputB.SymbolicDimensions, new string[] { "m" });
|
||||
Assert.Equal(outputC.Dimensions, new int[] { -1 });
|
||||
Assert.Equal(outputC.SymbolicDimensions, new string[] { "" }); // unnamed symbolic dim
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
[Fact]
|
||||
private void TestModelInputFloat()
|
||||
|
|
|
|||
|
|
@ -75,7 +75,10 @@
|
|||
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
|
||||
<Visible>false</Visible>
|
||||
</None>
|
||||
|
||||
<None Include="$(OnnxRuntimeCSharpRoot)\..\onnxruntime\test\testdata\capi_symbolic_dims.onnx">
|
||||
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
|
||||
<Visible>false</Visible>
|
||||
</None>
|
||||
<BuildEnvVars Include="OnnxRuntimeBuildDirectory=$(OnnxRuntimeBuildDirectory)" />
|
||||
</ItemGroup>
|
||||
|
||||
|
|
|
|||
|
|
@ -468,6 +468,7 @@ struct OrtApi {
|
|||
OrtStatus*(ORT_API_CALL* GetTensorElementType)(_In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out)NO_EXCEPTION;
|
||||
OrtStatus*(ORT_API_CALL* GetDimensionsCount)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out)NO_EXCEPTION;
|
||||
OrtStatus*(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length)NO_EXCEPTION;
|
||||
OrtStatus*(ORT_API_CALL* GetSymbolicDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ const char** dim_params, size_t dim_params_length)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* Return the number of elements specified by the tensor shape.
|
||||
|
|
|
|||
|
|
@ -200,6 +200,8 @@ struct TensorTypeAndShapeInfo : Base<OrtTensorTypeAndShapeInfo> {
|
|||
|
||||
size_t GetDimensionsCount() const;
|
||||
void GetDimensions(int64_t* values, size_t values_count) const;
|
||||
void GetSymbolicDimensions(const char** values, size_t values_count) const;
|
||||
|
||||
std::vector<int64_t> GetShape() const;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -322,6 +322,10 @@ inline void TensorTypeAndShapeInfo::GetDimensions(int64_t* values, size_t values
|
|||
ThrowOnError(g_api->GetDimensions(p_, values, values_count));
|
||||
}
|
||||
|
||||
inline void TensorTypeAndShapeInfo::GetSymbolicDimensions(const char** values, size_t values_count) const {
|
||||
ThrowOnError(g_api->GetSymbolicDimensions(p_, values, values_count));
|
||||
}
|
||||
|
||||
inline std::vector<int64_t> TensorTypeAndShapeInfo::GetShape() const {
|
||||
std::vector<int64_t> out(GetDimensionsCount(), 0);
|
||||
GetDimensions(out.data(), out.size());
|
||||
|
|
|
|||
|
|
@ -41,72 +41,74 @@ ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) {
|
|||
delete ptr;
|
||||
}
|
||||
|
||||
OrtStatus* GetTensorShapeAndType(const TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out);
|
||||
OrtStatus* GetTensorShapeAndType(const TensorShape* shape, const ONNX_NAMESPACE::TypeProto* type_proto, OrtTensorTypeAndShapeInfo** out);
|
||||
OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const onnxruntime::DataTypeImpl& tensor_data_type,
|
||||
OrtTensorTypeAndShapeInfo** out);
|
||||
OrtStatus* GetTensorShapeAndType(const TensorShape& shape, const std::vector<std::string>* dim_params,
|
||||
const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out);
|
||||
|
||||
OrtStatus* OrtTypeInfo::FromDataTypeImpl(const onnxruntime::DataTypeImpl* input, const TensorShape* shape, const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out) {
|
||||
if (input == nullptr) {
|
||||
OrtStatus* OrtTypeInfo::FromOrtValue(const OrtValue& value, OrtTypeInfo** out) {
|
||||
onnxruntime::MLDataType type = value.Type();
|
||||
if (type == nullptr) {
|
||||
*out = new OrtTypeInfo(ONNX_TYPE_UNKNOWN, nullptr);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// GetType<Tensor> and GetType<SparseTensor> do not have TypeProto populated because they return a static
|
||||
// TensorBase/SparseTensorBase instances, but other types are real MLDataTypes and they do have real protos
|
||||
// unless they are primitive data types, in which case we as before return them not implemented
|
||||
// however, this way we can support Opaque and we can avoid excessive calls to GetType()
|
||||
if (input->IsTensorType()) {
|
||||
if (type->IsTensorType()) {
|
||||
OrtTensorTypeAndShapeInfo* info = nullptr;
|
||||
const Tensor& tensor = value.Get<onnxruntime::Tensor>();
|
||||
const auto* tensor_data_type = tensor.DataType();
|
||||
if (tensor_data_type != nullptr) {
|
||||
OrtStatus* st = GetTensorShapeAndType(shape, tensor_data_type, &info);
|
||||
if (st != nullptr) return st;
|
||||
OrtStatus* st = GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, &info);
|
||||
if (st != nullptr)
|
||||
return st;
|
||||
}
|
||||
*out = new OrtTypeInfo(ONNX_TYPE_TENSOR, info);
|
||||
return nullptr;
|
||||
}
|
||||
if (input->IsSparseTensorType()) {
|
||||
|
||||
if (type->IsSparseTensorType()) {
|
||||
OrtTensorTypeAndShapeInfo* info = nullptr;
|
||||
const SparseTensor& tensor = value.Get<onnxruntime::SparseTensor>();
|
||||
const auto* tensor_data_type = tensor.Values().DataType();
|
||||
if (tensor_data_type != nullptr) {
|
||||
OrtStatus* st = GetTensorShapeAndType(shape, tensor_data_type, &info);
|
||||
OrtStatus* st = GetTensorShapeAndType(tensor.Shape(), *tensor_data_type, &info);
|
||||
if (st != nullptr) return st;
|
||||
}
|
||||
*out = new OrtTypeInfo(ONNX_TYPE_SPARSETENSOR, info);
|
||||
return nullptr;
|
||||
}
|
||||
const auto* type_proto = input->GetTypeProto();
|
||||
|
||||
const auto* type_proto = type->GetTypeProto();
|
||||
if (type_proto != nullptr) {
|
||||
// Place Opaque first as tensors will be
|
||||
// mostly handled above and maps and sequences
|
||||
// are not common
|
||||
// Place Opaque first as tensors will be mostly handled above and maps and sequences are not common
|
||||
switch (type_proto->value_case()) {
|
||||
case on::TypeProto::kOpaqueType: {
|
||||
*out = new OrtTypeInfo(ONNX_TYPE_OPAQUE, nullptr);
|
||||
return nullptr;
|
||||
} break;
|
||||
}
|
||||
case on::TypeProto::kMapType: {
|
||||
*out = new OrtTypeInfo(ONNX_TYPE_MAP, nullptr);
|
||||
return nullptr;
|
||||
} break;
|
||||
}
|
||||
case on::TypeProto::kSequenceType: {
|
||||
*out = new OrtTypeInfo(ONNX_TYPE_SEQUENCE, nullptr);
|
||||
return nullptr;
|
||||
} break;
|
||||
}
|
||||
// Real Tensor support
|
||||
case on::TypeProto::kTensorType:
|
||||
case on::TypeProto::kSparseTensorType: {
|
||||
OrtTensorTypeAndShapeInfo* info = nullptr;
|
||||
OrtStatus* st = GetTensorShapeAndType(shape, type_proto, &info);
|
||||
if (st != nullptr) return st;
|
||||
if (type_proto->value_case() == on::TypeProto::kTensorType) {
|
||||
*out = new OrtTypeInfo(ONNX_TYPE_TENSOR, info);
|
||||
} else {
|
||||
*out = new OrtTypeInfo(ONNX_TYPE_SPARSETENSOR, nullptr);
|
||||
}
|
||||
return nullptr;
|
||||
} break;
|
||||
return OrtApis::CreateStatus(ORT_FAIL, "Tensor types should have been handled already");
|
||||
}
|
||||
default:
|
||||
// NOT_IMPLEMENTED
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "not implemented");
|
||||
}
|
||||
|
||||
|
|
@ -146,7 +148,7 @@ const DataTypeImpl* OrtTypeInfo::ElementTypeFromProto(int type) {
|
|||
}
|
||||
}
|
||||
|
||||
OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input, OrtTypeInfo** out) {
|
||||
OrtStatus* OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::TypeProto* input, OrtTypeInfo** out) {
|
||||
auto value_case = input->value_case();
|
||||
switch (value_case) {
|
||||
case on::TypeProto::kTensorType:
|
||||
|
|
@ -168,11 +170,13 @@ OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input,
|
|||
sp = &sparse_type->shape();
|
||||
}
|
||||
}
|
||||
|
||||
OrtStatus* st = nullptr;
|
||||
OrtTensorTypeAndShapeInfo* info = nullptr;
|
||||
if (sp != nullptr) {
|
||||
const on::TensorShapeProto& s = *sp;
|
||||
std::vector<int64_t> dims(s.dim_size());
|
||||
std::vector<std::string> dim_params(s.dim_size());
|
||||
TensorShape shape_data(std::move(dims));
|
||||
for (int i = 0; i < s.dim_size(); ++i) {
|
||||
auto& t = s.dim(i);
|
||||
|
|
@ -181,6 +185,8 @@ OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input,
|
|||
shape_data[i] = t.dim_value();
|
||||
break;
|
||||
case on::TensorShapeProto::Dimension::kDimParam:
|
||||
dim_params[i] = t.dim_param();
|
||||
// fall through
|
||||
case on::TensorShapeProto::Dimension::VALUE_NOT_SET:
|
||||
shape_data[i] = -1;
|
||||
break;
|
||||
|
|
@ -188,9 +194,9 @@ OrtStatus* OrtTypeInfo::FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto* input,
|
|||
assert(false);
|
||||
}
|
||||
}
|
||||
st = GetTensorShapeAndType(&shape_data, input, &info);
|
||||
st = GetTensorShapeAndType(shape_data, &dim_params, *input, &info);
|
||||
} else {
|
||||
st = GetTensorShapeAndType(nullptr, input, &info);
|
||||
st = GetTensorShapeAndType(TensorShape(), nullptr, *input, &info);
|
||||
}
|
||||
if (st != nullptr) return st;
|
||||
*out = new OrtTypeInfo(ten_type, info);
|
||||
|
|
|
|||
|
|
@ -29,9 +29,8 @@ struct OrtTypeInfo {
|
|||
OrtTypeInfo(const OrtTypeInfo& other) = delete;
|
||||
OrtTypeInfo& operator=(const OrtTypeInfo& other) = delete;
|
||||
|
||||
static OrtStatus* FromDataTypeImpl(const onnxruntime::DataTypeImpl* input, const onnxruntime::TensorShape* shape,
|
||||
const onnxruntime::DataTypeImpl* tensor_data_type, OrtTypeInfo** out);
|
||||
static OrtStatus* FromDataTypeImpl(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out);
|
||||
static OrtStatus* FromOrtValue(const OrtValue& value, OrtTypeInfo** out);
|
||||
static OrtStatus* FromTypeProto(const ONNX_NAMESPACE::TypeProto*, OrtTypeInfo** out);
|
||||
|
||||
static const onnxruntime::DataTypeImpl* ElementTypeFromProto(int type);
|
||||
|
||||
|
|
|
|||
|
|
@ -61,6 +61,15 @@ ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndSh
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info,
|
||||
_Out_ const char** names, size_t dim_params_length) {
|
||||
for (size_t idx = 0, end = std::min(info->dim_params.size(), dim_params_length); idx < end; ++idx) {
|
||||
names[idx] = info->dim_params[idx].c_str();
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) {
|
||||
*out = static_cast<size_t>(this_ptr->shape.Size());
|
||||
return nullptr;
|
||||
|
|
@ -159,7 +168,8 @@ ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementDataType(
|
|||
return type;
|
||||
}
|
||||
|
||||
OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape* shape, OrtTensorTypeAndShapeInfo** out) {
|
||||
OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onnxruntime::TensorShape shape,
|
||||
const std::vector<std::string>* dim_params, OrtTensorTypeAndShapeInfo** out) {
|
||||
OrtTensorTypeAndShapeInfo* ret;
|
||||
if (auto* status = OrtApis::CreateTensorTypeAndShapeInfo(&ret))
|
||||
return status;
|
||||
|
|
@ -167,37 +177,47 @@ OrtStatus* GetTensorShapeAndTypeHelper(ONNXTensorElementDataType type, const onn
|
|||
OrtApis::ReleaseTensorTypeAndShapeInfo(ret);
|
||||
return status;
|
||||
}
|
||||
if (shape != nullptr) {
|
||||
auto* status = OrtApis::SetDimensions(ret, shape->GetDims().data(), shape->GetDims().size());
|
||||
if (status != nullptr) {
|
||||
OrtApis::ReleaseTensorTypeAndShapeInfo(ret);
|
||||
return status;
|
||||
}
|
||||
|
||||
auto* status = OrtApis::SetDimensions(ret, shape.GetDims().data(), shape.GetDims().size());
|
||||
if (status != nullptr) {
|
||||
OrtApis::ReleaseTensorTypeAndShapeInfo(ret);
|
||||
return status;
|
||||
}
|
||||
|
||||
if (dim_params != nullptr) {
|
||||
ret->dim_params = *dim_params;
|
||||
} else {
|
||||
// we expect to be being called with a concrete shape so validate that
|
||||
assert(shape.Size() >= 0);
|
||||
ret->dim_params.resize(shape.NumDimensions(), "");
|
||||
}
|
||||
|
||||
*out = ret;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape,
|
||||
const onnxruntime::DataTypeImpl* tensor_data_type, OrtTensorTypeAndShapeInfo** out) {
|
||||
ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(tensor_data_type);
|
||||
OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape,
|
||||
const onnxruntime::DataTypeImpl& tensor_data_type, OrtTensorTypeAndShapeInfo** out) {
|
||||
ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type);
|
||||
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
|
||||
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented");
|
||||
}
|
||||
return GetTensorShapeAndTypeHelper(type, shape, out);
|
||||
return GetTensorShapeAndTypeHelper(type, shape, nullptr, out);
|
||||
}
|
||||
|
||||
OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape* shape,
|
||||
const ONNX_NAMESPACE::TypeProto* type_proto, OrtTensorTypeAndShapeInfo** out) {
|
||||
assert(type_proto != nullptr);
|
||||
auto value_case = type_proto->value_case();
|
||||
assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType || value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType);
|
||||
auto dtype = (value_case == ONNX_NAMESPACE::TypeProto::kTensorType) ? type_proto->tensor_type().elem_type() : type_proto->sparse_tensor_type().elem_type();
|
||||
OrtStatus* GetTensorShapeAndType(const onnxruntime::TensorShape& shape, const std::vector<std::string>* dim_params,
|
||||
const ONNX_NAMESPACE::TypeProto& type_proto, OrtTensorTypeAndShapeInfo** out) {
|
||||
auto value_case = type_proto.value_case();
|
||||
assert(value_case == ONNX_NAMESPACE::TypeProto::kTensorType ||
|
||||
value_case == ONNX_NAMESPACE::TypeProto::kSparseTensorType);
|
||||
|
||||
auto dtype = (value_case == ONNX_NAMESPACE::TypeProto::kTensorType) ? type_proto.tensor_type().elem_type()
|
||||
: type_proto.sparse_tensor_type().elem_type();
|
||||
ONNXTensorElementDataType type = TensorDataTypeToOnnxRuntimeTensorElementDataType(dtype);
|
||||
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
|
||||
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Not implemented");
|
||||
}
|
||||
return GetTensorShapeAndTypeHelper(type, shape, out);
|
||||
return GetTensorShapeAndTypeHelper(type, shape, dim_params, out);
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out_ OrtTensorTypeAndShapeInfo** out) {
|
||||
|
|
@ -216,7 +236,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out
|
|||
shape = &tensor.Shape();
|
||||
data_type = tensor.Values().DataType();
|
||||
}
|
||||
return GetTensorShapeAndType(shape, data_type, out);
|
||||
return GetTensorShapeAndType(*shape, *data_type, out);
|
||||
} else {
|
||||
ORT_THROW("Argument is not a tensor");
|
||||
}
|
||||
|
|
@ -225,10 +245,11 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out
|
|||
|
||||
ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXType* out) {
|
||||
API_IMPL_BEGIN
|
||||
onnxruntime::MLDataType type = v->Type();
|
||||
OrtTypeInfo* type_info;
|
||||
if (auto status = OrtTypeInfo::FromDataTypeImpl(type, nullptr, nullptr, &type_info))
|
||||
auto status = OrtTypeInfo::FromOrtValue(*v, &type_info);
|
||||
if (status != nullptr)
|
||||
return status;
|
||||
|
||||
*out = type_info->type;
|
||||
OrtApis::ReleaseTypeInfo(type_info);
|
||||
return nullptr;
|
||||
|
|
@ -236,29 +257,21 @@ ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXTyp
|
|||
}
|
||||
|
||||
/**
|
||||
* Get the type information of an OrtValue
|
||||
* \param value
|
||||
* \return The returned value should be freed by OrtReleaseTypeInfo after use
|
||||
*/
|
||||
* Get the type information of an OrtValue
|
||||
* \param value
|
||||
* \return The returned value should be freed by OrtReleaseTypeInfo after use
|
||||
*/
|
||||
ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, struct OrtTypeInfo** out) {
|
||||
onnxruntime::MLDataType type = v->Type();
|
||||
if (type == nullptr) {
|
||||
API_IMPL_BEGIN
|
||||
// TODO: This is consistent with the previous implementation but inconsistent with GetValueType which returns
|
||||
// ONNX_TYPE_UNKNOWN if v->Type() is null. Should we instead just call OrtTypeInfo::FromOrtValue and
|
||||
// return an OrtTypeInfo value in 'out' with type set to ONNX_TYPE_UNKNOWN? Or is the inconsistency fine?
|
||||
if (v->Type() == nullptr) {
|
||||
*out = nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
if (type->IsTensorType() || type->IsSparseTensorType()) {
|
||||
const onnxruntime::TensorShape* shape = nullptr;
|
||||
onnxruntime::MLDataType data_type = nullptr;
|
||||
if (type->IsTensorType()) {
|
||||
const Tensor& tensor = v->Get<onnxruntime::Tensor>();
|
||||
shape = &tensor.Shape();
|
||||
data_type = tensor.DataType();
|
||||
} else {
|
||||
const SparseTensor& tensor = v->Get<onnxruntime::SparseTensor>();
|
||||
shape = &tensor.Shape();
|
||||
data_type = tensor.Values().DataType();
|
||||
}
|
||||
return OrtTypeInfo::FromDataTypeImpl(type, shape, data_type, out);
|
||||
}
|
||||
return OrtTypeInfo::FromDataTypeImpl(type, nullptr, nullptr, out);
|
||||
|
||||
auto status = OrtTypeInfo::FromOrtValue(*v, out);
|
||||
return status;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ struct OrtTensorTypeAndShapeInfo {
|
|||
public:
|
||||
ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
onnxruntime::TensorShape shape;
|
||||
// dim_param values. empty string if dim_value or no dim_param was specified.
|
||||
// one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs
|
||||
std::vector<std::string> dim_params;
|
||||
|
||||
OrtTensorTypeAndShapeInfo() = default;
|
||||
OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other) = delete;
|
||||
|
|
|
|||
|
|
@ -599,7 +599,7 @@ static OrtStatus* GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefListFn
|
|||
if (p.second->size() <= index)
|
||||
return OrtApis::CreateStatus(ORT_FAIL, "out of index");
|
||||
const ONNX_NAMESPACE::TypeProto* type_proto = (*p.second)[index]->TypeAsProto();
|
||||
return OrtTypeInfo::FromDataTypeImpl(type_proto, out);
|
||||
return OrtTypeInfo::FromTypeProto(type_proto, out);
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
|
|
@ -1310,6 +1310,7 @@ static constexpr OrtApi ort_api_1 = {
|
|||
&OrtApis::GetTensorElementType,
|
||||
&OrtApis::GetDimensionsCount,
|
||||
&OrtApis::GetDimensions,
|
||||
&OrtApis::GetSymbolicDimensions,
|
||||
&OrtApis::GetTensorShapeElementCount,
|
||||
&OrtApis::GetTensorTypeAndShape,
|
||||
&OrtApis::GetTypeInfo,
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ ORT_API_STATUS_IMPL(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const i
|
|||
ORT_API_STATUS_IMPL(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out);
|
||||
ORT_API_STATUS_IMPL(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out);
|
||||
ORT_API_STATUS_IMPL(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
|
||||
ORT_API_STATUS_IMPL(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ const char* dim_params[], size_t dim_params_length);
|
||||
ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out);
|
||||
ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
|
||||
ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_ OrtTypeInfo** out);
|
||||
|
|
|
|||
|
|
@ -120,6 +120,8 @@ void TestInference(Ort::Env& env, T model_uri,
|
|||
static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx");
|
||||
static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx");
|
||||
static constexpr PATH_TYPE OVERRIDABLE_INITIALIZER_MODEL_URI = TSTR("testdata/overridable_initializer.onnx");
|
||||
static constexpr PATH_TYPE NAMED_AND_ANON_DIM_PARAM_URI = TSTR("testdata/capi_symbolic_dims.onnx");
|
||||
|
||||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx");
|
||||
#endif
|
||||
|
|
@ -145,6 +147,34 @@ TEST_P(CApiTestWithProvider, simple) {
|
|||
TestInference<PATH_TYPE>(env_, MODEL_URI, inputs, "Y", expected_dims_y, expected_values_y, GetParam(), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(CApiTest, dim_param) {
|
||||
Ort::SessionOptions session_options;
|
||||
Ort::Session session(env_, NAMED_AND_ANON_DIM_PARAM_URI, session_options);
|
||||
|
||||
auto in0 = session.GetInputTypeInfo(0);
|
||||
auto in0_ttsi = in0.GetTensorTypeAndShapeInfo();
|
||||
|
||||
auto num_input_dims = in0_ttsi.GetDimensionsCount();
|
||||
ASSERT_GE(num_input_dims, 1);
|
||||
// reading 1st dimension only so don't need to malloc int64_t* or const char** values for the Get*Dimensions calls
|
||||
int64_t dim_value = 0;
|
||||
const char* dim_param = nullptr;
|
||||
in0_ttsi.GetDimensions(&dim_value, 1);
|
||||
in0_ttsi.GetSymbolicDimensions(&dim_param, 1);
|
||||
ASSERT_EQ(dim_value, -1) << "symbolic dimension should be -1";
|
||||
ASSERT_EQ(strcmp(dim_param, "n"), 0) << "Expected 'n'. Got: " << dim_param;
|
||||
|
||||
auto out0 = session.GetOutputTypeInfo(0);
|
||||
auto out0_ttsi = out0.GetTensorTypeAndShapeInfo();
|
||||
auto num_output_dims = out0_ttsi.GetDimensionsCount();
|
||||
ASSERT_EQ(num_output_dims, 1);
|
||||
|
||||
out0_ttsi.GetDimensions(&dim_value, 1);
|
||||
out0_ttsi.GetSymbolicDimensions(&dim_param, 1);
|
||||
ASSERT_EQ(dim_value, -1) << "symbolic dimension should be -1";
|
||||
ASSERT_EQ(strcmp(dim_param, ""), 0);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(CApiTestWithProviders,
|
||||
CApiTestWithProvider,
|
||||
::testing::Values(0, 1, 2, 3, 4));
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/capi_symbolic_dims.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/capi_symbolic_dims.onnx
vendored
Normal file
Binary file not shown.
40
onnxruntime/test/testdata/capi_symbolic_dims.py
vendored
Normal file
40
onnxruntime/test/testdata/capi_symbolic_dims.py
vendored
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
from onnx import shape_inference
|
||||
|
||||
# create output with rank but unnamed symbolic dim
|
||||
output = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1])
|
||||
output.type.tensor_type.shape.Clear()
|
||||
dim = output.type.tensor_type.shape.dim.add()
|
||||
print(dim)
|
||||
|
||||
graph_def = helper.make_graph(
|
||||
nodes = [
|
||||
helper.make_node(op_type = "Reshape", inputs = ['A', 'B'], outputs = ['C'], name = 'reshape'),
|
||||
],
|
||||
name = 'test-model',
|
||||
inputs = [
|
||||
# create inputs with symbolic dims
|
||||
helper.make_tensor_value_info("A", TensorProto.FLOAT, ['n', 2]),
|
||||
helper.make_tensor_value_info("B", TensorProto.INT64, ['m']),
|
||||
],
|
||||
outputs = [
|
||||
output
|
||||
],
|
||||
initializer = [
|
||||
]
|
||||
)
|
||||
|
||||
model = helper.make_model(graph_def, opset_imports=[helper.make_operatorsetid("", 11)])
|
||||
onnx.checker.check_model(model)
|
||||
|
||||
inferred_model = shape_inference.infer_shapes(model)
|
||||
onnx.checker.check_model(inferred_model)
|
||||
|
||||
onnx.save_model(model, "capi_symbolic_dims.onnx")
|
||||
|
||||
import onnxruntime as rt
|
||||
sess = rt.InferenceSession("capi_symbolic_dims.onnx")
|
||||
print([i.shape for i in sess.get_inputs()])
|
||||
print([i.shape for i in sess.get_outputs()])
|
||||
Loading…
Reference in a new issue