From dfc7c18e3104df7e9ed8eb4f8e012b13bbcbafa8 Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Fri, 5 Mar 2021 17:03:46 -0800
Subject: [PATCH] Introducing TrainingAgent interface to perform training using YieldOp (#6898)

---
 cmake/onnxruntime_training.cmake                 |   4 +-
 onnxruntime/core/session/inference_session.cc    | 130 -----------------
 onnxruntime/core/session/inference_session.h     |  24 ----
 .../onnxruntime_inference_collection.py          |  16 ---
 .../python/onnxruntime_pybind_state.cc           |  16 ---
 .../orttraining/core/agent/training_agent.cc     | 136 ++++++++++++++++++
 .../orttraining/core/agent/training_agent.h      |  44 ++++++
 .../python/orttraining_pybind_state.cc           |  24 ++++
 .../orttraining/python/training/__init__.py      |   4 +-
 .../orttraining/python/training/ortmodule.py     |   6 +-
 .../python/training/training_agent.py            |  73 ++++++++++
 11 files changed, 285 insertions(+), 192 deletions(-)
 create mode 100644 orttraining/orttraining/core/agent/training_agent.cc
 create mode 100644 orttraining/orttraining/core/agent/training_agent.h
 create mode 100644 orttraining/orttraining/python/training/training_agent.py

diff --git a/cmake/onnxruntime_training.cmake b/cmake/onnxruntime_training.cmake
index 128b66d2fd..0e03d79bb0 100644
--- a/cmake/onnxruntime_training.cmake
+++ b/cmake/onnxruntime_training.cmake
@@ -13,7 +13,9 @@ file(GLOB_RECURSE onnxruntime_training_srcs
   "${ORTTRAINING_SOURCE_DIR}/core/framework/communication/*"
   "${ORTTRAINING_SOURCE_DIR}/core/session/*.h"
   "${ORTTRAINING_SOURCE_DIR}/core/session/*.cc"
-)
+  "${ORTTRAINING_SOURCE_DIR}/core/agent/*.h"
+  "${ORTTRAINING_SOURCE_DIR}/core/agent/*.cc"
+  )
 
 add_library(onnxruntime_training ${onnxruntime_training_srcs})
 add_dependencies(onnxruntime_training onnx tensorboard ${onnxruntime_EXTERNAL_DEPENDENCIES})
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 880e3d8e86..e18d89068e 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -56,10 +56,6 @@
 #include "core/session/custom_ops.h"
 #endif
 
-#ifdef ENABLE_TRAINING
-#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
-#endif
-
 using namespace ONNX_NAMESPACE;
 using namespace onnxruntime::experimental;
 using namespace onnxruntime::common;
@@ -379,24 +375,6 @@ InferenceSession::~InferenceSession() {
     }
   }
 
-#ifdef ENABLE_TRAINING
-  // TODO: Properly cancel outstanding background tasks
-  // Following implementation only handle the case where bg_thread is waiting for backward inputs
-  // Background thread can also be in other states, such as running Forward() or running Backward()
-  std::vector<int64_t> run_ids;
-  {
-    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
-    for (auto it = bg_threads_.begin(); it != bg_threads_.end(); ++it) {
-      run_ids.push_back(it->first);
-    }
-  }
-  for (int64_t run_id : run_ids) {
-    if (!onnxruntime::contrib::OrtTasks::GetInstance().TaskIsCompleted(run_id)) {
-      CancelBackgroundTask(run_id);
-    }
-  }
-#endif
-
 #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
   if (session_activity_started_)
     TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
@@ -1739,114 +1717,6 @@ common::Status InferenceSession::Run(IOBinding& io_binding) {
   return Run(run_options, io_binding);
 }
 
-#ifdef ENABLE_TRAINING
-common::Status InferenceSession::RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding,
-                                                                std::vector<OrtValue>& user_outputs, int64_t& run_id) {
-  std::promise<void> setup_promise;
-  std::future<void> setup_future = setup_promise.get_future();
-
-  // Passing run_options and io_binding by reference to the bg_thread,
-  // this is ok because they are ORTModule's member, and they are presistent through forward and backward calls
-  auto bg_thread = std::thread([this](std::future<void> setup_future, const RunOptions& run_options, IOBinding& io_binding) {
-    // wait until task is properly setup
-    setup_future.get();
-
-    common::Status status = Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(),
-                                &io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo());
-
-    onnxruntime::contrib::OrtTasks::GetInstance().SetStatus(status);
-
-    // If forward outputs still hasn't been consumed at this point, i.e. forward function hasn't complete itself
-    // this indicates that Run() call returned before hitting YieldOp, due to hitting some exception during the forward subgraph execution
-    // In this case, we need to wake up the foreground thread and pass along the failed status.
-    // Otherwise, foreground thread will be stuck waiting for forward_outputs.
-    if (onnxruntime::contrib::OrtTasks::GetInstance().ForwardOutputsIsValid()) {
-      ORT_ENFORCE(!status.IsOK());
-      // signal main thread for background thread completion
-      onnxruntime::contrib::OrtTasks::GetInstance().SetForwardOutputs(status, {});
-    }
-  },
-                               std::move(setup_future), std::cref(run_options), std::ref(io_binding));
-
-  run_id = std::hash<std::thread::id>()(bg_thread.get_id());
-  {
-    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
-    bg_threads_[run_id] = std::move(bg_thread);
-  }
-
-  onnxruntime::contrib::OrtTasks::GetInstance().CreateBackgroundTask(run_id);
-
-  LOGS(*session_logger_, VERBOSE) << "InferenceSession::Forward() call created a task with run_id " << run_id;
-
-  // background task is setup, unblock background thread to continue
-  setup_promise.set_value();
-
-  // Wait for data/signal from
-  // 1. Yield op, if the bg thread sucessfully reached Yield's signal point
-  // 2. The end of bg thread, if it hit execptions and returned earlier
-  auto forward_outputs = onnxruntime::contrib::OrtTasks::GetInstance().WaitForForwardOutputs(run_id);
-  const Status& forward_status = forward_outputs.first;
-  user_outputs = std::move(forward_outputs.second);
-
-  // background thread has completed without hitting Yield Op
-  if (!forward_status.IsOK()) {
-    std::thread thread;
-    {
-      std::lock_guard<std::mutex> lock(bg_threads_mutex_);
-      std::swap(thread, bg_threads_[run_id]);
-      bg_threads_.erase(run_id);
-    }
-    ORT_ENFORCE(thread.joinable());
-    thread.join();
-    onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
-    return forward_status;
-  }
-
-  return Status::OK();
-}
-
-common::Status InferenceSession::ContinueRunInBackground(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) {
-  LOGS(*session_logger_, VERBOSE) << "Running InferenceSession::Backward() with run_id " << run_id;
-
-  // resume background thread
-  onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, backward_output_grads, false);
-
-  Status bg_thread_status = onnxruntime::contrib::OrtTasks::GetInstance().WaitForStatus(run_id);
-
-  std::thread bg_thread;
-  {
-    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
-    std::swap(bg_thread, bg_threads_[run_id]);
-    bg_threads_.erase(run_id);
-  }
-
-  // wait for bg_thread to complete
-  ORT_ENFORCE(bg_thread.joinable());
-  bg_thread.join();
-  onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
-
-  return bg_thread_status;
-}
-
-void InferenceSession::CancelBackgroundTask(int64_t run_id) {
-  LOGS(*session_logger_, WARNING) << "Canceling background task with run_id " << run_id;
-
-  // resume background thread with terminate = true
-  onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, {}, true);
-
-  // wait for bg_thread to complete
-  std::thread bg_thread;
-  {
-    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
-    std::swap(bg_thread, bg_threads_[run_id]);
-    bg_threads_.erase(run_id);
-  }
-  ORT_ENFORCE(bg_thread.joinable());
-  bg_thread.join();
-  onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
-}
-#endif
-
 template <typename T>
 void InferenceSession::StartProfiling(const std::basic_string<T>& file_prefix) {
   std::basic_ostringstream<T> ss;
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index bb5e6e91d7..d81fa49618 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -6,11 +6,6 @@
 #include <string>
 #include <vector>
 
-#ifdef ENABLE_TRAINING
-#include <mutex>
-#include <thread>
-#endif
-
 #include "core/common/common.h"
 #include "core/common/logging/logging.h"
 #include "core/common/profiler.h"
@@ -302,17 +297,6 @@ class InferenceSession {
   virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
   common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
 
-#ifdef ENABLE_TRAINING
-  // For ORTModule.forward()
-  virtual common::Status RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding,
-                                                        std::vector<OrtValue>& user_outputs,
-                                                        int64_t& run_id) ORT_MUST_USE_RESULT;
-
-  // For ORTModule.backward()
-  common::Status ContinueRunInBackground(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
-
-  void CancelBackgroundTask(int64_t run_id);
-#endif
   /**
     * @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.
    * @note lifetime of the returned pointer is valid as long as the Session object is live.
@@ -677,14 +661,6 @@ class InferenceSession {
   // those into new OrtValue instances, at which point we won't free them until the InferenceSession goes away.
   std::vector<uint8_t> ort_format_model_bytes_;
 
-#ifdef ENABLE_TRAINING
-  // mutex for accessing bg_threads_
-  std::mutex bg_threads_mutex_;
-
-  // background threads for RunInBackgroundAndWaitForYield and ContinueRunInBackground
-  std::unordered_map<int64_t, std::thread> bg_threads_;
-#endif
-
   std::shared_ptr<onnxruntime::AllocatorManager> allocator_manager_;
 };
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 6bfe7a0a72..c0ced40ed4 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -228,22 +228,6 @@ class Session:
         """
         self._sess.run_with_iobinding(iobinding._iobinding, run_options)
 
-    def run_forward(self, iobinding, run_options):
-        """
-        Compute the forward subgraph until it hits the Yield Op.
-        :param iobinding: the iobinding object that has graph inputs/outputs bind.
-        :param run_options: See :class:`onnxruntime.RunOptions`.
-        """
-        ortvalues, run_id = self._sess.run_forward(iobinding._iobinding, run_options)
-        return [OrtValue(ortvalue) for ortvalue in ortvalues], run_id
-
-    def run_backward(self, backward_output_grads, run_id):
-        """
-        Resume executing the backward subgraph starting from Yield Op.
-        :param backward_output_grads: Output gradients for backward.
-        """
-        self._sess.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads], run_id)
-
 
 class InferenceSession(Session):
     """
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index e7c71ad84c..445f2a76ef 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1835,22 +1835,6 @@ including arg name, arg type (contains both type and shape).)pbdoc")
         if (!status.IsOK())
           throw std::runtime_error("Error in execution: " + status.ErrorMessage());
       })
-#ifdef ENABLE_TRAINING
-      .def("run_forward", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions& run_options) -> py::tuple {
-        std::vector<OrtValue> module_outputs;
-        int64_t run_id;
-        Status status = sess->GetSessionHandle()->RunInBackgroundAndWaitForYield(run_options, *io_binding.Get(), module_outputs, run_id);
-        if (!status.IsOK()) {
-          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
-        }
-        return py::make_tuple(module_outputs, run_id);
-      })
-      .def("run_backward", [](PyInferenceSession* sess, const std::vector<OrtValue>& backward_output_grads, int64_t run_id) -> void {
-        Status status = sess->GetSessionHandle()->ContinueRunInBackground(run_id, backward_output_grads);
-        if (!status.IsOK())
-          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
-      })
-#endif
       ;
 
   py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())
diff --git a/orttraining/orttraining/core/agent/training_agent.cc b/orttraining/orttraining/core/agent/training_agent.cc
new file mode 100644
index 0000000000..1d1a978b86
--- /dev/null
+++ b/orttraining/orttraining/core/agent/training_agent.cc
@@ -0,0 +1,136 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "orttraining/core/agent/training_agent.h"
+#include "core/session/IOBinding.h"
+#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
+
+namespace onnxruntime {
+namespace training {
+
+TrainingAgent::TrainingAgent(InferenceSession* session) : inference_session_(session) {}
+
+TrainingAgent::~TrainingAgent() {
+  // TODO: Properly cancel outstanding background tasks.
+  // The following implementation only handles the case where the bg_thread is waiting for backward inputs.
+  // The background thread can also be in other states, such as running Forward() or running Backward().
+  std::vector<int64_t> run_ids;
+  {
+    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
+    for (auto it = bg_threads_.begin(); it != bg_threads_.end(); ++it) {
+      run_ids.push_back(it->first);
+    }
+  }
+  for (int64_t run_id : run_ids) {
+    if (!onnxruntime::contrib::OrtTasks::GetInstance().TaskIsCompleted(run_id)) {
+      CancelPendingBackwardRun(run_id);
+    }
+  }
+}
+
+common::Status TrainingAgent::RunForward(const RunOptions& run_options, onnxruntime::IOBinding& io_binding,
+                                         std::vector<OrtValue>& user_outputs, int64_t& run_id) {
+  std::promise<void> setup_promise;
+  std::future<void> setup_future = setup_promise.get_future();
+
+  // run_options and io_binding are passed to the bg_thread by reference.
+  // This is safe because they are members of ORTModule and persist across the forward and backward calls.
+  auto bg_thread = std::thread([this](std::future<void> setup_future, const RunOptions& run_options, onnxruntime::IOBinding& io_binding) {
+    // wait until the task is properly set up
+    setup_future.get();
+
+    common::Status status = inference_session_->Run(run_options, io_binding);
+
+    onnxruntime::contrib::OrtTasks::GetInstance().SetStatus(status);
+
+    // If the forward outputs still haven't been consumed at this point, i.e. the forward call hasn't completed,
+    // Run() returned before reaching YieldOp because it hit an exception while executing the forward subgraph.
+    // In this case, we need to wake up the foreground thread and pass along the failed status;
+    // otherwise, the foreground thread would be stuck waiting for forward_outputs.
+    if (onnxruntime::contrib::OrtTasks::GetInstance().ForwardOutputsIsValid()) {
+      ORT_ENFORCE(!status.IsOK());
+      // signal the main thread that the background thread has completed
+      onnxruntime::contrib::OrtTasks::GetInstance().SetForwardOutputs(status, {});
+    }
+  }, std::move(setup_future), std::cref(run_options), std::ref(io_binding));
+
+  run_id = std::hash<std::thread::id>()(bg_thread.get_id());
+  {
+    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
+    bg_threads_[run_id] = std::move(bg_thread);
+  }
+
+  onnxruntime::contrib::OrtTasks::GetInstance().CreateBackgroundTask(run_id);
+
+  LOGS(*inference_session_->GetLogger(), VERBOSE) << "TrainingAgent::RunForward() created a task with run_id " << run_id;
+
+  // the background task is set up, unblock the background thread to continue
+  setup_promise.set_value();
+
+  // Wait for data/signal from either
+  // 1. the Yield op, if the bg thread successfully reached Yield's signal point, or
+  // 2. the end of the bg thread, if it hit an exception and returned early
+  auto forward_outputs = onnxruntime::contrib::OrtTasks::GetInstance().WaitForForwardOutputs(run_id);
+  const Status& forward_status = forward_outputs.first;
+  user_outputs = std::move(forward_outputs.second);
+
+  // the background thread completed without reaching the Yield op
+  if (!forward_status.IsOK()) {
+    std::thread thread;
+    {
+      std::lock_guard<std::mutex> lock(bg_threads_mutex_);
+      std::swap(thread, bg_threads_[run_id]);
+      bg_threads_.erase(run_id);
+    }
+    ORT_ENFORCE(thread.joinable());
+    thread.join();
+    onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
+    return forward_status;
+  }
+
+  return Status::OK();
+}
+
+common::Status TrainingAgent::RunBackward(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) {
+  LOGS(*inference_session_->GetLogger(), VERBOSE) << "Running TrainingAgent::RunBackward() with run_id " << run_id;
+
+  // resume the background thread
+  onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, backward_output_grads, false);
+
+  Status bg_thread_status = onnxruntime::contrib::OrtTasks::GetInstance().WaitForStatus(run_id);
+
+  std::thread bg_thread;
+  {
+    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
+    std::swap(bg_thread, bg_threads_[run_id]);
+    bg_threads_.erase(run_id);
+  }
+
+  // wait for the bg_thread to complete
+  ORT_ENFORCE(bg_thread.joinable());
+  bg_thread.join();
+  onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
+
+  return bg_thread_status;
+}
+
+void TrainingAgent::CancelPendingBackwardRun(int64_t run_id) {
+  LOGS(*inference_session_->GetLogger(), WARNING) << "Canceling background task with run_id " << run_id;
+
+  // resume the background thread with terminate = true
+  onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, {}, true);
+
+  // wait for the bg_thread to complete
+  std::thread bg_thread;
+  {
+    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
+    std::swap(bg_thread, bg_threads_[run_id]);
+    bg_threads_.erase(run_id);
+  }
+  ORT_ENFORCE(bg_thread.joinable());
+  bg_thread.join();
+  onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
+}
+
+}  // namespace training
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/agent/training_agent.h b/orttraining/orttraining/core/agent/training_agent.h
new file mode 100644
index 0000000000..8918650705
--- /dev/null
+++ b/orttraining/orttraining/core/agent/training_agent.h
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <mutex>
+#include <thread>
+
+#include "core/common/common.h"
+#include "core/common/logging/logging.h"
+#include "core/framework/framework_common.h"
+#include "core/session/IOBinding.h"
+#include "core/session/inference_session.h"
+#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
+
+namespace onnxruntime {
+namespace training {
+
+class TrainingAgent {
+ public:
+  explicit TrainingAgent(InferenceSession* session);
+  virtual ~TrainingAgent();
+
+  // For ORTModule.forward()
+  virtual common::Status RunForward(const RunOptions& run_options, onnxruntime::IOBinding& io_binding,
+                                    std::vector<OrtValue>& user_outputs,
+                                    int64_t& run_id) ORT_MUST_USE_RESULT;
+  // For ORTModule.backward()
+  common::Status RunBackward(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
+
+  void CancelPendingBackwardRun(int64_t run_id);
+
+ private:
+  // mutex for accessing bg_threads_
+  std::mutex bg_threads_mutex_;
+  // background threads created by RunForward and joined by RunBackward
+  std::unordered_map<int64_t, std::thread> bg_threads_;
+  // TrainingAgent runs on an InferenceSession under the hood
+  InferenceSession* inference_session_;
+};
+
+}  // namespace training
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc
index 02b3cae098..b241a7db58 100644
--- a/orttraining/orttraining/python/orttraining_pybind_state.cc
+++ b/orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -9,6 +9,7 @@
 #include "core/session/environment.h"
 #include "orttraining/core/session/training_session.h"
+#include "orttraining/core/agent/training_agent.h"
 #include "orttraining/core/graph/optimizer_config.h"
 #include "orttraining/core/framework/communication/mpi/mpi_context.h"
 #include "orttraining/core/framework/module_gradient_graph_builder.h"
@@ -474,6 +475,28 @@ void addObjectMethodsForTraining(py::module& m) {
         return static_cast<TrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
       });
 
+  py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class used to run an ORTModule model.)pbdoc")
+      // A TrainingAgent is constructed on top of an existing InferenceSession, which already
+      // owns the loaded model; the agent itself does not load or parse the model.
+      .def(py::init([](PyInferenceSession* session) {
+        return onnxruntime::make_unique<TrainingAgent>(session->GetSessionHandle());
+      }))
+      .def("run_forward", [](TrainingAgent* agent, SessionIOBinding& io_binding, RunOptions& run_options) -> py::tuple {
+        std::vector<OrtValue> module_outputs;
+        int64_t run_id;
+        Status status = agent->RunForward(run_options, *io_binding.Get(), module_outputs, run_id);
+        if (!status.IsOK()) {
+          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
+        }
+        return py::make_tuple(module_outputs, run_id);
+      })
+      .def("run_backward", [](TrainingAgent* agent, const std::vector<OrtValue>& backward_output_grads, int64_t run_id) -> void {
+        Status status = agent->RunBackward(run_id, backward_output_grads);
+        if (!status.IsOK())
+          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
+      });
+
   py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
       m, "ModuleGradientGraphBuilderConfiguration",
       R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
@@ -518,5 +541,6 @@ void addObjectMethodsForTraining(py::module& m) {
         return module_gradient_graph_builder->GetTrainingGraphInfo();
       });
 }
+
 }  // namespace python
 }  // namespace onnxruntime
diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py
index 179e4c9f13..e8f5eff157 100644
--- a/orttraining/orttraining/python/training/__init__.py
+++ b/orttraining/orttraining/python/training/__init__.py
@@ -8,5 +8,5 @@ from onnxruntime.capi.training.training_session import TrainingSession
 from .orttrainer_options import ORTTrainerOptions
 from .orttrainer import ORTTrainer, TrainStepInfo
 from . import amp, checkpoint, optim, model_desc_validation
-
-from .ortmodule import ORTModule
\ No newline at end of file
+from .training_agent import TrainingAgent
+from .ortmodule import ORTModule
diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py
index de922ce286..812618f446 100644
--- a/orttraining/orttraining/python/training/ortmodule.py
+++ b/orttraining/orttraining/python/training/ortmodule.py
@@ -297,9 +297,9 @@ class ORTModule(torch.nn.Module):
         # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
         session_options.log_severity_level = 2
 
-        self._training_session = onnxruntime.InferenceSession(
-            self._onnx_training.SerializeToString(), session_options, providers=providers, provider_options=provider_options)
-
+        self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(),
+                                                                    session_options, providers, provider_options)
+
         # Use this global run_options for now
         self._run_options = C.RunOptions()
diff --git a/orttraining/orttraining/python/training/training_agent.py b/orttraining/orttraining/python/training/training_agent.py
new file mode 100644
index 0000000000..599fcdb5a0
--- /dev/null
+++ b/orttraining/orttraining/python/training/training_agent.py
@@ -0,0 +1,73 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------- + +import onnxruntime +from onnxruntime.capi import _pybind_state as C +from onnxruntime.capi.onnxruntime_inference_collection import IOBinding, OrtValue +from onnxruntime.capi._pybind_state import TrainingAgent as C_TrainingAgent + + +class TrainingAgent(object): + """ + This is the main class used to run a ORTModule model. + """ + + def __init__(self, path_or_bytes, session_options=None, providers=None, provider_options=None): + """ + :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string + :param sess_options: session options + :param providers: Optional sequence of providers in order of decreasing + precedence. Values can either be provider names or tuples of + (provider name, options dict). If not provided, then all available + providers are used with the default precedence. + :param provider_options: Optional sequence of options dicts corresponding + to the providers listed in 'providers'. + + The model type will be inferred unless explicitly set in the SessionOptions. + To explicitly set: + so = onnxruntime.SessionOptions() + so.add_session_config_entry('session.load_model_format', 'ONNX') or + so.add_session_config_entry('session.load_model_format', 'ORT') or + + A file extension of '.ort' will be inferred as an ORT format model. + All other filenames are assumed to be ONNX format models. + + 'providers' can contain either names or names and options. When any options + are given in 'providers', 'provider_options' should not be used. + + The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider'] + means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider. + """ + + self._training_agent = None + self._inference_session = None + + self.create_training_agent(path_or_bytes, session_options, providers, provider_options) + + + def create_training_agent(self, path_or_bytes, session_options, providers, provider_options): + self._inference_session = onnxruntime.InferenceSession(path_or_bytes, session_options, + providers, provider_options) + self._training_agent = C_TrainingAgent(self._inference_session._sess) + + def io_binding(self): + "Return an onnxruntime.IOBinding object`." + return IOBinding(self._inference_session) + + def run_forward(self, iobinding, run_options): + """ + Compute the forward subgraph until it hits the Yield Op. + :param iobinding: the iobinding object that has graph inputs/outputs bind. + :param run_options: See :class:`onnxruntime.RunOptions`. + """ + ortvalues, run_id = self._training_agent.run_forward(iobinding._iobinding, run_options) + return [OrtValue(ortvalue) for ortvalue in ortvalues], run_id + + def run_backward(self, backward_output_grads, run_id): + """ + Resume executing the backward subgraph starting from Yield Op. + :param backward_output_grads: Output gradients for backward. + """ + self._training_agent.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads], run_id)