mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
🛠️ __Changes in this pull request:__

This pull request introduces two significant changes to the project:

- Changing the on-device training checkpoint format: The current implementation stores the on-device training checkpoint as a sequence of tensors in multiple files inside a checkpoint folder, which can be inefficient in terms of storage and performance. In this PR, I have modified the checkpoint format to use a FlatBuffers table to save the checkpoint to a single file, providing a more compact and efficient representation. The changes around this are twofold:
  - Add the checkpoint FlatBuffers schema that will generate the necessary checkpoint source files.
  - Update the checkpoint saving and loading functionality to use the new format.
- Adding support for the onnxruntime minimal build: To support scenarios where binary size is a constraint, I made changes to ensure that the training build can work well with the minimal build.

🔍 __Open Issues:__

- In order to extract the optimizer type, the existing implementation re-loaded the ONNX optimizer model and parsed it. This is no longer possible, since the model format can be either ONNX or ORT. One idea is to do the same for the ORT format optimizer model. This needs some investigation.
- Changes to the offline tooling to generate ORT format training artifacts.
- End-to-end training example showcasing the use of the minimal training build.
- Add support for exporting the model for inferencing in a minimal build.
320 lines
5.8 KiB
Text
320 lines
5.8 KiB
Text
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Every type declared below lives in the onnxruntime.fbs namespace.
namespace onnxruntime.fbs;

// Attribute
// Discriminator for the kind of payload an Attribute carries (used by
// Attribute.type). Stored as int32.
// NOTE(review): the numbering looks like it mirrors ONNX's attribute type
// enum - confirm before touching any value; serialized files depend on them.
enum AttributeType : int32 {
  UNDEFINED = 0,
  FLOAT = 1,
  INT = 2,
  STRING = 3,
  TENSOR = 4,
  GRAPH = 5,
  FLOATS = 6,
  INTS = 7,
  STRINGS = 8,
  TENSORS = 9,
  GRAPHS = 10,
  SPARSE_TENSOR = 11,
  SPARSE_TENSORS = 12,
}

// Shape
// A shape expressed as an ordered vector of Dimension entries.
table Shape {
  dim:[Dimension];
}

// One entry of Shape.dim: a DimensionValue plus a free-form denotation string.
table Dimension {
  value:DimensionValue;
  denotation:string;
}

// Discriminator for DimensionValue.dim_type. Stored as int8.
enum DimensionValueType : int8 {
  UNKNOWN = 0,
  VALUE = 1,
  PARAM = 2,
}

// Value of a single dimension.
// NOTE(review): presumably dim_type selects which of dim_value (VALUE) or
// dim_param (PARAM) is meaningful - confirm against the reader/writer code.
table DimensionValue {
  dim_type:DimensionValueType;
  dim_value:int64;
  dim_param:string;
}

// Tensor
// Element type identifiers used by Tensor.data_type,
// TensorTypeAndShape.elem_type and MapType.key_type. Stored as int32.
// Values are part of the serialized format; do not renumber.
enum TensorDataType : int32 {
  UNDEFINED = 0,
  FLOAT = 1,
  UINT8 = 2,
  INT8 = 3,
  UINT16 = 4,
  INT16 = 5,
  INT32 = 6,
  INT64 = 7,
  STRING = 8,
  BOOL = 9,
  FLOAT16 = 10,
  DOUBLE = 11,
  UINT32 = 12,
  UINT64 = 13,
  COMPLEX64 = 14,
  COMPLEX128 = 15,
  BFLOAT16 = 16,
  // Float 8 types. See https://onnx.ai/onnx/technical/float8.html.
  FLOAT8E4M3FN = 17,
  FLOAT8E4M3FNUZ = 18,
  FLOAT8E5M2 = 19,
  FLOAT8E5M2FNUZ = 20,
}

// Tensor type description: element type plus shape.
// One of the members of the TypeInfoValue union.
table TensorTypeAndShape {
  elem_type:TensorDataType;
  shape:Shape;
}

// Map type description: keys are a TensorDataType, values are a nested
// TypeInfo. One of the members of the TypeInfoValue union.
table MapType {
  key_type:TensorDataType;
  value_type:TypeInfo;
}

// Sequence type description: the element type is a nested TypeInfo.
// One of the members of the TypeInfoValue union.
table SequenceType {
  elem_type:TypeInfo;
}

// Node
// Kind of a Node (used by Node.type). Stored as int32.
enum NodeType : int32 {
  Primitive = 0,
  Fused = 1,
}

// One end of a graph edge, referenced from NodeEdge.input_edges and
// NodeEdge.output_edges.
// This is a FlatBuffers struct (fixed inline layout), so fields cannot be
// added, removed or reordered without breaking existing buffers.
struct EdgeEnd {
  node_index:uint32;
  src_arg_index:int32;
  dst_arg_index:int32;
}

// Input and output edges recorded for the node identified by node_index.
table NodeEdge {
  node_index:uint32;
  input_edges:[EdgeEnd];
  output_edges:[EdgeEnd];
}

// A single graph node.
// Field order determines the FlatBuffers field ids; append new fields at the
// end only.
table Node {
  name:string;
  doc_string:string;
  domain:string;
  since_version:int32;

  index:uint32;
  op_type:string;
  type:NodeType;
  execution_provider_type:string;

  // Inputs and outputs are stored as strings.
  // NOTE(review): presumably these are names matching ValueInfo.name entries
  // in Graph.node_args - confirm.
  inputs:[string];
  outputs:[string];
  attributes:[Attribute];

  input_arg_counts:[int32];
  implicit_inputs:[string];
}

// ValueInfo
// Name, documentation string and type of a value (used by Graph.node_args).
table ValueInfo {
  name:string;
  doc_string:string;
  type:TypeInfo;
}

// TODO add support of SparseTensor, Opaque if needed
// The concrete type payload held by TypeInfo.value: tensor, sequence or map.
// Member order defines the union's discriminator values; append only.
union TypeInfoValue {
  tensor_type:TensorTypeAndShape,
  sequence_type:SequenceType,
  map_type:MapType,
}

// Type description: a denotation string plus one TypeInfoValue union member.
table TypeInfo {
  denotation:string;
  value:TypeInfoValue;
}

// OpSetId
// An operator set reference: domain string and version
// (used by Model.opset_import).
table OperatorSetId {
  domain:string;
  version:int64;
}

// For simplicity, we will have only two data fields
// - string_data for string
// - raw_data for all other types
table Tensor {
  name:string;
  doc_string:string;

  dims:[int64];
  data_type:TensorDataType;

  // Raw bytes of the tensor contents for every non-string data_type.
  raw_data:[uint8];

  // string_data is least used
  string_data:[string];
}

// Sparse tensor stored as two Tensor tables (values and indices) plus dims.
// NOTE(review): dims is presumably the dense shape - confirm against the
// serialization code.
table SparseTensor {
  values:Tensor;
  indices:Tensor;
  dims:[int64];
}

// A node attribute. `type` says which payload kind the attribute holds; the
// payload fields below are one scalar/single slot and one vector slot per kind.
// NOTE(review): AttributeType declares SPARSE_TENSOR / SPARSE_TENSORS but
// there is no matching field here - confirm whether that is intentional.
// Field order determines the FlatBuffers field ids; append new fields at the
// end only.
table Attribute {
  name:string;
  doc_string:string;

  type:AttributeType;

  // Single-valued payloads.
  f:float32;
  i:int64;
  s:string;
  t:Tensor;
  g:Graph;

  // Vector payloads.
  floats:[float32];
  ints:[int64];
  strings:[string];
  tensors:[Tensor];
  graphs:[Graph];
}

// runtime optimizations

/// nodes to consider for a runtime optimization
/// see corresponding type in onnxruntime/core/graph/runtime_optimization_record.h
table NodesToOptimizeIndices {
  node_indices:[uint32];
  num_inputs:uint32;
  num_outputs:uint32;
  has_variadic_input:bool;
  has_variadic_output:bool;
  num_variadic_inputs:uint32;
  num_variadic_outputs:uint32;
}

/// deprecated: no longer using kernel def hashes
table DeprecatedNodeIndexAndKernelDefHash {
  // Only referenced by the deprecated field
  // RuntimeOptimizationRecord.produced_nodes.
  node_index:uint32;
  kernel_def_hash:uint64;
}

/// a single runtime optimization
/// see corresponding type in onnxruntime/core/graph/runtime_optimization_record.h
table RuntimeOptimizationRecord {
  action_id:string;
  nodes_to_optimize_indices:NodesToOptimizeIndices;
  // Marked (deprecated): the declaration stays so later field ids do not
  // shift, but flatc generates no accessors for it.
  produced_nodes:[DeprecatedNodeIndexAndKernelDefHash] (deprecated);
  produced_op_ids:[string];
}

// One entry of RuntimeOptimizations.records: the records belonging to a
// single optimizer. optimizer_name is the (key) field, so a vector of these
// entries can be sorted and looked up by name.
table RuntimeOptimizationRecordContainerEntry {
  optimizer_name:string (key);
  runtime_optimization_records:[RuntimeOptimizationRecord];
}

// Container for all runtime optimization records of a Graph.
table RuntimeOptimizations {
  /// mapping from optimizer name to [RuntimeOptimizationRecord]
  records:[RuntimeOptimizationRecordContainerEntry];
}

// A graph: initializers, value infos, nodes, edges and graph inputs/outputs.
// Also used for nested graphs via Attribute.g / Attribute.graphs.
// Field order determines the FlatBuffers field ids; append new fields at the
// end only.
table Graph {
  initializers:[Tensor];

  node_args:[ValueInfo];
  nodes:[Node];
  max_node_index:uint32;

  node_edges:[NodeEdge];

  inputs:[string];
  outputs:[string];
  sparse_initializers:[SparseTensor];

  runtime_optimizations:RuntimeOptimizations;
}

// A string key/value pair (used by Model.metadata_props).
table StringStringEntry {
  key:string;
  value:string;
}

// Top-level model description: version/producer metadata, opset imports,
// the main graph and arbitrary metadata properties.
// Field order determines the FlatBuffers field ids; append new fields at the
// end only.
table Model {
  ir_version:int64;
  opset_import:[OperatorSetId];
  producer_name:string;
  producer_version:string;
  domain:string;
  model_version:int64;
  doc_string:string;

  graph:Graph;

  graph_doc_string:string;
  metadata_props:[StringStringEntry];
}

/// deprecated: no longer using kernel def hashes
table DeprecatedKernelCreateInfos {
  // NOTE(review): presumably parallel arrays (kernel_def_hashes[i] belongs to
  // node_indices[i]) - confirm if this ever needs to be read again.
  node_indices:[uint32];
  kernel_def_hashes:[uint64];
}

/// deprecated: no longer using kernel def hashes
table DeprecatedSubGraphSessionState {
  // graph_id can be used to binary search DeprecatedSubGraphSessionState in
  // DeprecatedSessionState.sub_graph_session_states
  graph_id:string (key);

  session_state:DeprecatedSessionState;
}

/// deprecated: no longer using kernel def hashes
table DeprecatedSessionState {
  // Recursive structure: each sub-graph entry carries its own
  // DeprecatedSessionState.
  kernels:DeprecatedKernelCreateInfos;
  sub_graph_session_states:[DeprecatedSubGraphSessionState];
}

// Whether an argument is an input or an output
// (used by ArgTypeAndIndex.arg_type). Stored as int8.
enum ArgType : int8 {
  INPUT = 0,
  OUTPUT = 1,
}

// Identifies one argument by its direction (input/output) and index.
table ArgTypeAndIndex {
  arg_type:ArgType;
  index:uint32;
}

// Associates a kernel type string with a list of arguments.
// kernel_type_str is the (key) field, so a vector of these entries can be
// sorted and looked up by it.
table KernelTypeStrArgsEntry {
  kernel_type_str:string (key);
  args:[ArgTypeAndIndex];
}

// Associates an op id with its KernelTypeStrArgsEntry list.
// op_id is the (key) field, so a vector of these entries can be sorted and
// looked up by it.
table OpIdKernelTypeStrArgsEntry {
  op_id:string (key);
  kernel_type_str_args:[KernelTypeStrArgsEntry];
}

// Holds the per-op kernel type string -> argument mappings
// (stored in InferenceSession.kernel_type_str_resolver).
table KernelTypeStrResolver {
  op_kernel_type_str_args:[OpIdKernelTypeStrArgsEntry];
}

// Root table of a serialized buffer (see root_type at the end of the file).
// Field order determines the FlatBuffers field ids; append new fields at the
// end only.
table InferenceSession {
  // This is the ORT format model version
  // The version number is defined as kOrtModelVersion in <repo root>/onnxruntime/core/flatbuffers/ort_format_version.h
  ort_version:string;

  model:Model;
  // Marked (deprecated): the declaration stays so later field ids do not
  // shift, but flatc generates no accessors for it.
  session_state:DeprecatedSessionState (deprecated);

  kernel_type_str_resolver:KernelTypeStrResolver;
}

// InferenceSession is the root table of every buffer built from this schema.
root_type InferenceSession;
// 4-character identifier embedded in serialized buffers.
file_identifier "ORTM";