2022-01-27 19:31:13 +00:00
|
|
|
import numpy
|
|
|
|
|
from numpy.testing import assert_almost_equal
|
2022-04-26 16:35:16 +00:00
|
|
|
from onnx import TensorProto, numpy_helper
|
|
|
|
|
from onnx.helper import make_graph, make_model, make_node, make_tensor, make_tensor_value_info, set_model_props
|
|
|
|
|
|
2022-01-27 19:31:13 +00:00
|
|
|
import onnxruntime
|
|
|
|
|
|
2022-02-15 09:21:02 +00:00
|
|
|
# Fail fast unless 'TvmExecutionProvider' is among onnxruntime's available
# providers: the TVM session created later in this script requests it explicitly.
if "TvmExecutionProvider" not in onnxruntime.get_available_providers():
    raise AssertionError(
        f"Unable to find 'TvmExecutionProvider' in {onnxruntime.get_available_providers()!r}."
    )
|
|
|
|
|
|
|
|
|
|
# Graph signature: inputs X, A, B and output Y are all float tensors of rank 2
# whose two dimensions are left symbolic ([None, None]).
X, A, B, Y = [
    make_tensor_value_info(tensor_name, TensorProto.FLOAT, [None, None])
    for tensor_name in ("X", "A", "B", "Y")
]

# Linear-regression graph "lr": Y = MatMul(X, A) + B, with "XA" as the
# intermediate tensor between the two nodes.
node1 = make_node("MatMul", inputs=["X", "A"], outputs=["XA"])
node2 = make_node("Add", inputs=["XA", "B"], outputs=["Y"])
graph = make_graph(nodes=[node1, node2], name="lr", inputs=[X, A, B], outputs=[Y])

onnx_model = make_model(graph)
|
|
|
|
|
|
|
|
|
|
# Random float32 operands for the three graph inputs. With a fixed seed, the
# inputs are identical on every run, so any CPU/TVM mismatch reported by the
# final comparison can be reproduced.
rng = numpy.random.default_rng(0)
a = rng.standard_normal((2, 2), dtype=numpy.float32)  # A: weights, (2, 2)
b = rng.standard_normal((1, 2), dtype=numpy.float32)  # B: bias, broadcast in Add
x = rng.standard_normal((1, 2), dtype=numpy.float32)  # X: one row of features
# Feed dict keyed by the model's input names ("X", "A", "B").
data = {"A": a, "B": b, "X": x}
|
2022-01-27 19:31:13 +00:00
|
|
|
|
2022-04-26 16:35:16 +00:00
|
|
|
# Reference run: the same serialized model on ORT's CPU provider. Its output
# ``y`` is the expected value for the TVM comparison at the end of the script.
sess = onnxruntime.InferenceSession(
    onnx_model.SerializeToString(),
    providers=["CPUExecutionProvider"],
)
# output_names=None requests every graph output; this graph has exactly one
# output (Y), so the first element is the result.
y = sess.run(output_names=None, input_feed=data)[0]
|
|
|
|
|
|
|
|
|
|
# Options handed to TvmExecutionProvider. ``sess`` here is still the CPU
# reference session, used only to enumerate the model's inputs.
# NOTE(review): "-mcpu=core-avx2" targets AVX2-capable x86 CPUs; confirm the
# test host supports AVX2 before relying on this target string.
provider_options = {
    "target": "llvm -mcpu=core-avx2",
    "target_host": "llvm -mcpu=core-avx2",
    "opt_level": 3,
    "freeze_weights": True,
    "tuning_file_path": "",  # empty: no pre-tuned schedule file is supplied
    "tuning_type": "Ansor",
    # Space-separated input names, in the session's input order.
    "input_names": " ".join([node_arg.name for node_arg in sess.get_inputs()]),
    # Space-separated shapes in the same order; str(numpy.array(shape)) renders
    # e.g. (2, 2) as "[2 2]" (presumably the format the TVM EP parses — confirm).
    "input_shapes": " ".join(
        [str(numpy.array(data[node_arg.name].shape)) for node_arg in sess.get_inputs()]
    ),
}
|
2022-01-27 19:31:13 +00:00
|
|
|
|
|
|
|
|
# Session options for the TVM run. Logging levels are set to 0, the most
# verbose level in ORT's severity scale. All ORT graph optimizations are
# disabled, presumably so the TVM EP receives the graph exactly as built above.
so = onnxruntime.SessionOptions()
so.log_severity_level = 0
so.log_verbosity_level = 0
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

# Rebinds ``sess`` to a session that runs the same model on the TVM EP,
# configured by ``provider_options`` (one options dict per provider listed).
sess = onnxruntime.InferenceSession(
    onnx_model.SerializeToString(),
    so,
    providers=["TvmExecutionProvider"],
    provider_options=[provider_options],
)

y_tvm = sess.run(None, data)[0]

# Both results are float32 and come from different kernels (TVM-generated vs
# ORT CPU), so they may legitimately differ in the last bit. The default
# decimal=7 (|diff| < 1.5e-7) is tighter than float32 resolution once
# |value| >= 2 (1 ulp = 2.4e-7), which made this check flaky. decimal=5 still
# catches real mismatches.
assert_almost_equal(y, y_tvm, decimal=5)
|