diff --git a/include/onnxruntime/core/framework/sparse_tensor.h b/include/onnxruntime/core/framework/sparse_tensor.h
index af66a80874..6a0596882a 100644
--- a/include/onnxruntime/core/framework/sparse_tensor.h
+++ b/include/onnxruntime/core/framework/sparse_tensor.h
@@ -7,13 +7,15 @@
#include "core/framework/tensor_shape.h"
#include "core/framework/tensor.h"
+struct OrtValue;
+
namespace onnxruntime {
class IDataTransfer;
class DataTransferManager;
/**
- * @brief This is a Sparse Format enumeration representing bitflags
+ * @brief This is a Sparse Format enumeration
*
*
*/
@@ -59,8 +61,8 @@ class SparseTensor final {
///
/// MlDataType
/// a shape of original tensor in dense form
- /// shape for user supplied values
- /// a pointer to values
+ /// shape for user supplied values. Use {0} shape for fully sparse tensors.
+ /// a pointer to values. Use nullptr for fully sparse tensors.
/// description of the user allocated memory
SparseTensor(MLDataType elt_type,
const TensorShape& dense_shape,
@@ -70,7 +72,7 @@ class SparseTensor final {
///
/// Use this constructor to hold sparse data in the buffer
- /// allocated with the specificed allocator. Use Make*() methods
+ /// allocated with the specified allocator. Use Make*() methods
/// to populate the instance with data which will be copied into the
/// allocated buffer.
///
@@ -87,6 +89,57 @@ class SparseTensor final {
ORT_DISALLOW_COPY_AND_ASSIGNMENT(SparseTensor);
+ ///
+ /// The factory function creates an instance of SparseTensor on the heap
+ /// using appropriate constructor and initializes OrtValue instance with it.
+ ///
+ /// element data type
+ /// dense shape of the sparse tensor
+ /// values shape. Use {0} for fully sparse tensors.
+ /// pointer to a user allocated buffer. Use nullptr for fully sparse tensors.
+ /// description of the user allocated buffer
+ /// default constructed input/output ort_value
+ static void InitOrtValue(MLDataType elt_type,
+ const TensorShape& dense_shape,
+ const TensorShape& values_shape,
+ void* values_data,
+ const OrtMemoryInfo& location,
+ OrtValue& ort_value);
+
+ ///
+ /// The factory function creates an instance of SparseTensor on the heap
+ /// using appropriate constructor and initializes OrtValue instance with it.
+ ///
+ /// element data type
+ /// dense shape of the sparse tensor
+ /// allocator to use
+ /// default constructed input/output ort_value
+ static void InitOrtValue(MLDataType elt_type,
+ const TensorShape& dense_shape,
+ std::shared_ptr allocator,
+ OrtValue& ort_value);
+
+ ///
+ /// The function will check if the OrtValue is allocated,
+ /// fetch the contained SparseTensor instance or throw if it
+ /// does not contain one. It will check that the SparseTensor has
+ /// its sparse format set (i.e. it is fully constructed).
+ ///
+ /// OrtValue instance
+ /// const SparseTensor Reference
+ static const SparseTensor& GetSparseTensorFromOrtValue(const OrtValue& v);
+
+ ///
+ /// The function will check if the OrtValue is allocated,
+ /// fetch the contained SparseTensor instance or throw if it
+ /// does not contain one. It will check that the SparseTensor does not
+ /// have sparse format set and will return a non-const ref so that indices
+ /// can be added to it.
+ ///
+ /// OrtValue
+ /// non-const reference to SparseTensor
+ static SparseTensor& GetSparseTensorFromOrtValue(OrtValue& v);
+
///
// Returns the number of non-zero values (aka "NNZ")
// For block sparse formats this may include some zeros in the blocks
@@ -195,7 +248,7 @@ class SparseTensor final {
/// index shape would be 1-D (values_count) or it must be twice the number of values
/// in which case its shape would be 2-D (values_count, 2)
///
- /// user allocated buffer span
+ /// user allocated buffer span. Use empty span for fully sparse tensors.
/// Status
Status UseCooIndices(gsl::span indices);
@@ -209,13 +262,25 @@ class SparseTensor final {
///
/// Values shape is supplied at construction time and its Size() must match values_count.
///
- ///
- ///
+ /// Use 0 for fully sparse tensors.
+ /// pointer to a buffer to be copied. Use nullptr for fully sparse tensors.
///
///
Status MakeCooData(const IDataTransfer& data_transfer, const OrtMemoryInfo& data_location,
size_t values_count, const void* values_data, gsl::span indices);
+ ///
+ /// The method allocates a single contiguous buffer and creates instances of std::strings in it, with
+ /// copies of the supplied zero-terminated strings followed by COO indices.
+ /// All data is assumed to be on CPU and the allocator supplied must be
+ /// a CPU based allocator.
+ ///
+ /// use 0 for fully sparse tensors
+ /// array of char* pointers. use nullptr for fully sparse tensors
+ /// span of indices. Use empty span for fully sparse tensors.
+ /// Status
+ Status MakeCooStrings(size_t string_count, const char* const* strings, gsl::span indices);
+
///
/// Gives mutable access to Coo buffers so they can be populated
///
@@ -234,8 +299,8 @@ class SparseTensor final {
/// Allocates memory for values and index and returns a mutator so
/// data can be copied into the buffer.
///
- ///
- ///
+ /// use 0 for fully sparse tensors
+ /// use 0 for fully sparse tensors
///
CooMutator MakeCooData(size_t values_count, size_t index_count);
@@ -255,17 +320,17 @@ class SparseTensor final {
};
///
- /// Returns Csr indices readonly view
+ /// Returns Csr indices read only view
///
///
CsrView AsCsr() const;
///
/// This function will use Csr indices contained within the user allocated buffers.
- /// The lifespan of the buffers must exclipse the lifespan of sparse tensor instance.
+ /// The lifespan of the buffers must eclipse the lifespan of sparse tensor instance.
///
- ///
- ///
+ /// User allocated buffer span. use empty span for fully sparse tensors
+ /// User allocated buffer span. Use empty span for fully sparse tensors
///
Status UseCsrIndices(gsl::span inner_index, gsl::span outer_index);
@@ -275,10 +340,10 @@ class SparseTensor final {
///
///
///
- ///
- ///
- ///
- ///
+ /// use 0 for fully sparse tensors
+ /// pointer to data to be copied. Use nullptr for fully sparse tensors.
+ /// inner index to be copied. Use empty span for fully sparse tensors.
+ /// outer index to be copied. Use empty span for fully sparse tensors.
///
Status MakeCsrData(const IDataTransfer& data_transfer,
const OrtMemoryInfo& data_location,
@@ -286,6 +351,21 @@ class SparseTensor final {
gsl::span inner_index,
gsl::span outer_index);
+ ///
+ /// The method allocates a single contiguous buffer and creates instances of std::strings in it, with
+ /// copies of the supplied zero-terminated strings followed by CSR indices.
+ /// All data is assumed to be on CPU and the allocator supplied must be
+ /// a CPU based allocator
+ ///
+ ///
+ /// array of char* pointers
+ /// inner index to be copied. Use empty span for fully sparse tensors.
+ /// outer index to be copied. Use empty span for fully sparse tensors.
+ ///
+ Status MakeCsrStrings(size_t string_count, const char* const* strings,
+ gsl::span inner_index,
+ gsl::span outer_index);
+
///
/// Give writable access to Csr values and indices
///
@@ -307,9 +387,9 @@ class SparseTensor final {
/// Allocates memory for values and index and returns mutator so
/// data can be populated.
///
- ///
- ///
- ///
+ /// Use 0 for fully sparse tensors.
+ /// Use 0 for fully sparse tensors.
+ /// Use 0 for fully sparse tensors.
///
CsrMutator MakeCsrData(size_t values_count, size_t inner_index_count, size_t outer_index_count);
@@ -338,8 +418,8 @@ class SparseTensor final {
/// were supplied to the constructor. The supplied buffer lifespan must eclipse the life
/// of sparse tensor instance.
///
- ///
- ///
+ /// Use {0} for fully sparse tensors.
+ /// Ptr to user allocated buffer. Use nullptr for fully sparse tensors.
///
Status UseBlockSparseIndices(const TensorShape& indices_shape, int32_t* indices_data);
@@ -350,20 +430,35 @@ class SparseTensor final {
///
// The shape of the index is must be at least 2-D and must contain one tuple per each of
// the value blocks that were supplied to the constructor. Each index tuple is a
- // (row, col) coordindate of the values block in a dense matrix.
+ // (row, col) coordinates of the values block in a dense matrix.
///
///
///
- ///
- ///
- ///
- ///
+ /// The shape is expected to be at least 3-D. However, use {0} for fully sparse tensors.
+ /// Pointer to a data to be copied. Use nullptr for fully sparse tensors.
+ /// The shape is expected to be 2-D. However, you can use {0} for fully sparse tensors.
+ /// Pointer to index data to be copied. Use nullptr for fully sparse tensors.
///
Status MakeBlockSparseData(const IDataTransfer& data_transfer,
const OrtMemoryInfo& data_location,
const TensorShape& values_shape, const void* values_data,
const TensorShape& indices_shape, const int32_t* indices_data);
+
+ ///
+ /// The method allocates a single contiguous buffer and creates instances of std::strings in it, with
+ /// copies of the supplied zero-terminated strings followed by block sparse indices.
+ /// All data is assumed to be on CPU and the allocator supplied must be
+ /// a CPU based allocator.
+ ///
+ /// Use {0} shape for fully sparse tensors
+ /// array of char* ptrs, use nullptr for fully sparse tensor
+ /// Use {0} for fully sparse tensors
+ /// use nullptr for fully sparse tensors
+ ///
+ Status MakeBlockSparseStrings(const TensorShape& values_shape, const char* const* strings,
+ const TensorShape& indices_shape, const int32_t* indices_data);
+
///
/// Mutable data access
///
@@ -383,8 +478,8 @@ class SparseTensor final {
/// Allocates memory for values and index and returns mutator so
/// data can be populated
///
- ///
- ///
+ /// Shape is expected to be 3-D, use {0} for fully sparse tensors
+ /// Shape is expected to be 2-D, use {0} for fully sparse tensors
///
BlockSparseMutator MakeBlockSparseData(const TensorShape& values_shape, const TensorShape& indices_shape);
@@ -416,6 +511,7 @@ class SparseTensor final {
Status ValidateCsrIndices(size_t values_count, size_t inner_size, size_t outer_size) const;
void InitCsrIndices(size_t inner_size, const int64_t* inner, size_t outer_size, const int64_t* outer);
+ void InitBlockSparseIndices(const TensorShape& indices_shape, int32_t* indices_data);
SparseFormat format_; // sparse format enum value
TensorShape dense_shape_; // a shape of a corresponding dense tensor
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 868695c7f7..dda996234e 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -120,7 +120,6 @@ typedef enum ONNXTensorElementDataType {
ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision
} ONNXTensorElementDataType;
-
// Synced with onnx TypeProto oneof
typedef enum ONNXType {
ONNX_TYPE_UNKNOWN,
@@ -132,8 +131,7 @@ typedef enum ONNXType {
} ONNXType;
// These types are synced with internal
-// SparseFormatFlags but are not exposed
-// as flags
+// SparseFormat enumeration
typedef enum OrtSparseFormat {
ORT_SPARSE_UNDEFINED = 0,
ORT_SPARSE_COO = 0x1,
@@ -141,6 +139,13 @@ typedef enum OrtSparseFormat {
ORT_SPARSE_BLOCK_SPARSE = 0x4
} OrtSparseFormat;
+// Identifies which index buffer of a sparse tensor is being queried (see GetSparseTensorIndicesTypeShape() and GetSparseTensorIndices()).
+enum OrtSparseIndicesFormat {
+ ORT_SPARSE_COO_INDICES,  // indices of a COO format sparse tensor
+ ORT_SPARSE_CSR_INNER_INDICES,  // inner indices of a CSR format sparse tensor
+ ORT_SPARSE_CSR_OUTER_INDICES,  // outer indices of a CSR format sparse tensor
+ ORT_SPARSE_BLOCK_SPARSE_INDICES  // indices of a block sparse format tensor
+};  // NOTE(review): not typedef'd, unlike OrtSparseFormat above; C callers must write 'enum OrtSparseIndicesFormat'
typedef enum OrtLoggingLevel {
ORT_LOGGING_LEVEL_VERBOSE,
@@ -589,23 +594,36 @@ struct OrtApi {
ORT_API2_STATUS(FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len);
/**
- * \param value A tensor created from OrtCreateTensor... function.
- * \param len total data length, not including the trailing '\0' chars.
+ * Obtain a total length of strings contained within a tensor.
+ * For sparse tensors it returns the total length of values (nnz) strings.
+ * \param[in] value A tensor created from OrtCreateTensor... function.
+ * \param[out] len total data length, not including the trailing '\0' chars.
*/
ORT_API2_STATUS(GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len);
/**
- * \param s string contents. Each string is NOT null-terminated.
- * \param value A tensor created from OrtCreateTensor... function.
- * \param s_len total data length, get it from OrtGetStringTensorDataLength
+ * This API returns all of the UTF-8 encoded strings that are contained within a tensor
+ * or in non-empty values of a sparse tensor in one single buffer. Use offsets to calculate
+ * the length of each string such as len[i] = offsets[i + 1] - offsets[i] except the last
+ * string for which the length is calculated as total_len - offsets[i].
+ *
+ * \param[in] value A tensor created from OrtCreateTensor... API or a sparse tensor
+ * created with OrtCreateSparseTensor... API.
+ * \param[in,out] s string contents. Each string is NOT null-terminated.
+ * \param[in] s_len total data length, get it from OrtGetStringTensorDataLength
+ * \param[in,out] offsets pointer to a preallocated buffer where offsets for each of the string
+ * elements are returned. The number of offsets must match the number of string elements.
+ * \param[in] offsets_len number of offsets expected in the buffer.
*/
ORT_API2_STATUS(GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s,
size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len);
- /**
- * Don't free the 'out' value
- */
- ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo*,
+ /** Retrieves OrtTensorTypeAndShapeInfo part of the OrtTypeInfo
+ *
+ * \param[in] type_info
+ * \param[out] out a returned ptr. Don't free the 'out' value, it is owned by type_info
+ */
+ ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo* type_info,
_Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out);
/**
@@ -647,25 +665,39 @@ struct OrtApi {
ORT_API2_STATUS(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out);
/**
- * \param out Should be freed by ReleaseTensorTypeAndShapeInfo after use
- */
+ * Returns data type and shape iff OrtValue contains a Tensor or a SparseTensor.
+ * For sparse tensors it returns a dense shape of the tensor.
+ *
+ * \param[in] value OrtValue that contains tensor or a sparse tensor
+ * \param[out] out Should be freed by ReleaseTensorTypeAndShapeInfo after use
+ */
ORT_API2_STATUS(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
/**
- * Get the type information of an OrtValue
- * \param value
- * \param out The returned value should be freed by ReleaseTypeInfo after use
- */
+ * Get the type information of an OrtValue. API works for tensors and sparse tensors.
+ *
+ * \param[in] value
+ * \param[out] out The returned value should be freed by ReleaseTypeInfo after use
+ */
ORT_API2_STATUS(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out);
ORT_API2_STATUS(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out);
- ORT_API2_STATUS(CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1,
- enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out);
+ /**
+ * Creates an instance of OrtMemoryInfo. It must be freed by ReleaseMemoryInfo after use.
+ * This may describe one of the existing ORT allocator types OR a custom allocator.
+ *
+ * \param[in] name such as "cpu", "gpu"
+ * \param[in] type one of the enum values
+ * \param[in] id device ID. For GPU, the GPU id.
+ * \param[in] mem_type memory type enum value.
+ */
+ ORT_API2_STATUS(CreateMemoryInfo, _In_ const char* name, enum OrtAllocatorType type, int id,
+ enum OrtMemType mem_type, _Outptr_ OrtMemoryInfo** out);
/**
- * Convenience function for special case of CreateMemoryInfo, for the CPU allocator. Uses name = "Cpu" and id = 0.
- */
+ * Convenience function for special case of CreateMemoryInfo, for the CPU allocator. Uses name = "Cpu" and id = 0.
+ */
ORT_API2_STATUS(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1,
_Outptr_ OrtMemoryInfo** out);
@@ -990,13 +1022,21 @@ struct OrtApi {
_In_ int providers_length);
/**
- * \param value - A tensor created from OrtCreateTensor... function.
- * \param index - index of string tensor element, length of element at index will be returned.
- * \param out - number of UTF-8 bytes that the string contains
+ * This API returns a length of string element at [index]. For sparse tensors
+ * it will return the length of a string element of sparse values. It is an error to request
+ * an out of bounds element.
+ *
+ * \param[in] value - A tensor created from OrtCreateTensor... function.
+ * \param[in] index - flat index of string tensor element, length of element at index will be returned.
+ * \param[out] out - number of UTF-8 bytes that the string contains
*/
ORT_API2_STATUS(GetStringTensorElementLength, _In_ const OrtValue* value, size_t index, _Out_ size_t* out);
/**
+ * This API will return a copy of the UTF-8 data contained within a string element at the specified index.
+ * For sparse tensors it would return a string element of sparse values. It is an error to request an out
+ * of bounds element.
+ *
* \param s string element contents in UTF-8 encoding. The string is NOT null-terminated.
* \param value A tensor created from OrtCreateTensor... function.
* \param s_len element length, get it from OrtGetStringTensorElementLength.
@@ -1472,12 +1512,15 @@ struct OrtApi {
* Registers a custom allocator instance with the env to enable
* sharing between multiple sessions that use the same env instance.
* Returns an error if an allocator with the same OrtMemoryInfo is already registered.
- * \param env OrtEnv instance (must be non-null).
- * \param allocator user provided allocator (must be non-null).
+ *
* The behavior of this API is exactly the same as CreateAndRegisterAllocator() except
* instead of ORT creating an allocator based on provided info, in this case
* ORT uses the user-provided custom allocator.
* See docs/C_API.md for details.
+ *
+ * \param[in,out] env OrtEnv instance (must be non-null).
+ * \param[in] allocator user provided allocator (must be non-null).
+ *
*/
ORT_API2_STATUS(RegisterAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* allocator);
@@ -1489,6 +1532,212 @@ struct OrtApi {
*/
ORT_API2_STATUS(UnregisterAllocator, _Inout_ OrtEnv* env,
_In_ const OrtMemoryInfo* mem_info);
+
+ /**
+ * Sets *out to 1 iff an OrtValue is a SparseTensor, and 0 otherwise
+ *
+ * \param[in] value existing OrtValue
+ * \param[out] out unless an error occurs, contains 1 iff the value contains an instance
+ * of sparse tensor or 0 otherwise.
+ */
+ ORT_API2_STATUS(IsSparseTensor, _In_ const OrtValue* value, _Out_ int* out);
+
+ /**
+ * Create an OrtValue with a sparse tensor that is empty.
+ * Use FillSparseTensor() functions to populate sparse tensor with non-zero values and
+ * format specific indices data.
+ * Use ReleaseValue to destroy the sparse tensor, this will also release the buffer inside the output value
+ * if any was allocated.
+ * \param[in,out] allocator allocator to use when performing an allocation. Allocation will be performed
+ * by FillSparseTensor() APIs. The lifespan of the allocator instance must eclipse the lifespan
+ * of this sparse tensor instance as the same allocator will be used to free memory.
+ * \param[in] dense_shape shape of the original dense tensor
+ * \param[in] dense_shape_len number of shape dimensions being passed
+ * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx
+ * \param[out] out Should be freed by calling ReleaseValue
+ * \return OrtStatus*
+ */
+ ORT_API2_STATUS(CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape,
+ size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out);
+
+ /**
+ * This API populates an empty sparse tensor that was created using CreateSparseTensorAsOrtValue API.
+ * The API will allocate required memory and copy the supplied NNZ values and COO indices into that memory allocation.
+ * Memory allocation is performed using the allocator that was specified with CreateSparseTensorAsOrtValue.
+ *
+ * \param[in,out] ort_value OrtValue to populate with data
+ * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified
+ * at the creation time has memory info that is not the same as data_mem_info argument to this function a X-device copy will be performed.
+ * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer.
+ * \param[in] values_shape pointer to values shape array
+ * \param[in] values_shape_len length of the values_shape
+ * \param[in] values pointer to an array of values. For strings, pass const char**.
+ * \param[in] indices_data pointer to a location of COO indices
+ * \param[in] indices_num number of COO indices
+ */
+ ORT_API2_STATUS(FillSparseTensorCoo, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+ _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+ _In_ const int64_t* indices_data, size_t indices_num);
+
+ /**
+ * This API populates an empty sparse tensor that was created using CreateSparseTensorAsOrtValue API.
+ * The API will allocate required memory and copy the supplied NNZ values and CSR indices into that memory allocation.
+ * Memory allocation is performed using the allocator that was specified with CreateSparseTensorAsOrtValue.
+ *
+ * \param[in,out] ort_value OrtValue to populate with data
+ * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified
+ * at the creation time has memory info that is not the same as data_mem_info argument to this function a X-device copy will be performed.
+ * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer.
+ * \param[in] values_shape pointer to values shape array
+ * \param[in] values_shape_len length of the values_shape
+ * \param[in] values - pointer to an array of values. For strings, pass const char**.
+ * \param[in] inner_indices_data pointer to a location of CSR inner indices
+ * \param[in] inner_indices_num number of CSR inner indices
+ * \param[in] outer_indices_data pointer to a location of CSR outer indices
+ * \param[in] outer_indices_num number of CSR outer indices
+ */
+ ORT_API2_STATUS(FillSparseTensorCsr, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+ _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+ _In_ const int64_t* inner_indices_data, size_t inner_indices_num,
+ _In_ const int64_t* outer_indices_data, size_t outer_indices_num);
+
+ /**
+ * This API populates an empty sparse tensor that was created using CreateSparseTensorAsOrtValue API.
+ * The API will allocate required memory and copy the supplied NNZ values and BlockSparse indices into that memory allocation.
+ * Memory allocation is performed using the allocator that was specified with CreateSparseTensorAsOrtValue.
+ *
+ * \param[in,out] ort_value OrtValue to populate with data
+ * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified
+ * at the creation time has memory info that is not the same as data_mem_info argument to this function a X-device copy will be performed.
+ * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer.
+ * \param[in] values_shape,values_shape_len,values values shape array, its length, and pointer to an array of values. For strings, pass const char**.
+ * \param[in] indices_shape_data pointer to a location of indices shape
+ * \param[in] indices_shape_len length of the block sparse indices shape
+ * \param[in] indices_data pointer to a location of indices data. Shape will determine the length of the indices data.
+ */
+ ORT_API2_STATUS(FillSparseTensorBlockSparse, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+ _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+ _In_ const int64_t* indices_shape_data, size_t indices_shape_len,
+ _In_ const int32_t* indices_data);
+
+ /**
+ * Create an OrtValue with a sparse tensor. This is the first step.
+ * Next, use Use*Indices() functions to supply sparse tensor with
+ * format specific indices data and set its sparse format to a specific enum value.
+ * This API will not perform memory allocations. It will
+ * use supplied user buffer which should outlive the created sparse tensor.
+ * Use ReleaseValue to destroy the sparse tensor. It would not release the supplied values buffer.
+ * This API can not be used to map strings from the user allocated memory. Strings must always be copied
+ * and have UTF-8 encoding. Therefore, use CreateSparseTensorAsOrtValue() API above and then fill it with data
+ * using appropriate FillSparseTensor*() function.
+ *
+ * \param[in] info memory info where sparse values reside.
+ * \param[in,out] p_data pointer to a user allocated buffer with values. To create a fully sparse tensor with no non-zero
+ * values, pass nullptr
+ * \param[in] dense_shape shape of the original dense tensor
+ * \param[in] dense_shape_len number of shape dimensions being passed
+ * \param[in] values_shape shape of the values data. To create a fully sparse tensor with no non-zero values,
+ * pass {0} shape.
+ * \param[in] values_shape_len number of values shape dimensions
+ * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx
+ * \param[out] out Should be freed by calling ReleaseValue
+ * \return OrtStatus*
+ */
+ ORT_API2_STATUS(CreateSparseTensorWithValuesAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data,
+ _In_ const int64_t* dense_shape, size_t dense_shape_len,
+ _In_ const int64_t* values_shape, size_t values_shape_len,
+ ONNXTensorElementDataType type, _Outptr_ OrtValue** out);
+
+ /**
+ * The API assigns Coo format indices to the SparseTensor that was created by
+ * CreateSparseTensorWithValuesAsOrtValue API above. It also sets OrtSparseFormat to
+ * ORT_SPARSE_COO. The API will not allocate any additional memory for data. The life span of
+ * indices_data buffer should eclipse the life span of this OrtValue.
+ *
+ * \param[in,out] ort_value OrtValue instance constructed with CreateSparseTensorWithValuesAsOrtValue
+ * \param[in,out] indices_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors.
+ * \param[in] indices_num number of COO indices. Should either be 0 for fully sparse tensors, be equal
+ * to the number of nnz values specified to CreateSparseTensorWithValuesAsOrtValue for 1-D {nnz} indices or
+ * be twice the number of nnz values for 2-D indices {nnz, 2}
+ */
+ ORT_API2_STATUS(UseCooIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* indices_data, size_t indices_num);
+
+ /**
+ * The API assigns CSR format indices to the SparseTensor that was created by
+ * CreateSparseTensorWithValuesAsOrtValue API above. It also sets OrtSparseFormat to
+ * ORT_SPARSE_CSRC. The API will not allocate any additional memory for data. The life spans of
+ * inner_data and outer_data buffers should eclipse the life span of this OrtValue.
+ *
+ * \param[in,out] ort_value OrtValue instance constructed with CreateSparseTensorWithValuesAsOrtValue
+ * \param[in,out] inner_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors.
+ * \param[in] inner_num number of inner CSR indices. Should either be 0 for fully sparse tensors or be equal
+ * to the number of nnz values specified to CreateSparseTensorWithValuesAsOrtValue.
+ * \param[in,out] outer_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors.
+ * \param[in] outer_num number of CSR outer indices. Should either be 0 for fully sparse tensors or
+ * equal to rows + 1 of the dense shape.
+ */
+ ORT_API2_STATUS(UseCsrIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* inner_data, size_t inner_num,
+ _Inout_ int64_t* outer_data, size_t outer_num);
+
+ /**
+ * The API assigns BlockSparse format indices to the SparseTensor that was created by
+ * CreateSparseTensorWithValuesAsOrtValue API above. It also sets OrtSparseFormat to
+ * ORT_SPARSE_BLOCK_SPARSE. The API will not allocate any additional memory for data. The life span of
+ * indices_data buffer must eclipse the lifespan of this OrtValue.
+ *
+ * \param[in,out] ort_value OrtValue instance constructed with CreateSparseTensorWithValuesAsOrtValue
+ * \param[in] indices_shape pointer to indices shape. Use {0} for fully sparse tensors
+ * \param[in] indices_shape_len length of the indices shape
+ * \param[in,out] indices_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors.
+ */
+ ORT_API2_STATUS(UseBlockSparseIndices, _Inout_ OrtValue* ort_value, const int64_t* indices_shape, size_t indices_shape_len, _Inout_ int32_t* indices_data);
+
+ /**
+ * The API returns sparse tensor format enum iff a given ort value contains an instance of sparse tensor.
+ *
+ * \param[in] ort_value OrtValue that contains an instance of sparse tensor
+ * \param[out] out pointer to out parameter
+ */
+ ORT_API2_STATUS(GetSparseTensorFormat, _In_ const OrtValue* ort_value, _Out_ enum OrtSparseFormat* out);
+
+ /**
+ * The API returns data type and shape of sparse tensor values (nnz) iff OrtValue contains a SparseTensor.
+ *
+ * \param[in] ort_value an OrtValue that contains a fully constructed sparse tensor
+ * \param[out] out Should be freed by ReleaseTensorTypeAndShapeInfo after use
+ */
+ ORT_API2_STATUS(GetSparseTensorValuesTypeAndShape, _In_ const OrtValue* ort_value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
+
+ /**
+ * The API returns numeric data for sparse tensor values (nnz). For string values use GetStringTensor*() API.
+ *
+ * \param[in] ort_value an instance of OrtValue containing sparse tensor
+ * \param[out] out returns a pointer to values data. Do not attempt to free this ptr.
+ */
+ ORT_API2_STATUS(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outptr_ const void** out);
+
+ /**
+ * The API returns data type and shape for the type of indices specified by
+ * indices_format.
+ *
+ * \param[in] ort_value OrtValue containing sparse tensor.
+ * \param[in] indices_format - one of the indices formats. It is an error to request a format that the sparse
+ * tensor does not contain.
+ * \param[out] out an instance of OrtTensorTypeAndShapeInfo. Must be freed by the ReleaseTensorTypeAndShapeInfo.
+ */
+ ORT_API2_STATUS(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out);
+
+ /**
+ * The API returns indices data for the type of the indices specified by indices_format.
+ * Do not free the returned ptr as it points directly to the internal sparse tensor buffer.
+ *
+ * \param[in] ort_value OrtValue containing sparse tensor.
+ * \param[in] indices_format - one of the indices formats. It is an error to request a format that the sparse
+ * tensor does not contain.
+ * \param[out] num_indices ptr where the number of indices entries is returned
+ * \param[out] indices out param where the pointer to the internal buffer is returned. Do not free this buffer.
+ */
+ ORT_API2_STATUS(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices);
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 63f2202357..0ae27590c6 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -420,14 +420,208 @@ struct TypeInfo : Base {
};
struct Value : Base {
+ // This structure is used to feed sparse tensor values
+ // information to the FillSparseTensor*() APIs.
+ // If the data type of the sparse tensor values is numeric,
+ // use data.p_data; otherwise, use the data.str pointer to feed
+ // values. data.str is an array of zero-terminated const char* strings.
+ // The number of strings in the array must match the shape size.
+ // For fully sparse tensors use shape {0} and set p_data/str
+ // to nullptr.
+ struct OrtSparseValuesParam {
+ const int64_t* values_shape;
+ size_t values_shape_len;
+ union {
+ const void* p_data;
+ const char** str;
+ } data;
+ };
+
+ // Provides a way to pass shape in a single
+ // argument
+ struct Shape {
+ const int64_t* shape;
+ size_t shape_len;
+ };
+
template
static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
ONNXTensorElementDataType type);
+
+
+ ///
+ /// This is a simple forwarding method to the other overload that helps deduce
+ /// the data type enum value from the type of the buffer.
+ ///
+ /// numeric datatype. This API is not suitable for strings.
+ /// Memory description where the user buffers reside (CPU vs GPU etc)
+ /// pointer to the user supplied buffer, use nullptr for fully sparse tensors
+ /// a would be dense shape of the tensor
+ /// non zero values shape. Use a single 0 shape for fully sparse tensors.
+ ///
+ template
+ static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
+ const Shape& values_shape);
+
+ ///
+ /// Creates an OrtValue instance containing SparseTensor. This constructs
+ /// a sparse tensor that makes use of user allocated buffers. It does not make copies
+ /// of the user provided data and does not modify it. The lifespan of user provided buffers should
+ /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contains
+ /// a pointer to non-zero values. To fully populate the sparse tensor call one of the Use*Indices() APIs below
+ /// to supply sparse format specific indices.
+ /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
+ /// can be properly copied into the allocated buffer.
+ ///
+ /// Memory description where the user buffers reside (CPU vs GPU etc)
+ /// pointer to the user supplied buffer, use nullptr for fully sparse tensors
+ /// a would be dense shape of the tensor
+ /// non zero values shape. Use a single 0 shape for fully sparse tensors.
+ /// data type
+ /// Ort::Value instance containing SparseTensor
+ static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
+ const Shape& values_shape, ONNXTensorElementDataType type);
+
+ ///
+ /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
+ ///
+ /// pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.
+ /// number of indices entries. Use 0 for fully sparse tensors
+ void UseCooIndices(int64_t* indices_data, size_t indices_num);
+
+ ///
+ /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
+ ///
+ /// pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors
+ /// number of csr inner indices or 0 for fully sparse tensors
+ /// pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors
+ /// number of csr outer indices or 0 for fully sparse tensors
+ void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
+
+ ///
+ /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
+ ///
+ /// indices shape or a {0} for fully sparse
+ /// user allocated buffer with indices or nullptr for fully sparse tensors
+ void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
+
template
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
+ ///
+ /// This is a simple forwarding method to the CreateSparseTensor() overload below.
+ /// It helps to specify the data type enum in terms of a C++ data type.
+ /// Use CreateSparseTensor
+ ///
+ /// numeric data type only. String data enum must be specified explicitly.
+ /// allocator to use
+ /// a would be dense shape of the tensor
+ /// Ort::Value
+ template
+ static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
+
+ ///
+ /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
+ /// The data must be supplied by one of the FillSparseTensor*() methods that take both non-zero values
+ /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
+ /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
+ /// strings.
+ ///
+ /// allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue
+ /// a would be dense shape of the tensor
+ /// data type
+ /// an instance of Ort::Value
+ static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
+
+ ///
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
+ /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
+ /// on a different device than the allocator, a cross-device copy will be performed if possible.
+ ///
+ /// specified buffer memory description
+ /// values buffer information.
+ /// coo indices buffer or nullptr for fully sparse data
+ /// number of COO indices or 0 for fully sparse data
+ void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
+ const int64_t* indices_data, size_t indices_num);
+
+ ///
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
+ /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
+ /// on a different device than the allocator, a cross-device copy will be performed if possible.
+ ///
+ /// specified buffer memory description
+ /// values buffer information
+ /// csr inner indices pointer or nullptr for fully sparse tensors
+ /// number of csr inner indices or 0 for fully sparse tensors
+ /// pointer to csr outer indices data or nullptr for fully sparse tensors
+ /// number of csr outer indices or 0
+ void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
+ const OrtSparseValuesParam& values,
+ const int64_t* inner_indices_data, size_t inner_indices_num,
+ const int64_t* outer_indices_data, size_t outer_indices_num);
+
+ ///
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
+ /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
+ /// on a different device than the allocator, a cross-device copy will be performed if possible.
+ ///
+ /// specified buffer memory description
+ /// values buffer information
+ /// indices shape. use {0} for fully sparse tensors
+ /// pointer to indices data or nullptr for fully sparse tensors
+ void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
+ const OrtSparseValuesParam& values,
+ const Shape& indices_shape,
+ const int32_t* indices_data);
+
+ ///
+ /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
+ /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
+ /// the value returned is ORT_SPARSE_UNDEFINED.
+ ///
+ /// Format enum
+ OrtSparseFormat GetSparseFormat() const;
+
+ ///
+ /// The API returns type and shape information for stored non-zero values of the
+ /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
+ ///
+ /// TensorTypeAndShapeInfo values information
+ TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
+
+ ///
+ /// The API returns type and shape information for the specified indices. Each supported
+ /// indices have their own enum values even if a given format has more than one kind of indices.
+ /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
+ ///
+ /// enum requested
+ /// type and shape information
+ TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat) const;
+
+ ///
+ /// The API retrieves a pointer to the internal indices buffer. The API merely performs
+ /// a convenience data type casting on the return type pointer. Make sure you are requesting
+ /// the right type; use GetSparseTensorIndicesTypeShapeInfo() to check.
+ ///
+ /// type to cast to
+ /// requested indices kind
+ /// number of indices entries
+ /// Pointer to the internal sparse tensor buffer containing indices. Do not free this pointer.
+ template
+ const T* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
+
static Value CreateMap(Value& keys, Value& values);
static Value CreateSequence(std::vector& values);
@@ -443,10 +637,38 @@ struct Value : Base {
Value& operator=(Value&&) = default;
bool IsTensor() const;
+
+ ///
+ /// Returns true if the OrtValue contains a sparse tensor
+ ///
+ ///
+ bool IsSparseTensor() const;
+
size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
Value GetValue(int index, OrtAllocator* allocator) const;
+ ///
+ /// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
+ /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
+ /// for allocating necessary memory and calling GetStringTensorContent().
+ ///
+ /// total length of UTF-8 encoded bytes contained. No zero terminators counted.
size_t GetStringTensorDataLength() const;
+
+ ///
+ /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
+ /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
+ /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
+ /// strings.
+ ///
+ /// Strings are always assumed to be on CPU, no X-device copy.
+ ///
+ /// user allocated buffer
+ /// length in bytes of the allocated buffer
+ /// a pointer to the offsets user allocated buffer
+ /// count of offsets, must be equal to the number of strings contained.
+ /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
+ /// for sparse tensors
void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
template
@@ -455,13 +677,52 @@ struct Value : Base {
template
const T* GetTensorData() const;
+ ///
+ /// The API returns a pointer to an internal buffer of the sparse tensor
+ /// containing non-zero values. The API merely does casting. Make sure you
+ /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
+ /// first.
+ ///
+ /// numeric data types only. Use GetStringTensor*() to retrieve strings.
+ /// a pointer to the internal values buffer. Do not free this pointer.
+ template
+ const T* GetSparseTensorValues() const;
+
template
T& At(const std::vector& location);
+ ///
+ /// The API returns type information for data contained in a tensor. For sparse
+ /// tensors it returns type information for contained non-zero values.
+ /// It returns dense shape for sparse tensors.
+ ///
+ /// TypeInfo
TypeInfo GetTypeInfo() const;
+
+ ///
+ /// The API returns type information for data contained in a tensor. For sparse
+ /// tensors it returns type information for contained non-zero values.
+ /// It returns dense shape for sparse tensors.
+ ///
+ /// TensorTypeAndShapeInfo
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
+ ///
+ /// The API returns a byte length of UTF-8 encoded string element
+ /// contained in either a tensor or a sparse tensor's values.
+ ///
+ ///
+ /// byte length for the specified string element
size_t GetStringTensorElementLength(size_t element_index) const;
+
+ ///
+ /// The API copies UTF-8 encoded bytes for the requested string element
+ /// contained within a tensor or a sparse tensor into a provided buffer.
+ /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
+ ///
+ ///
+ ///
+ ///
void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
void FillStringTensor(const char* const* s, size_t s_len);
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index a4596fc205..684c8fbaa4 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -755,6 +755,82 @@ inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t
return Value{out};
}
+template
+inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
+ const Shape& values_shape) {
+ return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType::type);
+}
+
+inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
+ const Shape& values_shape, ONNXTensorElementDataType type) {
+ OrtValue* out;
+ ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
+ values_shape.shape, values_shape.shape_len, type, &out));
+ return Value{out};
+}
+
+inline void Value::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
+ const int64_t* indices_data, size_t indices_num) {
+ ThrowOnError(GetApi().FillSparseTensorCoo(p_, mem_info, values_param.values_shape,
+ values_param.values_shape_len, values_param.data.p_data,
+ indices_data, indices_num));
+}
+
+inline void Value::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
+ const OrtSparseValuesParam& values,
+ const int64_t* inner_indices_data, size_t inner_indices_num,
+ const int64_t* outer_indices_data, size_t outer_indices_num) {
+ ThrowOnError(GetApi().FillSparseTensorCsr(p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
+ inner_indices_data, inner_indices_num,
+ outer_indices_data, outer_indices_num));
+}
+
+inline void Value::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
+ const OrtSparseValuesParam& values,
+ const Shape& indices_shape,
+ const int32_t* indices_data) {
+ ThrowOnError(GetApi().FillSparseTensorBlockSparse(p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
+ indices_shape.shape, indices_shape.shape_len,
+ indices_data));
+}
+
+inline void Value::UseCooIndices(int64_t* indices_data, size_t indices_num) {
+ ThrowOnError(GetApi().UseCooIndices(p_, indices_data, indices_num));
+}
+
+inline void Value::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
+ ThrowOnError(GetApi().UseCsrIndices(p_, inner_data, inner_num, outer_data, outer_num));
+}
+
+inline void Value::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
+ ThrowOnError(GetApi().UseBlockSparseIndices(p_, indices_shape.shape, indices_shape.shape_len, indices_data));
+}
+
+inline OrtSparseFormat Value::GetSparseFormat() const {
+ OrtSparseFormat format;
+ ThrowOnError(GetApi().GetSparseTensorFormat(p_, &format));
+ return format;
+}
+
+inline TensorTypeAndShapeInfo Value::GetSparseTensorValuesTypeAndShapeInfo() const {
+ OrtTensorTypeAndShapeInfo* output;
+ ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(p_, &output));
+ return TensorTypeAndShapeInfo{output};
+}
+
+inline TensorTypeAndShapeInfo Value::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
+ OrtTensorTypeAndShapeInfo* output;
+ ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(p_, indices_format, &output));
+ return TensorTypeAndShapeInfo{output};
+}
+
+template
+inline const T* Value::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
+ const void* out;
+ ThrowOnError(GetApi().GetSparseTensorIndices(p_, indices_format, &num_indices, &out));
+ return reinterpret_cast(out);
+}
+
template
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type);
@@ -766,6 +842,18 @@ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape,
return Value{out};
}
+template
+inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
+ return CreateSparseTensor(allocator, dense_shape, TypeToTensorType::type);
+}
+
+inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
+ ONNXTensorElementDataType type) {
+ OrtValue* out;
+ ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
+ return Value{out};
+}
+
inline Value Value::CreateMap(Value& keys, Value& values) {
OrtValue* out;
OrtValue* inputs[2] = {keys, values};
@@ -798,6 +886,12 @@ inline bool Value::IsTensor() const {
return out != 0;
}
+inline bool Value::IsSparseTensor() const {
+ int out;
+ ThrowOnError(GetApi().IsSparseTensor(p_, &out));
+ return out != 0;
+}
+
inline size_t Value::GetCount() const {
size_t out;
ThrowOnError(GetApi().GetValueCount(p_, &out));
@@ -852,6 +946,13 @@ const T* Value::GetTensorData() const {
return out;
}
+template
+inline const T* Value::GetSparseTensorValues() const {
+ const void* out;
+ ThrowOnError(GetApi().GetSparseTensorValues(p_, &out));
+ return reinterpret_cast(out);
+}
+
template
inline T& Value::At(const std::vector& location) {
static_assert(!std::is_same::value, "this api does not support std::string");
diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc
index caaa78f63f..7a0798782a 100644
--- a/onnxruntime/core/framework/execution_frame.cc
+++ b/onnxruntime/core/framework/execution_frame.cc
@@ -618,9 +618,7 @@ static Status AllocateSparseTensor(OrtValue& mlvalue, const DataTypeImpl& ml_typ
const TensorShape& shape, bool create_fence,
const SessionState& session_state) {
auto element_type = ml_type.AsSparseTensorType()->GetElementType();
- auto sparse = std::make_unique(element_type, shape, allocator);
- auto deleter = DataTypeImpl::GetType()->GetDeleteFunc();
- mlvalue.Init(sparse.release(), DataTypeImpl::GetType(), deleter);
+ SparseTensor::InitOrtValue(element_type, shape, std::move(allocator), mlvalue);
// create fence if needed
if (create_fence) {
diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc
index 9b2e5db7e8..bcfc36d131 100644
--- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc
+++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc
@@ -56,7 +56,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeI
ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input,
_Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) {
- *out = input->type == ONNX_TYPE_TENSOR ? input->data : nullptr;
+ *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data : nullptr;
return nullptr;
}
diff --git a/onnxruntime/core/framework/sparse_tensor.cc b/onnxruntime/core/framework/sparse_tensor.cc
index 640b62b168..817ee8b7af 100644
--- a/onnxruntime/core/framework/sparse_tensor.cc
+++ b/onnxruntime/core/framework/sparse_tensor.cc
@@ -4,6 +4,7 @@
#include "core/framework/data_types.h"
#include "core/framework/sparse_tensor.h"
#include "core/framework/data_transfer_manager.h"
+#include "core/framework/ort_value.h"
#include "core/framework/utils.h"
#include
@@ -42,13 +43,13 @@ inline std::vector> MakeListConst(const T&.
return std::vector{std::cref(t)...};
}
-void CopyStrings(const Tensor& src, Tensor& dst) {
- auto src_span = src.DataAsSpan();
- auto* dst_iter = dst.MutableData();
- std::copy(src_span.cbegin(), src_span.cend(), dst_iter);
+void CopyStrings(const Tensor& src_t, Tensor& dst_t) {
+ auto src_span = src_t.DataAsSpan();
+ std::string* dst = dst_t.MutableData();
+ std::copy(src_span.cbegin(), src_span.cend(), dst);
}
-Status CopyData(const IDataTransfer& data_transfer,
+Status CopyData(const IDataTransfer* data_transfer,
const std::vector>& src,
const std::vector>& dst) {
ORT_RETURN_IF_NOT(src.size() == dst.size(), "Must have the same size. Got src_size: ",
@@ -59,12 +60,26 @@ Status CopyData(const IDataTransfer& data_transfer,
if (src_t.IsDataTypeString()) {
CopyStrings(src_t, dst_t);
} else {
- ORT_RETURN_IF_ERROR(data_transfer.CopyTensor(src_t, dst_t));
+ if (data_transfer != nullptr) {
+ ORT_RETURN_IF_ERROR(data_transfer->CopyTensor(src_t, dst_t));
+ } else {
+ memcpy(dst_t.MutableDataRaw(), src_t.DataRaw(), src_t.SizeInBytes());
+ }
}
}
return Status::OK();
}
+Status CopyStringsAndIndices(size_t string_count, const char* const strings[], Tensor& values,
+ const std::vector>& src_ind,
+ const std::vector>& dst_ind) {
+ auto* str_dest = values.MutableData();
+ for (size_t i = 0; i < string_count; ++i) {
+ str_dest[i] = strings[i];
+ }
+
+ return CopyData(nullptr, src_ind, dst_ind);
+}
} // namespace
const void* SparseTensor::IndicesStart(int64_t values_bytes) const {
@@ -149,12 +164,58 @@ SparseTensor::~SparseTensor() {
ReleaseBuffer();
}
+void SparseTensor::InitOrtValue(MLDataType elt_type,
+ const TensorShape& dense_shape,
+ const TensorShape& values_shape,
+ void* values_data,
+ const OrtMemoryInfo& location,
+ OrtValue& ort_value) {
+ auto sparse_tensor = std::make_unique(elt_type, dense_shape, values_shape, values_data, location);
+ auto ml_tensor = DataTypeImpl::GetType();
+ ort_value.Init(sparse_tensor.release(),
+ ml_tensor,
+ ml_tensor->GetDeleteFunc());
+}
+
+void SparseTensor::InitOrtValue(MLDataType elt_type,
+ const TensorShape& dense_shape,
+ std::shared_ptr allocator,
+ OrtValue& ort_value) {
+ auto sparse_tensor = std::make_unique(elt_type, dense_shape, std::move(allocator));
+ auto ml_tensor = DataTypeImpl::GetType();
+ ort_value.Init(sparse_tensor.release(),
+ ml_tensor,
+ ml_tensor->GetDeleteFunc());
+}
+
+const SparseTensor& SparseTensor::GetSparseTensorFromOrtValue(const OrtValue& v) {
+ if (!v.IsAllocated()) {
+ ORT_THROW("the ort_value must contain a constructed sparse tensor");
+ }
+ const auto& sparse_tensor = v.Get();
+ if (sparse_tensor.Format() == onnxruntime::SparseFormat::kUndefined) {
+ ORT_THROW("Sparse Tensor does not contain sparse data");
+ }
+ return sparse_tensor;
+}
+
+SparseTensor& SparseTensor::GetSparseTensorFromOrtValue(OrtValue& v) {
+ if (!v.IsAllocated()) {
+ ORT_THROW("the ort_value must contain a constructed sparse tensor");
+ }
+ auto& sparse_tensor = *v.GetMutable();
+ if (sparse_tensor.Format() != SparseFormat::kUndefined) {
+ ORT_THROW("this tensor already has populated sparse_indices");
+ }
+ return sparse_tensor;
+}
+
Status SparseTensor::AllocateBuffer(int64_t buffer_size, size_t num_values) {
if (buffer_size > 0) {
SafeInt buffer_size_t(buffer_size);
const auto values_bytes = SafeInt(num_values) * ml_data_type_->Size();
ORT_RETURN_IF_NOT(buffer_size_t > values_bytes,
- "Values size ", static_cast(values_bytes), " must be less than total buffer size: ", buffer_size);
+ "Values size ", static_cast(values_bytes), " must be less than total buffer size: ", buffer_size);
auto data_ptr = IAllocator::MakeUniquePtr(allocator_, buffer_size_t);
ORT_RETURN_IF(data_ptr == nullptr, "SparseTensor Allocation failed for size: ", buffer_size);
if (IsDataTypeString()) {
@@ -206,6 +267,7 @@ void SparseTensor::InitCooIndex(const TensorShape& index_shape, int64_t* index_d
}
Status SparseTensor::UseCooIndices(gsl::span indices) {
+ ORT_RETURN_IF_NOT(Format() == SparseFormat::kUndefined, "Sparse format must not be set. Already contains format: ", Format());
ORT_RETURN_IF_NOT(allocator_ == nullptr, "Not expecting an allocator set");
TensorShape index_shape(GetCooIndexDims(NumValues(), indices.size()));
InitCooIndex(index_shape, indices.data());
@@ -216,6 +278,7 @@ Status SparseTensor::MakeCooData(const IDataTransfer& data_transfer,
const OrtMemoryInfo& data_location,
size_t values_count, const void* values_data,
gsl::span indices) {
+ ORT_RETURN_IF(IsDataTypeString(), "Use MakeCooStrings");
auto mutator = MakeCooData(values_count, indices.size());
if (values_count > 0) {
auto& dst_values = mutator.Values();
@@ -223,12 +286,26 @@ Status SparseTensor::MakeCooData(const IDataTransfer& data_transfer,
Tensor src_values(dst_values.DataType(), dst_values.Shape(), const_cast(values_data), data_location);
Tensor src_index(dst_index.DataType(), dst_index.Shape(), const_cast(indices.data()), data_location);
- ORT_RETURN_IF_ERROR(CopyData(data_transfer, MakeListConst(src_values, src_index), MakeListNonConst(dst_values, dst_index)));
+ ORT_RETURN_IF_ERROR(CopyData(&data_transfer, MakeListConst(src_values, src_index), MakeListNonConst(dst_values, dst_index)));
+ }
+ return Status::OK();
+}
+
+Status SparseTensor::MakeCooStrings(size_t string_count, const char* const* strings,
+ gsl::span indices) {
+ ORT_RETURN_IF_NOT(IsDataTypeString(), "Expecting data type to be set as string");
+ auto mutator = MakeCooData(string_count, indices.size());
+ if (string_count > 0) {
+ auto& dst_values = mutator.Values();
+ auto& dst_indices = mutator.Indices();
+ Tensor src_indices(dst_indices.DataType(), dst_indices.Shape(), const_cast(indices.data()), Location());
+ ORT_RETURN_IF_ERROR(CopyStringsAndIndices(string_count, strings, dst_values, {std::cref(src_indices)}, {std::ref(dst_indices)}));
}
return Status::OK();
}
SparseTensor::CooMutator SparseTensor::MakeCooData(size_t values_count, size_t index_count) {
+ ORT_ENFORCE(Format() == SparseFormat::kUndefined, "Sparse format must not be set. Already contains format: ", Format());
ORT_ENFORCE(allocator_ != nullptr, "This method should follow a call to constructor that supplies the allocator");
const auto num_values = gsl::narrow(values_count);
TensorShape values_shape{num_values};
@@ -253,11 +330,13 @@ SparseTensor::CsrView SparseTensor::AsCsr() const {
Status SparseTensor::ValidateCsrIndices(size_t values_count, size_t inner_size, size_t outer_size) const {
ORT_RETURN_IF_NOT(dense_shape_.NumDimensions() == 2U, "dense shape must 2-D. Got: ", dense_shape_.NumDimensions());
+ ORT_RETURN_IF_NOT((inner_size == 0 && outer_size == 0) || (inner_size > 0 && outer_size > 0),
+ "Inner and Outer indices must either be both zero or non-zero");
ORT_RETURN_IF_NOT(inner_size == values_count,
- "Expecting inner index size: ", inner_size, " the same as values size: ", values_count);
+ "Expecting inner index size: ", inner_size, " the same as values size: ", values_count);
const auto rows = dense_shape_.GetDims()[0];
ORT_RETURN_IF_NOT(outer_size == 0 || outer_size == static_cast(rows + 1),
- "Outer index count must be rows + 1 or zero. Got: ", outer_size, " rows: ", rows);
+ "Outer index count must be rows + 1 or zero. Got: ", outer_size, " rows: ", rows);
return Status::OK();
}
@@ -274,6 +353,7 @@ void SparseTensor::InitCsrIndices(size_t inner_size, const int64_t* inner, size_
Status SparseTensor::UseCsrIndices(gsl::span inner_index, gsl::span outer_index) {
ORT_RETURN_IF_NOT(allocator_ == nullptr, "This method does not expect allocator to be set");
+ ORT_RETURN_IF_NOT(Format() == SparseFormat::kUndefined, "Sparse format must not be set. Already contains format: ", Format());
ORT_RETURN_IF_ERROR(ValidateCsrIndices(NumValues(), inner_index.size(), outer_index.size()));
InitCsrIndices(inner_index.size(), inner_index.data(), outer_index.size(), outer_index.data());
return Status::OK();
@@ -282,6 +362,7 @@ Status SparseTensor::UseCsrIndices(gsl::span inner_index, gsl::span inner_index, gsl::span outer_index) {
+ ORT_RETURN_IF(IsDataTypeString(), "Use MakeCsrStrings");
auto mutator = MakeCsrData(values_count, inner_index.size(), outer_index.size());
if (values_count > 0) {
auto& dst_values = mutator.Values();
@@ -291,16 +372,34 @@ Status SparseTensor::MakeCsrData(const IDataTransfer& data_transfer, const OrtMe
Tensor src_values(dst_values.DataType(), dst_values.Shape(), const_cast(values_data), data_location);
Tensor src_inner(dst_inner.DataType(), dst_inner.Shape(), const_cast(inner_index.data()), data_location);
Tensor src_outer(dst_outer.DataType(), dst_outer.Shape(), const_cast(outer_index.data()), data_location);
- ORT_RETURN_IF_ERROR(CopyData(data_transfer, MakeListConst(src_values, src_inner, src_outer),
+ ORT_RETURN_IF_ERROR(CopyData(&data_transfer, MakeListConst(src_values, src_inner, src_outer),
MakeListNonConst(dst_values, dst_inner, dst_outer)));
}
return Status::OK();
}
+Status SparseTensor::MakeCsrStrings(size_t string_count, const char* const* strings,
+ gsl::span inner_index, gsl::span outer_index) {
+ ORT_RETURN_IF_NOT(IsDataTypeString(), "Expecting data type to be set as string");
+ auto mutator = MakeCsrData(string_count, inner_index.size(), outer_index.size());
+ if (string_count > 0) {
+ auto& dst_values = mutator.Values();
+ auto& dst_inner = mutator.Inner();
+ auto& dst_outer = mutator.Outer();
+ Tensor src_inner(dst_inner.DataType(), dst_inner.Shape(), const_cast(inner_index.data()), Location());
+ Tensor src_outer(dst_outer.DataType(), dst_outer.Shape(), const_cast(outer_index.data()), Location());
+ ORT_RETURN_IF_ERROR(CopyStringsAndIndices(string_count, strings, dst_values,
+ MakeListConst(src_inner, src_outer),
+ MakeListNonConst(dst_inner, dst_outer)));
+ }
+ return Status::OK();
+}
+
SparseTensor::CsrMutator SparseTensor::MakeCsrData(size_t values_count,
size_t inner_index_count,
size_t outer_index_count) {
ORT_ENFORCE(allocator_ != nullptr, "This method should follow a call to constructor that supplies the allocator");
+ ORT_ENFORCE(Format() == SparseFormat::kUndefined, "Sparse format must not be set. Already contains format: ", Format());
ORT_THROW_IF_ERROR(ValidateCsrIndices(values_count, inner_index_count, outer_index_count));
if (values_count > 0) {
@@ -326,44 +425,70 @@ SparseTensor::BlockSparseView SparseTensor::AsBlockSparse() const {
}
+// Validates the values/indices shape pair for the block-sparse format.
+// Non-empty values: values must be at least 3-D, indices must be 2-D with dim[0] == 2, and
+// indices_shape.Size() / 2 must equal values_shape.SizeFromDimension(2).
+// Empty values (fully sparse tensor): both shapes must be 1-D.
+// Returns Status::OK() on success, otherwise the status produced by the first failed check.
Status SparseTensor::ValidateBlockSparseShapes(const TensorShape& values_shape, const TensorShape& indices_shape) const {
- ORT_RETURN_IF_NOT(values_shape.NumDimensions() >= 3,
- "Expecting values dimensions to be at least 3. Got:", values_shape.NumDimensions());
- ORT_RETURN_IF_NOT(indices_shape.NumDimensions() == 2,
- "Expecting index dimensions to be 2. Got: ", indices_shape.NumDimensions());
- const auto values_blocks = values_shape.SizeFromDimension(2);
- const auto index_blocks = indices_shape.Size() / 2; // Two integers per block
- ORT_RETURN_IF_NOT(values_blocks == index_blocks,
- "Expecting index blocks: ", index_blocks, " to be equal to values blocks: ", values_blocks);
+  if (values_shape.Size() > 0) {
+    ORT_RETURN_IF_NOT(values_shape.NumDimensions() >= 3,
+                      "Expecting values to have at least 3-D shape. Got: ", values_shape.NumDimensions());
+    ORT_RETURN_IF_NOT(indices_shape.NumDimensions() == 2,
+                      "Expecting indices to have 2-D shape. Got: ", indices_shape.NumDimensions());
+    ORT_RETURN_IF_NOT(indices_shape.GetDims()[0] == 2, "Indices shape must have dim[0] == 2");
+    const auto values_blocks = values_shape.SizeFromDimension(2);
+    const auto index_blocks = indices_shape.Size() / 2;  // Two integers per block
+    ORT_RETURN_IF_NOT(values_blocks == index_blocks,
+                      "Expecting index blocks: ", index_blocks, " to be equal to values blocks: ", values_blocks);
+  } else {
+    // Fully sparse tensor: only the rank of each shape is checked here.
+    ORT_RETURN_IF_NOT(values_shape.GetDims().size() == 1, "Expecting fully sparse tensors to have value shape {0}");
+    ORT_RETURN_IF_NOT(indices_shape.GetDims().size() == 1, "Expecting fully sparse tensors to have indices shape {0}");
+  }
return Status::OK();
}
-Status SparseTensor::UseBlockSparseIndices(const TensorShape& index_shape, int32_t* indices_data) {
- ORT_RETURN_IF_NOT(allocator_ == nullptr, "Not expecting an allocator set");
- ORT_RETURN_IF_ERROR(ValidateBlockSparseShapes(Values().Shape(), index_shape));
-
+void SparseTensor::InitBlockSparseIndices(const TensorShape& indices_shape, int32_t* indices_data) {
format_data_.resize(1);
- format_data_[0] = Tensor(DataTypeImpl::GetType(), index_shape,
+ format_data_[0] = Tensor(DataTypeImpl::GetType(), indices_shape,
indices_data, Location());
format_ = SparseFormat::kBlockSparse;
+}
+
+Status SparseTensor::UseBlockSparseIndices(const TensorShape& indices_shape, int32_t* indices_data) {
+ ORT_RETURN_IF_NOT(allocator_ == nullptr, "Not expecting an allocator set");
+ ORT_RETURN_IF_NOT(Format() == SparseFormat::kUndefined, "Sparse format must not be set. Already contains format: ", Format());
+ ORT_RETURN_IF_ERROR(ValidateBlockSparseShapes(Values().Shape(), indices_shape));
+ InitBlockSparseIndices(indices_shape, indices_data);
return Status::OK();
}
Status SparseTensor::MakeBlockSparseData(const IDataTransfer& data_transfer, const OrtMemoryInfo& data_location,
const TensorShape& values_shape, const void* values_data,
const TensorShape& indices_shape, const int32_t* indices_data) {
+ ORT_RETURN_IF(IsDataTypeString(), "Use MakeBlockSparseStrings");
auto mutator = MakeBlockSparseData(values_shape, indices_shape);
if (values_shape.Size() > 0) {
auto& dst_values = mutator.Values();
auto& dst_indices = mutator.Indices();
Tensor src_values(dst_values.DataType(), dst_values.Shape(), const_cast(values_data), data_location);
Tensor src_index(dst_indices.DataType(), dst_indices.Shape(), const_cast(indices_data), data_location);
- ORT_RETURN_IF_ERROR(CopyData(data_transfer, MakeListConst(src_values, src_index), MakeListNonConst(dst_values, dst_indices)));
+ ORT_RETURN_IF_ERROR(CopyData(&data_transfer, MakeListConst(src_values, src_index), MakeListNonConst(dst_values, dst_indices)));
+ }
+ return Status::OK();
+}
+
+Status SparseTensor::MakeBlockSparseStrings(const TensorShape& values_shape, const char* const* strings,
+ const TensorShape& indices_shape, const int32_t* indices_data) {
+ ORT_RETURN_IF_NOT(IsDataTypeString(), "Expecting data type to be set as string");
+ auto mutator = MakeBlockSparseData(values_shape, indices_shape);
+ auto string_count = gsl::narrow(values_shape.Size());
+ if (string_count > 0) {
+ auto& dst_values = mutator.Values();
+ auto& dst_indices = mutator.Indices();
+ Tensor src_indices(dst_indices.DataType(), dst_indices.Shape(), const_cast(indices_data), Location());
+ ORT_RETURN_IF_ERROR(CopyStringsAndIndices(string_count, strings, dst_values, {std::cref(src_indices)}, {std::ref(dst_indices)}));
}
return Status::OK();
}
SparseTensor::BlockSparseMutator SparseTensor::MakeBlockSparseData(const TensorShape& values_shape, const TensorShape& indices_shape) {
ORT_ENFORCE(allocator_ != nullptr, "This method should follow a call to constructor that supplies the allocator");
+ ORT_ENFORCE(Format() == SparseFormat::kUndefined, "Sparse format must not be set. Already contains format: ", Format());
ORT_THROW_IF_ERROR(ValidateBlockSparseShapes(values_shape, indices_shape));
if (values_shape.Size() > 0) {
const auto data_size = SafeInt(values_shape.Size()) * ml_data_type_->Size();
@@ -372,10 +497,9 @@ SparseTensor::BlockSparseMutator SparseTensor::MakeBlockSparseData(const TensorS
gsl::narrow(index_size));
ORT_THROW_IF_ERROR(AllocateBuffer(required_buffer_size, static_cast(data_size / ml_data_type_->Size())));
}
+
values_ = Tensor(DataType(), values_shape, p_data_, Location());
- format_data_.resize(1);
- format_data_[0] = Tensor(DataTypeImpl::GetType(), indices_shape, IndicesStart(values_.SizeInBytes()), Location());
- format_ = SparseFormat::kBlockSparse;
+ InitBlockSparseIndices(indices_shape, reinterpret_cast(IndicesStart(values_.SizeInBytes())));
return BlockSparseMutator(values_, format_data_[0]);
}
diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc
index fa512f629b..f818ef7a26 100644
--- a/onnxruntime/core/framework/tensor_type_and_shape.cc
+++ b/onnxruntime/core/framework/tensor_type_and_shape.cc
@@ -203,12 +203,13 @@ OrtStatus* OrtTensorTypeAndShapeInfo::Clone(OrtTensorTypeAndShapeInfo** out) {
ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) {
API_IMPL_BEGIN
- onnxruntime::MLDataType type = v->Type();
- ORT_ENFORCE(type != nullptr, "OrtValue is not a Tensor");
- if (type->IsTensorType() || type->IsSparseTensorType()) {
+ if (!v->IsAllocated()) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "the ort_value must contain a constructed tensor or sparse tensor");
+ }
+ if (v->IsTensor() || v->IsSparseTensor()) {
const onnxruntime::TensorShape* shape = nullptr;
onnxruntime::MLDataType data_type = nullptr;
- if (type->IsTensorType()) {
+ if (v->IsTensor()) {
const Tensor& tensor = v->Get();
shape = &tensor.Shape();
data_type = tensor.DataType();
@@ -224,6 +225,57 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out
API_IMPL_END
}
+// Returns, through `out`, the type-and-shape info built from the shape and data type of the
+// values tensor of the sparse tensor held by `v`.
+ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorValuesTypeAndShape, _In_ const OrtValue* v,
+ _Outptr_ OrtTensorTypeAndShapeInfo** out) {
+ API_IMPL_BEGIN
+ const auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v);
+ const auto& values = sparse_tensor.Values();
+ return GetTensorShapeAndType(values.Shape(), *values.DataType(), out);
+ API_IMPL_END
+}
+
+namespace {
+// Returns the indices tensor of the requested kind from the sparse tensor held by `v`.
+// COO, CSR inner, CSR outer and block-sparse indices are supported; the matching As*() view
+// is used for each. Throws for any other `indices_format` value.
+const Tensor& GetIndicesTensor(const OrtValue& v, OrtSparseIndicesFormat indices_format) {
+  const auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(v);
+  const Tensor* indices_tensor = nullptr;
+  switch (indices_format) {
+    case OrtSparseIndicesFormat::ORT_SPARSE_COO_INDICES:
+      indices_tensor = &sparse_tensor.AsCoo().Indices();
+      break;
+    case OrtSparseIndicesFormat::ORT_SPARSE_CSR_INNER_INDICES:
+      indices_tensor = &sparse_tensor.AsCsr().Inner();
+      break;
+    case OrtSparseIndicesFormat::ORT_SPARSE_CSR_OUTER_INDICES:
+      indices_tensor = &sparse_tensor.AsCsr().Outer();
+      break;
+    case OrtSparseIndicesFormat::ORT_SPARSE_BLOCK_SPARSE_INDICES:
+      indices_tensor = &sparse_tensor.AsBlockSparse().Indices();
+      break;
+    default:
+      // Only message text is passed: the arguments of ORT_THROW form the exception message,
+      // so an error-code enum here would be rendered as a number in front of the text.
+      ORT_THROW("Unsupported indices_format passed");
+  }
+  return *indices_tensor;
+}
+} // namespace
+
+// Returns, through `out`, the type-and-shape info of the indices tensor selected by
+// `indices_format` from the sparse tensor held by `v` (selection is done by GetIndicesTensor).
+ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndicesTypeShape, _In_ const OrtValue* v,
+ OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out) {
+ API_IMPL_BEGIN
+ const Tensor& indices_tensor = GetIndicesTensor(*v, indices_format);
+ return GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType(), out);
+ API_IMPL_END
+}
+
+// Exposes the indices of the requested kind from the sparse tensor held by `v`.
+// `num_indices` receives the size reported by the indices tensor's shape and `indices`
+// receives the tensor's raw data pointer (no copy is made here).
+// NOTE(review): the returned pointer presumably stays valid only while `v` is alive - confirm.
+ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndices, _In_ const OrtValue* v,
+ enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices) {
+ API_IMPL_BEGIN
+ const Tensor& indices_tensor = GetIndicesTensor(*v, indices_format);
+ *num_indices = gsl::narrow(indices_tensor.Shape().Size());
+ *indices = indices_tensor.DataRaw();
+ return nullptr;
+ API_IMPL_END
+}
+
ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXType* out) {
API_IMPL_BEGIN
OrtTypeInfo* type_info;
diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc
index 2766573e1d..022182062e 100644
--- a/onnxruntime/core/framework/utils.cc
+++ b/onnxruntime/core/framework/utils.cc
@@ -140,13 +140,7 @@ static common::Status AllocateHelper(const AllocatorPtr& allocator,
allocator, target_mlvalue);
} else if (source_mlvalue.IsSparseTensor()) {
const SparseTensor& source_tensor = source_mlvalue.Get();
- auto p_tensor = std::make_unique(source_tensor.DataType(),
- source_tensor.DenseShape(),
- allocator);
- auto ml_tensor = DataTypeImpl::GetType();
- target_mlvalue.Init(p_tensor.release(),
- ml_tensor,
- ml_tensor->GetDeleteFunc());
+ SparseTensor::InitOrtValue(source_tensor.DataType(), source_tensor.DenseShape(), allocator, target_mlvalue);
} else if (source_mlvalue.IsTensorSequence()) {
const TensorSeq& source_tensor_seq = source_mlvalue.Get();
auto target_tensor_seq = std::make_unique(source_tensor_seq.DataType());
diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc
index 3b5ee4f98b..8bc84056b6 100644
--- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc
+++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc
@@ -146,9 +146,7 @@ Status OptimizerExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value,
"Tried to allocate without valid type information, ort_value index=" + std::to_string(ort_value_idx));
if (ml_type->IsSparseTensorType()) {
auto element_type = ml_type->AsSparseTensorType()->GetElementType();
- auto container_type = DataTypeImpl::GetType();
- auto sparse = std::make_unique(element_type, *shape, info_.GetAllocator());
- ort_value.Init(sparse.release(), container_type, container_type->GetDeleteFunc());
+ SparseTensor::InitOrtValue(element_type, *shape, info_.GetAllocator(), ort_value);
return Status::OK();
}
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 6a027e403f..7fffb25682 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -37,6 +37,14 @@
#include "core/framework/TensorSeq.h"
#include "core/platform/ort_mutex.h"
+#ifdef USE_CUDA
+#include "core/providers/cuda/cuda_provider_factory.h"
+#include "core/providers/cuda/cuda_execution_provider_info.h"
+namespace onnxruntime {
+ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
+}
+#endif
+
#ifdef ENABLE_EXTENSION_CUSTOM_OPS
#include "ortcustomops.h"
#endif
@@ -220,6 +228,224 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorAsOrtValue, _Inout_ OrtAllocator* alloc
API_IMPL_END
}
+// Creates a new OrtValue holding a SparseTensor of element type `type` and the given dense
+// shape, backed by `allocator`. Returns ORT_INVALID_ARGUMENT if any dense dimension is negative.
+// On success ownership of the new OrtValue is handed to the caller through `out`.
+// NOTE(review): no sparse data is set here; presumably it is supplied later via the
+// FillSparseTensor* APIs - confirm.
+ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape,
+ size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) {
+ API_IMPL_BEGIN
+ auto sparse_tensor_type = DataTypeImpl::SparseTensorTypeFromONNXEnum(type);
+ auto element_type = sparse_tensor_type->GetElementType();
+ assert(element_type->AsPrimitiveDataType() != nullptr);
+ TensorShape shape(dense_shape, dense_shape_len);
+ if (std::any_of(shape.GetDims().cbegin(), shape.GetDims().cend(),
+ [](int64_t v) { return v < 0; })) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape");
+ }
+
+ // The user allocator is wrapped in a shared_ptr and handed over to the sparse tensor.
+ auto alloc_ptr = std::make_shared(allocator);
+ auto value = std::make_unique();
+ SparseTensor::InitOrtValue(element_type, shape, std::move(alloc_ptr), *value);
+ *out = value.release();
+ return nullptr;
+ API_IMPL_END
+}
+
+namespace {
+// Picks a data-transfer object able to copy between the two devices.
+// CPU -> CPU gets a CPU transfer. When built with USE_CUDA and either side is a GPU, the CUDA
+// provider (if TryGetProviderInfo_CUDA() returns one) creates the transfer.
+// Throws when no suitable transfer can be created.
+std::unique_ptr GetDataTransfer(const OrtDevice& src_device, const OrtDevice& dst_device) {
+ if (src_device.Type() == OrtDevice::CPU && dst_device.Type() == OrtDevice::CPU) {
+ return std::make_unique();
+ }
+#ifdef USE_CUDA
+ if (src_device.Type() == OrtDevice::GPU || dst_device.Type() == OrtDevice::GPU) {
+ // A null provider info falls through to the throw below.
+ if (auto* provider_info = TryGetProviderInfo_CUDA()) {
+ return provider_info->CreateGPUDataTransfer(nullptr);
+ }
+ }
+#endif
+ ORT_THROW("Not able to find appropriate IDataTransfer to copy sparse data");
+}
+
+// Common argument validation for the FillSparseTensor* entry points.
+// Returns the mutable SparseTensor held by `v`. Throws when:
+//  - `v` or `data_mem_info` is null,
+//  - the tensor holds strings and either the source memory or the tensor is not on CPU,
+//  - `values_shape` contains a negative dimension.
+SparseTensor& ValidateFillInputArgs(OrtValue* v, const TensorShape& values_shape, const OrtMemoryInfo* data_mem_info) {
+  // Both pointers come straight from C API callers and are dereferenced here and by the
+  // Fill* functions after this call, so reject nulls up front.
+  if (v == nullptr) {
+    ORT_THROW("ort_value must not be null");
+  }
+  if (data_mem_info == nullptr) {
+    ORT_THROW("data_mem_info must not be null");
+  }
+  auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v);
+  if (sparse_tensor.IsDataTypeString()) {
+    if ((data_mem_info->device.Type() != OrtDevice::CPU) || sparse_tensor.Location().device.Type() != OrtDevice::CPU) {
+      ORT_THROW("Strings can only reside in CPU memory");
+    }
+  }
+  if (std::any_of(values_shape.GetDims().cbegin(), values_shape.GetDims().cend(),
+                  [](int64_t dim) { return dim < 0; })) {
+    ORT_THROW("tried Filling sparse tensor with negative value in values shape");
+  }
+
+  return sparse_tensor;
+}
+
+// Views an opaque values pointer as an array of C-string pointers.
+// The conversion is a static_cast done once in the constructor; reading a union member other
+// than the one last written is undefined behaviour in C++, so no union is used.
+struct PtrConvert {
+  explicit PtrConvert(const void* p_p) : p(p_p), strings(static_cast<const char* const*>(p_p)) {}
+  const void* p;               // the pointer exactly as supplied
+  const char* const* strings;  // the same address viewed as an array of C strings
+};
+
+} // namespace
+
+// Fills the sparse tensor held by `ort_value` with caller-supplied values and COO indices.
+// `data_mem_info` describes where `values`/`indices_data` live. String tensors go through
+// MakeCooStrings, with `values` treated as an array of C-string pointers; all other element
+// types are copied by MakeCooData using a transfer chosen from the source/destination devices.
+ORT_API_STATUS_IMPL(OrtApis::FillSparseTensorCoo, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+ _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+ _In_ const int64_t* indices_data, size_t indices_num) {
+ API_IMPL_BEGIN
+ TensorShape values_t_shape(values_shape, values_shape_len);
+ auto& sparse_tensor = ValidateFillInputArgs(ort_value, values_t_shape, data_mem_info);
+
+ auto values_size = gsl::narrow(values_t_shape.Size());
+ auto indices_span = gsl::make_span(indices_data, indices_num);
+
+ if (sparse_tensor.IsDataTypeString()) {
+ PtrConvert conv(values);
+ ORT_THROW_IF_ERROR(sparse_tensor.MakeCooStrings(values_size, conv.strings, indices_span));
+ } else {
+ auto data_transfer = GetDataTransfer(data_mem_info->device, sparse_tensor.Location().device);
+ ORT_THROW_IF_ERROR(sparse_tensor.MakeCooData(*data_transfer, *data_mem_info, values_size,
+ values, indices_span));
+ }
+ return nullptr;
+ API_IMPL_END
+}
+
+// Fills the sparse tensor held by `ort_value` with caller-supplied values and CSR inner/outer
+// indices. `data_mem_info` describes where the source buffers live. String tensors go through
+// MakeCsrStrings, with `values` treated as an array of C-string pointers; all other element
+// types are copied by MakeCsrData using a transfer chosen from the source/destination devices.
+ORT_API_STATUS_IMPL(OrtApis::FillSparseTensorCsr, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+ _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+ _In_ const int64_t* inner_indices_data, size_t inner_indices_num,
+ _In_ const int64_t* outer_indices_data, size_t outer_indices_num) {
+ API_IMPL_BEGIN
+ TensorShape values_t_shape(values_shape, values_shape_len);
+ auto& sparse_tensor = ValidateFillInputArgs(ort_value, values_t_shape, data_mem_info);
+ auto values_size = gsl::narrow(values_t_shape.Size());
+
+ auto inner_indices_span = gsl::make_span(inner_indices_data, inner_indices_num);
+ auto outer_indices_span = gsl::make_span(outer_indices_data, outer_indices_num);
+ if (sparse_tensor.IsDataTypeString()) {
+ PtrConvert conv(values);
+ ORT_THROW_IF_ERROR(sparse_tensor.MakeCsrStrings(values_size, conv.strings, inner_indices_span, outer_indices_span));
+ } else {
+ auto data_transfer = GetDataTransfer(data_mem_info->device, sparse_tensor.Location().device);
+ ORT_THROW_IF_ERROR(sparse_tensor.MakeCsrData(*data_transfer, *data_mem_info, values_size,
+ values, inner_indices_span, outer_indices_span));
+ }
+ return nullptr;
+ API_IMPL_END
+}
+
+// Fills the sparse tensor held by `ort_value` with caller-supplied values and block-sparse
+// indices. Unlike the COO/CSR variants, the indices are described by a full shape
+// (`indices_shape_data`/`indices_shape_len`) and are int32. Throws if the indices shape has a
+// negative dimension. String tensors go through MakeBlockSparseStrings; all other element types
+// are copied by MakeBlockSparseData using a transfer chosen from the source/destination devices.
+ORT_API_STATUS_IMPL(OrtApis::FillSparseTensorBlockSparse, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+ _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+ _In_ const int64_t* indices_shape_data, size_t indices_shape_len,
+ _In_ const int32_t* indices_data) {
+ API_IMPL_BEGIN
+ TensorShape values_t_shape(values_shape, values_shape_len);
+ auto& sparse_tensor = ValidateFillInputArgs(ort_value, values_t_shape, data_mem_info);
+
+ TensorShape indices_t_shape(indices_shape_data, indices_shape_len);
+ if (std::any_of(indices_t_shape.GetDims().cbegin(), indices_t_shape.GetDims().cend(),
+ [](int64_t v) { return v < 0; })) {
+ ORT_THROW("tried Filling sparse tensor with negative value in block sparse indices shape");
+ }
+
+ if (sparse_tensor.IsDataTypeString()) {
+ PtrConvert conv(values);
+ ORT_THROW_IF_ERROR(sparse_tensor.MakeBlockSparseStrings(values_t_shape, conv.strings, indices_t_shape, indices_data));
+ } else {
+ auto data_transfer = GetDataTransfer(data_mem_info->device, sparse_tensor.Location().device);
+ ORT_THROW_IF_ERROR(sparse_tensor.MakeBlockSparseData(*data_transfer, *data_mem_info, values_t_shape,
+ values, indices_t_shape, indices_data));
+ }
+ return nullptr;
+ API_IMPL_END
+}
+
+// Creates a new OrtValue holding a SparseTensor whose values live in the user buffer `p_data`
+// described by `info`. String element types are rejected with ORT_INVALID_ARGUMENT, as is any
+// negative dimension in either the dense shape or the values shape.
+// On success ownership of the new OrtValue is handed to the caller through `out`.
+ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorWithValuesAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data,
+                    _In_ const int64_t* dense_shape, size_t dense_shape_len,
+                    _In_ const int64_t* values_shape, size_t values_shape_len,
+                    ONNXTensorElementDataType type, _Outptr_ OrtValue** out) {
+  API_IMPL_BEGIN
+  auto sparse_tensor_type = DataTypeImpl::SparseTensorTypeFromONNXEnum(type);
+  auto element_type = sparse_tensor_type->GetElementType();
+  assert(element_type->AsPrimitiveDataType() != nullptr);
+  if (utils::IsDataTypeString(element_type)) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "Can not use strings in pre-allocated memory."
+                                 " Use CreateSparseTensorAsOrtValue() to allocate memory inside and copy");
+  }
+  TensorShape tensor_dense_shape(dense_shape, dense_shape_len);
+  TensorShape tensor_values_shape(values_shape, values_shape_len);
+  // The dense shape is checked as well as the values shape, matching CreateSparseTensorAsOrtValue.
+  const auto has_negative_dim = [](const TensorShape& shape) {
+    return std::any_of(shape.GetDims().cbegin(), shape.GetDims().cend(),
+                       [](int64_t dim) { return dim < 0; });
+  };
+  if (has_negative_dim(tensor_dense_shape) || has_negative_dim(tensor_values_shape)) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape");
+  }
+  auto value = std::make_unique<OrtValue>();
+  SparseTensor::InitOrtValue(element_type, tensor_dense_shape, tensor_values_shape, p_data, *info, *value);
+  *out = value.release();
+  return nullptr;
+  API_IMPL_END
+}
+
+// Hands caller-supplied COO indices to the sparse tensor held by `ort_value` via
+// SparseTensor::UseCooIndices. A null pointer or a zero count is passed on as an empty span.
+// NOTE(review): presumably the tensor references `indices_data` without copying - confirm.
+ORT_API_STATUS_IMPL(OrtApis::UseCooIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* indices_data, size_t indices_num) {
+ API_IMPL_BEGIN
+ auto v = reinterpret_cast<::OrtValue*>(ort_value);
+ auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v);
+ auto indices_span = (indices_num == 0 || indices_data == nullptr)
+ ? gsl::span()
+ : gsl::make_span(indices_data, indices_num);
+
+ ORT_THROW_IF_ERROR(sparse_tensor.UseCooIndices(indices_span));
+ return nullptr;
+ API_IMPL_END
+}
+
+// Hands caller-supplied CSR inner and outer indices to the sparse tensor held by `ort_value`
+// via SparseTensor::UseCsrIndices. For each of the two arrays, a null pointer or a zero count
+// is passed on as an empty span.
+// NOTE(review): presumably the tensor references the buffers without copying - confirm.
+ORT_API_STATUS_IMPL(OrtApis::UseCsrIndices, _Inout_ OrtValue* ort_value,
+ _Inout_ int64_t* inner_data, size_t inner_num,
+ _Inout_ int64_t* outer_data, size_t outer_num) {
+ API_IMPL_BEGIN
+ auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*ort_value);
+ auto inner_span = (inner_num == 0 || inner_data == nullptr)
+ ? gsl::span()
+ : gsl::make_span(inner_data, inner_num);
+ auto outer_span = (outer_num == 0 || outer_data == nullptr)
+ ? gsl::span()
+ : gsl::make_span(outer_data, outer_num);
+ ORT_THROW_IF_ERROR(sparse_tensor.UseCsrIndices(inner_span, outer_span));
+ return nullptr;
+ API_IMPL_END
+}
+
+// Hands caller-supplied block-sparse indices (int32, described by `indices_shape`) to the
+// sparse tensor held by `ort_value` via SparseTensor::UseBlockSparseIndices.
+// NOTE(review): unlike FillSparseTensorBlockSparse, the indices shape is not checked for
+// negative dimensions here - confirm whether that is intended.
+ORT_API_STATUS_IMPL(OrtApis::UseBlockSparseIndices, _Inout_ OrtValue* ort_value, const int64_t* indices_shape, size_t indices_shape_len,
+ _Inout_ int32_t* indices_data) {
+ API_IMPL_BEGIN
+ auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*ort_value);
+ TensorShape ind_shape(indices_shape, indices_shape_len);
+ ORT_THROW_IF_ERROR(sparse_tensor.UseBlockSparseIndices(ind_shape, indices_data));
+ return nullptr;
+ API_IMPL_END
+}
+
+// Reports, through `out`, the sparse format of the sparse tensor held by `ort_value`.
+// Returns ORT_INVALID_ARGUMENT when the OrtValue has not been allocated.
+// The tensor is read with Get<SparseTensor>() directly, so the format is reported as-is.
+ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorFormat, _In_ const OrtValue* ort_value, _Out_ enum OrtSparseFormat* out) {
+  API_IMPL_BEGIN
+  if (!ort_value->IsAllocated()) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "the ort_value must contain a constructed sparse tensor");
+  }
+  const auto& sparse_tensor = ort_value->Get<SparseTensor>();
+  *out = static_cast<OrtSparseFormat>(sparse_tensor.Format());
+  return nullptr;
+  API_IMPL_END
+}
+
+// Returns, through `out`, the raw data pointer of the values tensor of the sparse tensor held
+// by `ort_value` (no copy is made here). String tensors are rejected with ORT_INVALID_ARGUMENT.
+// NOTE(review): the returned pointer presumably stays valid only while `ort_value` is alive - confirm.
+ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outptr_ const void** out) {
+ API_IMPL_BEGIN
+ const auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*ort_value);
+ if (sparse_tensor.IsDataTypeString()) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Use GetStringTensor*() API to retrieve strings");
+ }
+ const auto& values = sparse_tensor.Values();
+ *out = values.DataRaw();
+ return nullptr;
+ API_IMPL_END
+}
+
ORT_API_STATUS_IMPL(OrtApis::CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtCustomOpDomain** out) {
API_IMPL_BEGIN
auto custom_op_domain = std::make_unique();
@@ -656,9 +882,18 @@ ORT_API_STATUS_IMPL(OrtApis::IsTensor, _In_ const OrtValue* value, _Out_ int* ou
return nullptr;
}
+// Sets `*out` to 1 when `value` holds a sparse tensor and to 0 otherwise; always returns nullptr.
+// NOTE(review): there is no API_IMPL_BEGIN/API_IMPL_END pair here, so this relies on
+// IsSparseTensor() not throwing - confirm.
+ORT_API_STATUS_IMPL(OrtApis::IsSparseTensor, _In_ const OrtValue* value, _Out_ int* out) {
+ auto v = reinterpret_cast(value);
+ *out = v->IsSparseTensor() ? 1 : 0;
+ return nullptr;
+}
+
ORT_API_STATUS_IMPL(OrtApis::GetTensorMutableData, _Inout_ OrtValue* value, _Outptr_ void** output) {
TENSOR_READWRITE_API_BEGIN
- //TODO: test if it's a string tensor
+ // Uncomment when WinML fixed their code
+ //if (tensor->IsDataTypeString()) {
+ // return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "this API does not support strings");
+ //}
*output = tensor->MutableDataRaw();
return nullptr;
API_IMPL_END
@@ -693,79 +928,127 @@ ORT_API_STATUS_IMPL(OrtApis::FillStringTensorElement, _Inout_ OrtValue* value, _
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtApis::GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* out) {
- TENSOR_READ_API_BEGIN
- const auto* src = tensor.Data();
- int64_t len = tensor.Shape().Size();
- if (len >= 0) {
- size_t ret = 0;
- for (int64_t i = 0; i != len; ++i) {
- ret += src[i].size();
+namespace {
+
+OrtStatusPtr GetTensorStringSpan(const ::OrtValue& v, gsl::span& span) {
+ if (!v.IsAllocated()) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue should contain a Tensor or a Sparse Tensor");
+ }
+ gsl::span str_span;
+ int64_t items = 0;
+ // Data type will be enforced on DataAsSpan() call.
+ if (v.IsTensor()) {
+ const auto& tensor = v.Get();
+ items = tensor.Shape().Size();
+ if (items >= 0) {
+ str_span = tensor.DataAsSpan();
}
- *out = ret;
- } else
+ } else if (v.IsSparseTensor()) {
+ const auto& sparse_tensor = v.Get();
+ if (sparse_tensor.Format() == onnxruntime::SparseFormat::kUndefined) {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Sparse Tensor does not contain sparse data");
+ }
+ items = sparse_tensor.Values().Shape().Size();
+ if (items >= 0) {
+ str_span = sparse_tensor.Values().DataAsSpan();
+ }
+ } else {
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API supports Tensors or SparseTensors");
+ }
+
+ if (items < 0) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "shape is invalid");
+ }
+ span = str_span;
+ return nullptr;
+}
+} // namespace
+
+ORT_API_STATUS_IMPL(OrtApis::GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* out) {
+ API_IMPL_BEGIN
+ gsl::span str_span;
+ if (auto* status = GetTensorStringSpan(*value, str_span)) {
+ return status;
+ }
+
+ size_t ret = 0;
+ for (const auto& s : str_span) {
+ ret += s.size();
+ }
+
+ *out = ret;
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::GetStringTensorElementLength, _In_ const OrtValue* value, size_t index, _Out_ size_t* out) {
- TENSOR_READ_API_BEGIN
- const auto* src = tensor.Data();
- auto len = static_cast(tensor.Shape().Size());
- if (index < len) {
- *out = src[index].size();
- } else
- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "shape is invalid");
+ API_IMPL_BEGIN
+ gsl::span str_span;
+ if (auto* status = GetTensorStringSpan(*value, str_span)) {
+ return status;
+ }
+
+ if (index < str_span.size()) {
+ *out = str_span[index].size();
+ } else {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "index is out of bounds");
+ }
+
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s,
size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len) {
- TENSOR_READ_API_BEGIN
- const auto* input = tensor.Data();
- auto len = static_cast(tensor.Shape().Size());
- if (offsets_len != len) {
+ API_IMPL_BEGIN
+
+ gsl::span str_span;
+ if (auto* status = GetTensorStringSpan(*value, str_span)) {
+ return status;
+ }
+
+ if (offsets_len != str_span.size()) {
return OrtApis::CreateStatus(ORT_FAIL, "offsets buffer is not equal to tensor size");
}
- {
- size_t ret = 0;
- for (size_t i = 0; i != len; ++i) {
- ret += input[i].size();
- }
- if (s_len < ret) {
- return OrtApis::CreateStatus(ORT_FAIL, "output buffer is too small");
- }
+
+ size_t total_size = 0;
+ for (const auto& str : str_span) {
+ total_size += str.size();
}
+
+ if (s_len < total_size) {
+ return OrtApis::CreateStatus(ORT_FAIL, "output buffer is too small. Use GetStringTensorDataLength.");
+ }
+
size_t f = 0;
char* p = static_cast(s);
- for (size_t i = 0; i != len; ++i, ++offsets) {
- memcpy(p, input[i].data(), input[i].size());
- p += input[i].size();
- *offsets = f;
- f += input[i].size();
+ for (const auto& str : str_span) {
+ memcpy(p, str.data(), str.size());
+ p += str.size();
+ *offsets++ = f;
+ f += str.size();
}
return nullptr;
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtApis::GetStringTensorElement, _In_ const OrtValue* value, size_t s_len, size_t index, _Out_writes_bytes_all_(s_len) void* s) {
- TENSOR_READ_API_BEGIN
- const auto* input = tensor.Data();
- auto len = static_cast(tensor.Shape().Size());
+ORT_API_STATUS_IMPL(OrtApis::GetStringTensorElement, _In_ const OrtValue* value,
+ size_t s_len, size_t index, _Out_writes_bytes_all_(s_len) void* s) {
+ API_IMPL_BEGIN
+ gsl::span str_span;
+ if (auto* status = GetTensorStringSpan(*value, str_span)) {
+ return status;
+ }
- if (index >= len) {
+ if (index < str_span.size()) {
+ const auto& str = str_span[index];
+ if (s_len < str.size()) {
+ return OrtApis::CreateStatus(ORT_FAIL, "buffer size is too small for string element");
+ }
+ memcpy(s, str.data(), str.size());
+ } else {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "element index is out of bounds");
}
-
- size_t ret = input[index].size();
- if (s_len < ret) {
- return OrtApis::CreateStatus(ORT_FAIL, "buffer size is too small for string");
- }
-
- memcpy(s, input[index].data(), input[index].size());
-
return nullptr;
API_IMPL_END
}
@@ -2097,6 +2380,20 @@ static constexpr OrtApi ort_api_1_to_9 = {
&OrtApis::EnableOrtCustomOps,
&OrtApis::RegisterAllocator,
&OrtApis::UnregisterAllocator,
+ &OrtApis::IsSparseTensor,
+ &OrtApis::CreateSparseTensorAsOrtValue,
+ &OrtApis::FillSparseTensorCoo,
+ &OrtApis::FillSparseTensorCsr,
+ &OrtApis::FillSparseTensorBlockSparse,
+ &OrtApis::CreateSparseTensorWithValuesAsOrtValue,
+ &OrtApis::UseCooIndices,
+ &OrtApis::UseCsrIndices,
+ &OrtApis::UseBlockSparseIndices,
+ &OrtApis::GetSparseTensorFormat,
+ &OrtApis::GetSparseTensorValuesTypeAndShape,
+ &OrtApis::GetSparseTensorValues,
+ &OrtApis::GetSparseTensorIndicesTypeShape,
+ &OrtApis::GetSparseTensorIndices,
};
// Asserts to do a some checks to ensure older Versions of the OrtApi never change (will detect an addition or deletion but not if they cancel out each other)
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 080431028f..6a6b3fa817 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -288,4 +288,31 @@ ORT_API(void, ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensorRTProvide
ORT_API_STATUS_IMPL(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(RegisterAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* allocator);
ORT_API_STATUS_IMPL(UnregisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info);
+// SparseTensor related API.
+// Parameter names and SAL annotations below mirror the definitions in onnxruntime_c_api.cc
+// and tensor_type_and_shape.cc.
+ORT_API_STATUS_IMPL(IsSparseTensor, _In_ const OrtValue* value, _Out_ int* out);
+ORT_API_STATUS_IMPL(CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape,
+                    size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out);
+ORT_API_STATUS_IMPL(FillSparseTensorCoo, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+                    _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+                    _In_ const int64_t* indices_data, size_t indices_num);
+ORT_API_STATUS_IMPL(FillSparseTensorCsr, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+                    _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+                    _In_ const int64_t* inner_indices_data, size_t inner_indices_num,
+                    _In_ const int64_t* outer_indices_data, size_t outer_indices_num);
+ORT_API_STATUS_IMPL(FillSparseTensorBlockSparse, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+                    _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+                    _In_ const int64_t* indices_shape_data, size_t indices_shape_len,
+                    _In_ const int32_t* indices_data);
+ORT_API_STATUS_IMPL(CreateSparseTensorWithValuesAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data,
+                    _In_ const int64_t* dense_shape, size_t dense_shape_len,
+                    _In_ const int64_t* values_shape, size_t values_shape_len,
+                    ONNXTensorElementDataType type, _Outptr_ OrtValue** out);
+ORT_API_STATUS_IMPL(UseCooIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* indices_data, size_t indices_num);
+ORT_API_STATUS_IMPL(UseCsrIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* inner_data, size_t inner_num, _Inout_ int64_t* outer_data, size_t outer_num);
+ORT_API_STATUS_IMPL(UseBlockSparseIndices, _Inout_ OrtValue* ort_value, const int64_t* indices_shape, size_t indices_shape_len, _Inout_ int32_t* indices_data);
+ORT_API_STATUS_IMPL(GetSparseTensorFormat, _In_ const OrtValue* ort_value, _Out_ enum OrtSparseFormat* out);
+ORT_API_STATUS_IMPL(GetSparseTensorValuesTypeAndShape, _In_ const OrtValue* ort_value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
+ORT_API_STATUS_IMPL(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outptr_ const void** out);
+ORT_API_STATUS_IMPL(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out);
+ORT_API_STATUS_IMPL(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices);
} // namespace OrtApis
diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc
index c2d0f86090..01a5adc387 100644
--- a/onnxruntime/test/framework/sparse_kernels_test.cc
+++ b/onnxruntime/test/framework/sparse_kernels_test.cc
@@ -36,7 +36,6 @@ inline int64_t vector_len(const std::vector& v) {
return static_cast(v.size());
}
-
// This file contains sample implementations of several ops with sparse-tensor inputs/outputs.
// Each op is implemented as a struct with the following signature:
// struct SparseOp {
@@ -1209,9 +1208,6 @@ TEST(SparseTensorConversionTests, TestDenseToSparseConversion) {
RawSparseDataChecker);
}
-template
-using SparseMatrixRowMajor = Eigen::SparseMatrix;
-
TEST(SparseTensorConversionTests, CsrConversion) {
auto* cpu_provider = TestCPUExecutionProvider();
auto cpu_allocator = cpu_provider->GetAllocator(0, OrtMemTypeDefault);
@@ -1234,6 +1230,7 @@ TEST(SparseTensorConversionTests, CsrConversion) {
const std::vector expected_values = {1, 1, 1};
const std::vector expected_values_str = {"1", "1", "1"};
+ const char* const strings[] = {"1", "1", "1"};
const std::vector expected_inner = {2, 0, 2};
const std::vector expected_outer = {0, 1, 3, 3};
@@ -1242,6 +1239,49 @@ TEST(SparseTensorConversionTests, CsrConversion) {
auto cpu_transfer = cpu_provider->GetDataTransfer();
dtm.RegisterDataTransfer(std::move(cpu_transfer));
}
+ {
+ {
+ // Test CSR initialization of a 100% sparse tensor that owns its buffer: MakeCsrData() with zero values and empty index spans
+ SparseTensor fully_sparse(DataTypeImpl::GetType(), TensorShape{3, 3}, cpu_allocator);
+ ASSERT_STATUS_OK(fully_sparse.MakeCsrData(*cpu_provider->GetDataTransfer(), cpu_allocator->Info(),
+ 0U, nullptr, gsl::span(), gsl::span()));
+ ASSERT_EQ(fully_sparse.Format(), SparseFormat::kCsrc);
+ ASSERT_EQ(0, fully_sparse.RequiredAllocationSize());
+ ASSERT_EQ(0U, fully_sparse.NumValues());
+ ASSERT_EQ(1U, fully_sparse.Values().Shape().GetDims().size());
+ ASSERT_EQ(0, fully_sparse.Values().Shape().Size());
+ ASSERT_TRUE(fully_sparse.Values().DataAsSpan().empty());
+ auto csr_view = fully_sparse.AsCsr();
+ const auto& inner = csr_view.Inner();
+ ASSERT_EQ(0, inner.Shape().Size());
+ ASSERT_EQ(1U, inner.Shape().GetDims().size());
+ ASSERT_TRUE(inner.DataAsSpan().empty());
+ const auto& outer = csr_view.Outer();
+ ASSERT_EQ(0, outer.Shape().Size());
+ ASSERT_EQ(1U, outer.Shape().GetDims().size());
+ ASSERT_TRUE(outer.DataAsSpan().empty());
+ }
+ {
+ // Test CSR initialization of a 100% sparse tensor over a user buffer: values shape {0}, null data and empty index spans
+ SparseTensor fully_sparse(DataTypeImpl::GetType(), TensorShape{3, 3}, TensorShape{0}, nullptr, cpu_allocator->Info());
+ ASSERT_STATUS_OK(fully_sparse.UseCsrIndices(gsl::span(), gsl::span()));
+ ASSERT_EQ(fully_sparse.Format(), SparseFormat::kCsrc);
+ ASSERT_EQ(0, fully_sparse.RequiredAllocationSize());
+ ASSERT_EQ(0U, fully_sparse.NumValues());
+ ASSERT_EQ(1U, fully_sparse.Values().Shape().GetDims().size());
+ ASSERT_EQ(0, fully_sparse.Values().Shape().Size());
+ ASSERT_TRUE(fully_sparse.Values().DataAsSpan().empty());
+ auto csr_view = fully_sparse.AsCsr();
+ const auto& inner = csr_view.Inner();
+ ASSERT_EQ(0, inner.Shape().Size());
+ ASSERT_EQ(1U, inner.Shape().GetDims().size());
+ ASSERT_TRUE(inner.DataAsSpan().empty());
+ const auto& outer = csr_view.Outer();
+ ASSERT_EQ(0, outer.Shape().Size());
+ ASSERT_EQ(1U, outer.Shape().GetDims().size());
+ ASSERT_TRUE(outer.DataAsSpan().empty());
+ }
+ }
Tensor dense_cpu_src(DataTypeImpl::GetType(), dense_shape, dense_data.data(), cpu_allocator->Info());
{
@@ -1309,6 +1349,28 @@ TEST(SparseTensorConversionTests, CsrConversion) {
ASSERT_TRUE(std::equal(dense_values_dst.cbegin(), dense_values_dst.cend(), dense_data_str.cbegin(), dense_data_str.cend()));
}
+ {
+ // Use MakeCsrStrings()
+ SparseTensor str_cpu_src(DataTypeImpl::GetType(), dense_shape, cpu_allocator);
+ ASSERT_STATUS_OK(str_cpu_src.MakeCsrStrings(expected_values_str.size(), strings,
+ gsl::make_span(expected_inner), gsl::make_span(expected_outer)));
+ ASSERT_EQ(str_cpu_src.Format(), SparseFormat::kCsrc);
+ ASSERT_TRUE(str_cpu_src.IsDataTypeString());
+ ASSERT_EQ(str_cpu_src.DenseShape().GetDims(), dense_shape);
+ ASSERT_EQ(str_cpu_src.NumValues(), expected_values_str.size());
+ auto values = str_cpu_src.Values().DataAsSpan();
+ ASSERT_TRUE(std::equal(expected_values_str.cbegin(), expected_values_str.cend(), values.cbegin(), values.cend()));
+
+ auto csr_view = str_cpu_src.AsCsr();
+ auto inner = csr_view.Inner().DataAsSpan();
+ ASSERT_EQ(expected_inner.size(), inner.size());
+ ASSERT_TRUE(std::equal(expected_inner.cbegin(), expected_inner.cend(), inner.cbegin(), inner.cend()));
+
+ auto outer = csr_view.Outer().DataAsSpan();
+ ASSERT_EQ(expected_outer.size(), outer.size());
+ ASSERT_TRUE(std::equal(expected_outer.cbegin(), expected_outer.cend(), outer.cbegin(), outer.cend()));
+ }
+
#ifdef USE_CUDA
auto cuda_provider = DefaultCudaExecutionProvider();
auto cuda_allocator = cuda_provider->GetAllocator(0, OrtMemTypeDefault);
@@ -1387,6 +1449,7 @@ TEST(SparseTensorConversionTests, CooConversion) {
const std::vector expected_values = {1, 1, 1};
const std::vector expected_values_str = {"1", "1", "1"};
+ const char* const strings[] = {"1", "1", "1"};
const std::vector expected_linear_indices = {2, 3, 5};
const std::vector expected_2d_indices = {0, 2, 1, 0, 1, 2};
@@ -1395,6 +1458,43 @@ TEST(SparseTensorConversionTests, CooConversion) {
auto cpu_transfer = cpu_provider->GetDataTransfer();
dtm.RegisterDataTransfer(std::move(cpu_transfer));
}
+
+ {
+ // Test COO initialization of a 100% sparse tensor that owns its buffer: MakeCooData() with zero values and an empty index span
+ SparseTensor fully_sparse(DataTypeImpl::GetType(), TensorShape{3, 3}, cpu_allocator);
+ ASSERT_STATUS_OK(fully_sparse.MakeCooData(*cpu_provider->GetDataTransfer(), cpu_allocator->Info(), 0, nullptr, gsl::span()));
+ ASSERT_EQ(fully_sparse.Format(), SparseFormat::kCoo);
+ ASSERT_EQ(0, fully_sparse.RequiredAllocationSize());
+ ASSERT_EQ(0U, fully_sparse.NumValues());
+ ASSERT_EQ(1U, fully_sparse.Values().Shape().GetDims().size());
+ ASSERT_EQ(0, fully_sparse.Values().Shape().Size());
+ ASSERT_TRUE(fully_sparse.Values().DataAsSpan().empty());
+ auto coo_view = fully_sparse.AsCoo();
+ const auto& indices = coo_view.Indices();
+ ASSERT_EQ(0, indices.Shape().Size());
+ // For a fully sparse tensor we assume 2-D indices.
+ ASSERT_EQ(2U, indices.Shape().GetDims().size());
+ ASSERT_TRUE(indices.DataAsSpan().empty());
+ }
+
+ {
+ // Test COO initialization of a 100% sparse tensor over a user buffer: values shape {0}, null data and an empty index span
+ SparseTensor fully_sparse(DataTypeImpl::GetType(), TensorShape{3, 3}, TensorShape{0}, nullptr, cpu_allocator->Info());
+ ASSERT_STATUS_OK(fully_sparse.UseCooIndices(gsl::span()));
+ ASSERT_EQ(fully_sparse.Format(), SparseFormat::kCoo);
+ ASSERT_EQ(0, fully_sparse.RequiredAllocationSize());
+ ASSERT_EQ(0U, fully_sparse.NumValues());
+ ASSERT_EQ(1U, fully_sparse.Values().Shape().GetDims().size());
+ ASSERT_EQ(0, fully_sparse.Values().Shape().Size());
+ ASSERT_TRUE(fully_sparse.Values().DataAsSpan().empty());
+ auto coo_view = fully_sparse.AsCoo();
+ const auto& indices = coo_view.Indices();
+ ASSERT_EQ(0, indices.Shape().Size());
+ // For a fully sparse tensor we assume 2-D indices.
+ ASSERT_EQ(2U, indices.Shape().GetDims().size());
+ ASSERT_TRUE(indices.DataAsSpan().empty());
+ }
+
Tensor dense_cpu_src(DataTypeImpl::GetType(), dense_shape, dense_data.data(), cpu_allocator->Info());
{
// test where both src and destination are on CPU. Linear index.
@@ -1452,6 +1552,25 @@ TEST(SparseTensorConversionTests, CooConversion) {
ASSERT_TRUE(std::equal(dense_values_dst.cbegin(), dense_values_dst.cend(), dense_data_str.cbegin(), dense_data_str.cend()));
}
+ {
+ // Use MakeCooStrings()
+ SparseTensor str_cpu_src(DataTypeImpl::GetType(), dense_shape, cpu_allocator);
+ ASSERT_STATUS_OK(str_cpu_src.MakeCooStrings(expected_values_str.size(), strings,
+ gsl::make_span(expected_linear_indices)));
+ ASSERT_EQ(str_cpu_src.Format(), SparseFormat::kCoo);
+ ASSERT_TRUE(str_cpu_src.IsDataTypeString());
+ ASSERT_EQ(str_cpu_src.DenseShape().GetDims(), dense_shape);
+ ASSERT_EQ(str_cpu_src.NumValues(), expected_values_str.size());
+ auto values = str_cpu_src.Values().DataAsSpan();
+ ASSERT_TRUE(std::equal(expected_values_str.cbegin(), expected_values_str.cend(), values.cbegin(), values.cend()));
+
+ auto coo_view = str_cpu_src.AsCoo();
+ auto indices = coo_view.Indices().DataAsSpan();
+ ASSERT_EQ(expected_linear_indices.size(), indices.size());
+ ASSERT_TRUE(std::equal(expected_linear_indices.cbegin(), expected_linear_indices.cend(), indices.cbegin(), indices.cend()));
+ }
+
+
{
// test where both src and destination are on CPU. 2-D index
SparseTensor dst;
@@ -1539,5 +1658,133 @@ TEST(SparseTensorConversionTests, CooConversion) {
#endif
}
#endif // !ORT_MINIMAL_BUILD
+
+TEST(SparseTensorConversionTests, BlockSparse) {  // Block-sparse creation: fully sparse, tensor-owned buffer, user buffer and string values
+ auto* cpu_provider = TestCPUExecutionProvider();
+ auto cpu_allocator = cpu_provider->GetAllocator(0, OrtMemTypeDefault);
+
+ DataTransferManager dtm;  // NOTE(review): only populated below; the calls in this test use cpu_provider->GetDataTransfer() directly
+ {
+ auto cpu_transfer = cpu_provider->GetDataTransfer();
+ dtm.RegisterDataTransfer(std::move(cpu_transfer));
+ }
+
+ {
+ // Fully sparse, tensor-owned buffer: MakeBlockSparseData() with {0} shapes and null data pointers
+ SparseTensor fully_sparse(DataTypeImpl::GetType(), TensorShape{3, 3}, cpu_allocator);
+ ASSERT_STATUS_OK(fully_sparse.MakeBlockSparseData(*cpu_provider->GetDataTransfer(), cpu_allocator->Info(),
+ TensorShape{0}, nullptr, TensorShape{0}, nullptr));
+ ASSERT_EQ(fully_sparse.Format(), SparseFormat::kBlockSparse);
+ ASSERT_EQ(0, fully_sparse.RequiredAllocationSize());
+ ASSERT_EQ(0U, fully_sparse.NumValues());
+ ASSERT_EQ(1U, fully_sparse.Values().Shape().GetDims().size());  // values keep a 1-D shape even when empty
+ ASSERT_EQ(0, fully_sparse.Values().Shape().Size());
+ ASSERT_TRUE(fully_sparse.Values().DataAsSpan().empty());
+ auto blocksparse_view = fully_sparse.AsBlockSparse();
+ const auto& indices = blocksparse_view.Indices();
+ ASSERT_EQ(0, indices.Shape().Size());
+ ASSERT_EQ(1U, indices.Shape().GetDims().size());  // empty indices are expected to be 1-D as well
+ ASSERT_TRUE(indices.DataAsSpan().empty());
+ }
+
+ {
+ // Fully sparse, user-supplied buffer: values shape {0} with null data, then UseBlockSparseIndices() with a {0} shape
+ SparseTensor fully_sparse(DataTypeImpl::GetType(), TensorShape{3, 3},
+ TensorShape{0}, nullptr, cpu_allocator->Info());
+ ASSERT_STATUS_OK(fully_sparse.UseBlockSparseIndices(TensorShape{0}, nullptr));
+ ASSERT_EQ(fully_sparse.Format(), SparseFormat::kBlockSparse);
+ ASSERT_EQ(0, fully_sparse.RequiredAllocationSize());
+ ASSERT_EQ(0U, fully_sparse.NumValues());
+ ASSERT_EQ(1U, fully_sparse.Values().Shape().GetDims().size());
+ ASSERT_EQ(0, fully_sparse.Values().Shape().Size());
+ ASSERT_TRUE(fully_sparse.Values().DataAsSpan().empty());
+ auto blocksparse_view = fully_sparse.AsBlockSparse();
+ const auto& indices = blocksparse_view.Indices();
+ ASSERT_EQ(0, indices.Shape().Size());
+ ASSERT_EQ(1U, indices.Shape().GetDims().size());
+ ASSERT_TRUE(indices.DataAsSpan().empty());
+ }
+
+ const TensorShape dense_shape{8, 8};
+ constexpr int64_t block_size = 2;
+ const TensorShape values_shape{2, block_size, block_size};
+ // Two dense blocks of block_size x block_size elements each (8 values in total)
+ std::vector data_blocks{
+ 1, 2, 3, 4, 5, 6, 7, 8};
+
+ const char* const strings[] = {  // same values as C strings, passed to MakeBlockSparseStrings() below
+ "1", "2", "3", "4", "5", "6", "7", "8"};
+
+ const std::string expected_strings[] = {  // expected contents of the string tensor
+ "1", "2", "3", "4", "5", "6", "7", "8"};
+
+
+ const TensorShape indices_shape{2, 2}; // two blocks by two coordinates
+ // Block coordinates (0, 0) and (0, 1)
+ std::vector blocksparse_indices = {
+ 0, 0, 0, 1};
+
+ {
+ // Test instantiation only: the tensor allocates its own buffer and MakeBlockSparseData() copies values and indices into it
+ SparseTensor own_buffer_tensor(DataTypeImpl::GetType(), dense_shape, cpu_allocator);
+ ASSERT_STATUS_OK(own_buffer_tensor.MakeBlockSparseData(*cpu_provider->GetDataTransfer(), cpu_allocator->Info(),
+ values_shape, data_blocks.data(),
+ indices_shape, blocksparse_indices.data()));
+ ASSERT_EQ(own_buffer_tensor.Format(), SparseFormat::kBlockSparse);
+ ASSERT_EQ(dense_shape, own_buffer_tensor.DenseShape());
+ ASSERT_EQ(data_blocks.size(), own_buffer_tensor.NumValues());
+ ASSERT_EQ(values_shape, own_buffer_tensor.Values().Shape());
+ auto data_span = own_buffer_tensor.Values().DataAsSpan();
+ ASSERT_EQ(data_blocks.size(), data_span.size());
+ ASSERT_TRUE(std::equal(data_blocks.cbegin(), data_blocks.cend(), data_span.cbegin(), data_span.cend()));
+
+ const auto& indices = own_buffer_tensor.AsBlockSparse().Indices();
+ ASSERT_EQ(indices_shape, indices.Shape());
+ auto indices_span = indices.DataAsSpan();
+ ASSERT_TRUE(std::equal(blocksparse_indices.cbegin(), blocksparse_indices.cend(),
+ indices_span.cbegin(), indices_span.cend()));
+ }
+
+ {
+ // Test instantiation only: values come from the user-supplied buffer and indices are attached with UseBlockSparseIndices()
+ SparseTensor user_buffer_tensor(DataTypeImpl::GetType(), dense_shape, values_shape, data_blocks.data(), cpu_allocator->Info());
+ ASSERT_STATUS_OK(user_buffer_tensor.UseBlockSparseIndices(indices_shape, blocksparse_indices.data()));
+ ASSERT_EQ(user_buffer_tensor.Format(), SparseFormat::kBlockSparse);
+ ASSERT_EQ(dense_shape, user_buffer_tensor.DenseShape());
+ ASSERT_EQ(data_blocks.size(), user_buffer_tensor.NumValues());
+ ASSERT_EQ(values_shape, user_buffer_tensor.Values().Shape());
+ auto data_span = user_buffer_tensor.Values().DataAsSpan();
+ ASSERT_EQ(data_blocks.size(), data_span.size());
+ ASSERT_TRUE(std::equal(data_blocks.cbegin(), data_blocks.cend(), data_span.cbegin(), data_span.cend()));
+
+ const auto& indices = user_buffer_tensor.AsBlockSparse().Indices();
+ ASSERT_EQ(indices_shape, indices.Shape());
+ auto indices_span = indices.DataAsSpan();
+ ASSERT_TRUE(std::equal(blocksparse_indices.cbegin(), blocksparse_indices.cend(),
+ indices_span.cbegin(), indices_span.cend()));
+ }
+
+ {
+ // Use MakeBlockSparseStrings(): same layout as above, but the values are strings built from C string pointers
+ SparseTensor own_buffer_tensor(DataTypeImpl::GetType(), dense_shape, cpu_allocator);
+ ASSERT_STATUS_OK(own_buffer_tensor.MakeBlockSparseStrings(values_shape, strings, indices_shape, blocksparse_indices.data()));
+ ASSERT_TRUE(own_buffer_tensor.IsDataTypeString());
+ ASSERT_EQ(own_buffer_tensor.Format(), SparseFormat::kBlockSparse);
+ ASSERT_EQ(dense_shape, own_buffer_tensor.DenseShape());
+ ASSERT_EQ(data_blocks.size(), own_buffer_tensor.NumValues());  // data_blocks is used here only for its element count
+ ASSERT_EQ(values_shape, own_buffer_tensor.Values().Shape());
+ auto data_span = own_buffer_tensor.Values().DataAsSpan();
+ auto expected_span = gsl::make_span(expected_strings);
+ ASSERT_EQ(expected_span.size(), data_span.size());
+ ASSERT_TRUE(std::equal(expected_span.cbegin(), expected_span.cend(), data_span.cbegin(), data_span.cend()));
+
+ const auto& indices = own_buffer_tensor.AsBlockSparse().Indices();
+ ASSERT_EQ(indices_shape, indices.Shape());
+ auto indices_span = indices.DataAsSpan();
+ ASSERT_TRUE(std::equal(blocksparse_indices.cbegin(), blocksparse_indices.cend(),
+ indices_span.cbegin(), indices_span.cend()));
+
+ }
+}
} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 4c43a575b2..4c9316e6dd 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -23,6 +23,7 @@
#include "test_fixture.h"
#include "utils.h"
#include "custom_op_utils.h"
+#include
#ifdef _WIN32
#include
@@ -175,6 +176,10 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx");
+static constexpr PATH_TYPE SPARSE_OUTPUT_MODEL_URI = TSTR("testdata/sparse_initializer_as_output.onnx");  // loaded by the SparseOutputModel test
+#ifndef DISABLE_CONTRIB_OPS
+static constexpr PATH_TYPE SPARSE_INPUT_MATMUL_MODEL_URI = TSTR("testdata/sparse_to_dense_matmul.onnx");  // loaded by the SparseInputModel test, which is compiled under the same guard
+#endif  // !DISABLE_CONTRIB_OPS
#ifdef ENABLE_EXTENSION_CUSTOM_OPS
static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URI = TSTR("testdata/custom_op_string_lower.onnx");
@@ -239,6 +244,121 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders,
CApiTestWithProvider,
::testing::Values(0, 1, 2, 3, 4));
+TEST(CApiTest, SparseOutputModel) {  // Runs a model without feeding inputs and verifies its sparse COO output named "values"
+ std::vector dense_shape{3, 3};
+ std::vector values{1.764052391052246, 0.40015721321105957, 0.978738009929657};  // expected non-zero values
+ std::vector values_shape{3};
+ std::vector coo_indices{2, 3, 5};  // expected COO indices (1-D, see indices_shape)
+ std::vector indices_shape{3};
+
+ std::vector ort_inputs;  // left empty: no inputs are passed to Run()
+ std::vector input_names;
+ const char* const output_names[] = {"values"};
+ Ort::Session session(*ort_env, SPARSE_OUTPUT_MODEL_URI, Ort::SessionOptions{});
+ auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
+ output_names, 1);
+ ASSERT_EQ(ort_outputs.size(), 1U);
+ const auto& sparse_output = ort_outputs[0];
+ auto ti = sparse_output.GetTypeInfo();
+ ASSERT_EQ(ONNX_TYPE_SPARSETENSOR, ti.GetONNXType());
+ auto tensor_type_shape = ti.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(dense_shape, tensor_type_shape.GetShape());  // the type info reports the dense shape
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_shape.GetElementType());
+
+ ASSERT_EQ(ORT_SPARSE_COO, sparse_output.GetSparseFormat());
+ auto values_ts = sparse_output.GetSparseTensorValuesTypeAndShapeInfo();
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, values_ts.GetElementType());
+ ASSERT_EQ(values_shape, values_ts.GetShape());
+
+ const auto* values_fetch = sparse_output.GetSparseTensorValues();
+ auto val_span = gsl::make_span(values_fetch, values.size());
+ ASSERT_TRUE(std::equal(values.cbegin(), values.cend(), val_span.cbegin(), val_span.cend()));  // exact element-wise comparison, no tolerance
+
+ auto indices_ts = sparse_output.GetSparseTensorIndicesTypeShapeInfo(ORT_SPARSE_COO_INDICES);
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, indices_ts.GetElementType());
+ ASSERT_EQ(indices_shape, indices_ts.GetShape());
+
+ size_t num_indices = 0;  // filled in by GetSparseTensorIndicesData()
+ const int64_t* indices = sparse_output.GetSparseTensorIndicesData(ORT_SPARSE_COO_INDICES, num_indices);
+ ASSERT_EQ(num_indices, static_cast(indices_shape[0]));
+ auto ind_span = gsl::make_span(indices, num_indices);
+ ASSERT_TRUE(std::equal(coo_indices.cbegin(), coo_indices.cend(), ind_span.cbegin(), ind_span.cend()));
+}
+
+#ifndef DISABLE_CONTRIB_OPS
+TEST(CApiTest, SparseInputModel) {  // Feeds a COO sparse tensor "sparse_A" and a dense tensor "dense_B" to a model and checks the dense output "dense_Y"
+
+ std::vector common_shape{9, 9}; // inputs and outputs same shape
+ std::vector A_values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
+ 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0,
+ 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
+ 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0,
+ 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0,
+ 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0,
+ 50.0, 51.0, 52.0, 53.0};
+
+ // 2-D index: one pair of coordinates per value in A_values
+ std::vector indices_shape{gsl::narrow(A_values.size()), 2};
+ std::vector A_indices{0, 1, 0, 2, 0, 6, 0, 7, 0, 8, 1, 0, 1,
+ 1, 1, 2, 1, 6, 1, 7, 1, 8, 2, 0, 2, 1,
+ 2, 2, 2, 6, 2, 7, 2, 8, 3, 3, 3, 4, 3,
+ 5, 3, 6, 3, 7, 3, 8, 4, 3, 4, 4, 4, 5,
+ 4, 6, 4, 7, 4, 8, 5, 3, 5, 4, 5, 5, 5,
+ 6, 5, 7, 5, 8, 6, 0, 6, 1, 6, 2, 6, 3,
+ 6, 4, 6, 5, 7, 0, 7, 1, 7, 2, 7, 3, 7,
+ 4, 7, 5, 8, 0, 8, 1, 8, 2, 8, 3, 8, 4,
+ 8, 5};
+
+ std::vector B_data{0, 1, 2, 0, 0, 0, 3, 4, 5,
+ 6, 7, 8, 0, 0, 0, 9, 10, 11,
+ 12, 13, 14, 0, 0, 0, 15, 16, 17,
+ 0, 0, 0, 18, 19, 20, 21, 22, 23,
+ 0, 0, 0, 24, 25, 26, 27, 28, 29,
+ 0, 0, 0, 30, 31, 32, 33, 34, 35,
+ 36, 37, 38, 39, 40, 41, 0, 0, 0,
+ 42, 43, 44, 45, 46, 47, 0, 0, 0,
+ 48, 49, 50, 51, 52, 53, 0, 0, 0};
+
+ std::vector Y_result{546, 561, 576, 552, 564, 576, 39, 42, 45,
+ 1410, 1461, 1512, 1362, 1392, 1422, 201, 222, 243,
+ 2274, 2361, 2448, 2172, 2220, 2268, 363, 402, 441,
+ 2784, 2850, 2916, 4362, 4485, 4608, 1551, 1608, 1665,
+ 3540, 3624, 3708, 5604, 5763, 5922, 2037, 2112, 2187,
+ 4296, 4398, 4500, 6846, 7041, 7236, 2523, 2616, 2709,
+ 678, 789, 900, 2892, 3012, 3132, 4263, 4494, 4725,
+ 786, 915, 1044, 3324, 3462, 3600, 4911, 5178, 5445,
+ 894, 1041, 1188, 3756, 3912, 4068, 5559, 5862, 6165};
+
+ Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+ Ort::Value::Shape ort_dense_shape{common_shape.data(), common_shape.size()};
+ Ort::Value::Shape ort_values_shape{&indices_shape[0], 1U};  // 1-D values shape; reuses indices_shape[0], which equals A_values.size()
+ auto a_st = Ort::Value::CreateSparseTensor(info, A_values.data(), ort_dense_shape, ort_values_shape);
+ a_st.UseCooIndices(A_indices.data(), A_indices.size());  // passes the flat element count of the 2-D index array
+
+ auto b_tensor = Ort::Value::CreateTensor(info, B_data.data(), B_data.size(), common_shape.data(), common_shape.size());
+
+ std::vector ort_inputs;
+ ort_inputs.push_back(std::move(a_st));
+ ort_inputs.push_back(std::move(b_tensor));
+ const char* input_names[] = {"sparse_A", "dense_B"};  // order matches ort_inputs
+ const char* const output_names[] = {"dense_Y"};
+ Ort::Session session(*ort_env, SPARSE_INPUT_MATMUL_MODEL_URI, Ort::SessionOptions{});
+ auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
+ output_names, 1);
+ ASSERT_EQ(ort_outputs.size(), 1U);
+ const auto& dense_Y = ort_outputs[0];
+ ASSERT_TRUE(dense_Y.IsTensor());
+
+ auto result_ts = dense_Y.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, result_ts.GetElementType());
+ ASSERT_EQ(common_shape, result_ts.GetShape());
+
+ const auto* result_vals = dense_Y.GetTensorData();
+ auto result_span = gsl::make_span(result_vals, Y_result.size());
+ ASSERT_TRUE(std::equal(Y_result.cbegin(), Y_result.cend(), result_span.cbegin(), result_span.cend()));  // exact element-wise comparison, no tolerance
+}
+#endif // !DISABLE_CONTRIB_OPS
+
TEST(CApiTest, custom_op_handler) {
std::cout << "Running custom op inference" << std::endl;
diff --git a/onnxruntime/test/shared_lib/test_nontensor_types.cc b/onnxruntime/test/shared_lib/test_nontensor_types.cc
index 8110074638..232b1d8a62 100644
--- a/onnxruntime/test/shared_lib/test_nontensor_types.cc
+++ b/onnxruntime/test/shared_lib/test_nontensor_types.cc
@@ -9,6 +9,8 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "test_allocator.h"
+#include
+
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -306,3 +308,617 @@ TEST(CApiTest, TypeInfoSequence) {
ASSERT_EQ(seq_type_info.GetSequenceElementType().GetTensorTypeAndShapeInfo().GetElementType(),
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
}
+
+TEST(CApiTest, SparseTensorUsingAPI) {  // Builds COO, CSR and block-sparse tensors over user buffers and reads format, values and indices back
+ Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+
+ {
+ // COO with 1-D (linear) indices
+ const std::vector dense_shape{3, 3};
+ const std::vector values_shape{3};
+ std::vector expected_values = {1, 1, 1};
+ constexpr int64_t values_len = 3;
+ std::vector expected_linear_indices = {2, 3, 5};
+ const std::vector indices_shape{3};
+
+ Ort::Value::Shape ort_dense_shape{dense_shape.data(), dense_shape.size()};
+ Ort::Value::Shape ort_values_shape{&values_len, 1U};  // 1-D values shape of length values_len
+ auto coo_st = Ort::Value::CreateSparseTensor(info, expected_values.data(), ort_dense_shape, ort_values_shape);
+ coo_st.UseCooIndices(expected_linear_indices.data(), expected_linear_indices.size());
+
+ {
+ auto ti = coo_st.GetTypeInfo();
+ ASSERT_EQ(ONNX_TYPE_SPARSETENSOR, ti.GetONNXType());
+ auto tensor_type_shape = ti.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(dense_shape, tensor_type_shape.GetShape());
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, tensor_type_shape.GetElementType());
+ ASSERT_EQ(dense_shape.size(), tensor_type_shape.GetDimensionsCount());
+ }
+
+ {
+ auto t_type_shape = coo_st.GetTensorTypeAndShapeInfo();  // same checks as above, obtained directly from the value
+ ASSERT_EQ(dense_shape, t_type_shape.GetShape());
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, t_type_shape.GetElementType());
+ ASSERT_EQ(dense_shape.size(), t_type_shape.GetDimensionsCount());
+ }
+
+ ASSERT_EQ(ORT_SPARSE_COO, coo_st.GetSparseFormat());
+
+ {
+ auto values_ts = coo_st.GetSparseTensorValuesTypeAndShapeInfo();
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, values_ts.GetElementType());
+ ASSERT_EQ(values_shape, values_ts.GetShape());
+ }
+
+ {
+ const auto* values = coo_st.GetSparseTensorValues();
+ auto val_span = gsl::make_span(values, values_shape[0]);
+ ASSERT_TRUE(std::equal(expected_values.cbegin(), expected_values.cend(), val_span.cbegin(), val_span.cend()));
+ }
+
+ {
+ auto indices_ts = coo_st.GetSparseTensorIndicesTypeShapeInfo(ORT_SPARSE_COO_INDICES);
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, indices_ts.GetElementType());
+ ASSERT_EQ(indices_shape, indices_ts.GetShape());
+
+ size_t num_indices = 0;  // filled in by GetSparseTensorIndicesData()
+ const int64_t* indices = coo_st.GetSparseTensorIndicesData(ORT_SPARSE_COO_INDICES, num_indices);
+ ASSERT_EQ(num_indices, static_cast(indices_shape[0]));
+ auto ind_span = gsl::make_span(indices, num_indices);
+ ASSERT_TRUE(std::equal(expected_linear_indices.cbegin(), expected_linear_indices.cend(), ind_span.cbegin(), ind_span.cend()));
+ }
+ }
+
+ {
+ // CSR test: inner and outer indices are checked separately
+ const std::vector dense_shape{3, 3};
+ const std::vector values_shape{3};
+ const std::vector inner_shape{3};
+ const std::vector outer_shape{4};
+ std::vector expected_values = {1, 1, 1};
+ const std::vector expected_values_str = {"1", "1", "1"};  // NOTE(review): not referenced anywhere in this CSR section - confirm whether it can be dropped
+ std::vector expected_inner = {2, 0, 2};
+ std::vector expected_outer = {0, 1, 3, 3};
+
+ Ort::Value::Shape ort_dense_shape{dense_shape.data(), dense_shape.size()};
+ constexpr int64_t values_len = 3;
+ Ort::Value::Shape ort_values_shape{&values_len, 1U};
+ auto csr_st = Ort::Value::CreateSparseTensor(info, expected_values.data(), ort_dense_shape, ort_values_shape);
+ csr_st.UseCsrIndices(expected_inner.data(), expected_inner.size(), expected_outer.data(), expected_outer.size());
+ {
+ auto ti = csr_st.GetTypeInfo();
+ ASSERT_EQ(ONNX_TYPE_SPARSETENSOR, ti.GetONNXType());
+ auto tensor_type_shape = ti.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(dense_shape, tensor_type_shape.GetShape());
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, tensor_type_shape.GetElementType());
+ ASSERT_EQ(dense_shape.size(), tensor_type_shape.GetDimensionsCount());
+ }
+
+ {
+ auto t_type_shape = csr_st.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(dense_shape, t_type_shape.GetShape());
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, t_type_shape.GetElementType());
+ ASSERT_EQ(dense_shape.size(), t_type_shape.GetDimensionsCount());
+ }
+
+ ASSERT_EQ(ORT_SPARSE_CSRC, csr_st.GetSparseFormat());
+
+ {
+ auto values_ts = csr_st.GetSparseTensorValuesTypeAndShapeInfo();
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, values_ts.GetElementType());
+ ASSERT_EQ(values_shape, values_ts.GetShape());
+ }
+
+ {
+ const auto* values = csr_st.GetSparseTensorValues();
+ auto val_span = gsl::make_span(values, expected_values.size());
+ ASSERT_TRUE(std::equal(expected_values.cbegin(), expected_values.cend(), val_span.cbegin(), val_span.cend()));
+ }
+
+ {
+ auto indices_ts = csr_st.GetSparseTensorIndicesTypeShapeInfo(ORT_SPARSE_CSR_INNER_INDICES);
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, indices_ts.GetElementType());
+ ASSERT_EQ(inner_shape, indices_ts.GetShape());
+
+ size_t num_indices = 0;
+ const int64_t* indices = csr_st.GetSparseTensorIndicesData(ORT_SPARSE_CSR_INNER_INDICES, num_indices);
+ ASSERT_EQ(num_indices, expected_inner.size());
+ auto ind_span = gsl::make_span(indices, num_indices);
+ ASSERT_TRUE(std::equal(expected_inner.cbegin(), expected_inner.cend(), ind_span.cbegin(), ind_span.cend()));
+ }
+
+ {
+ auto indices_ts = csr_st.GetSparseTensorIndicesTypeShapeInfo(ORT_SPARSE_CSR_OUTER_INDICES);
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, indices_ts.GetElementType());
+ ASSERT_EQ(outer_shape, indices_ts.GetShape());
+
+ size_t num_indices = 0;
+ const int64_t* indices = csr_st.GetSparseTensorIndicesData(ORT_SPARSE_CSR_OUTER_INDICES, num_indices);
+ ASSERT_EQ(num_indices, expected_outer.size());
+ auto ind_span = gsl::make_span(indices, num_indices);
+ ASSERT_TRUE(std::equal(expected_outer.cbegin(), expected_outer.cend(), ind_span.cbegin(), ind_span.cend()));
+ }
+ }
+
+ {
+ // BlockSparse test: indices are int32 and carry an explicit 2-D shape
+ const std::vector dense_shape{8, 8};
+ constexpr int64_t block_size = 2;
+ const std::vector values_shape{2, block_size, block_size};
+ // Two dense blocks of block_size x block_size elements each (8 values in total)
+ std::vector data_blocks{
+ 1, 2, 3, 4, 5, 6, 7, 8};
+ const std::vector indices_shape{2, 2}; // two blocks by two coordinates
+ // Block coordinates (0, 0) and (0, 1)
+ std::vector blocksparse_indices = {
+ 0, 0, 0, 1};
+
+ Ort::Value::Shape ort_dense_shape{dense_shape.data(), dense_shape.size()};
+ Ort::Value::Shape ort_values_shape{values_shape.data(), values_shape.size()};
+ auto bsp_st = Ort::Value::CreateSparseTensor(info, data_blocks.data(), ort_dense_shape, ort_values_shape);
+ bsp_st.UseBlockSparseIndices({indices_shape.data(), indices_shape.size()}, blocksparse_indices.data());
+ {
+ auto ti = bsp_st.GetTypeInfo();
+ ASSERT_EQ(ONNX_TYPE_SPARSETENSOR, ti.GetONNXType());
+ auto tensor_type_shape = ti.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(dense_shape, tensor_type_shape.GetShape());
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, tensor_type_shape.GetElementType());
+ ASSERT_EQ(dense_shape.size(), tensor_type_shape.GetDimensionsCount());
+ }
+ {
+ auto t_type_shape = bsp_st.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(dense_shape, t_type_shape.GetShape());
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, t_type_shape.GetElementType());
+ ASSERT_EQ(dense_shape.size(), t_type_shape.GetDimensionsCount());
+ }
+ ASSERT_EQ(ORT_SPARSE_BLOCK_SPARSE, bsp_st.GetSparseFormat());
+ {
+ auto values_ts = bsp_st.GetSparseTensorValuesTypeAndShapeInfo();
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, values_ts.GetElementType());
+ ASSERT_EQ(values_shape, values_ts.GetShape());
+ }
+ {
+ const auto* values = bsp_st.GetSparseTensorValues();
+ auto val_span = gsl::make_span(values, data_blocks.size());
+ ASSERT_TRUE(std::equal(data_blocks.cbegin(), data_blocks.cend(), val_span.cbegin(), val_span.cend()));
+ }
+ {
+ auto indices_ts = bsp_st.GetSparseTensorIndicesTypeShapeInfo(ORT_SPARSE_BLOCK_SPARSE_INDICES);
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, indices_ts.GetElementType());
+ ASSERT_EQ(indices_shape, indices_ts.GetShape());
+
+ size_t num_indices = 0;
+ const int32_t* indices = bsp_st.GetSparseTensorIndicesData(ORT_SPARSE_BLOCK_SPARSE_INDICES, num_indices);
+ ASSERT_EQ(num_indices, blocksparse_indices.size());  // flat element count, not the number of blocks
+ auto ind_span = gsl::make_span(indices, num_indices);
+ ASSERT_TRUE(std::equal(blocksparse_indices.cbegin(), blocksparse_indices.cend(), ind_span.cbegin(), ind_span.cend()));
+ }
+ }
+}
+
+TEST(CApiTest, SparseTensorFillSparseTensorFormatAPI) {
+ auto allocator = Ort::AllocatorWithDefaultOptions();
+ Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+ {
+ // COO
+ const std::vector dense_shape{3, 3};
+ const std::vector values_shape{3};
+ std::vector expected_values = {1, 1, 1};
+ constexpr int64_t values_len = 3;
+ std::vector expected_linear_indices = {2, 3, 5};
+ const std::vector indices_shape{3};
+
+ Ort::Value::Shape ort_dense_shape{dense_shape.data(), dense_shape.size()};
+ auto coo_st = Ort::Value::CreateSparseTensor(allocator, ort_dense_shape);
+ coo_st.FillSparseTensorCoo(info, {&values_len, 1U, {expected_values.data()}},
+ expected_linear_indices.data(), expected_linear_indices.size());
+ {
+ auto ti = coo_st.GetTypeInfo();
+ ASSERT_EQ(ONNX_TYPE_SPARSETENSOR, ti.GetONNXType());
+ auto tensor_type_shape = ti.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(dense_shape, tensor_type_shape.GetShape());
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, tensor_type_shape.GetElementType());
+ ASSERT_EQ(dense_shape.size(), tensor_type_shape.GetDimensionsCount());
+ }
+
+ {
+ auto t_type_shape = coo_st.GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(dense_shape, t_type_shape.GetShape());
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, t_type_shape.GetElementType());
+ ASSERT_EQ(dense_shape.size(), t_type_shape.GetDimensionsCount());
+ }
+
+ ASSERT_EQ(ORT_SPARSE_COO, coo_st.GetSparseFormat());
+
+ {
+ auto values_ts = coo_st.GetSparseTensorValuesTypeAndShapeInfo();
+ ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, values_ts.GetElementType());
+ ASSERT_EQ(values_shape, values_ts.GetShape());
+ }
+
+ {
+ const auto* values = coo_st.GetSparseTensorValues