diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 9b9dd81a74..d3a8cade4d 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2044,6 +2044,9 @@ inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_n int64_t i = {}; size_t out = {}; // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); if (status) { size_t num_i = out / sizeof(int64_t); @@ -2051,6 +2054,9 @@ inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_n Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out)); return ints; } else { + if (out == 0u) { + return {}; + } return {i}; } } @@ -2068,6 +2074,9 @@ inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* at float f = {}; size_t out = {}; // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); if (status) { size_t num_f = out / sizeof(float); @@ -2075,6 +2084,9 @@ inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* at Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); return floats; } else { + if (out == 0u) { + return {}; + } return {f}; } } @@ -2099,6 +2111,9 @@ inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* char c = {}; size_t out = {}; // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); if (status) { std::vector chars(out, '\0'); @@ -2115,6 +2130,9 @@ inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* } return strings; } else { + if (out == 0u) { + return {}; + } return {std::string{c}}; } }