mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Add compatibility with Protobuf 3.12 (#4291)
In Protobuf 3.12, classes generated from protobuf files are declared as `final`, so use those classes as members rather than base classes. Ref: https://github.com/protocolbuffers/protobuf/releases/tag/v3.12.0
This commit is contained in:
parent
5db67ec000
commit
a37e2e33b4
2 changed files with 70 additions and 62 deletions
|
|
@ -149,44 +149,48 @@ struct DimSetter<d, dims...> {
|
|||
};
|
||||
|
||||
template <int... dims>
|
||||
struct TensorShapeTypeProto : public TensorShapeProto {
|
||||
struct TensorShapeTypeProto {
|
||||
TensorShapeTypeProto() {
|
||||
DimSetter<dims...>::set(*this);
|
||||
DimSetter<dims...>::set(proto);
|
||||
}
|
||||
TensorShapeProto proto;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TensorShapeTypeProto<> : public TensorShapeProto {};
|
||||
struct TensorShapeTypeProto<> { TensorShapeProto proto; };
|
||||
|
||||
template <TensorProto_DataType T>
|
||||
struct TensorTypeProto : public TypeProto {
|
||||
struct TensorTypeProto {
|
||||
TensorTypeProto() {
|
||||
mutable_tensor_type()->set_elem_type(T);
|
||||
proto.mutable_tensor_type()->set_elem_type(T);
|
||||
}
|
||||
TypeProto proto;
|
||||
};
|
||||
|
||||
template <TensorProto_DataType T>
|
||||
struct SparseTensorTypeProto : public TypeProto {
|
||||
struct SparseTensorTypeProto {
|
||||
SparseTensorTypeProto() {
|
||||
mutable_sparse_tensor_type()->set_elem_type(T);
|
||||
proto.mutable_sparse_tensor_type()->set_elem_type(T);
|
||||
}
|
||||
void SetShape(const TensorShapeProto& shape) {
|
||||
mutable_sparse_tensor_type()->mutable_shape()->CopyFrom(shape);
|
||||
proto.mutable_sparse_tensor_type()->mutable_shape()->CopyFrom(shape);
|
||||
}
|
||||
void SetShape(TensorShapeProto&& shape) {
|
||||
*mutable_sparse_tensor_type()->mutable_shape() = std::move(shape);
|
||||
*proto.mutable_sparse_tensor_type()->mutable_shape() = std::move(shape);
|
||||
}
|
||||
void ClearShape() {
|
||||
mutable_sparse_tensor_type()->clear_shape();
|
||||
proto.mutable_sparse_tensor_type()->clear_shape();
|
||||
}
|
||||
TypeProto proto;
|
||||
};
|
||||
|
||||
template <TensorProto_DataType key, TensorProto_DataType value>
|
||||
struct MapTypeProto : public TypeProto {
|
||||
struct MapTypeProto {
|
||||
MapTypeProto() {
|
||||
mutable_map_type()->set_key_type(key);
|
||||
mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(value);
|
||||
proto.mutable_map_type()->set_key_type(key);
|
||||
proto.mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(value);
|
||||
}
|
||||
TypeProto proto;
|
||||
};
|
||||
|
||||
class DataTypeTest : public testing::Test {
|
||||
|
|
@ -241,9 +245,9 @@ TEST_F(DataTypeTest, OpaqueRegistrationTest) {
|
|||
TEST_F(DataTypeTest, MapStringStringTest) {
|
||||
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
|
||||
auto ml_str_str = DataTypeImpl::GetType<MapStringToString>();
|
||||
EXPECT_TRUE(DataTypeImpl::GetTensorType<float>()->IsCompatible(tensor_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetTensorType<uint64_t>()->IsCompatible(tensor_type));
|
||||
EXPECT_FALSE(ml_str_str->IsCompatible(tensor_type));
|
||||
EXPECT_TRUE(DataTypeImpl::GetTensorType<float>()->IsCompatible(tensor_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetTensorType<uint64_t>()->IsCompatible(tensor_type.proto));
|
||||
EXPECT_FALSE(ml_str_str->IsCompatible(tensor_type.proto));
|
||||
utils::ContainerChecker c_checker(ml_str_str);
|
||||
bool result = c_checker.IsMapOf<std::string, std::string>();
|
||||
EXPECT_TRUE(result);
|
||||
|
|
@ -257,17 +261,17 @@ TEST_F(DataTypeTest, MapStringStringTest) {
|
|||
|
||||
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_STRING> maps2s_type;
|
||||
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_INT64> maps2i_type;
|
||||
EXPECT_TRUE(ml_str_str->IsCompatible(maps2s_type));
|
||||
EXPECT_FALSE(ml_str_str->IsCompatible(maps2i_type));
|
||||
EXPECT_TRUE(ml_str_str->IsCompatible(maps2s_type.proto));
|
||||
EXPECT_FALSE(ml_str_str->IsCompatible(maps2i_type.proto));
|
||||
}
|
||||
|
||||
TEST_F(DataTypeTest, MapStringInt64Test) {
|
||||
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_STRING> maps2s_type;
|
||||
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_INT64> maps2i_type;
|
||||
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(maps2s_type));
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(maps2i_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(tensor_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(maps2s_type.proto));
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(maps2i_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToInt64>()->IsCompatible(tensor_type.proto));
|
||||
|
||||
utils::ContainerChecker c_checker(DataTypeImpl::GetType<MapStringToInt64>());
|
||||
bool result = c_checker.IsMapOf<std::string, int64_t>();
|
||||
|
|
@ -278,36 +282,36 @@ TEST_F(DataTypeTest, MapStringFloatTest) {
|
|||
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_FLOAT> maps2f_type;
|
||||
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_INT64> maps2i_type;
|
||||
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(maps2f_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(maps2i_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(tensor_type));
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(maps2f_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(maps2i_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToFloat>()->IsCompatible(tensor_type.proto));
|
||||
}
|
||||
|
||||
TEST_F(DataTypeTest, MapStringDoubleTest) {
|
||||
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_DOUBLE> maps2d_type;
|
||||
MapTypeProto<TensorProto_DataType_STRING, TensorProto_DataType_INT64> maps2i_type;
|
||||
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(maps2d_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(maps2i_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(tensor_type));
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(maps2d_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(maps2i_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapStringToDouble>()->IsCompatible(tensor_type.proto));
|
||||
}
|
||||
|
||||
TEST_F(DataTypeTest, MapInt64StringTest) {
|
||||
MapTypeProto<TensorProto_DataType_INT64, TensorProto_DataType_STRING> mapi2s_type;
|
||||
MapTypeProto<TensorProto_DataType_INT64, TensorProto_DataType_INT64> mapi2i_type;
|
||||
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2s_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2i_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(tensor_type));
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2s_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2i_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(tensor_type.proto));
|
||||
}
|
||||
|
||||
TEST_F(DataTypeTest, MapInt64DoubleTest) {
|
||||
MapTypeProto<TensorProto_DataType_INT64, TensorProto_DataType_DOUBLE> mapi2d_type;
|
||||
MapTypeProto<TensorProto_DataType_INT64, TensorProto_DataType_INT64> mapi2i_type;
|
||||
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapInt64ToDouble>()->IsCompatible(mapi2d_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2i_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(tensor_type));
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<MapInt64ToDouble>()->IsCompatible(mapi2d_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(mapi2i_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<MapInt64ToString>()->IsCompatible(tensor_type.proto));
|
||||
}
|
||||
|
||||
TEST_F(DataTypeTest, RecursiveMapTest) {
|
||||
|
|
@ -366,9 +370,9 @@ TEST_F(DataTypeTest, VectorMapStringToFloatTest) {
|
|||
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
|
||||
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(vector_map_string_to_float));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(mapi2d_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(mapi2i_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(tensor_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(mapi2d_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(mapi2i_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapStringToFloat>()->IsCompatible(tensor_type.proto));
|
||||
utils::ContainerChecker c_check(DataTypeImpl::GetType<VectorMapStringToFloat>());
|
||||
bool result = c_check.IsSequenceOf<MapStringToFloat>();
|
||||
EXPECT_TRUE(result);
|
||||
|
|
@ -384,9 +388,9 @@ TEST_F(DataTypeTest, VectorMapInt64ToFloatTest) {
|
|||
TensorTypeProto<TensorProto_DataType_FLOAT> tensor_type;
|
||||
|
||||
EXPECT_TRUE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(type_proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(mapi2d_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(mapi2i_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(tensor_type));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(mapi2d_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(mapi2i_type.proto));
|
||||
EXPECT_FALSE(DataTypeImpl::GetType<VectorMapInt64ToFloat>()->IsCompatible(tensor_type.proto));
|
||||
}
|
||||
|
||||
TEST_F(DataTypeTest, BFloat16Test) {
|
||||
|
|
@ -476,7 +480,7 @@ TEST_F(DataTypeTest, DataUtilsTest) {
|
|||
// We expect that the above string will be matched in both cases
|
||||
// where we have shape and where we don't
|
||||
SparseTensorTypeProto<TensorProto_DataType_UINT64> sparse_proto;
|
||||
DataType ten_dt = DataTypeUtils::ToType(sparse_proto);
|
||||
DataType ten_dt = DataTypeUtils::ToType(sparse_proto.proto);
|
||||
EXPECT_NE(ten_dt, nullptr);
|
||||
EXPECT_EQ(tensor_uint64, *ten_dt);
|
||||
DataType ten_from_str = DataTypeUtils::ToType(*ten_dt);
|
||||
|
|
@ -485,8 +489,8 @@ TEST_F(DataTypeTest, DataUtilsTest) {
|
|||
|
||||
// Now add empty shape, we expect the same string
|
||||
TensorShapeTypeProto<> shape_no_dims;
|
||||
sparse_proto.SetShape(shape_no_dims);
|
||||
ten_dt = DataTypeUtils::ToType(sparse_proto);
|
||||
sparse_proto.SetShape(shape_no_dims.proto);
|
||||
ten_dt = DataTypeUtils::ToType(sparse_proto.proto);
|
||||
EXPECT_NE(ten_dt, nullptr);
|
||||
EXPECT_EQ(tensor_uint64, *ten_dt);
|
||||
ten_from_str = DataTypeUtils::ToType(*ten_dt);
|
||||
|
|
@ -496,8 +500,8 @@ TEST_F(DataTypeTest, DataUtilsTest) {
|
|||
// Now add shape with dimensions, we expect no difference
|
||||
sparse_proto.ClearShape();
|
||||
TensorShapeTypeProto<10, 12> shape_with_dim;
|
||||
sparse_proto.SetShape(shape_with_dim);
|
||||
ten_dt = DataTypeUtils::ToType(sparse_proto);
|
||||
sparse_proto.SetShape(shape_with_dim.proto);
|
||||
ten_dt = DataTypeUtils::ToType(sparse_proto.proto);
|
||||
EXPECT_NE(ten_dt, nullptr);
|
||||
EXPECT_EQ(tensor_uint64, *ten_dt);
|
||||
ten_from_str = DataTypeUtils::ToType(*ten_dt);
|
||||
|
|
|
|||
|
|
@ -120,12 +120,12 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType TypeToDataType<BFloat16>() {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
struct TTypeProto : ONNX_NAMESPACE::TypeProto {
|
||||
struct TTypeProto {
|
||||
TTypeProto(const std::vector<int64_t>* shape = nullptr) {
|
||||
mutable_tensor_type()->set_elem_type(TypeToDataType<T>());
|
||||
proto.mutable_tensor_type()->set_elem_type(TypeToDataType<T>());
|
||||
|
||||
if (shape) {
|
||||
auto mutable_shape = mutable_tensor_type()->mutable_shape();
|
||||
auto mutable_shape = proto.mutable_tensor_type()->mutable_shape();
|
||||
for (auto i : *shape) {
|
||||
auto* mutable_dim = mutable_shape->add_dim();
|
||||
if (i != -1)
|
||||
|
|
@ -135,6 +135,7 @@ struct TTypeProto : ONNX_NAMESPACE::TypeProto {
|
|||
}
|
||||
}
|
||||
}
|
||||
ONNX_NAMESPACE::TypeProto proto;
|
||||
};
|
||||
|
||||
// Variable template for ONNX_NAMESPACE::TensorProto_DataTypes, s_type_proto<float>, etc..
|
||||
|
|
@ -148,12 +149,13 @@ const TTypeProto<T> TTensorType<T>::s_type_proto;
|
|||
|
||||
// TypeProto for map<TKey, TVal>
|
||||
template <typename TKey, typename TVal>
|
||||
struct MTypeProto : ONNX_NAMESPACE::TypeProto {
|
||||
struct MTypeProto {
|
||||
MTypeProto() {
|
||||
mutable_map_type()->set_key_type(TypeToDataType<TKey>());
|
||||
mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(TypeToDataType<TVal>());
|
||||
mutable_map_type()->mutable_value_type()->mutable_tensor_type()->mutable_shape()->clear_dim();
|
||||
proto.mutable_map_type()->set_key_type(TypeToDataType<TKey>());
|
||||
proto.mutable_map_type()->mutable_value_type()->mutable_tensor_type()->set_elem_type(TypeToDataType<TVal>());
|
||||
proto.mutable_map_type()->mutable_value_type()->mutable_tensor_type()->mutable_shape()->clear_dim();
|
||||
}
|
||||
ONNX_NAMESPACE::TypeProto proto;
|
||||
};
|
||||
|
||||
template <typename TKey, typename TVal>
|
||||
|
|
@ -166,13 +168,14 @@ const MTypeProto<TKey, TVal> MMapType<TKey, TVal>::s_map_type_proto;
|
|||
|
||||
// TypeProto for vector<map<TKey, TVal>>
|
||||
template <typename TKey, typename TVal>
|
||||
struct VectorOfMapTypeProto : ONNX_NAMESPACE::TypeProto {
|
||||
struct VectorOfMapTypeProto {
|
||||
VectorOfMapTypeProto() {
|
||||
auto* map_type = mutable_sequence_type()->mutable_elem_type()->mutable_map_type();
|
||||
auto* map_type = proto.mutable_sequence_type()->mutable_elem_type()->mutable_map_type();
|
||||
map_type->set_key_type(TypeToDataType<TKey>());
|
||||
map_type->mutable_value_type()->mutable_tensor_type()->set_elem_type(TypeToDataType<TVal>());
|
||||
map_type->mutable_value_type()->mutable_tensor_type()->mutable_shape()->clear_dim();
|
||||
}
|
||||
ONNX_NAMESPACE::TypeProto proto;
|
||||
};
|
||||
|
||||
template <typename TKey, typename TVal>
|
||||
|
|
@ -184,14 +187,15 @@ template <typename TKey, typename TVal>
|
|||
const VectorOfMapTypeProto<TKey, TVal> VectorOfMapType<TKey, TVal>::s_vec_map_type_proto;
|
||||
|
||||
template <typename ElemType>
|
||||
struct SequenceTensorTypeProto : ONNX_NAMESPACE::TypeProto {
|
||||
struct SequenceTensorTypeProto {
|
||||
SequenceTensorTypeProto() {
|
||||
MLDataType dt = DataTypeImpl::GetTensorType<ElemType>();
|
||||
const auto* elem_proto = dt->GetTypeProto();
|
||||
mutable_sequence_type()->mutable_elem_type()->CopyFrom(*elem_proto);
|
||||
auto* tensor_type = mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type();
|
||||
proto.mutable_sequence_type()->mutable_elem_type()->CopyFrom(*elem_proto);
|
||||
auto* tensor_type = proto.mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type();
|
||||
tensor_type->set_elem_type(TypeToDataType<ElemType>());
|
||||
}
|
||||
ONNX_NAMESPACE::TypeProto proto;
|
||||
};
|
||||
|
||||
template <typename ElemType>
|
||||
|
|
@ -306,14 +310,14 @@ class OpTester {
|
|||
OrtValue value;
|
||||
value.Init(ptr.release(), DataTypeImpl::GetType<std::map<TKey, TVal>>(),
|
||||
DataTypeImpl::GetType<std::map<TKey, TVal>>()->GetDeleteFunc());
|
||||
input_data_.push_back(Data(NodeArg(name, &MMapType<TKey, TVal>::s_map_type_proto), std::move(value),
|
||||
input_data_.push_back(Data(NodeArg(name, &MMapType<TKey, TVal>::s_map_type_proto.proto), std::move(value),
|
||||
optional<float>(), optional<float>()));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AddMissingOptionalInput() {
|
||||
std::string name; // empty == input doesn't exist
|
||||
input_data_.push_back(Data(NodeArg(name, &TTensorType<T>::s_type_proto), OrtValue(), optional<float>(),
|
||||
input_data_.push_back(Data(NodeArg(name, &TTensorType<T>::s_type_proto.proto), OrtValue(), optional<float>(),
|
||||
optional<float>()));
|
||||
}
|
||||
|
||||
|
|
@ -338,7 +342,7 @@ class OpTester {
|
|||
template <typename T>
|
||||
void AddMissingOptionalOutput() {
|
||||
std::string name; // empty == input doesn't exist
|
||||
output_data_.push_back(Data(NodeArg(name, &TTensorType<T>::s_type_proto), OrtValue(), optional<float>(),
|
||||
output_data_.push_back(Data(NodeArg(name, &TTensorType<T>::s_type_proto.proto), OrtValue(), optional<float>(),
|
||||
optional<float>()));
|
||||
}
|
||||
|
||||
|
|
@ -374,7 +378,7 @@ class OpTester {
|
|||
OrtValue ml_value;
|
||||
ml_value.Init(ptr.release(), DataTypeImpl::GetType<std::vector<std::map<TKey, TVal>>>(),
|
||||
DataTypeImpl::GetType<std::vector<std::map<TKey, TVal>>>()->GetDeleteFunc());
|
||||
output_data_.push_back(Data(NodeArg(name, &VectorOfMapType<TKey, TVal>::s_vec_map_type_proto), std::move(ml_value),
|
||||
output_data_.push_back(Data(NodeArg(name, &VectorOfMapType<TKey, TVal>::s_vec_map_type_proto.proto), std::move(ml_value),
|
||||
optional<float>(), optional<float>()));
|
||||
}
|
||||
|
||||
|
|
@ -530,7 +534,7 @@ class OpTester {
|
|||
OrtValue value;
|
||||
value.Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(),
|
||||
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
|
||||
auto node_arg = NodeArg(name, &type_proto);
|
||||
auto node_arg = NodeArg(name, &type_proto.proto);
|
||||
if (dim_params && !(dim_params->empty())) {
|
||||
// If dim_params presents, configure node_arg's dim value based on dim_params, which supports symbolic dim and dim broadcast.
|
||||
auto& dim_params_data = *dim_params;
|
||||
|
|
@ -590,7 +594,7 @@ class OpTester {
|
|||
ptr->SetElements(std::move(tensors));
|
||||
value.Init(ptr.get(), mltype, mltype->GetDeleteFunc());
|
||||
ptr.release();
|
||||
data.push_back(Data(NodeArg(name, &SequenceTensorType<T>::s_sequence_tensor_type_proto), std::move(value),
|
||||
data.push_back(Data(NodeArg(name, &SequenceTensorType<T>::s_sequence_tensor_type_proto.proto), std::move(value),
|
||||
optional<float>(), optional<float>()));
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue