2020-11-13 01:57:08 +00:00
|
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
|
|
|
|
import onnx
|
2022-04-26 16:35:16 +00:00
|
|
|
from onnx import TensorProto, helper
|
2020-11-13 01:57:08 +00:00
|
|
|
from onnx.helper import make_opsetid
|
|
|
|
|
|
2022-04-26 16:35:16 +00:00
|
|
|
# Graph input and output: both are BFLOAT16 tensors of static shape [1, 5].
input_info = helper.make_tensor_value_info(name="input", elem_type=TensorProto.BFLOAT16, shape=[1, 5])
output_info = helper.make_tensor_value_info(name="output", elem_type=TensorProto.BFLOAT16, shape=[1, 5])
|
2020-11-13 01:57:08 +00:00
|
|
|
|
|
|
|
|
# Create a node (NodeProto): a single Identity op that copies "input" to "output"
|
2022-04-26 16:35:16 +00:00
|
|
|
node_def = helper.make_node(
    op_type="Identity",  # operator type (not a node name; the node is left unnamed)
    inputs=["input"],  # consumes the graph input
    outputs=["output"],  # produces the graph output
)
|
2020-11-13 01:57:08 +00:00
|
|
|
|
2022-04-26 16:35:16 +00:00
|
|
|
# Wrap the single Identity node in a graph. The graph name is spelled to
# match the BFLOAT16 type under test and the saved file name.
graph_def = helper.make_graph(
    nodes=[node_def],
    name="test_types_BFLOAT16",
    inputs=[input_info],
    outputs=[output_info],
)
|
2020-11-13 01:57:08 +00:00
|
|
|
|
2022-04-26 16:35:16 +00:00
|
|
|
# Build the model, importing only the default ("") ONNX domain at opset 13.
model_def = helper.make_model(
    graph_def,
    opset_imports=[make_opsetid(domain="", version=13)],
    producer_name="AIInfra",
)
|
2020-11-13 01:57:08 +00:00
|
|
|
|
2021-03-29 18:00:38 +00:00
|
|
|
# Validate the hand-built model before post-processing it.
onnx.checker.check_model(model_def)

# Clear doc_string fields on the model; the call's return value is not used,
# so it is relied on to modify model_def in place.
helper.strip_doc_string(model_def)

# Run shape inference and re-validate the resulting model.
final_model = onnx.shape_inference.infer_shapes(model_def)
onnx.checker.check_model(final_model)
|
2022-04-26 16:35:16 +00:00
|
|
|
# Serialize the shape-inferred model to the current working directory.
onnx.save_model(final_model, "test_types_BFLOAT16.onnx")
|