Fix a build break in tf_test_session.h (#1205)

This file need tensorflow C API package to build, so it is not part of the CI.
This commit is contained in:
Changming Sun 2019-06-11 14:05:24 -07:00 committed by GitHub
parent 6d5ea08936
commit 15bcde5053
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -116,12 +116,16 @@ class TensorflowTestSession : public TestSession {
ORT_THROW_ON_ERROR(OrtGetTensorTypeAndShape(value, &shape));
size_t buffer_length = 0;
std::vector<int64_t> dims;
size_t dim_count = OrtGetDimensionsCount(shape);
size_t dim_count;
ORT_THROW_ON_ERROR(OrtGetDimensionsCount(shape, &dim_count));
dims.resize(dim_count);
OrtGetDimensions(shape, dims.data(), dim_count);
int64_t ele_count = OrtGetTensorShapeElementCount(shape);
ORT_THROW_ON_ERROR(OrtGetDimensions(shape, dims.data(), dim_count));
size_t ele_count;
ORT_THROW_ON_ERROR(OrtGetTensorShapeElementCount(shape, &ele_count));
TF_DataType tf_datatype;
switch (OrtGetTensorElementType(shape)) {
ONNXTensorElementDataType element_type;
ORT_THROW_ON_ERROR(OrtGetTensorElementType(shape, &element_type));
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: // maps to c type float
buffer_length = ele_count * sizeof(float);
tf_datatype = TF_FLOAT;