Allow creation of string tensor sequence (#3048)

This commit is contained in:
Pranav Sharma 2020-02-20 11:27:42 -08:00 committed by GitHub
parent cb24e2a214
commit 21f9a8bdc2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 1 deletions

View file

@ -1108,7 +1108,7 @@ static OrtStatus* OrtCreateValueImplSeqHelper(const OrtValue* const* in, size_t
}
OrtStatus* st{};
utils::MLTypeCallDispatcherRet<OrtStatus*, CallCreateValueImpl, bool, float, double,
utils::MLTypeCallDispatcherRet<OrtStatus*, CallCreateValueImpl, bool, float, double, std::string,
MLFloat16, BFloat16, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t>
t_disp(one_tensor.GetElementType());

View file

@ -176,3 +176,42 @@ TEST(CApiTest, CreateGetSeqTensors) {
std::set<int64_t>(std::begin(vals), std::end(vals)));
}
}
TEST(CApiTest, CreateGetSeqStringTensors) {
// Creation
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
std::vector<Ort::Value> in;
const char* string_input_data[] = {"abs", "def"};
const int N = 2;
for (int i = 0; i < N; ++i) {
// create tensor
std::vector<int64_t> shape{2};
auto value = Ort::Value::CreateTensor(Ort::AllocatorWithDefaultOptions(), shape.data(), shape.size(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
Ort::GetApi().FillStringTensor(value, string_input_data, 2);
in.push_back(std::move(value));
}
Ort::Value seq_ort = Ort::Value::CreateSequence(in);
// Fetch
std::set<std::string> string_set;
for (int idx = 0; idx < N; ++idx) {
Ort::Value out = seq_ort.GetValue(idx, default_allocator.get());
size_t data_len = out.GetStringTensorDataLength();
std::string result(data_len, '\0');
std::vector<size_t> offsets(N);
out.GetStringTensorContent((void*)result.data(), data_len, offsets.data(), offsets.size());
const char* s = result.data();
for (size_t i = 0; i < offsets.size(); ++i) {
size_t start = offsets[i];
size_t count = (i + 1) < offsets.size() ? offsets[i + 1] - start : data_len - start;
std::string stemp(s + start, count);
string_set.insert(stemp);
}
}
ASSERT_EQ(string_set, std::set<std::string>(std::begin(string_input_data), std::end(string_input_data)));
}