From 1eb146f561fab04445f36ee7f18b967433ecccdb Mon Sep 17 00:00:00 2001
From: Ryan Lai
Date: Thu, 10 Dec 2020 17:47:50 -0800
Subject: [PATCH] Implement conversion from ORT String to WinML Tensor String (#6097)

* Implement conversion from ort string to winml string

* NIT:comment
---
Reviewer notes (text in this section is ignored by git am):
  1. The line structure of this patch was rebuilt from a whitespace-collapsed
     copy; indentation and blank context lines are inferred.
  2. Template argument lists were missing throughout. They are restored in the
     added lines only (std::vector<winrt::hstring> is an inference from
     HStringFromUTF8 - confirm). Unchanged context lines are left as found and
     must be checked against the target files before applying.
  3. Relative to the collapsed copy, the added code validates each dimension
     before multiplying, uses a size_t loop index, and allocates the buffer
     with std::make_unique. Hunk counts and the diffstat reflect the two
     comment lines added.

 winml/test/model/ort_value_helper.cpp | 56 +++++++++++++++++++++++++++
 winml/test/model/skip_model_tests.h   |  4 +-
 2 files changed, 57 insertions(+), 3 deletions(-)

diff --git a/winml/test/model/ort_value_helper.cpp b/winml/test/model/ort_value_helper.cpp
index d3db477f41..7e461a4cb1 100644
--- a/winml/test/model/ort_value_helper.cpp
+++ b/winml/test/model/ort_value_helper.cpp
@@ -1,5 +1,6 @@
 #include "testPch.h"
 #include "ort_value_helper.h"
+#include "StringHelpers.h"
 
 using namespace winml;
 namespace OrtValueHelpers {
@@ -13,6 +14,58 @@ winml::ITensor CreateTensorFromShape(std::vector& shape)
   return tensor;
 }
 
+// Returns the total element count of a tensor shape (the product of its dimensions).
+// Throws E_INVALIDARG if any dimension is zero or negative; an empty shape yields 1.
+static int64_t ShapeSize(const int64_t* shape, size_t count) {
+  int64_t size = 1;
+  for (size_t i = 0; i < count; i++) {
+    // reject invalid dimensions (-1 or any non-positive value) before using them
+    THROW_HR_IF(E_INVALIDARG, shape[i] <= 0);
+    // fold this dimension into the total size
+    size *= shape[i];
+  }
+  return size;
+}
+
+// Copies the contents of an ORT string tensor into a new WinML TensorString.
+winml::ITensor CreateStringTensor(Ort::Value& val) {
+  size_t dimensionCount = 0;
+  WINML_EXPECT_NO_THROW(dimensionCount = val.GetTensorTypeAndShapeInfo().GetDimensionsCount());
+  std::vector<int64_t> shape;
+  if (dimensionCount > 0) {
+    WINML_EXPECT_NO_THROW(shape = val.GetTensorTypeAndShapeInfo().GetShape());
+  }
+  auto length = ShapeSize(shape.data(), shape.size());
+
+  // make a big buffer to hold all the string data
+  size_t bufferLength = 0;
+  WINML_EXPECT_NO_THROW(bufferLength = val.GetStringTensorDataLength());
+
+  std::vector<winrt::hstring> strings;
+  auto buffer = std::make_unique<uint8_t[]>(bufferLength);
+  std::vector<size_t> offsets(static_cast<size_t>(length));
+
+  WINML_EXPECT_NO_THROW(val.GetStringTensorContent(buffer.get(), bufferLength, offsets.data(), offsets.size()));
+
+  // now go build all the strings; string i spans [offsets[i], offsets[i + 1]) in the buffer
+  for (size_t i = 0; i < offsets.size(); ++i) {
+    size_t strLength = 0;
+    // the last string has no following offset, so it runs to the end of the buffer
+    if (i == (offsets.size() - 1)) {
+      strLength = bufferLength - offsets[i];
+    } else {
+      strLength = offsets[i + 1] - offsets[i];
+    }
+    auto strView = std::string_view(reinterpret_cast<const char*>(buffer.get() + offsets[i]), strLength);
+    strings.push_back(_winml::Strings::HStringFromUTF8(strView.data(), strLength));
+  }
+
+  TensorString tensor = nullptr;
+  WINML_EXPECT_NO_THROW(tensor = TensorString::CreateFromShapeArrayAndDataArray(shape, strings));
+  return tensor;
+}
+
+
 // This function takes in an Ort::Value and returns a copy of winml::ITensor
 // TODO: String types still need to be implemented.
 winml::ITensor LoadTensorFromOrtValue(Ort::Value& val) {
@@ -39,6 +92,9 @@ winml::ITensor LoadTensorFromOrtValue(Ort::Value& val) {
       tensor = CreateTensorFromShape(shape);
       break;
     }
+    case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING): {
+      return CreateStringTensor(val);
+    }
     case (ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32): {
       tensor = CreateTensorFromShape(shape);
       break;
diff --git a/winml/test/model/skip_model_tests.h b/winml/test/model/skip_model_tests.h
index ccf12736da..29e3db71b8 100644
--- a/winml/test/model/skip_model_tests.h
+++ b/winml/test/model/skip_model_tests.h
@@ -7,9 +7,7 @@ static const std::string disabledGpuTestDefaultReason = "Model not working on GP
 
 // {"model test name", "reason for why it is happening and bug filed for it."}
 std::unordered_map disabledTests(
-    {// Onnx zoo models
-     {"test_bidaf_opset9", "Bug 31011100: Processing string tensors need to be implemented in WinML model tests https://microsoft.visualstudio.com/OS/_workitems/edit/31011100"},
-
+    {
      // Tier 2 models
      {"coreml_VGG16_ImageNet_opset8", "Bug 31011100: Processing string tensors need to be implemented in WinML model tests https://microsoft.visualstudio.com/OS/_workitems/edit/31011100"},
      {"coreml_VGG16_ImageNet_opset9", "Bug 31011100: Processing string tensors need to be implemented in WinML model tests https://microsoft.visualstudio.com/OS/_workitems/edit/31011100"},