Add support to serialize qtensor in JIT. (#23356)

Summary:
Adds qtensor-specific fields (is_quantized, scale, zero_point) to the TensorDef proto so that quantized tensor metadata gets serialized into the model.json.
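
For context, the user-visible effect (mirrored by the test_serialize_qtensor test added below) is that a quantized buffer now survives a torch.jit.save / torch.jit.load round trip. A minimal sketch against the quantization API of this era (quantize_linear was later renamed quantize_per_tensor):

    import io

    import torch

    class SimpleQTensor(torch.jit.ScriptModule):
        def __init__(self):
            super(SimpleQTensor, self).__init__()
            # per-tensor affine quantization: scale=0.2, zero_point=10
            x_q = torch.quantize_linear(torch.rand(5, 5), 0.2, 10, torch.quint8)
            self.register_buffer('x', x_q)

        @torch.jit.script_method
        def forward(self):
            return self.x

    buffer = io.BytesIO()
    torch.jit.save(SimpleQTensor(), buffer)
    buffer.seek(0)
    loaded = torch.jit.load(buffer)
    # the quantizer params now come back intact
    assert loaded().q_scale() == 0.2 and loaded().q_zero_point() == 10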

Pull Request resolved: https://github.com/pytorch/pytorch/pull/23356
ghstack-source-id: 87263428

Differential Revision: D16473237

fbshipit-source-id: bf5b51d0863d036d30a1644a3c3b74516468224b
Author: Supriya Rao (committed by Facebook Github Bot)
Date: 2019-07-26 15:45:39 -07:00
Parent: 9dad13e1f0
Commit: 9223fa1c46

5 changed files with 70 additions and 24 deletions

caffe2/core/types.cc

@@ -9,28 +9,31 @@
 namespace caffe2 {
 
 TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) {
-  static_assert(sizeof(int) == 4,
-                "int in this compiler does not equal to 4 bytes.");
-  static std::map<TypeIdentifier, TensorProto::DataType> data_type_map {
+  static_assert(
+      sizeof(int) == 4, "int in this compiler does not equal to 4 bytes.");
+  static std::map<TypeIdentifier, TensorProto::DataType> data_type_map{
       {TypeMeta::Id<float>(), TensorProto_DataType_FLOAT},
       {TypeMeta::Id<int>(), TensorProto_DataType_INT32},
       // BYTE does not have a type meta to proto mapping: we should
       // always use uint8_t when serializing. BYTE is kept for backward
       // compatibility.
       // {TypeMeta::Id<>(), TensorProto_DataType_BYTE},
       {TypeMeta::Id<string>(), TensorProto_DataType_STRING},
       {TypeMeta::Id<bool>(), TensorProto_DataType_BOOL},
       {TypeMeta::Id<uint8_t>(), TensorProto_DataType_UINT8},
       {TypeMeta::Id<int8_t>(), TensorProto_DataType_INT8},
       {TypeMeta::Id<uint16_t>(), TensorProto_DataType_UINT16},
       {TypeMeta::Id<int16_t>(), TensorProto_DataType_INT16},
       {TypeMeta::Id<int64_t>(), TensorProto_DataType_INT64},
       {TypeMeta::Id<at::Half>(), TensorProto_DataType_FLOAT16},
       {TypeMeta::Id<double>(), TensorProto_DataType_DOUBLE},
+      {TypeMeta::Id<c10::qint8>(), TensorProto_DataType_INT8},
+      {TypeMeta::Id<c10::quint8>(), TensorProto_DataType_UINT8},
+      {TypeMeta::Id<c10::qint32>(), TensorProto_DataType_INT32},
   };
   const auto it = data_type_map.find(meta.id());
-  return (it == data_type_map.end()
-      ? TensorProto_DataType_UNDEFINED : it->second);
+  return (
+      it == data_type_map.end() ? TensorProto_DataType_UNDEFINED : it->second);
 }
 
 const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt) {
@@ -55,4 +58,4 @@ const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt) {
   return it->second;
 }
 
-}  // namespace caffe2
+} // namespace caffe2

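Note that the three new map entries reuse the plain integer proto types: a quantized tensor's storage holds its underlying int8/uint8/int32 payload, so the bytes written out are exactly what int_repr() exposes on the Python side. A small illustration (assuming the quantize_linear API of this era):

    import torch

    x_q = torch.quantize_linear(torch.rand(3), 0.2, 10, torch.quint8)
    # int_repr() returns the raw uint8 values; these are the bytes the
    # serializer writes out, tagged as TensorProto_DataType_UINT8.
    print(x_q.int_repr())
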
caffe2/proto/torch.proto

@@ -21,6 +21,10 @@ message TensorDef {
   // device field stores the canonical device string, and it follows the
   // format below: `(cpu|cuda)[:<device-index>]`, e.g., 'cuda:0'
   optional string device = 7;
+
+  optional bool is_quantized = 8;
+  optional double scale = 9;
+  optional int64 zero_point = 10;
 }
 
 message AttributeDef {

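To see the new fields in a saved model, one can crack open the saved archive and read model.json. A hedged sketch; the class name is illustrative, and the exact entry prefix and the nesting of TensorDef records inside the JSON depend on the archive layout of this era:

    import io
    import json
    import zipfile

    import torch

    class QBuf(torch.jit.ScriptModule):
        def __init__(self):
            super(QBuf, self).__init__()
            self.register_buffer(
                'x', torch.quantize_linear(torch.rand(2, 2), 0.2, 10, torch.quint8))

    buffer = io.BytesIO()
    torch.jit.save(QBuf(), buffer)
    buffer.seek(0)

    with zipfile.ZipFile(buffer) as zf:
        # the entry prefix varies with the archive layout; match on the suffix
        name = next(n for n in zf.namelist() if n.endswith('model.json'))
        model_def = json.loads(zf.read(name))
    # TensorDef records in the JSON should now carry isQuantized / scale /
    # zeroPoint alongside dims, offset, and device.
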
test/test_jit.py

@@ -2847,6 +2847,25 @@ graph(%Ra, %Rb):
         s = str(torch.ops)
         self.assertRegex(s, r'ops')
 
+    def test_serialize_qtensor(self):
+        class SimpleQTensor(torch.jit.ScriptModule):
+            def __init__(self):
+                super(SimpleQTensor, self).__init__()
+                x = torch.rand(5, 5).float()
+                x_q = torch.quantize_linear(x, 0.2, 10, torch.quint8)
+                self.register_buffer('x', x_q)
+
+            @torch.jit.script_method
+            def forward(self):
+                return self.x
+
+        model = SimpleQTensor()
+        buffer = io.BytesIO()
+        torch.jit.save(model, buffer)
+        buffer.seek(0)
+        model_loaded = torch.jit.load(buffer)
+        self.assertEqual(model_loaded(), model())
+
 
 class TestScript(JitTestCase):
     def test_sequence_parsing(self):

torch/csrc/jit/export.cpp

@@ -746,6 +746,12 @@ void ScriptModuleSerializer::convertAndWriteTensor(
   tensor_proto->set_requires_grad(tensor.requires_grad());
+  tensor_proto->set_is_quantized(tensor.is_quantized());
+  if (tensor.is_quantized()) {
+    tensor_proto->set_scale(tensor.q_scale());
+    tensor_proto->set_zero_point(tensor.q_zero_point());
+  }
+
   auto* key = tensor.storage().unsafeGetStorageImpl();
   auto storage_it = storageMap.find(key);
   if (storage_it == storageMap.end()) {

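The accessors the serializer calls here have direct Python counterparts, which makes it easy to see what lands in each new proto field (an illustrative snippet, with values chosen to match the new test):

    import torch

    x_q = torch.quantize_linear(torch.rand(5, 5), 0.2, 10, torch.quint8)
    x_q.is_quantized     # True -> TensorDef.is_quantized
    x_q.q_scale()        # 0.2  -> TensorDef.scale
    x_q.q_zero_point()   # 10   -> TensorDef.zero_point
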
torch/csrc/jit/import.cpp

@@ -181,6 +181,9 @@ at::Tensor ScriptModuleDeserializer::loadTensor(
       tensor_proto.strides().begin(), tensor_proto.strides().end());
   auto type = at::typeMetaToScalarType(
       caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
+  if (tensor_proto.is_quantized()) {
+    type = toQIntType(type);
+  }
   const std::string& record_key = tensor_proto.data().key();
   AT_ASSERT(tensor_proto.has_device() && !tensor_proto.device().empty());
   at::Device device(tensor_proto.device());
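
Here toQIntType re-tags the integer scalar type recovered from the proto as its quantized counterpart, i.e. the inverse of the data_type_map additions above. The correspondence, written out in Python terms purely for reference (the dict is illustrative, not a torch API):

    import torch

    # Inverse of the serializer-side mapping; illustrative only.
    TO_QINT_TYPE = {
        torch.int8: torch.qint8,
        torch.uint8: torch.quint8,
        torch.int32: torch.qint32,
    }
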
@@ -227,10 +230,21 @@ at::Tensor ScriptModuleDeserializer::loadTensor(
   }
 
   at::Tensor result;
   if (device.type() == at::DeviceType::CPU) {
-    result =
-        at::empty({0}, at::CPU(type).options())
-            .set_(storage_it->second, tensor_proto.offset(), dims, strides);
+    if (tensor_proto.is_quantized()) {
+      result = at::_empty_affine_quantized(
+                   {0},
+                   type,
+                   tensor_proto.scale(),
+                   tensor_proto.zero_point())
+                   .set_(storage_it->second, tensor_proto.offset(), dims, strides);
+    } else {
+      result =
+          at::empty({0}, at::CPU(type).options())
+              .set_(storage_it->second, tensor_proto.offset(), dims, strides);
+    }
   } else if (device.type() == at::DeviceType::CUDA) {
     result =
         at::empty(