mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-20 21:40:57 +00:00
Fix conversion of TensorData, TensorsData to json (#22166)
### Description
Fix `write_calibration_table` so that it supports `TensorData` and `TensorsData` objects when converting the calibration cache to JSON.
This commit is contained in:
parent
280c013d67
commit
407c1ab2e2
3 changed files with 56 additions and 7 deletions
|
|
@ -69,6 +69,7 @@ class TensorData:
|
|||
_floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"])
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._attrs = list(kwargs.keys())
|
||||
for k, v in kwargs.items():
|
||||
if k not in TensorData._allowed:
|
||||
raise ValueError(f"Unexpected value {k!r} not in {TensorData._allowed}.")
|
||||
|
|
@ -91,6 +92,12 @@ class TensorData:
|
|||
raise AttributeError(f"Attributes 'avg' and/or 'std' missing in {dir(self)}.")
|
||||
return (self.avg, self.std)
|
||||
|
||||
def to_dict(self):
|
||||
# This is needed to serialize the data into JSON.
|
||||
data = {k: getattr(self, k) for k in self._attrs}
|
||||
data["CLS"] = self.__class__.__name__
|
||||
return data
|
||||
|
||||
|
||||
class TensorsData:
|
||||
def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]]):
|
||||
|
|
@ -125,12 +132,24 @@ class TensorsData:
|
|||
raise RuntimeError(f"Only an existing tensor can be modified, {key!r} is not.")
|
||||
self.data[key] = value
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
def values(self):
|
||||
return self.data.values()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def to_dict(self):
|
||||
# This is needed to serialize the data into JSON.
|
||||
data = {
|
||||
"CLS": self.__class__.__name__,
|
||||
"data": self.data,
|
||||
"calibration_method": self.calibration_method,
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class CalibrationMethod(Enum):
|
||||
MinMax = 0
|
||||
|
|
|
|||
|
|
@ -671,21 +671,41 @@ def write_calibration_table(calibration_cache, dir="."):
|
|||
import json
|
||||
|
||||
import flatbuffers
|
||||
import numpy as np
|
||||
|
||||
import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
|
||||
import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
|
||||
from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData
|
||||
|
||||
logging.info(f"calibration cache: {calibration_cache}")
|
||||
|
||||
class MyEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, (TensorData, TensorsData)):
|
||||
return obj.to_dict()
|
||||
if isinstance(obj, np.ndarray):
|
||||
return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
|
||||
if isinstance(obj, CalibrationMethod):
|
||||
return {"CLS": obj.__class__.__name__, "value": str(obj)}
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
json_data = json.dumps(calibration_cache, cls=MyEncoder)
|
||||
|
||||
with open(os.path.join(dir, "calibration.json"), "w") as file:
|
||||
file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse
|
||||
file.write(json_data) # use `json.loads` to do the reverse
|
||||
|
||||
# Serialize data using FlatBuffers
|
||||
zero = np.array(0)
|
||||
builder = flatbuffers.Builder(1024)
|
||||
key_value_list = []
|
||||
for key in sorted(calibration_cache.keys()):
|
||||
values = calibration_cache[key]
|
||||
value = str(max(abs(values[0]), abs(values[1])))
|
||||
d_values = values.to_dict()
|
||||
floats = [
|
||||
float(d_values.get("highest", zero).item()),
|
||||
float(d_values.get("lowest", zero).item()),
|
||||
]
|
||||
value = str(max(floats))
|
||||
|
||||
flat_key = builder.CreateString(key)
|
||||
flat_value = builder.CreateString(value)
|
||||
|
|
@ -724,9 +744,14 @@ def write_calibration_table(calibration_cache, dir="."):
|
|||
# write plain text
|
||||
with open(os.path.join(dir, "calibration.cache"), "w") as file:
|
||||
for key in sorted(calibration_cache.keys()):
|
||||
value = calibration_cache[key]
|
||||
s = key + " " + str(max(abs(value[0]), abs(value[1])))
|
||||
file.write(s)
|
||||
values = calibration_cache[key]
|
||||
d_values = values.to_dict()
|
||||
floats = [
|
||||
float(d_values.get("highest", zero).item()),
|
||||
float(d_values.get("lowest", zero).item()),
|
||||
]
|
||||
value = key + " " + str(max(floats))
|
||||
file.write(value)
|
||||
file.write("\n")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ from op_test_utils import (
|
|||
create_clip_node,
|
||||
)
|
||||
|
||||
from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantType, quantize_static
|
||||
from onnxruntime.quantization.calibrate import TensorData
|
||||
from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantType, quantize_static, write_calibration_table
|
||||
from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData
|
||||
|
||||
|
||||
class TestQDQFormat(unittest.TestCase):
|
||||
|
|
@ -1720,6 +1720,11 @@ class TestQDQ4bit(TestQDQFormat):
|
|||
size_ratio = weight_quant_init.ByteSize() / unpacked_size
|
||||
self.assertLess(size_ratio, 0.55)
|
||||
|
||||
def test_json_serialization(self):
|
||||
td = TensorData(lowest=np.array([0.1], dtype=np.float32), highest=np.array([1.1], dtype=np.float32))
|
||||
new_calibrate_tensors_range = TensorsData(CalibrationMethod.MinMax, {"td": td})
|
||||
write_calibration_table(new_calibrate_tensors_range)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in a new issue