onnxruntime/onnxruntime/core/flatbuffers/schema/ort.fbs
Baiju Meswani 10ba1e270c
Minimal Build for On-Device Training (#16326)
🛠️ __Changes in this pull request:__

This pull request introduces two significant changes to the project:

- Changing on device training checkpoint format: The current
implementation stores the on device training checkpoint as a sequence of
tensors in multiple files inside a checkpoint folder, which can be
inefficient in terms of storage and performance. In this PR, I have
modified the checkpoint format to utilize the flatbuffer table to save
the checkpoint to a single file, providing a more compact and efficient
representation. The changes around this are twofold:
- Add the checkpoint flatbuffer schema that will generate the necessary
checkpoint source files.
- Update the checkpoint saving and loading functionality to use the new
format.

- Adding support for onnxruntime minimal build: To support scenarios
where binary size is a constraint, I made changes to ensure that the
training build can work well with the minimal build.

🔍 __Open Issues:__
- In order to extract the optimizer type, the existing implementation
re-loaded the onnx optimizer model and parsed it. This is no longer
possible, since the model format can be either onnx or ort. One idea is
to do the same for the ort format optimizer model. This needs some
investigation.
- Changes to the offline tooling to generate ort format training
artifacts.
- End-to-end training example showcasing the use of the minimal training
build.
- Add support for exporting the model for inferencing in a minimal build.
2023-06-22 12:27:23 -07:00

320 lines
5.8 KiB
Text

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
namespace onnxruntime.fbs;
// Attribute
// Kind of value carried by an Attribute; stored as a 32-bit integer and
// referenced by Attribute.type.
// NOTE(review): the numeric values look like ONNX AttributeProto.AttributeType - confirm.
// SPARSE_TENSOR / SPARSE_TENSORS are declared here, but the Attribute table in
// this schema has no sparse tensor fields.
enum AttributeType : int32 {
UNDEFINED = 0,
FLOAT = 1,
INT = 2,
STRING = 3,
TENSOR = 4,
GRAPH = 5,
FLOATS = 6,
INTS = 7,
STRINGS = 8,
TENSORS = 9,
GRAPHS = 10,
SPARSE_TENSOR = 11,
SPARSE_TENSORS = 12,
}
// Shape
// A shape expressed as a list of Dimension tables; used by
// TensorTypeAndShape.shape.
table Shape {
// One entry per dimension (presumably in axis order - confirm).
dim:[Dimension];
}
// A single entry of Shape.dim: the dimension's value plus an optional
// denotation string.
table Dimension {
// The dimension itself; see DimensionValue for the possible forms.
value:DimensionValue;
// Free-form text label.
// NOTE(review): presumably the ONNX dimension denotation - confirm.
denotation:string;
}
// Discriminator for DimensionValue.dim_type; stored as an 8-bit integer.
// NOTE(review): presumably VALUE selects DimensionValue.dim_value and PARAM
// selects DimensionValue.dim_param - confirm against the reader code.
enum DimensionValueType : int8 {
UNKNOWN = 0,
VALUE = 1,
PARAM = 2,
}
// Value of one dimension: a type tag plus a numeric field and a string field.
// NOTE(review): presumably only the field matching dim_type is meaningful - confirm.
table DimensionValue {
// Which form this dimension takes (UNKNOWN, VALUE or PARAM).
dim_type:DimensionValueType;
// Numeric dimension size.
dim_value:int64;
// Named (symbolic) dimension.
dim_param:string;
}
// Tensor
// Element data type tag; stored as a 32-bit integer. Referenced by
// TensorTypeAndShape.elem_type, MapType.key_type and Tensor.data_type.
// NOTE(review): the numeric values look like ONNX TensorProto.DataType - confirm
// before adding or renumbering entries.
enum TensorDataType : int32 {
UNDEFINED = 0,
FLOAT = 1,
UINT8 = 2,
INT8 = 3,
UINT16 = 4,
INT16 = 5,
INT32 = 6,
INT64 = 7,
STRING = 8,
BOOL = 9,
FLOAT16 = 10,
DOUBLE = 11,
UINT32 = 12,
UINT64 = 13,
COMPLEX64 = 14,
COMPLEX128 = 15,
BFLOAT16 = 16,
// Float 8 types. See https://onnx.ai/onnx/technical/float8.html.
FLOAT8E4M3FN = 17,
FLOAT8E4M3FNUZ = 18,
FLOAT8E5M2 = 19,
FLOAT8E5M2FNUZ = 20,
}
// Tensor type description: element data type plus shape. One of the members
// of the TypeInfoValue union.
table TensorTypeAndShape {
// Data type of the tensor elements.
elem_type:TensorDataType;
// Shape of the tensor as a list of dimensions.
shape:Shape;
}
// Map type description. One of the members of the TypeInfoValue union.
table MapType {
// Key type, expressed as a TensorDataType tag.
key_type:TensorDataType;
// Value type; a full TypeInfo, so values can themselves be tensors,
// sequences or maps.
value_type:TypeInfo;
}
// Sequence type description. One of the members of the TypeInfoValue union.
table SequenceType {
// Type of the sequence elements; a full TypeInfo, so nesting is possible.
elem_type:TypeInfo;
}
// Node
// Kind of node; stored as a 32-bit integer and referenced by Node.type.
// NOTE(review): presumably Fused marks a node produced by fusing other
// nodes - confirm against the graph implementation.
enum NodeType : int32 {
Primitive = 0,
Fused = 1,
}
// One end of an edge between nodes. Declared as a FlatBuffers struct, so it
// is stored inline with a fixed layout; the fields cannot be added to,
// removed or reordered without changing the binary format.
// Used by NodeEdge.input_edges and NodeEdge.output_edges.
struct EdgeEnd {
// Index of the node at this end of the edge.
// NOTE(review): presumably the node at the *other* end relative to the
// owning NodeEdge.node_index - confirm.
node_index:uint32;
// Argument index on the source side of the edge.
src_arg_index:int32;
// Argument index on the destination side of the edge.
dst_arg_index:int32;
}
// Input and output edges recorded for one node; stored in Graph.node_edges.
table NodeEdge {
// Index of the node these edges belong to (presumably matches Node.index - confirm).
node_index:uint32;
// Edge ends for the node's inputs.
input_edges:[EdgeEnd];
// Edge ends for the node's outputs.
output_edges:[EdgeEnd];
}
// A single node of a Graph (see Graph.nodes).
// Field order determines the FlatBuffers field ids; append new fields at the
// end rather than inserting or reordering.
table Node {
name:string;
doc_string:string;
// Operator domain and the operator-set version of op_type.
domain:string;
since_version:int32;
// Index of this node (presumably the value referenced by
// NodeEdge.node_index and bounded by Graph.max_node_index - confirm).
index:uint32;
op_type:string;
// Primitive or Fused; see NodeType.
type:NodeType;
execution_provider_type:string;
// Names of the node's inputs and outputs.
inputs:[string];
outputs:[string];
attributes:[Attribute];
// NOTE(review): presumably the number of actual arguments per formal
// input parameter - confirm against the Node implementation.
input_arg_counts:[int32];
// NOTE(review): presumably names of outer-scope values used by subgraph
// attributes - confirm.
implicit_inputs:[string];
}
// ValueInfo
// Name, documentation string and type of a value; stored in Graph.node_args.
table ValueInfo {
name:string;
doc_string:string;
// Type description (tensor, sequence or map); see TypeInfo.
type:TypeInfo;
}
// The concrete type held by a TypeInfo: a tensor, a sequence or a map.
// The order of union members determines their discriminator values, so new
// members must be appended, not inserted.
// TODO add support of SparseTensor, Opaque if needed
union TypeInfoValue {
tensor_type:TensorTypeAndShape,
sequence_type:SequenceType,
map_type:MapType,
}
// Type description used by ValueInfo.type, MapType.value_type and
// SequenceType.elem_type.
table TypeInfo {
// Free-form text label attached to the type.
denotation:string;
// The actual type: one of the TypeInfoValue union members.
value:TypeInfoValue;
}
// OpSetId
// An operator set identified by domain and version; stored in
// Model.opset_import.
table OperatorSetId {
domain:string;
version:int64;
}
// A tensor with its data. Used for Graph.initializers, for tensor-valued
// attributes (Attribute.t / Attribute.tensors) and for the values and
// indices of a SparseTensor.
// For simplicity, we will have only two data fields
// - string_data for string
// - raw_data for all other types
table Tensor {
name:string;
doc_string:string;
// Size of each dimension.
dims:[int64];
// Element data type; determines how raw_data is to be interpreted.
data_type:TensorDataType;
// Tensor contents as bytes for every element type other than string.
raw_data:[uint8];
// string_data is least used
string_data:[string];
}
// A sparse tensor stored as a values tensor, an indices tensor and the
// dimension list; stored in Graph.sparse_initializers.
table SparseTensor {
// Tensor holding the stored values.
values:Tensor;
// Tensor holding the positions of those values.
// NOTE(review): index layout is not specified here - confirm with the loader.
indices:Tensor;
// NOTE(review): presumably the dense shape of the sparse tensor - confirm.
dims:[int64];
}
// A named attribute of a Node (see Node.attributes). `type` tags which of
// the value fields below carries the attribute's value.
// NOTE(review): presumably exactly one value field is populated, matching
// `type` - confirm. There are no fields for AttributeType.SPARSE_TENSOR or
// AttributeType.SPARSE_TENSORS in this table.
table Attribute {
name:string;
doc_string:string;
type:AttributeType;
// Single-value forms: float, int, string, tensor, graph.
f:float32;
i:int64;
s:string;
t:Tensor;
g:Graph;
// List forms of the same value kinds.
floats:[float32];
ints:[int64];
strings:[string];
tensors:[Tensor];
graphs:[Graph];
}
// runtime optimizations
// Referenced by RuntimeOptimizationRecord.nodes_to_optimize_indices.
// NOTE(review): the field meanings are defined by the corresponding C++ type
// named below; this schema only fixes their types and order.
/// nodes to consider for a runtime optimization
/// see corresponding type in onnxruntime/core/graph/runtime_optimization_record.h
table NodesToOptimizeIndices {
// Indices of the nodes involved.
node_indices:[uint32];
num_inputs:uint32;
num_outputs:uint32;
has_variadic_input:bool;
has_variadic_output:bool;
num_variadic_inputs:uint32;
num_variadic_outputs:uint32;
}
// Pairs a node index with a kernel def hash. Only referenced by the
// deprecated field RuntimeOptimizationRecord.produced_nodes; the table is
// still declared so that field's declaration remains valid.
/// deprecated: no longer using kernel def hashes
table DeprecatedNodeIndexAndKernelDefHash {
node_index:uint32;
kernel_def_hash:uint64;
}
// Stored in RuntimeOptimizationRecordContainerEntry.runtime_optimization_records.
/// a single runtime optimization
/// see corresponding type in onnxruntime/core/graph/runtime_optimization_record.h
table RuntimeOptimizationRecord {
// Identifier of the action to apply.
action_id:string;
// The nodes the optimization applies to.
nodes_to_optimize_indices:NodesToOptimizeIndices;
// Marked deprecated: the declaration must stay so that the field ids of
// later fields do not shift, but it is no longer meant to be used.
produced_nodes:[DeprecatedNodeIndexAndKernelDefHash] (deprecated);
// NOTE(review): presumably op identifiers of the nodes the optimization
// produces (the replacement for produced_nodes) - confirm.
produced_op_ids:[string];
}
// One entry of RuntimeOptimizations.records: an optimizer name together with
// the runtime optimization records stored for it.
table RuntimeOptimizationRecordContainerEntry {
// Declared as the key field, which lets a vector of these entries be
// sorted and looked up by optimizer name.
optimizer_name:string (key);
runtime_optimization_records:[RuntimeOptimizationRecord];
}
// Container for all runtime optimization records of a graph; stored in
// Graph.runtime_optimizations.
table RuntimeOptimizations {
/// mapping from optimizer name to [RuntimeOptimizationRecord]
records:[RuntimeOptimizationRecordContainerEntry];
}
// A graph: used as Model.graph and, recursively, for graph-valued
// attributes (Attribute.g / Attribute.graphs).
// Field order determines the FlatBuffers field ids; append new fields at the
// end rather than inserting or reordering.
table Graph {
// Tensors stored with the graph as initializers.
initializers:[Tensor];
// Name and type information for values in the graph.
node_args:[ValueInfo];
nodes:[Node];
// NOTE(review): presumably an upper bound for Node.index values - confirm
// whether it is inclusive.
max_node_index:uint32;
// Per-node input/output edges.
node_edges:[NodeEdge];
// Names of the graph inputs and outputs.
inputs:[string];
outputs:[string];
sparse_initializers:[SparseTensor];
runtime_optimizations:RuntimeOptimizations;
}
// A string key/value pair; stored in Model.metadata_props.
// Note: `key` here is an ordinary field name, not the FlatBuffers (key)
// attribute, so no sorted lookup is declared for this table.
table StringStringEntry {
key:string;
value:string;
}
// A model: version and producer information, imported operator sets,
// metadata and the main graph. Stored in InferenceSession.model.
// NOTE(review): the fields look like a subset of ONNX ModelProto - confirm.
table Model {
ir_version:int64;
// Operator sets (domain + version) the model imports.
opset_import:[OperatorSetId];
producer_name:string;
producer_version:string;
domain:string;
model_version:int64;
doc_string:string;
// The main graph of the model.
graph:Graph;
// Documentation string for the graph, kept on the model because the
// Graph table itself has no doc_string field.
graph_doc_string:string;
// Arbitrary string key/value metadata.
metadata_props:[StringStringEntry];
}
// Only referenced by DeprecatedSessionState.kernels.
/// deprecated: no longer using kernel def hashes
table DeprecatedKernelCreateInfos {
// NOTE(review): presumably parallel arrays (kernel_def_hashes[i] belongs
// to node_indices[i]) - confirm.
node_indices:[uint32];
kernel_def_hashes:[uint64];
}
// Associates a graph id with a nested DeprecatedSessionState. Only referenced
// by DeprecatedSessionState.sub_graph_session_states.
/// deprecated: no longer using kernel def hashes
table DeprecatedSubGraphSessionState {
// graph_id can be used to binary search DeprecatedSubGraphSessionState in
// DeprecatedSessionState.sub_graph_session_states
graph_id:string (key);
session_state:DeprecatedSessionState;
}
// Only referenced by the deprecated field InferenceSession.session_state and,
// recursively, by DeprecatedSubGraphSessionState.session_state.
/// deprecated: no longer using kernel def hashes
table DeprecatedSessionState {
kernels:DeprecatedKernelCreateInfos;
// Nested session states, keyed by graph_id.
sub_graph_session_states:[DeprecatedSubGraphSessionState];
}
// Whether an argument is an input or an output; stored as an 8-bit integer
// and referenced by ArgTypeAndIndex.arg_type.
enum ArgType : int8 {
INPUT = 0,
OUTPUT = 1,
}
// Identifies one argument by direction (input/output) and index; stored in
// KernelTypeStrArgsEntry.args.
table ArgTypeAndIndex {
arg_type:ArgType;
// Index of the argument (presumably within the inputs or outputs selected
// by arg_type - confirm).
index:uint32;
}
// Maps one kernel type string to a list of arguments; stored in
// OpIdKernelTypeStrArgsEntry.kernel_type_str_args.
table KernelTypeStrArgsEntry {
// Declared as the key field, which lets a vector of these entries be
// sorted and looked up by kernel type string.
kernel_type_str:string (key);
// Arguments associated with this kernel type string.
args:[ArgTypeAndIndex];
}
// Maps one op id to its kernel type string entries; stored in
// KernelTypeStrResolver.op_kernel_type_str_args.
table OpIdKernelTypeStrArgsEntry {
// Declared as the key field, which lets a vector of these entries be
// sorted and looked up by op id.
// NOTE(review): the string format of an op id is not defined in this
// schema - confirm with the code that writes it.
op_id:string (key);
kernel_type_str_args:[KernelTypeStrArgsEntry];
}
// Two-level lookup: op id -> kernel type string -> arguments. Stored in
// InferenceSession.kernel_type_str_resolver.
table KernelTypeStrResolver {
op_kernel_type_str_args:[OpIdKernelTypeStrArgsEntry];
}
// Root table of the schema (see the root_type declaration): the ORT format
// version, the model and the kernel type string resolver.
table InferenceSession {
// This is the ORT format model version
// The version number is defined as kOrtModelVersion in <repo root>/onnxruntime/core/flatbuffers/ort_format_version.h
ort_version:string;
model:Model;
// Marked deprecated: the declaration must stay so that the field id of
// kernel_type_str_resolver does not shift, but it is no longer meant to
// be used.
session_state:DeprecatedSessionState (deprecated);
kernel_type_str_resolver:KernelTypeStrResolver;
}
// InferenceSession is the root table of a serialized buffer.
root_type InferenceSession;
// Four-character identifier embedded in serialized buffers so that readers
// can check that a buffer was produced from this schema.
file_identifier "ORTM";