Rewrite ORTModule background task coordination (#6700)

* Introduce OrtTasks to replace EventPool

* return run_id to frontend

* pass run_id to backward

* OrtTasks support multiple bg_events

* make message_queue a member of orttask

* Replace MessageQueue with std::promise

* Move status_promise into Task

* Move terminate flag into Task

* Reenable previously disabled UTs

* Add unit tests

* Replace condition variables with std::promise

* Move to CreateBackgroundTask in the main thread

* return status and output in forward_future

* use throw for terminating background thread

* cleanup tasks at destructor

* reenable test_mixed_nnmodule_ortmodules_training

* add mutex for ORTTasks functions

* add mutex for bg_threads

* delay tests before start

* add ut for multi-task common backbone

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
Sherlock 2021-02-24 18:00:25 -08:00 committed by GitHub
parent 7ce4075bbd
commit 8e200e13fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 470 additions and 170 deletions

View file

@ -1,47 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <queue>
#include <vector>
#include "core/common/common.h"
#include "core/framework/ml_value.h"
namespace onnxruntime {
namespace contrib {
class OrtMessageQueue final {
public:
static OrtMessageQueue& GetInstance() {
static OrtMessageQueue instance_;
return instance_;
}
void Push(const OrtValue& ort_value) { ort_values.emplace(ort_value); }
OrtValue Pop() {
OrtValue ort_value = ort_values.front();
ort_values.pop();
return ort_value;
}
void PopAll(std::vector<OrtValue>& results) {
while (!ort_values.empty()) {
OrtValue ort_value = ort_values.front();
ort_values.pop();
results.emplace_back(ort_value);
}
}
private:
OrtMessageQueue() = default;
~OrtMessageQueue() = default;
OrtMessageQueue(const OrtMessageQueue&) = delete;
OrtMessageQueue& operator=(const OrtMessageQueue&) = delete;
std::queue<OrtValue> ort_values;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -57,8 +57,7 @@
#endif
#ifdef ENABLE_TRAINING
#include "core/providers/cpu/controlflow/event_pool.h"
#include "core/providers/cpu/controlflow/message_queue.h"
#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
#endif
using namespace ONNX_NAMESPACE;
@ -381,12 +380,20 @@ InferenceSession::~InferenceSession() {
}
#ifdef ENABLE_TRAINING
// TODO: find a better way to terminate the background thread
// backward is not completed yet, set terminate_flag to True
if (task_.bg_thread_future_.valid()) {
*(task_.terminate_flag_) = true;
Status s = ContinueRunInBackground({});
ORT_UNUSED_PARAMETER(s);
// TODO: Properly cancel outstanding background tasks
// Following implementation only handle the case where bg_thread is waiting for backward inputs
// Background thread can also be in other states, such as running Forward() or running Backward()
std::vector<int64_t> run_ids;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
for (auto it = bg_threads_.begin(); it != bg_threads_.end(); ++it) {
run_ids.push_back(it->first);
}
}
for (int64_t run_id : run_ids) {
if (!onnxruntime::contrib::OrtTasks::GetInstance().TaskIsCompleted(run_id)) {
CancelBackgroundTask(run_id);
}
}
#endif
@ -1733,59 +1740,111 @@ common::Status InferenceSession::Run(IOBinding& io_binding) {
}
#ifdef ENABLE_TRAINING
common::Status InferenceSession::RunInBackgroundAndWaitForYield(RunOptions& run_options, IOBinding& io_binding,
std::vector<OrtValue>& user_outputs) {
const int64_t main_thread_event_id = 0;
onnxruntime::contrib::OrtEventPool::GetInstance().ResetEvent(0);
common::Status InferenceSession::RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding,
std::vector<OrtValue>& user_outputs, int64_t& run_id) {
std::promise<void> setup_promise;
std::future<void> setup_future = setup_promise.get_future();
task_.terminate_flag_ = &(run_options.terminate);
task_.bg_thread_promise_ = std::promise<Status>();
task_.bg_thread_future_ = task_.bg_thread_promise_.get_future();
task_.bg_thread_ = std::thread([&](std::promise<common::Status> result_promise) {
common::Status s = Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(),
&io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo());
// Passing run_options and io_binding by reference to the bg_thread,
// this is ok because they are ORTModule's member, and they are presistent through forward and backward calls
auto bg_thread = std::thread([this](std::future<void> setup_future, const RunOptions& run_options, IOBinding& io_binding) {
// wait until task is properly setup
setup_future.get();
result_promise.set_value(s);
common::Status status = Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(),
&io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo());
// signal main thread for background thread completion
const int64_t main_thread_event_id = 0;
onnxruntime::contrib::OrtEventPool::GetInstance().SignalEvent(main_thread_event_id);
onnxruntime::contrib::OrtTasks::GetInstance().SetStatus(status);
// If forward outputs still hasn't been consumed at this point, i.e. forward function hasn't complete itself
// this indicates that Run() call returned before hitting YieldOp, due to hitting some exception during the forward subgraph execution
// In this case, we need to wake up the foreground thread and pass along the failed status.
// Otherwise, foreground thread will be stuck waiting for forward_outputs.
if (onnxruntime::contrib::OrtTasks::GetInstance().ForwardOutputsIsValid()) {
ORT_ENFORCE(!status.IsOK());
// signal main thread for background thread completion
onnxruntime::contrib::OrtTasks::GetInstance().SetForwardOutputs(status, {});
}
},
std::move(task_.bg_thread_promise_));
std::move(setup_future), std::cref(run_options), std::ref(io_binding));
// Wait for events from
// 1. Yield op, if the bg thread sucessfully reached Yield's signal point
// 2. The end of bg thread, if it hit execptions and returned earlier
onnxruntime::contrib::OrtEventPool::GetInstance().WaitAndResetEvent(main_thread_event_id);
// background thread has completed without hitting Yield Op
if (task_.bg_thread_future_.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) {
Status bg_thread_status = task_.bg_thread_future_.get();
task_.bg_thread_.join();
return bg_thread_status;
run_id = std::hash<std::thread::id>()(bg_thread.get_id());
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
bg_threads_[run_id] = std::move(bg_thread);
}
onnxruntime::contrib::OrtTasks::GetInstance().CreateBackgroundTask(run_id);
LOGS(*session_logger_, VERBOSE) << "InferenceSession::Forward() call created a task with run_id " << run_id;
// background task is setup, unblock background thread to continue
setup_promise.set_value();
// Wait for data/signal from
// 1. Yield op, if the bg thread sucessfully reached Yield's signal point
// 2. The end of bg thread, if it hit execptions and returned earlier
auto forward_outputs = onnxruntime::contrib::OrtTasks::GetInstance().WaitForForwardOutputs(run_id);
const Status& forward_status = forward_outputs.first;
user_outputs = std::move(forward_outputs.second);
// background thread has completed without hitting Yield Op
if (!forward_status.IsOK()) {
std::thread thread;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
std::swap(thread, bg_threads_[run_id]);
bg_threads_.erase(run_id);
}
ORT_ENFORCE(thread.joinable());
thread.join();
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
return forward_status;
}
onnxruntime::contrib::OrtMessageQueue::GetInstance().PopAll(user_outputs);
return Status::OK();
}
common::Status InferenceSession::ContinueRunInBackground(const std::vector<OrtValue>& backward_output_grads) {
for (const auto& ort_value : backward_output_grads) {
onnxruntime::contrib::OrtMessageQueue::GetInstance().Push(ort_value);
}
common::Status InferenceSession::ContinueRunInBackground(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) {
LOGS(*session_logger_, VERBOSE) << "Running InferenceSession::Backward() with run_id " << run_id;
// resume background thread
const int64_t background_thread_event_id = 1;
onnxruntime::contrib::OrtEventPool::GetInstance().SignalEvent(background_thread_event_id);
onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, backward_output_grads, false);
Status bg_thread_status = task_.bg_thread_future_.get();
// wait for bg_thread to complete
if (task_.bg_thread_.joinable()) {
task_.bg_thread_.join();
Status bg_thread_status = onnxruntime::contrib::OrtTasks::GetInstance().WaitForStatus(run_id);
std::thread bg_thread;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
std::swap(bg_thread, bg_threads_[run_id]);
bg_threads_.erase(run_id);
}
// wait for bg_thread to complete
ORT_ENFORCE(bg_thread.joinable());
bg_thread.join();
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
return bg_thread_status;
}
void InferenceSession::CancelBackgroundTask(int64_t run_id) {
LOGS(*session_logger_, WARNING) << "Canceling background task with run_id " << run_id;
// resume background thread with terminate = true
onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, {}, true);
// wait for bg_thread to complete
std::thread bg_thread;
{
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
std::swap(bg_thread, bg_threads_[run_id]);
bg_threads_.erase(run_id);
}
ORT_ENFORCE(bg_thread.joinable());
bg_thread.join();
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
}
#endif
template <typename T>

View file

@ -304,13 +304,15 @@ class InferenceSession {
#ifdef ENABLE_TRAINING
// For ORTModule.forward()
virtual common::Status RunInBackgroundAndWaitForYield(RunOptions& run_options, IOBinding& io_binding,
std::vector<OrtValue>& user_outputs) ORT_MUST_USE_RESULT;
virtual common::Status RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding,
std::vector<OrtValue>& user_outputs,
int64_t& run_id) ORT_MUST_USE_RESULT;
// For ORTModule.backward()
common::Status ContinueRunInBackground(const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
#endif
common::Status ContinueRunInBackground(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
void CancelBackgroundTask(int64_t run_id);
#endif
/**
* @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.
* @note lifetime of the returned pointer is valid as long as the Session object is live.
@ -676,18 +678,16 @@ class InferenceSession {
std::vector<uint8_t> ort_format_model_bytes_;
#ifdef ENABLE_TRAINING
// background thread for RunInBackgroundAndWaitForYield
struct Task {
std::thread bg_thread_;
std::promise<Status> bg_thread_promise_;
std::future<Status> bg_thread_future_;
bool* terminate_flag_ = nullptr;
} task_;
// mutex for accessing bg_threads_
std::mutex bg_threads_mutex_;
// background threads for RunInBackgroundAndWaitForYield and ContinueRunInBackground
std::unordered_map<int64_t, std::thread> bg_threads_;
#endif
std::shared_ptr<onnxruntime::AllocatorManager> allocator_manager_;
};
struct SessionIOBinding {
public:
SessionIOBinding(InferenceSession* session);

View file

@ -234,14 +234,15 @@ class Session:
:param iobinding: the iobinding object that has graph inputs/outputs bind.
:param run_options: See :class:`onnxruntime.RunOptions`.
"""
return [OrtValue(ortvalue) for ortvalue in self._sess.run_forward(iobinding._iobinding, run_options)]
ortvalues, run_id = self._sess.run_forward(iobinding._iobinding, run_options)
return [OrtValue(ortvalue) for ortvalue in ortvalues], run_id
def run_backward(self, backward_output_grads):
def run_backward(self, backward_output_grads, run_id):
"""
Resume executing the backward subgraph starting from Yield Op.
:param backward_output_grads: Output gradients for backward.
"""
self._sess.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads])
self._sess.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads], run_id)
class InferenceSession(Session):

View file

@ -1222,8 +1222,7 @@ void addObjectMethods(py::module& m, Environment& env) {
return ml_value;
})
#ifdef ENABLE_TRAINING
.def_static("ortvalue_from_data_ptr", [](std::vector<int64_t>& shape, py::object& element_type,
OrtDevice& device, int64_t data_ptr) {
.def_static("ortvalue_from_data_ptr", [](std::vector<int64_t>& shape, py::object& element_type, OrtDevice& device, int64_t data_ptr) {
ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is invalid");
PyArray_Descr* dtype;
if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) {
@ -1313,7 +1312,7 @@ void addObjectMethods(py::module& m, Environment& env) {
PyCapsule_New(dlmanaged_tensor, "dltensor", dlpack_capsule_destructor));
})
#endif
;
;
py::class_<SessionIOBinding> session_io_binding(m, "SessionIOBinding");
session_io_binding
@ -1837,22 +1836,22 @@ including arg name, arg type (contains both type and shape).)pbdoc")
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
})
#ifdef ENABLE_TRAINING
.def("run_forward", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions& run_options) -> std::vector<OrtValue> {
.def("run_forward", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions& run_options) -> py::tuple {
std::vector<OrtValue> module_outputs;
Status status = sess->GetSessionHandle()->RunInBackgroundAndWaitForYield(run_options, *io_binding.Get(), module_outputs);
int64_t run_id;
Status status = sess->GetSessionHandle()->RunInBackgroundAndWaitForYield(run_options, *io_binding.Get(), module_outputs, run_id);
if (!status.IsOK()) {
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
}
return module_outputs;
return py::make_tuple(module_outputs, run_id);
})
.def("run_backward", [](PyInferenceSession* sess, const std::vector<OrtValue>& backward_output_grads) -> void {
Status status = sess->GetSessionHandle()->ContinueRunInBackground(backward_output_grads);
.def("run_backward", [](PyInferenceSession* sess, const std::vector<OrtValue>& backward_output_grads, int64_t run_id) -> void {
Status status = sess->GetSessionHandle()->ContinueRunInBackground(run_id, backward_output_grads);
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
})
#endif
;
;
py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())
.value("kNextPowerOfTwo", onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo)

View file

@ -244,8 +244,10 @@ class ORTModule(torch.nn.Module):
_create_iobinding(self._training_io_binding, inputs, self._onnx_training, self._device)
# Run and return module outputs.
user_outputs = tuple(_ort_output_to_torch_tensor(forward_output) \
for forward_output in self._training_session.run_forward(self._training_io_binding, self._run_options))
forward_outputs, run_id = self._training_session.run_forward(self._training_io_binding, self._run_options)
user_outputs = tuple(_ort_output_to_torch_tensor(forward_output) for forward_output in forward_outputs)
ctx.run_id = run_id
return user_outputs
@staticmethod
@ -261,7 +263,8 @@ class ORTModule(torch.nn.Module):
grad_output.dtype), grad_output.device.type, _utils.get_device_index(grad_output.device), grad_output.data_ptr()))
# Run and get results
self._training_session.run_backward(backward_grad_output_ortvalue)
run_id = ctx.run_id
self._training_session.run_backward(backward_grad_output_ortvalue, run_id)
backward_outputs = self._training_io_binding.get_outputs()
# Return input and initializer gradients

View file

@ -26,29 +26,10 @@ def parse_arguments():
def run_ortmodule_api_tests(cwd, log):
log.debug('Running: ORTModule-API tests')
class TestNameCollecterPlugin:
def __init__(self):
self.collected = set()
command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_ortmodule_api.py']
def pytest_collection_modifyitems(self, items):
for item in items:
print('item.name: ', item.name)
self.collected.add(item.name)
run_subprocess(command, cwd=cwd, log=log).check_returncode()
import os
import pytest
plugin = TestNameCollecterPlugin()
print(cwd)
test_script_filename = os.path.join("orttraining_test_ortmodule_api.py")
pytest.main(['--collect-only', test_script_filename], plugins=[plugin])
# TODO: FIX THIS!
# Running tests in a loop one after another,
# because ORTModule doesn't support multiple run call at the same time
for test_name in plugin.collected:
run_subprocess([
sys.executable, '-m', 'pytest', '-sv',
'orttraining_test_ortmodule_api.py' + '::' + test_name], cwd=cwd).check_returncode()
def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir):
log.debug('Running: ORTModule POCNet for MNIST with --no-cuda arg {}.'.format(no_cuda))

View file

@ -7,6 +7,7 @@ import torch
from transformers import AutoConfig, BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
import pytest
from time import sleep
import warnings
from unittest.mock import patch
from collections import OrderedDict
@ -103,6 +104,13 @@ class NeuralNetSimplePositionalAndKeywordArguments(torch.nn.Module):
return torch.mean(self.a) + 3 * y
return torch.mean(self.a) + x
# TODO: This is a workaround for the problem that pytest is still cleaning up the previous test
# while the next task already start.
@pytest.fixture(autouse=True)
def run_before_tests():
# wait for 50ms before starting the next test
sleep(0.05)
def _get_bert_for_sequence_classification_model(device, output_attentions = False, \
output_hidden_states = False, return_dict = True):
"""Returns the BertForSequenceClassification pretrained model"""
@ -341,7 +349,6 @@ def test_model_and_input_without_device():
# x = torch.randn(N, D_in, device=device)
# y = model(x)
# TODO: Re-enable this Test when .to(), .cpu() and .cuda() are fixed
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
def test_input_requires_grad_saved(device):
N, D_in, H, D_out = 32, 784, 500, 10
@ -363,17 +370,154 @@ def test_input_requires_grad_backward_creates_input_grad(device):
s.backward()
assert x.grad is not None
# TODO: Re-enable this Test when .to(), .cpu() and .cuda() are fixed
# @pytest.mark.parametrize("device", ['cuda', 'cpu'])
# def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device):
# N, D_in, H, D_out = 32, 784, 500, 10
# model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
# model = ORTModule(model)
# x = torch.randn(N, D_in, device=device, requires_grad=True)
# model(x.data)
# module_gradient_graph_builder = model._module_gradient_graph_builder
# model(x)
# assert module_gradient_graph_builder != model._module_gradient_graph_builder
def test_multiple_forward_only_calls():
N, D_in, H, D_out = 32, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model = ORTModule(model)
for step in range(10):
x = torch.randn(N, D_in, device='cuda', requires_grad=False)
prediction1 = model(x)
def test_multiple_overlapping_forward_backward_calls():
N, D_in, H, D_out = 32, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model = ORTModule(model)
for step in range(10):
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
assert x1.grad is None and x2.grad is None
prediction1 = model(x1)
s1 = prediction1.sum()
prediction2 = model(x2)
s2 = prediction2.sum()
s1.backward()
s2.backward()
assert x1.grad is not None and x2.grad is not None
def test_multiple_ortmodules_training():
N, D_in, H, D_out = 32, 784, 500, 10
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model1 = ORTModule(model1)
model2 = ORTModule(model2)
for step in range(10):
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
assert x1.grad is None and x2.grad is None
prediction1 = model1(x1)
s1 = prediction1.sum()
prediction2 = model2(x2)
s2 = prediction2.sum()
s1.backward()
s2.backward()
assert x1.grad is not None and x2.grad is not None
for param in model1.parameters():
assert param.grad is not None
param.grad = None
for param in model2.parameters():
assert param.grad is not None
param.grad = None
def test_multiple_ortmodules_common_backbone_training():
N, D_in, H, D_out = 32, 64, 500, 64
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
# model is the common backbone shared by model1 and model2
model = ORTModule(model)
model1 = ORTModule(model1)
model2 = ORTModule(model2)
for step in range(10):
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
assert x1.grad is None and x2.grad is None
prediction1 = model1(model(x1))
s1 = prediction1.sum()
s1.backward()
prediction2 = model2(model(x2))
s2 = prediction2.sum()
s2.backward()
assert x1.grad is not None and x2.grad is not None
for param in model.parameters():
assert param.grad is not None
param.grad = None
for param in model1.parameters():
assert param.grad is not None
param.grad = None
for param in model2.parameters():
assert param.grad is not None
param.grad = None
def test_multiple_chained_ortmodules_training():
N, D_in, H, D_out = 32, 128, 500, 128
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model1 = ORTModule(model1)
model2 = ORTModule(model2)
all_params = list(model1.parameters()) + list(model2.parameters())
for step in range(10):
x = torch.randn(N, D_in, device='cuda', requires_grad=True)
output1 = model1(x)
output2 = model2(output1)
s = output2.sum()
s.backward()
assert x.grad is not None
for param in all_params:
assert param.grad is not None
param.grad = None
def test_mixed_nnmodule_ortmodules_training():
N, D_in, H, D_out = 32, 128, 500, 128
model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda')
model3 = NeuralNetMultiplePositionalArguments(D_in, H, D_out).to('cuda')
model1 = ORTModule(model1)
# model2 is intentionally left as nn.module
model3 = ORTModule(model3)
all_params = list(model1.parameters()) + list(model2.parameters()) + list(model3.parameters())
for step in range(10):
x1 = torch.randn(N, D_in, device='cuda', requires_grad=True)
x2 = torch.randn(N, D_in, device='cuda', requires_grad=True)
a1 = model1(x1)
a2 = model2(x2)
a3 = model3(torch.sin(a1), torch.cos(a2))
loss = a3.sum()
loss.backward()
assert x1.grad is not None and x2.grad is not None
for param in all_params:
assert param.grad is not None
param.grad = None
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device):
N, D_in, H, D_out = 32, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
model = ORTModule(model)
x = torch.randn(N, D_in, device=device, requires_grad=True)
model(x.data)
module_gradient_graph_builder = model._module_gradient_graph_builder
model(x)
assert module_gradient_graph_builder != model._module_gradient_graph_builder
def test_gpu_reserved_memory_with_torch_no_grad():
device = 'cuda'

View file

@ -0,0 +1,99 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "ort_tasks.h"
namespace onnxruntime {
namespace contrib {
void OrtTasks::CreateBackgroundTask(int64_t run_id) {
std::lock_guard<std::mutex> lock(mutex_);
ORT_ENFORCE(bg_tasks_.find(run_id) == bg_tasks_.end());
bg_tasks_.insert(std::make_pair(run_id, std::make_unique<Task>()));
}
void OrtTasks::RemoveTask(int64_t run_id) {
std::lock_guard<std::mutex> lock(mutex_);
auto iter = bg_tasks_.find(run_id);
ORT_ENFORCE(iter != bg_tasks_.end());
bg_tasks_.erase(iter);
}
void OrtTasks::SetForwardOutputs(Status s, const std::vector<OrtValue>& forward_outputs) {
int64_t run_id = hasher_(std::this_thread::get_id());
std::lock_guard<std::mutex> lock(mutex_);
auto iter = bg_tasks_.find(run_id);
ORT_ENFORCE(iter != bg_tasks_.end());
iter->second->forward_output_promise_.set_value(std::make_pair(s, forward_outputs));
}
ForwardReturnType OrtTasks::WaitForForwardOutputs(int64_t run_id) {
OrtTasks::Task* task;
{
std::lock_guard<std::mutex> lock(mutex_);
auto iter = bg_tasks_.find(run_id);
ORT_ENFORCE(iter != bg_tasks_.end());
task = (*iter).second.get();
}
return task->forward_output_future_.get();
}
bool OrtTasks::ForwardOutputsIsValid() {
int64_t run_id = hasher_(std::this_thread::get_id());
std::lock_guard<std::mutex> lock(mutex_);
auto iter = bg_tasks_.find(run_id);
ORT_ENFORCE(iter != bg_tasks_.end());
return iter->second->forward_output_future_.valid();
}
void OrtTasks::SetBackwardInputs(int64_t run_id, const std::vector<OrtValue>& backward_inputs, bool terminate) {
std::lock_guard<std::mutex> lock(mutex_);
auto iter = bg_tasks_.find(run_id);
ORT_ENFORCE(iter != bg_tasks_.end());
iter->second->backward_input_promise_.set_value(std::make_pair(terminate, backward_inputs));
}
BackwardReturnType OrtTasks::WaitForBackwardInputs() {
int64_t run_id = hasher_(std::this_thread::get_id());
OrtTasks::Task* task;
{
std::lock_guard<std::mutex> lock(mutex_);
auto iter = bg_tasks_.find(run_id);
ORT_ENFORCE(iter != bg_tasks_.end());
task = (*iter).second.get();
}
return task->backward_input_future_.get();
}
void OrtTasks::SetStatus(const Status& status) {
int64_t run_id = hasher_(std::this_thread::get_id());
std::lock_guard<std::mutex> lock(mutex_);
auto iter = bg_tasks_.find(run_id);
ORT_ENFORCE(iter != bg_tasks_.end());
iter->second->status_promise_.set_value(status);
}
bool OrtTasks::TaskIsCompleted(int64_t run_id) {
std::lock_guard<std::mutex> lock(mutex_);
auto iter = bg_tasks_.find(run_id);
ORT_ENFORCE(iter != bg_tasks_.end());
// if status_future has been invalidated, the task is completed
return !iter->second->status_future_.valid();
}
Status OrtTasks::WaitForStatus(int64_t run_id) {
OrtTasks::Task* task;
{
std::lock_guard<std::mutex> lock(mutex_);
auto iter = bg_tasks_.find(run_id);
ORT_ENFORCE(iter != bg_tasks_.end());
task = (*iter).second.get();
}
return task->status_future_.get();
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/ml_value.h"
#include <mutex>
#include <future>
namespace onnxruntime {
namespace contrib {
typedef std::pair<Status, std::vector<OrtValue>> ForwardReturnType;
typedef std::pair<bool, std::vector<OrtValue>> BackwardReturnType;
class OrtTasks final {
public:
static OrtTasks& GetInstance() {
static OrtTasks* instance_ = new OrtTasks;
return *instance_;
}
void CreateBackgroundTask(int64_t run_id);
void RemoveTask(int64_t run_id);
void SetForwardOutputs(Status s, const std::vector<OrtValue>& forward_outputs);
ForwardReturnType WaitForForwardOutputs(int64_t run_id);
bool ForwardOutputsIsValid();
void SetBackwardInputs(int64_t run_id, const std::vector<OrtValue>& backward_inputs, bool terminate);
BackwardReturnType WaitForBackwardInputs();
void SetStatus(const Status& status);
Status WaitForStatus(int64_t run_id);
bool TaskIsCompleted(int64_t run_id);
private:
OrtTasks() = default;
~OrtTasks() = default;
OrtTasks(const OrtTasks&) = delete;
OrtTasks& operator=(const OrtTasks&) = delete;
struct Task {
std::promise<ForwardReturnType> forward_output_promise_;
std::future<ForwardReturnType> forward_output_future_ = forward_output_promise_.get_future();
std::promise<BackwardReturnType> backward_input_promise_;
std::future<BackwardReturnType> backward_input_future_ = backward_input_promise_.get_future();
std::promise<Status> status_promise_;
std::future<Status> status_future_ = status_promise_.get_future();
};
std::hash<std::thread::id> hasher_;
mutable std::mutex mutex_;
std::unordered_map<int64_t, std::unique_ptr<Task>> bg_tasks_;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -2,8 +2,7 @@
// Licensed under the MIT License.
#include "orttraining/training_ops/cpu/controlflow/yield.h"
#include "core/providers/cpu/controlflow/event_pool.h"
#include "core/providers/cpu/controlflow/message_queue.h"
#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
#include "core/framework/op_kernel_context_internal.h"
namespace onnxruntime {
@ -21,27 +20,26 @@ ONNX_OPERATOR_KERNEL_EX(
Status YieldOp::Compute(OpKernelContext* ctx) const {
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
for (int i_in = 0; i_in < ctx->InputCount(); ++i_in) {
onnxruntime::contrib::OrtMessageQueue::GetInstance().Push(*ctx_internal->GetInputMLValue(i_in));
std::vector<OrtValue> forward_outputs;
forward_outputs.reserve(ctx->InputCount());
for (int i = 0; i < ctx->InputCount(); ++i) {
forward_outputs.push_back(*ctx_internal->GetInputMLValue(i));
}
// Reset background event before returning to main thread
const int64_t background_thread_event_id = 1;
onnxruntime::contrib::OrtEventPool::GetInstance().ResetEvent(background_thread_event_id);
// return forward output and single that FW graph is completed
OrtTasks::GetInstance().SetForwardOutputs(Status::OK(), forward_outputs);
// single event for InferenceSession::RunInBackgroundAndWaitForYield() that FW graph is done
const int64_t main_thread_event_id = 0;
OrtEventPool::GetInstance().SignalEvent(main_thread_event_id);
// wait for data from SetBackwardInputs() to continue executing the BW graph
auto backward_inputs = OrtTasks::GetInstance().WaitForBackwardInputs();
bool terminate = backward_inputs.first;
// wait for event from InferenceSession::ContinueRunInBackground() to continue the BW graph
OrtEventPool::GetInstance().WaitAndResetEvent(background_thread_event_id);
if (ctx_internal->GetTerminateFlag()) {
LOGS(ctx->Logger(), WARNING) << "Resumed executing backward subgraph, terminate_flag is set to true.";
if (terminate) {
ORT_THROW("Terminating backward run, since the terminate is set to true.");
} else {
// Get output grad from somewhere and prepare Op outputs.
for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
ctx_internal->SetOutputMLValue(i_out, OrtMessageQueue::GetInstance().Pop());
ORT_ENFORCE(backward_inputs.second.size() == static_cast<size_t>(ctx->OutputCount()));
for (int i = 0; i < ctx->OutputCount(); ++i) {
ctx_internal->SetOutputMLValue(i, backward_inputs.second[i]);
}
}