From 8ee8fdd59b7830bb0edde7c14067a73e57e2735e Mon Sep 17 00:00:00 2001 From: pengwa Date: Sat, 16 Apr 2022 04:42:10 +0800 Subject: [PATCH] Add training api test runner (#10972) * add api test runner * add build flag for training_api * address review comments * some fixes * address more comments * make the build pass by filling in empty implementation * fix more --- cmake/CMakeLists.txt | 5 + cmake/onnxruntime_training.cmake | 26 +- .../test/training_api/test_runner.cc | 301 ++++++++++++++++++ .../orttraining/training_api/interfaces.h | 232 ++++++++++++++ 4 files changed, 562 insertions(+), 2 deletions(-) create mode 100644 orttraining/orttraining/test/training_api/test_runner.cc create mode 100644 orttraining/orttraining/training_api/interfaces.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1811e99a0d..e5793057a2 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -137,6 +137,7 @@ option(onnxruntime_FUZZ_TEST "Enable Fuzz testing" OFF) option(onnxruntime_ENABLE_NVTX_PROFILE "Enable NVTX profile." OFF) option(onnxruntime_ENABLE_MEMORY_PROFILE "Enable memory profile." OFF) option(onnxruntime_ENABLE_TRAINING "Enable training functionality." OFF) +option(onnxruntime_ENABLE_TRAINING_ON_DEVICE "Enable training on device." OFF) option(onnxruntime_ENABLE_TRAINING_OPS "Include training operators but no training session support." OFF) option(onnxruntime_ENABLE_TRAINING_TORCH_INTEROP "Enable training kernels interop with torch." OFF) option(onnxruntime_ENABLE_TRAINING_E2E_TESTS "Enable training end-to-end tests." OFF) @@ -1811,6 +1812,10 @@ if (onnxruntime_USE_DML) include(dml) endif() +if (onnxruntime_ENABLE_TRAINING_ON_DEVICE) + add_compile_definitions(ENABLE_TRAINING_ON_DEVICE) +endif() + if (onnxruntime_ENABLE_TRAINING_OPS) add_compile_definitions(ENABLE_TRAINING_OPS) endif() diff --git a/cmake/onnxruntime_training.cmake b/cmake/onnxruntime_training.cmake index 67388195b7..097484a41f 100644 --- a/cmake/onnxruntime_training.cmake +++ b/cmake/onnxruntime_training.cmake @@ -18,12 +18,12 @@ file(GLOB_RECURSE onnxruntime_training_srcs ) # This needs to be built in framework.cmake -file(GLOB_RECURSE onnxruntime_training_framework_excude_srcs CONFIGURE_DEPENDS +file(GLOB_RECURSE onnxruntime_training_framework_excluded_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/core/framework/torch/*.h" "${ORTTRAINING_SOURCE_DIR}/core/framework/torch/*.cc" ) -list(REMOVE_ITEM onnxruntime_training_srcs ${onnxruntime_training_framework_excude_srcs}) +list(REMOVE_ITEM onnxruntime_training_srcs ${onnxruntime_training_framework_excluded_srcs}) onnxruntime_add_static_library(onnxruntime_training ${onnxruntime_training_srcs}) add_dependencies(onnxruntime_training onnx tensorboard ${onnxruntime_EXTERNAL_DEPENDENCIES}) @@ -230,4 +230,26 @@ if (onnxruntime_BUILD_UNIT_TESTS) target_link_libraries(onnxruntime_training_gpt2 PRIVATE onnxruntime_training_runner onnxruntime_training ${ONNXRUNTIME_LIBS} ${onnxruntime_EXTERNAL_LIBRARIES}) set_target_properties(onnxruntime_training_gpt2 PROPERTIES FOLDER "ONNXRuntimeTest") + # Training API Tests + # Currently disable it by default for internal development usage. + if (onnxruntime_ENABLE_TRAINING_ON_DEVICE) + file(GLOB_RECURSE training_api_test_runner_src + "${ORTTRAINING_SOURCE_DIR}/test/training_api/*.h" + "${ORTTRAINING_SOURCE_DIR}/test/training_api/*.cc" + ) + onnxruntime_add_executable(onnxruntime_training_api_test_runner ${training_api_test_runner_src}) + + if(UNIX AND NOT APPLE) + if (HAS_NO_MAYBE_UNINITIALIZED) + target_compile_options(onnxruntime_training_api_test_runner PUBLIC "-Wno-maybe-uninitialized") + endif() + endif() + + onnxruntime_add_include_to_target(onnxruntime_training_api_test_runner onnxruntime_common onnx onnx_proto ${PROTOBUF_LIB} onnxruntime_training flatbuffers) + target_include_directories(onnxruntime_training_api_test_runner PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${ORTTRAINING_ROOT} ${MPI_CXX_INCLUDE_DIRS} ${eigen_INCLUDE_DIRS} ${CXXOPTS} ${extra_includes} ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/onnx) + + target_link_libraries(onnxruntime_training_api_test_runner PRIVATE onnxruntime_training ${ONNXRUNTIME_LIBS} ${onnxruntime_EXTERNAL_LIBRARIES}) + set_target_properties(onnxruntime_training_api_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") + endif() + endif() diff --git a/orttraining/orttraining/test/training_api/test_runner.cc b/orttraining/orttraining/test/training_api/test_runner.cc new file mode 100644 index 0000000000..868efb75f8 --- /dev/null +++ b/orttraining/orttraining/test/training_api/test_runner.cc @@ -0,0 +1,301 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "cxxopts.hpp" +#include "core/util/math.h" +#include "core/common/common.h" +#include "core/common/logging/logging.h" +#include "core/common/logging/sinks/clog_sink.h" +#include "core/providers/cpu/cpu_execution_provider.h" +#include "core/session/environment.h" +#include "core/session/inference_session.h" +#include "core/providers/cpu/cpu_provider_factory_creator.h" +#include "orttraining/core/framework/tensorboard/event_writer.h" +#include "orttraining/training_api/interfaces.h" + +using namespace onnxruntime; +using namespace onnxruntime::common; +using namespace onnxruntime::training; +using namespace onnxruntime::training::tensorboard; +using namespace onnxruntime::training::api_test; +using namespace std; + +#ifdef USE_CUDA +namespace onnxruntime { + +std::shared_ptr CreateExecutionProviderFactory_Cuda(const OrtCUDAProviderOptions* provider_options); +std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const char* name); + +} // namespace onnxruntime +#endif + +static SessionOptions session_options; + +struct TestRunnerParameters { + PathString model_training_graph_path; + PathString model_evaluation_graph_path; + PathString optimizer_training_graph_path; + // path to checkpoint to load + PathString checkpoint_to_load_path; + std::string model_name; + + PathString train_data_dir; + PathString test_data_dir; + PathString output_dir; // Output of training, e.g., trained model files. + + size_t train_batch_size; + size_t num_train_epochs; + size_t eval_batch_size; + size_t eval_interval; + size_t checkpoint_interval; + int gradient_accumulation_steps = 1; + + // Allocator to use for allocating inputs from the dataset (optional). + AllocatorPtr input_allocator; + std::unique_ptr provider; +}; + +struct OrtTestRunnerParameters { + logging::Severity log_severity{logging::Severity::kWARNING}; + int vlog_level{-1}; +}; + +Status ParseArguments(int argc, char* argv[], TestRunnerParameters& params, OrtTestRunnerParameters& ort_params) { + cxxopts::Options options("Training API Test", "Main Program to test training C++ APIs."); + // clang-format off + options + .add_options() + ("model_training_graph_path", "The path to the training model to load. ", + cxxopts::value()->default_value("")) + ("model_evaluation_graph_path", "The path to the evaluation model to load. ", + cxxopts::value()->default_value("")) + ("optimizer_training_graph_path", "The path to the optimizer graph to load. ", + cxxopts::value()->default_value("")) + ("checkpoint_to_load_path", + "The path to the checkpoint to load. If not provided, the latest " + "checkpoint in checkpoints_dir, if any, is used.", + cxxopts::value()->default_value("")) + ("model_name", + "The name of the model.", + cxxopts::value()->default_value("model_test")) + + ("train_data_dir", "Input ONNX example files (can be a glob or comma separated).", + cxxopts::value()->default_value("bert_data/128/books_wiki_en_corpus/train")) + ("test_data_dir", "Input ONNX example files (can be a glob or comma separated).", + cxxopts::value()->default_value("bert_data/128/books_wiki_en_corpus/test")) + ("output_dir", "The output directory where the trained model files will be written.", + cxxopts::value()->default_value("")) + + ("train_batch_size", "Total batch size for training.", cxxopts::value()) + ("eval_batch_size", "Total batch size for eval.", cxxopts::value()) + ("num_train_epochs", "Total number of training epochs to perform.", cxxopts::value()->default_value("100")) + ("eval_interval", "Number of training steps before doing evaluation.", cxxopts::value()->default_value("1000")) + ("checkpoint_interval", "Number of training steps before saving checkpoint.", cxxopts::value()->default_value("1000")) + ("gradient_accumulation_steps", "The number of gradient accumulation steps before performing a backward/update pass.", + cxxopts::value()->default_value("1")); + + options + .add_options("ORT configuration") + ("ort_log_severity", "ORT minimum logging severity (see onnxruntime::logging::Severity values)", + cxxopts::value()->default_value("2"/*logging::Severity::kWARNING*/)) + ("ort_vlog_level", "ORT maximum VLOG level (verbose debug logging)", + cxxopts::value()->default_value("-1")); + // clang-format on + + try { + auto flags = options.parse(argc, argv); + + params.model_training_graph_path = ToPathString(flags["model_training_graph_path"].as()); + params.model_evaluation_graph_path = ToPathString(flags["model_evaluation_graph_path"].as()); + params.optimizer_training_graph_path = ToPathString(flags["optimizer_training_graph_path"].as()); + params.checkpoint_to_load_path = ToPathString(flags["checkpoint_to_load_path"].as()); + params.model_name = flags["model_name"].as(); + + params.train_batch_size = flags["train_batch_size"].as(); + if (flags.count("eval_batch_size")) { + params.eval_batch_size = flags["eval_batch_size"].as(); + } else { + params.eval_batch_size = params.train_batch_size; + } + params.num_train_epochs = flags["num_train_epochs"].as(); + params.eval_interval = flags["eval_interval"].as(); + params.checkpoint_interval = flags["checkpoint_interval"].as(); + + params.gradient_accumulation_steps = flags["gradient_accumulation_steps"].as(); + if (params.gradient_accumulation_steps < 1) { + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid gradient_accumulation_steps parameter: should be >= 1"); + } + + params.train_data_dir = ToPathString(flags["train_data_dir"].as()); + params.test_data_dir = ToPathString(flags["test_data_dir"].as()); + params.output_dir = ToPathString(flags["output_dir"].as()); + if (params.output_dir.empty()) { + printf("No output directory specified. Trained model files will not be saved.\n"); + } + + session_options.use_deterministic_compute = flags["use_deterministic_compute"].as(); + + ort_params.log_severity = static_cast(flags["ort_log_severity"].as()); + ORT_RETURN_IF_NOT( + logging::Severity::kVERBOSE <= ort_params.log_severity && ort_params.log_severity <= logging::Severity::kFATAL, + "Log severity must be in the range [", static_cast(logging::Severity::kVERBOSE), + ", ", static_cast(logging::Severity::kFATAL), "]."); + ort_params.vlog_level = flags["ort_vlog_level"].as(); + } catch (const exception& e) { + const std::string msg = "Failed to parse the command line arguments"; + cerr << msg << ": " << e.what() << "\n" + << options.help() << "\n"; + return Status(ONNXRUNTIME, INVALID_ARGUMENT, msg); + } + + return Status::OK(); +} + +template +static void CreateInputOrtValue(gsl::span dims, + const std::vector& value, + OrtValue* p_ortvalue, + AllocatorPtr alloc = nullptr) { + static CPUExecutionProviderInfo info; + static CPUExecutionProvider cpu_provider(info); + static AllocatorPtr cpu_allocator = cpu_provider.GetAllocator(0, OrtMemTypeDefault); + + TensorShape shape(dims); + assert(shape.Size() == static_cast(value.size())); + auto element_type = DataTypeImpl::GetType(); + auto allocator = alloc ? alloc : cpu_allocator; + auto p_tensor = std::make_unique(element_type, shape, allocator); + + if (value.size() > 0) { + memcpy(p_tensor->MutableDataRaw(), value.data(), p_tensor->SizeInBytes()); + } + + p_ortvalue->Init(p_tensor.release(), + DataTypeImpl::GetType(), + DataTypeImpl::GetType()->GetDeleteFunc()); +} + +std::vector> CreateSyntheticDataLoader(size_t batch_size, + AllocatorPtr alloc = nullptr) { + OrtValue input, positions; + // hard coded each sample to have 4 elements so far. + // todo: we can make it support more generic once we are clear what our offline process graph needed. + CreateInputOrtValue(std::array{4}, std::vector{1, 2, 3, 4}, &input, alloc = alloc); + CreateInputOrtValue(std::array{4}, std::vector{1, 2, 3, 3}, &positions, alloc = alloc); + return std::vector>(batch_size, std::vector{input, positions}); +} + +float GetLossValue(OrtValue& ort_value) { + const Tensor& loss_tensor = ort_value.Get(); + float loss = 0; + if (DataTypeImpl::GetType() == loss_tensor.DataType()) { + loss = *(loss_tensor.template Data()); + } else { + ORT_THROW("loss data type not supported."); + } + return loss; +} + +Status RunTraining(const TestRunnerParameters& params) { + std::string tensorboard_file = params.output_dir + "/tb.event"; + std::shared_ptr tensorboard = std::make_shared(tensorboard_file); + + api_test::utils::CheckpointStates state_dicts; + ORT_ENFORCE(api_test::utils::Ort_Load(params.checkpoint_to_load_path, state_dicts).IsOK()); + + Module module(params.model_training_graph_path, + state_dicts.named_parameters, + params.model_evaluation_graph_path); + + Optimizer optimizer(params.optimizer_training_graph_path, + state_dicts.named_parameters); + +#ifdef USE_CUDA + api_test::utils::SetExecutionProvider(module, optimizer, params.provider.get()); +#endif + + auto scheduler = std::make_unique(optimizer, 0.3333f, 1.0f, 5); + std::vector> + data_loader = CreateSyntheticDataLoader(params.train_batch_size, + params.input_allocator); + + size_t NUM_EPOCHS = params.num_train_epochs; + size_t GRAD_ACC_STEPS = params.gradient_accumulation_steps; + size_t EVAL_STEPS = params.eval_interval; + size_t SAVE_STEPS = params.checkpoint_interval; + std::string tag("train"); + + for (size_t epoch = 0, batch_idx = 0; epoch < NUM_EPOCHS; ++epoch) { + for (auto it = data_loader.begin(); it != data_loader.end(); ++it) { + std::vector& inputs = *it; + std::vector fetches; + ORT_ENFORCE(module.TrainStep(inputs, fetches).IsOK()); + + float loss = GetLossValue(fetches[3]); + tensorboard->AddSummary(std::to_string(loss), batch_idx, tag); + std::cout << "Batch # : " << batch_idx << " Loss: " << loss << std::endl; + + if (batch_idx % GRAD_ACC_STEPS == 0) { + // gradient accumulation steps completed + ORT_ENFORCE(optimizer.Step().IsOK()); + // modify learning rate + ORT_ENFORCE(scheduler->Step().IsOK()); + ORT_ENFORCE(optimizer.ResetGrad().IsOK()); + } + + if (batch_idx % EVAL_STEPS == 0) { + std::vector eval_results; + ORT_ENFORCE(module.EvalStep(inputs, eval_results).IsOK()); + } + + if (batch_idx % SAVE_STEPS == 0) { + // save trained weights + api_test::utils::CheckpointStates state_dicts_to_save; + ORT_ENFORCE(module.GetStateDict(state_dicts_to_save.named_parameters).IsOK()); + ORT_ENFORCE(optimizer.GetStateDict(state_dicts_to_save.optimizer_states).IsOK()); + std::string ckpt_file = params.output_dir + "/ckpt_" + params.model_name + std::to_string(batch_idx); + ORT_ENFORCE(api_test::utils::Ort_Save(state_dicts_to_save, ckpt_file).IsOK()); + } + + batch_idx++; + } + } + + return Status::OK(); +} + +#define RETURN_IF_FAIL(expr) \ + do { \ + auto status = (expr); \ + if ((!status.IsOK())) { \ + printf("Fail: %s \n", status.ErrorMessage().c_str()); \ + return -1; \ + } \ + } while (0); + +int main(int argc, char* argv[]) { + TestRunnerParameters params; + OrtTestRunnerParameters ort_params{}; + RETURN_IF_FAIL(ParseArguments(argc, argv, params, ort_params)); + + // setup logger, be noted: LOGS_DEFAULT must be after logging manager initialization. + string default_logger_id{"Default"}; + logging::LoggingManager default_logging_manager{std::make_unique(), + ort_params.log_severity, + false, + logging::LoggingManager::InstanceType::Default, + &default_logger_id, + ort_params.vlog_level}; +#ifdef USE_CUDA + OrtCUDAProviderOptions provider_options{}; + if (auto factory = CreateExecutionProviderFactory_Cuda(&provider_options)) + params.provider = std::move(factory->CreateProvider()); + + params.input_allocator = CreateCUDAPinnedAllocator(provider_options.device_id, CUDA_PINNED); +#endif + + // start training session + RETURN_IF_FAIL(RunTraining(params)); + return 0; +} diff --git a/orttraining/orttraining/training_api/interfaces.h b/orttraining/orttraining/training_api/interfaces.h new file mode 100644 index 0000000000..1ae67accff --- /dev/null +++ b/orttraining/orttraining/training_api/interfaces.h @@ -0,0 +1,232 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(ENABLE_TRAINING) && defined(ENABLE_TRAINING_ON_DEVICE) + +namespace onnxruntime { +namespace training { +namespace api_test { + +class Parameter { + public: + // create parameter + Parameter(std::string /*name*/, const OrtValue& /*data*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + } + + // Return the mutable data + OrtValue& data() { return data_; } + std::string name() const { return name_; } + + // Return if trainable. The trainable property of a param + // cannot change over the lifetime of the on-device training + // session since the gradient graph is prebuilt for this setting. + bool requires_grad() const { return requires_grad_; } + + // Return the mutable gradient for trainable parameter + OrtValue& gradient() { return gradient_; } + std::string gradient_name() const { return gradient_name_; } + + // Reset and release the gradient buffer of this Parameter + Status ResetGrad() { + return Status::OK(); + } + // need to set grad but not public api + private: + OrtValue data_; + std::string name_; + + OrtValue gradient_; + std::string gradient_name_; + + // Whether the param is trainable. The optimizer state is + // only created for a trainable param + bool requires_grad_{true}; +}; + +class Module { + public: + // Initialize a module from an ORT inference session with loaded + // training ONNX model and load parameters + Module(const std::string& /*train_model_path_or_bytes*/, + std::unordered_map>& /*parameters*/, + const std::optional& /*eval_model_path_or_bytes*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + } + + // Return the trainable/nontrainable parameters + std::vector> parameters() const { + return parameters_; + } + std::unordered_map> named_parameters() const { + ORT_NOT_IMPLEMENTED("Not implemented."); + return {}; + } + + // Train Step – does forward and backward computation. The outputs will be the forward’s outputs. Gradients will be accumulated within the Parameter object + Status TrainStep(const std::vector& /*inputs*/, std::vector& /*outputs*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); + } + + // Eval Step – does forward computation. This will use a separate inference session + // and take in a separate inference graph, while sharing the parameters + Status EvalStep(const std::vector& /*inputs*/, std::vector& /*outputs*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); + } + + // Return the states of the module as a map. + Status GetStateDict(const std::unordered_map>& /*module_state_dict*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); + } + + private: + std::unique_ptr train_sess_; + std::unique_ptr eval_sess_; + std::vector> parameters_; +}; + +// Internal state +struct ParameterOptimizerState { + int64_t step_; + float learning_rate_; + // Per param optimizer state. E.g. For Adam and param_0, this would contain + // {“Moment_1_param_0”:, …}, + // It should be noted that the names should only be maintained to correlate with + // the graph inputs for the optimizer graph + std::map states_; +}; + +struct OptimizerState { + // overall state related to optimizer + int64_t step_; + float learning_rate_; + std::unordered_map optimizer_states_; +}; + +class Optimizer { + public: + // Initialize an optimizer module from an ORT inference session with loaded + // training ONNX model For each parameter, initialize the OptimizerState based + // on the graph input’s ValueInfoProto if the parameter doesn’t have it already. + Optimizer(const std::string& /*optim_path_or_bytes*/, + std::unordered_map>& /*parameters*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + } + + // Reset and release the gradient buffer of all trainable params + Status ResetGrad() { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); + } + + // Optimizer Step. + Status Step() { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); + } + + // Return the states of the optimizer as a map. + Status GetStateDict(const OptimizerState& /*optimizer_state_dict*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); + } + + protected: + int64_t GetStep() const { + ORT_NOT_IMPLEMENTED("Not implemented."); + return 0; + } + Status SetLearningRate(float /*lr*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); + } + + private: + std::unique_ptr optim_sess_; + std::vector> parameters_; + OptimizerState optimizer_state_; +}; + +class LearningRateScheduler { + public: + LearningRateScheduler(const Optimizer& optim) + : optim_(optim) { + ORT_NOT_IMPLEMENTED("Not implemented."); + } + + virtual ~LearningRateScheduler() = default; + + // Modify the current learning rate based on current step + virtual Status Step(/*int64_t step*/) = 0; + + const Optimizer& optim_; +}; + +class LinearScheduler : public LearningRateScheduler { + public: + explicit LinearScheduler(const Optimizer& optim, float start_factor, float end_factor, int64_t total_iters) + : LearningRateScheduler(optim), + start_factor_(start_factor), + end_factor_(end_factor), + total_iters_(total_iters) { + ORT_NOT_IMPLEMENTED("Not implemented."); + } + + // Fetch the step, calculate next value and set lr in optimizer + Status Step(/*int64_t step*/) override { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); + } + + private: + float start_factor_; + float end_factor_; + int64_t total_iters_; +}; + +namespace utils { + +struct CheckpointProperty { + int value; + // Support primitive types like int, float, string leveraging type trait. +}; + +struct CheckpointStates { + CheckpointStates() { + ORT_NOT_IMPLEMENTED("Not implemented."); + } + std::unordered_map> named_parameters; + OptimizerState optimizer_states; + std::unordered_map named_properties; +}; + +// Save properties into a checkpoint property file (with postfix .prop). +Status Ort_Save(CheckpointStates& /*state_dicts*/, const PathString& /*checkpoint_path*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); +} + +// Load properties file having postfix being '.prop'. +Status Ort_Load(const PathString& /*checkpoint_path*/, CheckpointStates& /*state_dicts*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); + return Status::OK(); +} + +/* + module.train_sess.RegisterExecutionProvider(provider); + module.eval_sess.RegisterExecutionProvider(provider); + optimizer.optim_sess.RegisterExecutionProvider(provider); +*/ +void SetExecutionProvider(const Module& /*module*/, const Optimizer& /*optimizer*/, IExecutionProvider* /*provider*/) { + ORT_NOT_IMPLEMENTED("Not implemented."); +} +} // namespace utils + +} // namespace api_test +} // namespace training +} // namespace onnxruntime + +#endif