#include "testPch.h"
#include "ort_value_helper.h"

using namespace winml;

namespace OrtValueHelpers {

// Creates an empty WinML tensor with the given shape. The tensor's WinML kind is the one
// mapped from the ONNX element type T by ONNXTensorElementDataTypeToWinMLTensorKind<T>::Type.
//
// T     - ONNX element type selecting the WinML tensor class.
// shape - tensor dimensions, passed straight to that class's Create().
// Returns the created tensor. If Create() throws, the failure is reported through
// WINML_EXPECT_NO_THROW and the returned tensor stays nullptr.
// NOTE(review): the nullptr return assumes WINML_EXPECT_NO_THROW is a non-fatal expectation — confirm.
template <ONNXTensorElementDataType T>
winml::ITensor CreateTensorFromShape(std::vector<int64_t>& shape) {
  using WinMLTensorKind = typename ONNXTensorElementDataTypeToWinMLTensorKind<T>::Type;
  ITensor tensor = nullptr;
  WINML_EXPECT_NO_THROW(tensor = WinMLTensorKind::Create(shape));
  return tensor;
}

// This function takes in an Ort::Value and returns a copy of winml::ITensor
// TODO: String types still need to be implemented.
//
// val - an Ort::Value holding a tensor. Its element type selects the WinML tensor kind,
//       and its shape is reused for the new tensor.
// Returns a newly created WinML tensor whose buffer holds a byte-for-byte copy of val's data.
// Throws winrt::hresult_invalid_argument for element types not listed in the switch
// (e.g. string).
winml::ITensor LoadTensorFromOrtValue(Ort::Value& val) {
  ITensor tensor = nullptr;
  auto tensorTypeAndShape = val.GetTensorTypeAndShapeInfo();
  auto shape = tensorTypeAndShape.GetShape();
  switch (tensorTypeAndShape.GetElementType()) {
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64>(shape);
      break;
    }
    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16): {
      tensor = CreateTensorFromShape<ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16>(shape);
      break;
    }
    default:
      throw winrt::hresult_invalid_argument(L"TensorType not implemented yet.");
  }

  // Destination: the WinML tensor's backing buffer and its size in bytes.
  BYTE* actualData = nullptr;
  uint32_t actualSizeInBytes = 0;
  WINML_EXPECT_NO_THROW(tensor.as<ITensorNative>()->GetBuffer(&actualData, &actualSizeInBytes));

  // Source: the ORT tensor's data. This C API call reports failure through the returned
  // OrtStatus* instead of throwing. Ort::ThrowOnError releases that status and converts it
  // into an exception, so the expectation can actually observe a failure.
  void* ortValueTensorData = nullptr;
  WINML_EXPECT_NO_THROW(Ort::ThrowOnError(Ort::GetApi().GetTensorMutableData(val, &ortValueTensorData)));

  // Copy exactly as many bytes as the WinML buffer holds. Skip the copy if either pointer is
  // missing: any failure has already been recorded above, and memcpy on a null pointer is
  // undefined behaviour.
  // NOTE(review): assumes the ORT data spans at least actualSizeInBytes bytes (same element
  // width and shape) — not verified here.
  if (actualData != nullptr && ortValueTensorData != nullptr) {
    WINML_EXPECT_NO_THROW(memcpy(actualData, ortValueTensorData, actualSizeInBytes));
  }
  return tensor;
}
}  // namespace OrtValueHelpers