Introducing TrainingAgent interface to perform training using YieldOp (#6898)

This commit is contained in:
Thiago Crepaldi 2021-03-05 17:03:46 -08:00 committed by GitHub
parent 79f832c682
commit dfc7c18e31
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 285 additions and 192 deletions

View file

@ -13,7 +13,9 @@ file(GLOB_RECURSE onnxruntime_training_srcs
"${ORTTRAINING_SOURCE_DIR}/core/framework/communication/*"
"${ORTTRAINING_SOURCE_DIR}/core/session/*.h"
"${ORTTRAINING_SOURCE_DIR}/core/session/*.cc"
)
"${ORTTRAINING_SOURCE_DIR}/core/agent/*.h"
"${ORTTRAINING_SOURCE_DIR}/core/agent/*.cc"
)
add_library(onnxruntime_training ${onnxruntime_training_srcs})
add_dependencies(onnxruntime_training onnx tensorboard ${onnxruntime_EXTERNAL_DEPENDENCIES})

View file

@ -56,10 +56,6 @@
#include "core/session/custom_ops.h"
#endif
#ifdef ENABLE_TRAINING
#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
#endif
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::experimental;
using namespace onnxruntime::common;
@ -379,24 +375,6 @@ InferenceSession::~InferenceSession() {
}
}
#ifdef ENABLE_TRAINING
// TODO: Properly cancel outstanding background tasks
// Following implementation only handle the case where bg_thread is waiting for backward inputs
// Background thread can also be in other states, such as running Forward() or running Backward()
std::vector<int64_t> run_ids;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
for (auto it = bg_threads_.begin(); it != bg_threads_.end(); ++it) {
run_ids.push_back(it->first);
}
}
for (int64_t run_id : run_ids) {
if (!onnxruntime::contrib::OrtTasks::GetInstance().TaskIsCompleted(run_id)) {
CancelBackgroundTask(run_id);
}
}
#endif
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
if (session_activity_started_)
TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
@ -1739,114 +1717,6 @@ common::Status InferenceSession::Run(IOBinding& io_binding) {
return Run(run_options, io_binding);
}
#ifdef ENABLE_TRAINING
// Runs the forward subgraph on a background thread and blocks the caller until either
// (a) YieldOp publishes the forward outputs, or (b) the background Run() returns early
// with a failure status. On success the background thread stays parked inside YieldOp
// waiting for backward inputs; callers resume it later via ContinueRunInBackground(run_id).
// run_id (out): key identifying the background task/thread for subsequent calls.
common::Status InferenceSession::RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding,
std::vector<OrtValue>& user_outputs, int64_t& run_id) {
std::promise<void> setup_promise;
std::future<void> setup_future = setup_promise.get_future();
// Passing run_options and io_binding by reference to the bg_thread;
// this is OK because they are ORTModule members and persist through forward and backward calls.
auto bg_thread = std::thread([this](std::future<void> setup_future, const RunOptions& run_options, IOBinding& io_binding) {
// Block until the foreground thread has finished the task bookkeeping (CreateBackgroundTask).
setup_future.get();
common::Status status = Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(),
&io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo());
onnxruntime::contrib::OrtTasks::GetInstance().SetStatus(status);
// If the forward outputs still haven't been consumed at this point (i.e. the forward pass
// hasn't completed), Run() returned before hitting YieldOp because of an exception during
// the forward subgraph execution. In that case, wake up the foreground thread and pass
// along the failed status; otherwise it would be stuck waiting for forward outputs.
if (onnxruntime::contrib::OrtTasks::GetInstance().ForwardOutputsIsValid()) {
ORT_ENFORCE(!status.IsOK());
// Signal the main thread that the background thread completed (with an error).
onnxruntime::contrib::OrtTasks::GetInstance().SetForwardOutputs(status, {});
}
},
std::move(setup_future), std::cref(run_options), std::ref(io_binding));
// The run id is derived from the background thread's id and keys all OrtTasks bookkeeping.
run_id = std::hash<std::thread::id>()(bg_thread.get_id());
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
bg_threads_[run_id] = std::move(bg_thread);
}
onnxruntime::contrib::OrtTasks::GetInstance().CreateBackgroundTask(run_id);
LOGS(*session_logger_, VERBOSE) << "InferenceSession::Forward() call created a task with run_id " << run_id;
// Background task is set up; unblock the background thread to continue.
setup_promise.set_value();
// Wait for data/signal from either:
// 1. YieldOp, if the bg thread successfully reached Yield's signal point, or
// 2. the end of the bg thread, if it hit an exception and returned early.
auto forward_outputs = onnxruntime::contrib::OrtTasks::GetInstance().WaitForForwardOutputs(run_id);
const Status& forward_status = forward_outputs.first;
user_outputs = std::move(forward_outputs.second);
// Failure path: the background thread completed without hitting YieldOp — join it and clean up.
if (!forward_status.IsOK()) {
std::thread thread;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
std::swap(thread, bg_threads_[run_id]);
bg_threads_.erase(run_id);
}
ORT_ENFORCE(thread.joinable());
thread.join();
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
return forward_status;
}
return Status::OK();
}
// Resumes the background thread parked inside YieldOp by supplying the backward
// output gradients, then joins the thread and returns its final Run() status.
common::Status InferenceSession::ContinueRunInBackground(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) {
LOGS(*session_logger_, VERBOSE) << "Running InferenceSession::Backward() with run_id " << run_id;
// Resume the background thread (terminate = false means continue the backward pass).
onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, backward_output_grads, false);
Status bg_thread_status = onnxruntime::contrib::OrtTasks::GetInstance().WaitForStatus(run_id);
// Detach the thread handle from the registry under the lock, then join outside it.
std::thread bg_thread;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
std::swap(bg_thread, bg_threads_[run_id]);
bg_threads_.erase(run_id);
}
// Wait for bg_thread to complete before dropping the task bookkeeping.
ORT_ENFORCE(bg_thread.joinable());
bg_thread.join();
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
return bg_thread_status;
}
// Wakes the background thread with the terminate flag set so it unwinds instead of
// running backward, then joins it and removes the task bookkeeping.
void InferenceSession::CancelBackgroundTask(int64_t run_id) {
LOGS(*session_logger_, WARNING) << "Canceling background task with run_id " << run_id;
// Resume background thread with terminate = true (empty gradients, abandon the backward pass).
onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, {}, true);
// Detach the thread handle from the registry under the lock, then join outside it.
std::thread bg_thread;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
std::swap(bg_thread, bg_threads_[run_id]);
bg_threads_.erase(run_id);
}
ORT_ENFORCE(bg_thread.joinable());
bg_thread.join();
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
}
#endif
template <typename T>
void InferenceSession::StartProfiling(const std::basic_string<T>& file_prefix) {
std::basic_ostringstream<T> ss;

View file

@ -6,11 +6,6 @@
#include <string>
#include <unordered_map>
#ifdef ENABLE_TRAINING
#include <thread>
#include <future>
#endif
#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/common/profiler.h"
@ -302,17 +297,6 @@ class InferenceSession {
virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
#ifdef ENABLE_TRAINING
// For ORTModule.forward()
virtual common::Status RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding,
std::vector<OrtValue>& user_outputs,
int64_t& run_id) ORT_MUST_USE_RESULT;
// For ORTModule.backward()
common::Status ContinueRunInBackground(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
void CancelBackgroundTask(int64_t run_id);
#endif
/**
* @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.
* @note lifetime of the returned pointer is valid as long as the Session object is live.
@ -677,14 +661,6 @@ class InferenceSession {
// those into new OrtValue instances, at which point we won't free them until the InferenceSession goes away.
std::vector<uint8_t> ort_format_model_bytes_;
#ifdef ENABLE_TRAINING
// mutex for accessing bg_threads_
std::mutex bg_threads_mutex_;
// background threads for RunInBackgroundAndWaitForYield and ContinueRunInBackground
std::unordered_map<int64_t, std::thread> bg_threads_;
#endif
std::shared_ptr<onnxruntime::AllocatorManager> allocator_manager_;
};

View file

@ -228,22 +228,6 @@ class Session:
"""
self._sess.run_with_iobinding(iobinding._iobinding, run_options)
def run_forward(self, iobinding, run_options):
"""
Compute the forward subgraph until it hits the Yield Op.
:param iobinding: the iobinding object that has graph inputs/outputs bound.
:param run_options: See :class:`onnxruntime.RunOptions`.
:return: tuple of (list of OrtValue forward outputs, run_id identifying the background run).
"""
ortvalues, run_id = self._sess.run_forward(iobinding._iobinding, run_options)
return [OrtValue(ortvalue) for ortvalue in ortvalues], run_id
def run_backward(self, backward_output_grads, run_id):
"""
Resume executing the backward subgraph starting from Yield Op.
:param backward_output_grads: Output gradients for backward.
:param run_id: identifier of the background run returned by :meth:`run_forward`.
"""
self._sess.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads], run_id)
class InferenceSession(Session):
"""

View file

@ -1835,22 +1835,6 @@ including arg name, arg type (contains both type and shape).)pbdoc")
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
})
#ifdef ENABLE_TRAINING
.def("run_forward", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions& run_options) -> py::tuple {
std::vector<OrtValue> module_outputs;
int64_t run_id;
Status status = sess->GetSessionHandle()->RunInBackgroundAndWaitForYield(run_options, *io_binding.Get(), module_outputs, run_id);
if (!status.IsOK()) {
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
}
return py::make_tuple(module_outputs, run_id);
})
.def("run_backward", [](PyInferenceSession* sess, const std::vector<OrtValue>& backward_output_grads, int64_t run_id) -> void {
Status status = sess->GetSessionHandle()->ContinueRunInBackground(run_id, backward_output_grads);
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
})
#endif
;
py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())

View file

@ -0,0 +1,136 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/core/agent/training_agent.h"
#include "core/session/IOBinding.h"
#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
namespace onnxruntime {
namespace training {
TrainingAgent::TrainingAgent(InferenceSession *session) : inference_session_(session) {}
// Cancels any background runs that are still parked waiting for backward inputs,
// joining their threads before the agent is torn down.
TrainingAgent::~TrainingAgent() {
  // TODO: Properly cancel outstanding background tasks.
  // The following implementation only handles the case where the background thread is
  // waiting for backward inputs; the thread can also be in other states, such as
  // currently executing the forward or backward subgraph.
  // Snapshot the run ids under the lock; CancelPendingBackwardRun takes the same
  // mutex, so it must be called after the lock is released.
  std::vector<int64_t> run_ids;
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    run_ids.reserve(bg_threads_.size());
    for (const auto& entry : bg_threads_) {
      run_ids.push_back(entry.first);
    }
  }
  for (int64_t run_id : run_ids) {
    if (!onnxruntime::contrib::OrtTasks::GetInstance().TaskIsCompleted(run_id)) {
      CancelPendingBackwardRun(run_id);
    }
  }
}
// Runs the forward subgraph on a background thread and blocks the caller until either
// (a) YieldOp publishes the forward outputs, or (b) the background Run() returns early
// with a failure status. On success the background thread stays parked inside YieldOp
// waiting for backward inputs; callers resume it later via RunBackward(run_id, ...).
// run_id (out): key identifying the background task/thread for subsequent calls.
common::Status TrainingAgent::RunForward(const RunOptions& run_options, onnxruntime::IOBinding& io_binding,
std::vector<OrtValue>& user_outputs, int64_t& run_id) {
std::promise<void> setup_promise;
std::future<void> setup_future = setup_promise.get_future();
// Passing run_options and io_binding by reference to the bg_thread;
// this is OK because they are ORTModule members and persist through forward and backward calls.
auto bg_thread = std::thread([this](std::future<void> setup_future, const RunOptions& run_options, onnxruntime::IOBinding& io_binding) {
// Block until the foreground thread has finished the task bookkeeping (CreateBackgroundTask).
setup_future.get();
common::Status status = inference_session_->Run(run_options, io_binding);
onnxruntime::contrib::OrtTasks::GetInstance().SetStatus(status);
// If the forward outputs still haven't been consumed at this point (i.e. the forward pass
// hasn't completed), Run() returned before hitting YieldOp because of an exception during
// the forward subgraph execution. In that case, wake up the foreground thread and pass
// along the failed status; otherwise it would be stuck waiting for forward outputs.
if (onnxruntime::contrib::OrtTasks::GetInstance().ForwardOutputsIsValid()) {
ORT_ENFORCE(!status.IsOK());
// Signal the main thread that the background thread completed (with an error).
onnxruntime::contrib::OrtTasks::GetInstance().SetForwardOutputs(status, {});
}
}, std::move(setup_future), std::cref(run_options), std::ref(io_binding));
// The run id is derived from the background thread's id and keys all OrtTasks bookkeeping.
run_id = std::hash<std::thread::id>()(bg_thread.get_id());
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
bg_threads_[run_id] = std::move(bg_thread);
}
onnxruntime::contrib::OrtTasks::GetInstance().CreateBackgroundTask(run_id);
LOGS(*inference_session_->GetLogger(), VERBOSE) << "InferenceSession::Forward() call created a task with run_id " << run_id;
// Background task is set up; unblock the background thread to continue.
setup_promise.set_value();
// Wait for data/signal from either:
// 1. YieldOp, if the bg thread successfully reached Yield's signal point, or
// 2. the end of the bg thread, if it hit an exception and returned early.
auto forward_outputs = onnxruntime::contrib::OrtTasks::GetInstance().WaitForForwardOutputs(run_id);
const Status& forward_status = forward_outputs.first;
user_outputs = std::move(forward_outputs.second);
// Failure path: the background thread completed without hitting YieldOp — join it and clean up.
if (!forward_status.IsOK()) {
std::thread thread;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
std::swap(thread, bg_threads_[run_id]);
bg_threads_.erase(run_id);
}
ORT_ENFORCE(thread.joinable());
thread.join();
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
return forward_status;
}
return Status::OK();
}
// Resumes the background thread parked inside YieldOp, feeding it the backward
// output gradients, then joins the thread and reports its final Run() status.
common::Status TrainingAgent::RunBackward(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) {
  LOGS(*inference_session_->GetLogger(), VERBOSE) << "Running TrainingAgent::Backward() with run_id " << run_id;
  // Hand the gradients to the parked thread (terminate == false resumes the backward pass).
  onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, backward_output_grads, false);
  const Status final_status = onnxruntime::contrib::OrtTasks::GetInstance().WaitForStatus(run_id);
  // Detach the worker's handle from the registry while holding the lock, then join outside it.
  std::thread worker;
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    std::swap(worker, bg_threads_[run_id]);
    bg_threads_.erase(run_id);
  }
  ORT_ENFORCE(worker.joinable());
  worker.join();
  // Only drop the task bookkeeping once the worker has fully finished.
  onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
  return final_status;
}
void TrainingAgent::CancelPendingBackwardRun(int64_t run_id) {
LOGS(*inference_session_->GetLogger(), WARNING) << "Canceling background task with run_id " << run_id;
// resume background thread with terminate = true
onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, {}, true);
// wait for bg_thread to complete
std::thread bg_thread;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
std::swap(bg_thread, bg_threads_[run_id]);
bg_threads_.erase(run_id);
}
ORT_ENFORCE(bg_thread.joinable());
bg_thread.join();
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
}
} // namespace training
} // namespace onnxruntime

View file

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include <future>
#include <mutex>
#include <thread>
#include <unordered_map>

#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/framework/framework_common.h"
#include "core/session/IOBinding.h"
#include "core/session/inference_session.h"
#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"

namespace onnxruntime {
namespace training {

// TrainingAgent drives ORTModule training on top of an InferenceSession: the forward
// subgraph runs on a background thread until YieldOp, and the backward subgraph is
// resumed later with the output gradients.
class TrainingAgent {
 public:
  // The agent does not take ownership of the session; it must outlive the agent.
  explicit TrainingAgent(InferenceSession* session);
  virtual ~TrainingAgent();
  // For ORTModule.forward(): runs the forward subgraph until YieldOp; returns the
  // forward outputs in user_outputs and the id of the background run in run_id.
  virtual common::Status RunForward(const RunOptions& run_options, onnxruntime::IOBinding& io_binding,
                                    std::vector<OrtValue>& user_outputs,
                                    int64_t& run_id) ORT_MUST_USE_RESULT;
  // For ORTModule.backward(): resumes the run identified by run_id with the given gradients.
  common::Status RunBackward(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
  // Wakes and joins a background run that is still waiting for backward inputs.
  void CancelPendingBackwardRun(int64_t run_id);

 private:
  // mutex for accessing bg_threads_
  std::mutex bg_threads_mutex_;
  // background threads created by RunForward and joined by RunBackward / CancelPendingBackwardRun
  std::unordered_map<int64_t, std::thread> bg_threads_;
  // TrainingAgent runs on an InferenceSession under the hood (non-owning)
  InferenceSession* inference_session_;
};

}  // namespace training
}  // namespace onnxruntime

View file

@ -9,6 +9,7 @@
#include "core/session/environment.h"
#include "orttraining/core/session/training_session.h"
#include "orttraining/core/agent/training_agent.h"
#include "orttraining/core/graph/optimizer_config.h"
#include "orttraining/core/framework/communication/mpi/mpi_context.h"
#include "orttraining/core/framework/module_gradient_graph_builder.h"
@ -474,6 +475,28 @@ void addObjectMethodsForTraining(py::module& m) {
return static_cast<PipelineTrainingSession*>(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name);
});
py::class_<TrainingAgent>(m, "TrainingAgent", R"pbdoc(This is the main class used to run a ORTModule model.)pbdoc")
// TrainingAgent wraps an existing InferenceSession; the agent does not take ownership
// of the session, so the Python-side session must outlive the agent.
.def(py::init([](PyInferenceSession * session) {
return onnxruntime::make_unique<TrainingAgent>(session->GetSessionHandle());
}))
// Runs the forward subgraph until YieldOp; returns (forward outputs, run_id).
.def("run_forward", [](TrainingAgent* agent, SessionIOBinding& io_binding, RunOptions& run_options) -> py::tuple {
std::vector<OrtValue> module_outputs;
int64_t run_id;
Status status = agent->RunForward(run_options, *io_binding.Get(), module_outputs, run_id);
if (!status.IsOK()) {
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
}
return py::make_tuple(module_outputs, run_id);
})
// Resumes the backward subgraph of the run identified by run_id with the given output gradients.
.def("run_backward", [](TrainingAgent* agent, const std::vector<OrtValue>& backward_output_grads, int64_t run_id) -> void {
Status status = agent->RunBackward(run_id, backward_output_grads);
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
})
;
py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
m, "ModuleGradientGraphBuilderConfiguration",
R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
@ -518,5 +541,6 @@ void addObjectMethodsForTraining(py::module& m) {
return module_gradient_graph_builder->GetTrainingGraphInfo();
});
}
} // namespace python
} // namespace onnxruntime

View file

@ -8,5 +8,5 @@ from onnxruntime.capi.training.training_session import TrainingSession
from .orttrainer_options import ORTTrainerOptions
from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, checkpoint, optim, model_desc_validation
from .ortmodule import ORTModule
from .training_agent import TrainingAgent
from .ortmodule import ORTModule

View file

@ -297,9 +297,9 @@ class ORTModule(torch.nn.Module):
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
session_options.log_severity_level = 2
self._training_session = onnxruntime.InferenceSession(
self._onnx_training.SerializeToString(), session_options, providers=providers, provider_options=provider_options)
self._training_session = onnxruntime.training.TrainingAgent(self._onnx_training.SerializeToString(),
session_options, providers, provider_options)
# Use this global run_options for now
self._run_options = C.RunOptions()

View file

@ -0,0 +1,73 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import onnxruntime
from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_inference_collection import IOBinding, OrtValue
from onnxruntime.capi._pybind_state import TrainingAgent as C_TrainingAgent
class TrainingAgent(object):
"""
This is the main class used to run a ORTModule model.

It owns an :class:`onnxruntime.InferenceSession` and a native TrainingAgent bound to
that session; forward runs until YieldOp, and backward resumes from it.
"""
def __init__(self, path_or_bytes, session_options=None, providers=None, provider_options=None):
"""
:param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string
:param session_options: session options
:param providers: Optional sequence of providers in order of decreasing
precedence. Values can either be provider names or tuples of
(provider name, options dict). If not provided, then all available
providers are used with the default precedence.
:param provider_options: Optional sequence of options dicts corresponding
to the providers listed in 'providers'.
The model type will be inferred unless explicitly set in the SessionOptions.
To explicitly set:
so = onnxruntime.SessionOptions()
so.add_session_config_entry('session.load_model_format', 'ONNX') or
so.add_session_config_entry('session.load_model_format', 'ORT') or
A file extension of '.ort' will be inferred as an ORT format model.
All other filenames are assumed to be ONNX format models.
'providers' can contain either names or names and options. When any options
are given in 'providers', 'provider_options' should not be used.
The list of providers is ordered by precedence. For example ['CUDAExecutionProvider', 'CPUExecutionProvider']
means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider.
"""
self._training_agent = None
self._inference_session = None
self.create_training_agent(path_or_bytes, session_options, providers, provider_options)
def create_training_agent(self, path_or_bytes, session_options, providers, provider_options):
"""
Create the underlying InferenceSession and bind a native TrainingAgent to it.
"""
self._inference_session = onnxruntime.InferenceSession(path_or_bytes, session_options,
providers, provider_options)
self._training_agent = C_TrainingAgent(self._inference_session._sess)
def io_binding(self):
"Return an onnxruntime.IOBinding object bound to the underlying session."
return IOBinding(self._inference_session)
def run_forward(self, iobinding, run_options):
"""
Compute the forward subgraph until it hits the Yield Op.
:param iobinding: the iobinding object that has graph inputs/outputs bound.
:param run_options: See :class:`onnxruntime.RunOptions`.
:return: tuple of (list of OrtValue forward outputs, run_id identifying the background run).
"""
ortvalues, run_id = self._training_agent.run_forward(iobinding._iobinding, run_options)
return [OrtValue(ortvalue) for ortvalue in ortvalues], run_id
def run_backward(self, backward_output_grads, run_id):
"""
Resume executing the backward subgraph starting from Yield Op.
:param backward_output_grads: Output gradients for backward.
:param run_id: identifier of the background run returned by :meth:`run_forward`.
"""
self._training_agent.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads], run_id)