mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
### Description Allows to specify fixed size for dynamic input of a model. resolves #16707 Pending test
387 lines
19 KiB
C
387 lines
19 KiB
C
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
// NOTE: This file contains declarations of exported functions as WebAssembly API.
|
|
// Unlike a normal C-API, the purpose of this API is to make emcc to generate correct exports for the WebAssembly. The
|
|
// macro "EMSCRIPTEN_KEEPALIVE" helps the compiler to mark the function as an exported funtion of the WebAssembly
|
|
// module. Users are expected to consume those functions from JavaScript side.
|
|
|
|
#pragma once
|
|
|
|
#include <emscripten.h>
|
|
|
|
#include <stddef.h>
|
|
|
|
struct OrtSession;
|
|
using ort_session_handle_t = OrtSession*;
|
|
|
|
struct OrtSessionOptions;
|
|
using ort_session_options_handle_t = OrtSessionOptions*;
|
|
|
|
struct OrtRunOptions;
|
|
using ort_run_options_handle_t = OrtRunOptions*;
|
|
|
|
struct OrtValue;
|
|
using ort_tensor_handle_t = OrtValue*;
|
|
|
|
#ifdef ENABLE_TRAINING_APIS
|
|
struct OrtTrainingSession;
|
|
using ort_training_session_handle_t = OrtTrainingSession*;
|
|
|
|
struct OrtCheckpointState;
|
|
using ort_training_checkpoint_handle_t = OrtCheckpointState*;
|
|
#endif
|
|
|
|
extern "C" {
|
|
|
|
/**
|
|
* perform global initialization. should be called only once.
|
|
* @param num_threads number of total threads to use.
|
|
* @param logging_level default logging level.
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtInit(int num_threads, int logging_level);
|
|
|
|
/**
|
|
* get the last error.
|
|
* @param error_code [out] a pointer to accept the error code.
|
|
* @param error_message [out] a pointer to accept the error message. The message buffer is only available before any ORT API is called.
|
|
*/
|
|
void EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_message);
|
|
|
|
/**
|
|
* create an instance of ORT session options.
|
|
* assume that all enum type parameters, such as graph_optimization_level, execution_mode, and log_severity_level,
|
|
* are checked and set properly at JavaScript.
|
|
* @param graph_optimization_level disabled, basic, extended, or enable all
|
|
* @param enable_cpu_mem_arena enable or disable cpu memory arena
|
|
* @param enable_mem_pattern enable or disable memory pattern
|
|
* @param execution_mode sequential or parallel execution mode
|
|
* @param enable_profiling enable or disable profiling.
|
|
* @param profile_file_prefix file prefix for profiling data. it's a no-op and for a future use.
|
|
* @param log_id logger id for session output
|
|
* @param log_severity_level verbose, info, warning, error or fatal
|
|
* @param log_verbosity_level vlog level
|
|
* @param optimized_model_filepath filepath of the optimized model to dump.
|
|
* @returns a session option handle. Caller must release it after use by calling OrtReleaseSessionOptions().
|
|
*/
|
|
ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t graph_optimization_level,
|
|
bool enable_cpu_mem_arena,
|
|
bool enable_mem_pattern,
|
|
size_t execution_mode,
|
|
bool enable_profiling,
|
|
const char* profile_file_prefix,
|
|
const char* log_id,
|
|
size_t log_severity_level,
|
|
size_t log_verbosity_level,
|
|
const char* optimized_model_filepath);
|
|
|
|
/**
|
|
* append an execution provider for a session.
|
|
* @param name the name of the execution provider
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtAppendExecutionProvider(ort_session_options_handle_t session_options,
|
|
const char* name);
|
|
|
|
/**
|
|
* add a free dimension override for one dimension of a session's input.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtAddFreeDimensionOverride(ort_session_options_handle_t session_options,
|
|
const char* dim_param_name,
|
|
int dim_value);
|
|
|
|
/**
|
|
* store configurations for a session.
|
|
* @param session_options a handle to session options created by OrtCreateSessionOptions
|
|
* @param config_key configuration keys and value formats are defined in
|
|
* include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
|
|
* @param config_value value for config_key
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtAddSessionConfigEntry(ort_session_options_handle_t session_options,
|
|
const char* config_key,
|
|
const char* config_value);
|
|
|
|
/**
|
|
* release the specified ORT session options.
|
|
*/
|
|
void EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t session_options);
|
|
|
|
/**
|
|
* create an instance of ORT session.
|
|
* @param data a pointer to a buffer that contains the ONNX or ORT format model.
|
|
* @param data_length the size of the buffer in bytes.
|
|
* @returns an ORT session handle. Caller must release it after use by calling OrtReleaseSession().
|
|
*/
|
|
ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSession(void* data,
|
|
size_t data_length,
|
|
ort_session_options_handle_t session_options);
|
|
|
|
/**
|
|
* release the specified ORT session.
|
|
*/
|
|
void EMSCRIPTEN_KEEPALIVE OrtReleaseSession(ort_session_handle_t session);
|
|
|
|
/**
|
|
* get model's input count and output count.
|
|
* @param session handle of the specified session
|
|
* @param input_count [out] a pointer to a size_t variable to accept input_count.
|
|
* @param output_count [out] a pointer to a size_t variable to accept output_count.
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtGetInputOutputCount(ort_session_handle_t session,
|
|
size_t* input_count,
|
|
size_t* output_count);
|
|
|
|
/**
|
|
* get the model's input name.
|
|
* @param session handle of the specified session
|
|
* @param index the input index
|
|
* @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by
|
|
* calling OrtFree().
|
|
*/
|
|
char* EMSCRIPTEN_KEEPALIVE OrtGetInputName(ort_session_handle_t session, size_t index);
|
|
/**
|
|
* get the model's output name.
|
|
* @param session handle of the specified session
|
|
* @param index the output index
|
|
* @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by
|
|
* calling OrtFree().
|
|
*/
|
|
char* EMSCRIPTEN_KEEPALIVE OrtGetOutputName(ort_session_handle_t session, size_t index);
|
|
|
|
/**
|
|
* free the specified buffer.
|
|
* @param ptr a pointer to the buffer.
|
|
*/
|
|
void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr);
|
|
|
|
/**
|
|
* create an instance of ORT tensor.
|
|
* @param data_type data type defined in enum ONNXTensorElementDataType.
|
|
* @param data for numeric tensor: a pointer to the tensor data buffer. for string tensor: a pointer to a C-Style null terminated string array.
|
|
* @param data_length size of the buffer 'data' in bytes.
|
|
* @param dims a pointer to an array of dims. the array should contain (dims_length) element(s).
|
|
* @param dims_length the length of the tensor's dimension
|
|
* @returns a tensor handle. Caller must release it after use by calling OrtReleaseTensor().
|
|
*/
|
|
ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length);
|
|
|
|
/**
|
|
* get type, shape info and data of the specified tensor.
|
|
* @param tensor handle of the tensor.
|
|
* @param data_type [out] specify the memory to write data type
|
|
* @param data [out] specify the memory to write the tensor data. for string tensor: an array of C-Style null terminated string.
|
|
* @param dims [out] specify the memory to write address of the buffer containing value of each dimension.
|
|
* @param dims_length [out] specify the memory to write dims length
|
|
* @remarks following temporary buffers are allocated during the call. Caller must release the buffers after use by calling OrtFree():
|
|
* 'dims' (for all types of tensor), 'data' (only for string tensor)
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, int* data_type, void** data, size_t** dims, size_t* dims_length);
|
|
|
|
/**
|
|
* release the specified tensor.
|
|
*/
|
|
void EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor);
|
|
|
|
/**
|
|
* create an instance of ORT run options.
|
|
* @param log_severity_level verbose, info, warning, error or fatal
|
|
* @param log_verbosity_level vlog level
|
|
* @param terminate if true, all incomplete OrtRun calls will exit as soon as possible
|
|
* @param tag tag for this run
|
|
* @returns a run option handle. Caller must release it after use by calling OrtReleaseRunOptions().
|
|
*/
|
|
ort_run_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateRunOptions(size_t log_severity_level,
|
|
size_t log_verbosity_level,
|
|
bool terminate,
|
|
const char* tag);
|
|
|
|
/**
|
|
* set a single run configuration entry
|
|
* @param run_options a handle to run options created by OrtCreateRunOptions
|
|
* @param config_key configuration keys and value formats are defined in
|
|
* include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
|
|
* @param config_value value for config_key
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_options,
|
|
const char* config_key,
|
|
const char* config_value);
|
|
|
|
/**
|
|
* release the specified ORT run options.
|
|
*/
|
|
void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options);
|
|
|
|
/**
|
|
* inference the model.
|
|
* @param session handle of the specified session
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtRun(ort_session_handle_t session,
|
|
const char** input_names,
|
|
const ort_tensor_handle_t* inputs,
|
|
size_t input_count,
|
|
const char** output_names,
|
|
size_t output_count,
|
|
ort_tensor_handle_t* outputs,
|
|
ort_run_options_handle_t run_options);
|
|
|
|
/**
|
|
* end profiling.
|
|
* @param session handle of the specified session
|
|
* @returns a pointer to a buffer which contains C-style string of profile filename.
|
|
* Caller must release the C style string after use by calling OrtFree().
|
|
*/
|
|
char* EMSCRIPTEN_KEEPALIVE OrtEndProfiling(ort_session_handle_t session);
|
|
|
|
// Training API Section
|
|
|
|
#ifdef ENABLE_TRAINING_APIS
|
|
/**
|
|
* @brief Load the checkpoint for training.
|
|
*
|
|
* @param checkpoint_data_buffer pointer to a buffer containing the CheckpointState
|
|
* @param checkpoint_size size of the CheckpointState in bytes
|
|
* @return ort_training_checkpoint_handle_t
|
|
*/
|
|
ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, size_t checkpoint_size);
|
|
|
|
/**
|
|
* @brief Release the specified ORT training checkpoint state.
|
|
*
|
|
* @param training_checkpoint_state_handle handle for the CheckpointState
|
|
*/
|
|
void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle);
|
|
|
|
/**
|
|
* Creates an instance of a training session that can be used to begin or resume training from a given checkpoint state
|
|
* for the given onnx models.
|
|
* @param options Session options that the user can customize for this training session.
|
|
* @param training_checkpoint_state_handle Training states that the training session uses as a starting point for training.
|
|
* @param train_model pointer to a buffer containing the ONNX training model
|
|
* @param train_size size of the train_model buffer in bytes
|
|
* @param eval_model pointer to a buffer containing the ONNX evaluation model
|
|
* @param eval_size size of the eval_model buffer in bytes
|
|
* @param optimizer_model pointer to a buffer containing the ONNX optimizer model
|
|
* @param optimizer_size size of the optimizer_model buffer in bytes
|
|
* @return a handle of the ORT training session
|
|
*
|
|
*/
|
|
ort_training_session_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingCreateSession(ort_session_options_handle_t options,
|
|
ort_training_checkpoint_handle_t training_checkpoint_state_handle,
|
|
void* train_model,
|
|
size_t train_size,
|
|
void* eval_model,
|
|
size_t eval_size,
|
|
void* optimizer_model,
|
|
size_t optimizer_size);
|
|
|
|
/**
|
|
* Resets the gradients of all trainable parameters to zero for the specified TrainingSession
|
|
* @param training_handle handle of the training session
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtTrainingLazyResetGrad(ort_training_session_handle_t training_handle);
|
|
|
|
/**
|
|
* @brief Run a single training step.
|
|
*
|
|
* @param training_handle session handle of the specified session
|
|
* @param inputs user inputs to the training model
|
|
* @param input_count number of user inputs to the training model
|
|
* @param outputs [out] user outputs computed by train step
|
|
* @param output_count [out] number of user outputs expected from this train step
|
|
* @param run_options handle of the run options
|
|
* @return int ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtTrainingRunTrainStep(ort_training_session_handle_t training_handle,
|
|
ort_tensor_handle_t* inputs, size_t input_count,
|
|
ort_tensor_handle_t* outputs,
|
|
size_t output_count,
|
|
ort_run_options_handle_t run_options = nullptr);
|
|
|
|
/**
|
|
* Performs weight updates for the trainable parameters in the given training session using the optimizer model.
|
|
* @param training_handle handle of the training session
|
|
* @param run_options optional parameter of run options for this training step
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtTrainingOptimizerStep(ort_training_session_handle_t training_handle,
|
|
ort_run_options_handle_t run_options = nullptr);
|
|
|
|
/**
|
|
* Computs outputs for the eval model associated with the given training session.
|
|
* @param training_handle handle of the training session
|
|
* @param options run options for this eval step
|
|
* @param input_count number of user inputs to the eval model
|
|
* @param inputs the user inputs to the eval model
|
|
* @param output_count [out] number of user outputs expected from this eval step
|
|
* @param outputs [out] user outputs computed by the eval step
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtTrainingEvalStep(ort_training_session_handle_t training_handle,
|
|
ort_tensor_handle_t* inputs,
|
|
size_t input_count,
|
|
ort_tensor_handle_t* outputs,
|
|
size_t output_count,
|
|
ort_run_options_handle_t options = nullptr);
|
|
|
|
/**
|
|
* Retrieves the size of all parameters for the training state.
|
|
* When the trainable_only argument is true, the size is calculated for trainable params only.
|
|
*
|
|
* @param training_handle handle of the training session
|
|
* @param param_size [out] size of all parameter elements
|
|
* @param trainable_only skips non-trainable parameters when true.
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtTrainingGetParametersSize(ort_training_session_handle_t training_handle,
|
|
size_t* param_size,
|
|
bool trainable_only);
|
|
|
|
/**
|
|
* Copy all parameters to a contiguous buffer held by the argument parameters_buffer
|
|
*
|
|
* User is responsible for allocating and freeing resources used by the parameters_buffer.
|
|
* Parameter ordering is preserved.
|
|
*
|
|
* @param training_handle handle of the training session
|
|
* @param parameters_buffer [out] pre-allocated OrtValue buffer to copy onto. Must be same size as results of
|
|
* GetParametersSize api call
|
|
* @param parameter_count number of parameters expected in the parameters_buffer
|
|
* @param trainable_only whether to skip non-trainable parameters
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersToBuffer(ort_training_session_handle_t training_handle,
|
|
ort_tensor_handle_t parameters_buffer,
|
|
size_t parameter_count,
|
|
bool trainable_only);
|
|
|
|
/**
|
|
* Copy parameters values from given contiguous buffer held by parameters_buffer to the training state.
|
|
* Parameter ordering is preserved.
|
|
* @param training_handle handle of the training session
|
|
* @param parameters_buffer OrtValue buffer to copy from. Must be same size as results of
|
|
* GetParametersSize api call
|
|
* @param parameter_count number of parameters expected in the parameters_buffer
|
|
* @param trainable_only whether to skip non-trainable parameters
|
|
* @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
|
|
*/
|
|
int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_session_handle_t training_handle,
|
|
ort_tensor_handle_t parameters_buffer,
|
|
size_t parameter_count,
|
|
bool trainable_only);
|
|
|
|
/**
|
|
* @brief Release the specified ORT training session.
|
|
*
|
|
* @param training_session_handle handle of the training session
|
|
*/
|
|
void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_session_handle);
|
|
|
|
#endif
|
|
};
|