mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Allow creation of string tensor sequence (#3048)
This commit is contained in:
parent
cb24e2a214
commit
21f9a8bdc2
2 changed files with 40 additions and 1 deletions
|
|
@ -1108,7 +1108,7 @@ static OrtStatus* OrtCreateValueImplSeqHelper(const OrtValue* const* in, size_t
|
|||
}
|
||||
|
||||
OrtStatus* st{};
|
||||
utils::MLTypeCallDispatcherRet<OrtStatus*, CallCreateValueImpl, bool, float, double,
|
||||
utils::MLTypeCallDispatcherRet<OrtStatus*, CallCreateValueImpl, bool, float, double, std::string,
|
||||
MLFloat16, BFloat16, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t>
|
||||
t_disp(one_tensor.GetElementType());
|
||||
|
||||
|
|
|
|||
|
|
@ -176,3 +176,42 @@ TEST(CApiTest, CreateGetSeqTensors) {
|
|||
std::set<int64_t>(std::begin(vals), std::end(vals)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CApiTest, CreateGetSeqStringTensors) {
|
||||
// Creation
|
||||
auto default_allocator = onnxruntime::make_unique<MockedOrtAllocator>();
|
||||
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
||||
|
||||
std::vector<Ort::Value> in;
|
||||
const char* string_input_data[] = {"abs", "def"};
|
||||
const int N = 2;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
// create tensor
|
||||
std::vector<int64_t> shape{2};
|
||||
auto value = Ort::Value::CreateTensor(Ort::AllocatorWithDefaultOptions(), shape.data(), shape.size(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
|
||||
Ort::GetApi().FillStringTensor(value, string_input_data, 2);
|
||||
in.push_back(std::move(value));
|
||||
}
|
||||
|
||||
Ort::Value seq_ort = Ort::Value::CreateSequence(in);
|
||||
|
||||
// Fetch
|
||||
std::set<std::string> string_set;
|
||||
for (int idx = 0; idx < N; ++idx) {
|
||||
Ort::Value out = seq_ort.GetValue(idx, default_allocator.get());
|
||||
size_t data_len = out.GetStringTensorDataLength();
|
||||
std::string result(data_len, '\0');
|
||||
std::vector<size_t> offsets(N);
|
||||
out.GetStringTensorContent((void*)result.data(), data_len, offsets.data(), offsets.size());
|
||||
|
||||
const char* s = result.data();
|
||||
for (size_t i = 0; i < offsets.size(); ++i) {
|
||||
size_t start = offsets[i];
|
||||
size_t count = (i + 1) < offsets.size() ? offsets[i + 1] - start : data_len - start;
|
||||
std::string stemp(s + start, count);
|
||||
string_set.insert(stemp);
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(string_set, std::set<std::string>(std::begin(string_input_data), std::end(string_input_data)));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue