pytorch/caffe2/core/qtensor_serialization.h
Michael Antonov a6949abb15 Guard all Caffe2 protobuf string serializations with CAFFE_ENFORCE (fixed reverted bug) (#12848)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12848

Updated all non-test uses of protobuf::MessageLite::SerializeAsString to call
SerializeAsString_EnforceCheck so that the return value is checked and can
throw an exception if failing.

Most of the affected code was called from classes derived from BlobSerializerBase.
Didn't touch most tests and ENFORCE calls because they usually do checks
anyway.

Original commit changeset: c0760e73ecc7

Reviewed By: dzhulgakov

Differential Revision: D10453456

fbshipit-source-id: d2f2b7b4578e721924354149f08f627c7e3bf070
2018-10-23 16:21:26 -07:00

89 lines
2.6 KiB
C++

#ifndef CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
#define CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/qtensor.h"
namespace caffe2 {
// Type tag that QTensorSerializer::Serialize writes into BlobProto's `type`
// field for blobs holding a QTensor.
constexpr const char* kQTensorBlobQType = "QTensor";
/**
 * BlobSerializerBase implementation for blobs that hold a QTensor<Context>.
 *
 * Keeps a default-constructed Context that Serialize() passes to the proto
 * copy helper when the tensor payload is written out.
 */
template <class Context>
class QTensorSerializer : public BlobSerializerBase {
 public:
  QTensorSerializer() : context_() {}
  // `override` lets the compiler verify that the base-class destructor is
  // virtual, matching how Serialize() is declared below.
  ~QTensorSerializer() override {}

  /**
   * Serializes a Blob. Note that this blob has to contain QTensor<Context>.
   *
   * @param pointer  Address of the QTensor<Context> to serialize.
   * @param typeMeta Type of the pointed-to object; enforced to match
   *                 QTensor<Context>.
   * @param name     Name stored in both the BlobProto and the nested
   *                 QTensorProto, and forwarded to the acceptor.
   * @param acceptor Callback invoked once with `name` and the serialized
   *                 BlobProto string.
   */
  void Serialize(
      const void* pointer,
      TypeMeta typeMeta,
      const string& name,
      SerializationAcceptor acceptor) override;

 private:
  Context context_;
};
/**
 * BlobDeserializerBase implementation that restores a QTensor<Context>.
 */
template <class Context>
class QTensorDeserializer : public BlobDeserializerBase {
 public:
  // Restores the QTensor<Context> held by `blob` from `blob_proto.qtensor()`.
  void Deserialize(const BlobProto& blob_proto, Blob* blob) override;

  // Fills `qtensor` (dims, precision, scale, bias, signedness and data) from
  // `proto`.
  void Deserialize(const QTensorProto& proto, QTensor<Context>* qtensor);
};
/**
 * Packs the QTensor<Context> at `pointer` into a BlobProto and hands the
 * serialized string to `acceptor` under `name`.
 *
 * Fails via CAFFE_ENFORCE when `typeMeta` does not describe QTensor<Context>.
 */
template <class Context>
void QTensorSerializer<Context>::Serialize(
    const void* pointer,
    TypeMeta typeMeta,
    const string& name,
    BlobSerializerBase::SerializationAcceptor acceptor) {
  CAFFE_ENFORCE(typeMeta.Match<QTensor<Context>>());
  const QTensor<Context>& source =
      *static_cast<const QTensor<Context>*>(pointer);

  // Outer envelope: blob name plus the QTensor type tag.
  BlobProto envelope;
  envelope.set_name(name);
  envelope.set_type(kQTensorBlobQType);

  // Nested QTensor message: shape first, then quantization parameters.
  QTensorProto* payload = envelope.mutable_qtensor();
  payload->set_name(name);
  const int rank = source.ndim();
  for (int axis = 0; axis < rank; ++axis) {
    payload->add_dims(source.dim32(axis));
  }
  payload->set_precision(source.precision());
  payload->set_scale(source.scale());
  payload->set_bias(source.bias());
  payload->set_is_signed(source.is_signed());

  // Copy the tensor bytes into the proto's data field using this
  // serializer's context.
  detail::CopyToProtoWithCast(
      source.nbytes(),
      source.data(),
      payload->mutable_data(),
      &this->context_);

  acceptor(name, SerializeBlobProtoAsString_EnforceCheck(envelope));
}
/**
 * Blob-level entry point: takes the QTensorProto embedded in `blob_proto`
 * and deserializes it into the QTensor<Context> obtained from `blob`.
 */
template <class Context>
void QTensorDeserializer<Context>::Deserialize(
    const BlobProto& blob_proto,
    Blob* blob) {
  const QTensorProto& payload = blob_proto.qtensor();
  QTensor<Context>* target = blob->GetMutable<QTensor<Context>>();
  Deserialize(payload, target);
}
/**
 * Populates `qtensor` from `proto`: resizes it to the stored dims, applies
 * precision, scale, bias and signedness, then copies the data field into the
 * tensor's buffer using a locally constructed Context.
 */
template <class Context>
void QTensorDeserializer<Context>::Deserialize(
    const QTensorProto& proto,
    QTensor<Context>* qtensor) {
  Context context{};

  // Gather the serialized dims into the int vector passed to Resize().
  const auto& stored_dims = proto.dims();
  vector<int> shape(stored_dims.begin(), stored_dims.end());

  // NOTE(review): Resize() runs before precision/signedness are applied;
  // confirm QTensor::Resize does not depend on them.
  qtensor->Resize(shape);
  qtensor->SetPrecision(proto.precision());
  qtensor->SetScale(proto.scale());
  qtensor->SetBias(proto.bias());
  qtensor->SetSigned(proto.is_signed());

  detail::CopyFromProtoWithCast(
      qtensor->nbytes(), proto.data(), qtensor->mutable_data(), &context);
}
} // namespace caffe2
#endif // CAFFE2_CORE_QTENSOR_SERIALIZATION_H_