mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Fix a build break in tf_test_session.h (#1205)
This file need tensorflow C API package to build, so it is not part of the CI.
This commit is contained in:
parent
6d5ea08936
commit
15bcde5053
1 changed files with 8 additions and 4 deletions
|
|
@ -116,12 +116,16 @@ class TensorflowTestSession : public TestSession {
|
|||
ORT_THROW_ON_ERROR(OrtGetTensorTypeAndShape(value, &shape));
|
||||
size_t buffer_length = 0;
|
||||
std::vector<int64_t> dims;
|
||||
size_t dim_count = OrtGetDimensionsCount(shape);
|
||||
size_t dim_count;
|
||||
ORT_THROW_ON_ERROR(OrtGetDimensionsCount(shape, &dim_count));
|
||||
dims.resize(dim_count);
|
||||
OrtGetDimensions(shape, dims.data(), dim_count);
|
||||
int64_t ele_count = OrtGetTensorShapeElementCount(shape);
|
||||
ORT_THROW_ON_ERROR(OrtGetDimensions(shape, dims.data(), dim_count));
|
||||
size_t ele_count;
|
||||
ORT_THROW_ON_ERROR(OrtGetTensorShapeElementCount(shape, &ele_count));
|
||||
TF_DataType tf_datatype;
|
||||
switch (OrtGetTensorElementType(shape)) {
|
||||
ONNXTensorElementDataType element_type;
|
||||
ORT_THROW_ON_ERROR(OrtGetTensorElementType(shape, &element_type));
|
||||
switch (element_type) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: // maps to c type float
|
||||
buffer_length = ele_count * sizeof(float);
|
||||
tf_datatype = TF_FLOAT;
|
||||
|
|
|
|||
Loading…
Reference in a new issue