mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Introducing TrainingAgent interface to perform training using YieldOp (#6898)
This commit is contained in:
parent
79f832c682
commit
dfc7c18e31
11 changed files with 285 additions and 192 deletions
|
|
@ -13,7 +13,9 @@ file(GLOB_RECURSE onnxruntime_training_srcs
|
|||
"${ORTTRAINING_SOURCE_DIR}/core/framework/communication/*"
|
||||
"${ORTTRAINING_SOURCE_DIR}/core/session/*.h"
|
||||
"${ORTTRAINING_SOURCE_DIR}/core/session/*.cc"
|
||||
)
|
||||
"${ORTTRAINING_SOURCE_DIR}/core/agent/*.h"
|
||||
"${ORTTRAINING_SOURCE_DIR}/core/agent/*.cc"
|
||||
)
|
||||
|
||||
add_library(onnxruntime_training ${onnxruntime_training_srcs})
|
||||
add_dependencies(onnxruntime_training onnx tensorboard ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
|
|
|
|||
|
|
@ -56,10 +56,6 @@
|
|||
#include "core/session/custom_ops.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
|
||||
#endif
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::experimental;
|
||||
using namespace onnxruntime::common;
|
||||
|
|
@ -379,24 +375,6 @@ InferenceSession::~InferenceSession() {
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// TODO: Properly cancel outstanding background tasks
|
||||
// Following implementation only handle the case where bg_thread is waiting for backward inputs
|
||||
// Background thread can also be in other states, such as running Forward() or running Backward()
|
||||
std::vector<int64_t> run_ids;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
|
||||
for (auto it = bg_threads_.begin(); it != bg_threads_.end(); ++it) {
|
||||
run_ids.push_back(it->first);
|
||||
}
|
||||
}
|
||||
for (int64_t run_id : run_ids) {
|
||||
if (!onnxruntime::contrib::OrtTasks::GetInstance().TaskIsCompleted(run_id)) {
|
||||
CancelBackgroundTask(run_id);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
if (session_activity_started_)
|
||||
TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
|
||||
|
|
@ -1739,114 +1717,6 @@ common::Status InferenceSession::Run(IOBinding& io_binding) {
|
|||
return Run(run_options, io_binding);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// Runs the forward subgraph on a background thread until YieldOp signals, then returns
// the forward outputs to the caller. On success, `user_outputs` receives the forward
// outputs and `run_id` identifies the background task so that ContinueRunInBackground()
// or CancelBackgroundTask() can later resume or cancel it.
common::Status InferenceSession::RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding,
                                                                std::vector<OrtValue>& user_outputs, int64_t& run_id) {
  // The promise/future pair gates the background thread until the task record
  // (keyed by run_id) has been registered below.
  std::promise<void> setup_promise;
  std::future<void> setup_future = setup_promise.get_future();

  // Passing run_options and io_binding by reference to the bg_thread,
  // this is ok because they are ORTModule's members, and they are persistent through forward and backward calls.
  auto bg_thread = std::thread([this](std::future<void> setup_future, const RunOptions& run_options, IOBinding& io_binding) {
    // wait until task is properly set up
    setup_future.get();

    common::Status status = Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(),
                                &io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo());

    onnxruntime::contrib::OrtTasks::GetInstance().SetStatus(status);

    // If forward outputs still haven't been consumed at this point, i.e. the forward function hasn't completed,
    // this indicates that the Run() call returned before hitting YieldOp, due to hitting some exception during the
    // forward subgraph execution.
    // In this case, we need to wake up the foreground thread and pass along the failed status.
    // Otherwise, the foreground thread will be stuck waiting for forward_outputs.
    if (onnxruntime::contrib::OrtTasks::GetInstance().ForwardOutputsIsValid()) {
      ORT_ENFORCE(!status.IsOK());
      // signal main thread for background thread completion
      onnxruntime::contrib::OrtTasks::GetInstance().SetForwardOutputs(status, {});
    }
  },
                               std::move(setup_future), std::cref(run_options), std::ref(io_binding));

  // NOTE(review): run_id is a hash of the thread id — presumably unique enough per
  // live thread; a collision would clobber an existing bg_threads_ entry.
  run_id = std::hash<std::thread::id>()(bg_thread.get_id());
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    bg_threads_[run_id] = std::move(bg_thread);
  }

  onnxruntime::contrib::OrtTasks::GetInstance().CreateBackgroundTask(run_id);

  LOGS(*session_logger_, VERBOSE) << "InferenceSession::Forward() call created a task with run_id " << run_id;

  // background task is set up, unblock background thread to continue
  setup_promise.set_value();

  // Wait for data/signal from
  // 1. Yield op, if the bg thread successfully reached Yield's signal point
  // 2. The end of bg thread, if it hit exceptions and returned earlier
  auto forward_outputs = onnxruntime::contrib::OrtTasks::GetInstance().WaitForForwardOutputs(run_id);
  const Status& forward_status = forward_outputs.first;
  user_outputs = std::move(forward_outputs.second);

  // background thread has completed without hitting Yield Op:
  // join it, drop the task record, and surface the failure status.
  if (!forward_status.IsOK()) {
    std::thread thread;
    {
      std::lock_guard<std::mutex> lock(bg_threads_mutex_);
      std::swap(thread, bg_threads_[run_id]);
      bg_threads_.erase(run_id);
    }
    ORT_ENFORCE(thread.joinable());
    thread.join();
    onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
    return forward_status;
  }

  return Status::OK();
}
|
||||
|
||||
// Resumes the background task identified by run_id (parked at YieldOp) with the
// gradients for the backward pass, waits for the backward subgraph to finish,
// joins the background thread, and returns its final status.
common::Status InferenceSession::ContinueRunInBackground(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) {
  LOGS(*session_logger_, VERBOSE) << "Running InferenceSession::Backward() with run_id " << run_id;

  // resume background thread (terminate = false)
  onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, backward_output_grads, false);

  Status bg_thread_status = onnxruntime::contrib::OrtTasks::GetInstance().WaitForStatus(run_id);

  // Detach the thread handle from the bookkeeping map while holding the lock.
  std::thread bg_thread;
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    std::swap(bg_thread, bg_threads_[run_id]);
    bg_threads_.erase(run_id);
  }

  // wait for bg_thread to complete, then drop the task record
  ORT_ENFORCE(bg_thread.joinable());
  bg_thread.join();
  onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);

  return bg_thread_status;
}
|
||||
|
||||
// Cancels the background task identified by run_id by resuming its thread with
// the terminate flag set, then joins the thread and removes the task record.
// Called from the destructor for tasks that never received backward inputs.
void InferenceSession::CancelBackgroundTask(int64_t run_id) {
  LOGS(*session_logger_, WARNING) << "Canceling background task with run_id " << run_id;

  // resume background thread with terminate = true
  onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, {}, true);

  // wait for bg_thread to complete
  std::thread bg_thread;
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    std::swap(bg_thread, bg_threads_[run_id]);
    bg_threads_.erase(run_id);
  }
  ORT_ENFORCE(bg_thread.joinable());
  bg_thread.join();
  onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void InferenceSession::StartProfiling(const std::basic_string<T>& file_prefix) {
|
||||
std::basic_ostringstream<T> ss;
|
||||
|
|
|
|||
|
|
@ -6,11 +6,6 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
#include <thread>
|
||||
#include <future>
|
||||
#endif
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/profiler.h"
|
||||
|
|
@ -302,17 +297,6 @@ class InferenceSession {
|
|||
virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// For ORTModule.forward()
|
||||
virtual common::Status RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding,
|
||||
std::vector<OrtValue>& user_outputs,
|
||||
int64_t& run_id) ORT_MUST_USE_RESULT;
|
||||
|
||||
// For ORTModule.backward()
|
||||
common::Status ContinueRunInBackground(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
|
||||
|
||||
void CancelBackgroundTask(int64_t run_id);
|
||||
#endif
|
||||
/**
|
||||
* @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.
|
||||
* @note lifetime of the returned pointer is valid as long as the Session object is live.
|
||||
|
|
@ -677,14 +661,6 @@ class InferenceSession {
|
|||
// those into new OrtValue instances, at which point we won't free them until the InferenceSession goes away.
|
||||
std::vector<uint8_t> ort_format_model_bytes_;
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// mutex for accessing bg_threads_
|
||||
std::mutex bg_threads_mutex_;
|
||||
|
||||
// background threads for RunInBackgroundAndWaitForYield and ContinueRunInBackground
|
||||
std::unordered_map<int64_t, std::thread> bg_threads_;
|
||||
#endif
|
||||
|
||||
std::shared_ptr<onnxruntime::AllocatorManager> allocator_manager_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -228,22 +228,6 @@ class Session:
|
|||
"""
|
||||
self._sess.run_with_iobinding(iobinding._iobinding, run_options)
|
||||
|
||||
def run_forward(self, iobinding, run_options):
|
||||
"""
|
||||
Compute the forward subgraph until it hits the Yield Op.
|
||||
:param iobinding: the iobinding object that has graph inputs/outputs bind.
|
||||
:param run_options: See :class:`onnxruntime.RunOptions`.
|
||||
"""
|
||||
ortvalues, run_id = self._sess.run_forward(iobinding._iobinding, run_options)
|
||||
return [OrtValue(ortvalue) for ortvalue in ortvalues], run_id
|
||||
|
||||
def run_backward(self, backward_output_grads, run_id):
|
||||
"""
|
||||
Resume executing the backward subgraph starting from Yield Op.
|
||||
:param backward_output_grads: Output gradients for backward.
|
||||
"""
|
||||
self._sess.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads], run_id)
|
||||
|
||||
|
||||
class InferenceSession(Session):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1835,22 +1835,6 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
|||
if (!status.IsOK())
|
||||
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
|
||||
})
|
||||
#ifdef ENABLE_TRAINING
|
||||
.def("run_forward", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions& run_options) -> py::tuple {
|
||||
std::vector<OrtValue> module_outputs;
|
||||
int64_t run_id;
|
||||
Status status = sess->GetSessionHandle()->RunInBackgroundAndWaitForYield(run_options, *io_binding.Get(), module_outputs, run_id);
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
|
||||
}
|
||||
return py::make_tuple(module_outputs, run_id);
|
||||
})
|
||||
.def("run_backward", [](PyInferenceSession* sess, const std::vector<OrtValue>& backward_output_grads, int64_t run_id) -> void {
|
||||
Status status = sess->GetSessionHandle()->ContinueRunInBackground(run_id, backward_output_grads);
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
|
||||
})
|
||||
#endif
|
||||
;
|
||||
|
||||
py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())
|
||||
|
|
|
|||
136
orttraining/orttraining/core/agent/training_agent.cc
Normal file
136
orttraining/orttraining/core/agent/training_agent.cc
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/core/agent/training_agent.h"
|
||||
#include "core/session/IOBinding.h"
|
||||
#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
// The InferenceSession is borrowed, not owned: the caller keeps ownership and
// must keep the session alive for the lifetime of this agent.
TrainingAgent::TrainingAgent(InferenceSession *session) : inference_session_(session) {}
|
||||
|
||||
TrainingAgent::~TrainingAgent() {
  // TODO: Properly cancel outstanding background tasks.
  // The following implementation only handles the case where the background thread
  // is waiting for backward inputs. The background thread can also be in other
  // states, such as running Forward() or running Backward().

  // Snapshot the run ids under the lock, then cancel outside of it, because
  // CancelPendingBackwardRun() acquires bg_threads_mutex_ itself.
  std::vector<int64_t> run_ids;
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    for (const auto& entry : bg_threads_) {
      run_ids.push_back(entry.first);
    }
  }
  for (int64_t run_id : run_ids) {
    if (!onnxruntime::contrib::OrtTasks::GetInstance().TaskIsCompleted(run_id)) {
      CancelPendingBackwardRun(run_id);
    }
  }
}
|
||||
|
||||
// Runs the forward subgraph on a background thread until YieldOp signals, then
// returns the forward outputs to the caller.
// On success, `user_outputs` receives the forward outputs and `run_id` identifies
// the background task so RunBackward()/CancelPendingBackwardRun() can resume or
// cancel it later. On failure, the background thread is joined and its status
// is returned.
common::Status TrainingAgent::RunForward(const RunOptions& run_options, onnxruntime::IOBinding& io_binding,
                                         std::vector<OrtValue>& user_outputs, int64_t& run_id) {
  // The promise/future pair gates the background thread until the task record
  // (keyed by run_id) has been registered below.
  std::promise<void> setup_promise;
  std::future<void> setup_future = setup_promise.get_future();

  // Passing run_options and io_binding by reference to the bg_thread is safe
  // because they are ORTModule's members, and they are persistent through
  // forward and backward calls.
  auto bg_thread = std::thread([this](std::future<void> setup_future, const RunOptions& run_options, onnxruntime::IOBinding& io_binding) {
    // wait until task is properly set up
    setup_future.get();

    common::Status status = inference_session_->Run(run_options, io_binding);

    onnxruntime::contrib::OrtTasks::GetInstance().SetStatus(status);

    // If forward outputs still haven't been consumed at this point, i.e. the forward function hasn't completed,
    // this indicates that the Run() call returned before hitting YieldOp, due to hitting some exception during
    // the forward subgraph execution.
    // In this case, we need to wake up the foreground thread and pass along the failed status.
    // Otherwise, the foreground thread will be stuck waiting for forward_outputs.
    if (onnxruntime::contrib::OrtTasks::GetInstance().ForwardOutputsIsValid()) {
      ORT_ENFORCE(!status.IsOK());
      // signal main thread for background thread completion
      onnxruntime::contrib::OrtTasks::GetInstance().SetForwardOutputs(status, {});
    }
  }, std::move(setup_future), std::cref(run_options), std::ref(io_binding));

  run_id = std::hash<std::thread::id>()(bg_thread.get_id());
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    bg_threads_.emplace(run_id, std::move(bg_thread));
  }

  onnxruntime::contrib::OrtTasks::GetInstance().CreateBackgroundTask(run_id);

  LOGS(*inference_session_->GetLogger(), VERBOSE) << "TrainingAgent::RunForward() call created a task with run_id " << run_id;

  // background task is set up, unblock background thread to continue
  setup_promise.set_value();

  // Wait for data/signal from
  // 1. Yield op, if the bg thread successfully reached Yield's signal point
  // 2. The end of bg thread, if it hit exceptions and returned earlier
  auto forward_outputs = onnxruntime::contrib::OrtTasks::GetInstance().WaitForForwardOutputs(run_id);
  const Status& forward_status = forward_outputs.first;
  user_outputs = std::move(forward_outputs.second);

  // background thread has completed without hitting Yield Op:
  // join it, drop the task record, and surface the failure status.
  if (!forward_status.IsOK()) {
    std::thread thread;
    {
      std::lock_guard<std::mutex> lock(bg_threads_mutex_);
      std::swap(thread, bg_threads_[run_id]);
      bg_threads_.erase(run_id);
    }
    ORT_ENFORCE(thread.joinable());
    thread.join();
    onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
    return forward_status;
  }

  return Status::OK();
}
|
||||
|
||||
// Resume the background task parked at YieldOp, feed it the gradients for the
// backward pass, wait for the backward subgraph to finish, and return the
// background thread's final status.
common::Status TrainingAgent::RunBackward(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) {
  LOGS(*inference_session_->GetLogger(), VERBOSE) << "Running TrainingAgent::Backward() with run_id " << run_id;

  // Hand the gradients to the parked background thread (terminate = false).
  onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, backward_output_grads, false);

  // Block until the background thread reports its overall status.
  Status worker_status = onnxruntime::contrib::OrtTasks::GetInstance().WaitForStatus(run_id);

  // Detach the thread handle from the bookkeeping map while holding the lock.
  std::thread worker;
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    std::swap(worker, bg_threads_[run_id]);
    bg_threads_.erase(run_id);
  }

  // Join the worker outside the lock, then drop the task record.
  ORT_ENFORCE(worker.joinable());
  worker.join();
  onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);

  return worker_status;
}
|
||||
|
||||
void TrainingAgent::CancelPendingBackwardRun(int64_t run_id) {
|
||||
LOGS(*inference_session_->GetLogger(), WARNING) << "Canceling background task with run_id " << run_id;
|
||||
|
||||
// resume background thread with terminate = true
|
||||
onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, {}, true);
|
||||
|
||||
// wait for bg_thread to complete
|
||||
std::thread bg_thread;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
|
||||
std::swap(bg_thread, bg_threads_[run_id]);
|
||||
bg_threads_.erase(run_id);
|
||||
}
|
||||
ORT_ENFORCE(bg_thread.joinable());
|
||||
bg_thread.join();
|
||||
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
44
orttraining/orttraining/core/agent/training_agent.h
Normal file
44
orttraining/orttraining/core/agent/training_agent.h
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thread>
|
||||
#include <future>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/session/IOBinding.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
|
||||
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
class IOBinding;
|
||||
|
||||
class TrainingAgent {
|
||||
|
||||
public:
|
||||
explicit TrainingAgent(InferenceSession* session);
|
||||
virtual ~TrainingAgent();
|
||||
// For ORTModule.forward()
|
||||
virtual common::Status RunForward(const RunOptions& run_options, onnxruntime::IOBinding& io_binding,
|
||||
std::vector<OrtValue>& user_outputs,
|
||||
int64_t& run_id) ORT_MUST_USE_RESULT;
|
||||
// For ORTModule.backward()
|
||||
common::Status RunBackward(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
|
||||
void CancelPendingBackwardRun(int64_t run_id);
|
||||
|
||||
private:
|
||||
// mutex for accessing bg_threads_
|
||||
std::mutex bg_threads_mutex_;
|
||||
// background threads for RunInBackgroundAndWaitForYield and ContinueRunInBackground
|
||||
std::unordered_map<int64_t, std::thread> bg_threads_;
|
||||
// TrainingAgent runs on a InferenceSession under the hood
|
||||
InferenceSession* inference_session_;
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
#include "core/session/environment.h"
|
||||
#include "orttraining/core/session/training_session.h"
|
||||
#include "orttraining/core/agent/training_agent.h"
|
||||
#include "orttraining/core/graph/optimizer_config.h"
|
||||
#include "orttraining/core/framework/communication/mpi/mpi_context.h"
|
||||
#include "orttraining/core/framework/module_gradient_graph_builder.h"
|
||||
|
|
@ -474,6 +475,28 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
return static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
|
||||
});
|
||||
|
||||
py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class used to run a ORTModule model.)pbdoc")
|
||||
// In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char*
|
||||
// without any conversion. So this init method can be used for model file path (string) and model content (bytes)
|
||||
.def(py::init([](PyInferenceSession * session) {
|
||||
return onnxruntime::make_unique<TrainingAgent>(session->GetSessionHandle());
|
||||
}))
|
||||
.def("run_forward", [](TrainingAgent* agent, SessionIOBinding& io_binding, RunOptions& run_options) -> py::tuple {
|
||||
std::vector<OrtValue> module_outputs;
|
||||
int64_t run_id;
|
||||
Status status = agent->RunForward(run_options, *io_binding.Get(), module_outputs, run_id);
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
|
||||
}
|
||||
return py::make_tuple(module_outputs, run_id);
|
||||
})
|
||||
.def("run_backward", [](TrainingAgent* agent, const std::vector<OrtValue>& backward_output_grads, int64_t run_id) -> void {
|
||||
Status status = agent->RunBackward(run_id, backward_output_grads);
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
|
||||
})
|
||||
;
|
||||
|
||||
py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
|
||||
m, "ModuleGradientGraphBuilderConfiguration",
|
||||
R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
|
||||
|
|
@ -518,5 +541,6 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
return module_gradient_graph_builder->GetTrainingGraphInfo();
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -8,5 +8,5 @@ from onnxruntime.capi.training.training_session import TrainingSession
|
|||
from .orttrainer_options import ORTTrainerOptions
|
||||
from .orttrainer import ORTTrainer, TrainStepInfo
|
||||
from . import amp, checkpoint, optim, model_desc_validation
|
||||
|
||||
from .ortmodule import ORTModule
|
||||
from .training_agent import TrainingAgent
|
||||
from .ortmodule import ORTModule
|
||||
|
|
|
|||
|
|
@ -297,9 +297,9 @@ class ORTModule(torch.nn.Module):
|
|||
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
|
||||
session_options.log_severity_level = 2
|
||||
|
||||
self._training_session = onnxruntime.InferenceSession(
|
||||
self._onnx_training.SerializeToString(), session_options, providers=providers, provider_options=provider_options)
|
||||
|
||||
self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(),
|
||||
session_options, providers, provider_options)
|
||||
|
||||
# Use this global run_options for now
|
||||
self._run_options = C.RunOptions()
|
||||
|
||||
|
|
|
|||
73
orttraining/orttraining/python/training/training_agent.py
Normal file
73
orttraining/orttraining/python/training/training_agent.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.capi.onnxruntime_inference_collection import IOBinding, OrtValue
|
||||
from onnxruntime.capi._pybind_state import TrainingAgent as C_TrainingAgent
|
||||
|
||||
|
||||
class TrainingAgent(object):
    """
    This is the main class used to run a ORTModule model.

    It wraps an onnxruntime.InferenceSession and a native C TrainingAgent that
    splits execution at the Yield Op: run_forward() executes the forward
    subgraph, run_backward() resumes with the backward subgraph.
    """

    def __init__(self, path_or_bytes, session_options=None, providers=None, provider_options=None):
        """
        :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string
        :param sess_options: session options
        :param providers: Optional sequence of providers in order of decreasing
            precedence. Values can either be provider names or tuples of
            (provider name, options dict). If not provided, then all available
            providers are used with the default precedence.
        :param provider_options: Optional sequence of options dicts corresponding
            to the providers listed in 'providers'.

        The model type will be inferred unless explicitly set in the SessionOptions.
        To explicitly set:
          so = onnxruntime.SessionOptions()
          so.add_session_config_entry('session.load_model_format', 'ONNX') or
          so.add_session_config_entry('session.load_model_format', 'ORT') or

        A file extension of '.ort' will be inferred as an ORT format model.
        All other filenames are assumed to be ONNX format models.

        'providers' can contain either names or names and options. When any options
        are given in 'providers', 'provider_options' should not be used.

        The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
        means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
        """

        # Native C TrainingAgent and the InferenceSession it runs on; both are
        # populated by create_training_agent() below.
        self._training_agent = None
        self._inference_session = None

        self.create_training_agent(path_or_bytes, session_options, providers, provider_options)

    def create_training_agent(self, path_or_bytes, session_options, providers, provider_options):
        # Build the underlying InferenceSession first, then hand its native
        # session handle to the C TrainingAgent.
        self._inference_session = onnxruntime.InferenceSession(path_or_bytes, session_options,
                                                               providers, provider_options)
        self._training_agent = C_TrainingAgent(self._inference_session._sess)

    def io_binding(self):
        "Return an onnxruntime.IOBinding object bound to the underlying session."
        return IOBinding(self._inference_session)

    def run_forward(self, iobinding, run_options):
        """
        Compute the forward subgraph until it hits the Yield Op.
        :param iobinding: the iobinding object that has graph inputs/outputs bind.
        :param run_options: See :class:`onnxruntime.RunOptions`.
        :return: tuple of (list of OrtValue forward outputs, run_id of the background task).
        """
        ortvalues, run_id = self._training_agent.run_forward(iobinding._iobinding, run_options)
        return [OrtValue(ortvalue) for ortvalue in ortvalues], run_id

    def run_backward(self, backward_output_grads, run_id):
        """
        Resume executing the backward subgraph starting from Yield Op.
        :param backward_output_grads: Output gradients for backward.
        :param run_id: identifier of the background task returned by run_forward().
        """
        self._training_agent.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads], run_id)
|
||||
Loading…
Reference in a new issue