mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add support to serialize qtensor in JIT. (#23356)
Summary: Adds qtensor-specific fields to the proto file so that they get serialized into the model.json Pull Request resolved: https://github.com/pytorch/pytorch/pull/23356 ghstack-source-id: 87263428 Differential Revision: D16473237 fbshipit-source-id: bf5b51d0863d036d30a1644a3c3b74516468224b
This commit is contained in:
parent
9dad13e1f0
commit
9223fa1c46
5 changed files with 70 additions and 24 deletions
|
|
@ -9,28 +9,31 @@
|
|||
namespace caffe2 {
|
||||
|
||||
// Maps a caffe2 TypeMeta to the TensorProto::DataType used when serializing
// tensor data. Returns TensorProto_DataType_UNDEFINED for any type that has
// no proto mapping.
//
// NOTE: the source span interleaved the pre- and post-change bodies of this
// function (a unified diff with its +/- markers stripped); this is the
// coherent post-change version, which adds the quantized-dtype entries.
TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta) {
  // The INT32 mapping below relies on `int` being exactly 4 bytes.
  static_assert(
      sizeof(int) == 4, "int in this compiler does not equal to 4 bytes.");
  static std::map<TypeIdentifier, TensorProto::DataType> data_type_map{
      {TypeMeta::Id<float>(), TensorProto_DataType_FLOAT},
      {TypeMeta::Id<int>(), TensorProto_DataType_INT32},
      // BYTE does not have a type meta to proto mapping: we should
      // always use uint8_t when serializing. BYTE is kept for backward
      // compatibility.
      // {TypeMeta::Id<>(), TensorProto_DataType_BYTE},
      {TypeMeta::Id<string>(), TensorProto_DataType_STRING},
      {TypeMeta::Id<bool>(), TensorProto_DataType_BOOL},
      {TypeMeta::Id<uint8_t>(), TensorProto_DataType_UINT8},
      {TypeMeta::Id<int8_t>(), TensorProto_DataType_INT8},
      {TypeMeta::Id<uint16_t>(), TensorProto_DataType_UINT16},
      {TypeMeta::Id<int16_t>(), TensorProto_DataType_INT16},
      {TypeMeta::Id<int64_t>(), TensorProto_DataType_INT64},
      {TypeMeta::Id<at::Half>(), TensorProto_DataType_FLOAT16},
      {TypeMeta::Id<double>(), TensorProto_DataType_DOUBLE},
      // Quantized dtypes serialize as the DataType of their underlying
      // storage type; the quantization parameters travel separately on the
      // TensorDef (is_quantized / scale / zero_point fields).
      {TypeMeta::Id<c10::qint8>(), TensorProto_DataType_INT8},
      {TypeMeta::Id<c10::quint8>(), TensorProto_DataType_UINT8},
      {TypeMeta::Id<c10::qint32>(), TensorProto_DataType_INT32},
  };
  const auto it = data_type_map.find(meta.id());
  return (
      it == data_type_map.end() ? TensorProto_DataType_UNDEFINED : it->second);
}
|
||||
|
||||
const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt) {
|
||||
|
|
@ -55,4 +58,4 @@ const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt) {
|
|||
return it->second;
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
} // namespace caffe2
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ message TensorDef {
|
|||
// device field stores the canonical device string, and it follows the
|
||||
// format below: `(cpu|cuda)[:<device-index>]`, e.g., 'cuda:0'
|
||||
optional string device = 7;
|
||||
|
||||
optional bool is_quantized = 8;
|
||||
optional double scale = 9;
|
||||
optional int64 zero_point = 10;
|
||||
}
|
||||
|
||||
message AttributeDef {
|
||||
|
|
|
|||
|
|
@ -2847,6 +2847,25 @@ graph(%Ra, %Rb):
|
|||
s = str(torch.ops)
|
||||
self.assertRegex(s, r'ops')
|
||||
|
||||
def test_serialize_qtensor(self):
    """Round-trip a quantized tensor through torch.jit.save / torch.jit.load.

    Registers a quantized tensor as a module buffer so it goes through the
    ScriptModule serializer, then checks the reloaded module returns a
    tensor equal to the original (i.e. scale/zero_point survive).
    """
    class SimpleQTensor(torch.jit.ScriptModule):
        def __init__(self):
            super(SimpleQTensor, self).__init__()
            x = torch.rand(5, 5).float()
            # quantize_linear was the per-tensor affine quantization entry
            # point at the time of this commit (scale=0.2, zero_point=10,
            # dtype=quint8); it was later renamed to
            # torch.quantize_per_tensor -- TODO confirm against the torch
            # version in use.
            x_q = torch.quantize_linear(x, 0.2, 10, torch.quint8)
            # A buffer (not a parameter) is enough to force the qtensor
            # through tensor serialization.
            self.register_buffer('x', x_q)

        @torch.jit.script_method
        def forward(self):
            return self.x

    model = SimpleQTensor()
    buffer = io.BytesIO()
    torch.jit.save(model, buffer)
    buffer.seek(0)
    model_loaded = torch.jit.load(buffer)
    self.assertEqual(model_loaded(), model())
||||
|
||||
class TestScript(JitTestCase):
|
||||
def test_sequence_parsing(self):
|
||||
|
|
|
|||
|
|
@ -746,6 +746,12 @@ void ScriptModuleSerializer::convertAndWriteTensor(
|
|||
|
||||
tensor_proto->set_requires_grad(tensor.requires_grad());
|
||||
|
||||
tensor_proto->set_is_quantized(tensor.is_quantized());
|
||||
if (tensor.is_quantized()) {
|
||||
tensor_proto->set_scale(tensor.q_scale());
|
||||
tensor_proto->set_zero_point(tensor.q_zero_point());
|
||||
}
|
||||
|
||||
auto* key = tensor.storage().unsafeGetStorageImpl();
|
||||
auto storage_it = storageMap.find(key);
|
||||
if (storage_it == storageMap.end()) {
|
||||
|
|
|
|||
|
|
@ -181,6 +181,9 @@ at::Tensor ScriptModuleDeserializer::loadTensor(
|
|||
tensor_proto.strides().begin(), tensor_proto.strides().end());
|
||||
auto type = at::typeMetaToScalarType(
|
||||
caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
|
||||
if (tensor_proto.is_quantized()) {
|
||||
type = toQIntType(type);
|
||||
}
|
||||
const std::string& record_key = tensor_proto.data().key();
|
||||
AT_ASSERT(tensor_proto.has_device() && !tensor_proto.device().empty());
|
||||
at::Device device(tensor_proto.device());
|
||||
|
|
@ -227,10 +230,21 @@ at::Tensor ScriptModuleDeserializer::loadTensor(
|
|||
}
|
||||
|
||||
at::Tensor result;
|
||||
|
||||
if (device.type() == at::DeviceType::CPU) {
|
||||
result =
|
||||
at::empty({0}, at::CPU(type).options())
|
||||
.set_(storage_it->second, tensor_proto.offset(), dims, strides);
|
||||
if (tensor_proto.is_quantized()) {
|
||||
result = at::_empty_affine_quantized(
|
||||
{0},
|
||||
type,
|
||||
tensor_proto.scale(),
|
||||
tensor_proto.zero_point())
|
||||
.set_(storage_it->second, tensor_proto.offset(), dims, strides);
|
||||
}
|
||||
else {
|
||||
result =
|
||||
at::empty({0}, at::CPU(type).options())
|
||||
.set_(storage_it->second, tensor_proto.offset(), dims, strides);
|
||||
}
|
||||
} else if (device.type() == at::DeviceType::CUDA) {
|
||||
result =
|
||||
at::empty(
|
||||
|
|
|
|||
Loading…
Reference in a new issue