diff --git a/onnxruntime/core/providers/cpu/controlflow/message_queue.h b/onnxruntime/core/providers/cpu/controlflow/message_queue.h deleted file mode 100644 index c743950056..0000000000 --- a/onnxruntime/core/providers/cpu/controlflow/message_queue.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/common/common.h" -#include "core/framework/ml_value.h" - -namespace onnxruntime { -namespace contrib { - -class OrtMessageQueue final { - public: - static OrtMessageQueue& GetInstance() { - static OrtMessageQueue instance_; - return instance_; - } - - void Push(const OrtValue& ort_value) { ort_values.emplace(ort_value); } - OrtValue Pop() { - OrtValue ort_value = ort_values.front(); - ort_values.pop(); - return ort_value; - } - - void PopAll(std::vector& results) { - while (!ort_values.empty()) { - OrtValue ort_value = ort_values.front(); - ort_values.pop(); - results.emplace_back(ort_value); - } - } - - private: - OrtMessageQueue() = default; - ~OrtMessageQueue() = default; - OrtMessageQueue(const OrtMessageQueue&) = delete; - OrtMessageQueue& operator=(const OrtMessageQueue&) = delete; - - std::queue ort_values; -}; - -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ef207af880..880e3d8e86 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -57,8 +57,7 @@ #endif #ifdef ENABLE_TRAINING -#include "core/providers/cpu/controlflow/event_pool.h" -#include "core/providers/cpu/controlflow/message_queue.h" +#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h" #endif using namespace ONNX_NAMESPACE; @@ -381,12 +380,20 @@ InferenceSession::~InferenceSession() { } #ifdef ENABLE_TRAINING - // TODO: find a better way to terminate the background thread - // backward is not completed yet, set terminate_flag to True - if (task_.bg_thread_future_.valid()) { - *(task_.terminate_flag_) = true; - Status s = ContinueRunInBackground({}); - ORT_UNUSED_PARAMETER(s); + // TODO: Properly cancel outstanding background tasks + // Following implementation only handle the case where bg_thread is waiting for backward inputs + // Background thread can also be in other states, such as running Forward() or running Backward() + std::vector run_ids; + { + std::lock_guard lock(bg_threads_mutex_); + for (auto it = bg_threads_.begin(); it != bg_threads_.end(); ++it) { + run_ids.push_back(it->first); + } + } + for (int64_t run_id : run_ids) { + if (!onnxruntime::contrib::OrtTasks::GetInstance().TaskIsCompleted(run_id)) { + CancelBackgroundTask(run_id); + } } #endif @@ -1733,59 +1740,111 @@ common::Status InferenceSession::Run(IOBinding& io_binding) { } #ifdef ENABLE_TRAINING -common::Status InferenceSession::RunInBackgroundAndWaitForYield(RunOptions& run_options, IOBinding& io_binding, - std::vector& user_outputs) { - const int64_t main_thread_event_id = 0; - onnxruntime::contrib::OrtEventPool::GetInstance().ResetEvent(0); +common::Status InferenceSession::RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding, + std::vector& user_outputs, int64_t& run_id) { + std::promise setup_promise; + std::future setup_future = setup_promise.get_future(); - task_.terminate_flag_ = &(run_options.terminate); - task_.bg_thread_promise_ = std::promise(); - task_.bg_thread_future_ = task_.bg_thread_promise_.get_future(); - task_.bg_thread_ = std::thread([&](std::promise result_promise) { - common::Status s = Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(), - &io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo()); + // Passing run_options and io_binding by reference to the bg_thread, + // this is ok because they are ORTModule's member, and they are presistent through forward and backward calls + auto bg_thread = std::thread([this](std::future setup_future, const RunOptions& run_options, IOBinding& io_binding) { + // wait until task is properly setup + setup_future.get(); - result_promise.set_value(s); + common::Status status = Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(), + &io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo()); - // signal main thread for background thread completion - const int64_t main_thread_event_id = 0; - onnxruntime::contrib::OrtEventPool::GetInstance().SignalEvent(main_thread_event_id); + onnxruntime::contrib::OrtTasks::GetInstance().SetStatus(status); + + // If forward outputs still hasn't been consumed at this point, i.e. forward function hasn't complete itself + // this indicates that Run() call returned before hitting YieldOp, due to hitting some exception during the forward subgraph execution + // In this case, we need to wake up the foreground thread and pass along the failed status. + // Otherwise, foreground thread will be stuck waiting for forward_outputs. + if (onnxruntime::contrib::OrtTasks::GetInstance().ForwardOutputsIsValid()) { + ORT_ENFORCE(!status.IsOK()); + // signal main thread for background thread completion + onnxruntime::contrib::OrtTasks::GetInstance().SetForwardOutputs(status, {}); + } }, - std::move(task_.bg_thread_promise_)); + std::move(setup_future), std::cref(run_options), std::ref(io_binding)); - // Wait for events from - // 1. Yield op, if the bg thread sucessfully reached Yield's signal point - // 2. The end of bg thread, if it hit execptions and returned earlier - onnxruntime::contrib::OrtEventPool::GetInstance().WaitAndResetEvent(main_thread_event_id); - - // background thread has completed without hitting Yield Op - if (task_.bg_thread_future_.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) { - Status bg_thread_status = task_.bg_thread_future_.get(); - task_.bg_thread_.join(); - return bg_thread_status; + run_id = std::hash()(bg_thread.get_id()); + { + std::lock_guard lock(bg_threads_mutex_); + bg_threads_[run_id] = std::move(bg_thread); + } + + onnxruntime::contrib::OrtTasks::GetInstance().CreateBackgroundTask(run_id); + + LOGS(*session_logger_, VERBOSE) << "InferenceSession::Forward() call created a task with run_id " << run_id; + + // background task is setup, unblock background thread to continue + setup_promise.set_value(); + + // Wait for data/signal from + // 1. Yield op, if the bg thread sucessfully reached Yield's signal point + // 2. The end of bg thread, if it hit execptions and returned earlier + auto forward_outputs = onnxruntime::contrib::OrtTasks::GetInstance().WaitForForwardOutputs(run_id); + const Status& forward_status = forward_outputs.first; + user_outputs = std::move(forward_outputs.second); + + // background thread has completed without hitting Yield Op + if (!forward_status.IsOK()) { + std::thread thread; + { + std::lock_guard lock(bg_threads_mutex_); + std::swap(thread, bg_threads_[run_id]); + bg_threads_.erase(run_id); + } + ORT_ENFORCE(thread.joinable()); + thread.join(); + onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id); + return forward_status; } - onnxruntime::contrib::OrtMessageQueue::GetInstance().PopAll(user_outputs); return Status::OK(); } -common::Status InferenceSession::ContinueRunInBackground(const std::vector& backward_output_grads) { - for (const auto& ort_value : backward_output_grads) { - onnxruntime::contrib::OrtMessageQueue::GetInstance().Push(ort_value); - } +common::Status InferenceSession::ContinueRunInBackground(int64_t run_id, const std::vector& backward_output_grads) { + LOGS(*session_logger_, VERBOSE) << "Running InferenceSession::Backward() with run_id " << run_id; // resume background thread - const int64_t background_thread_event_id = 1; - onnxruntime::contrib::OrtEventPool::GetInstance().SignalEvent(background_thread_event_id); + onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, backward_output_grads, false); - Status bg_thread_status = task_.bg_thread_future_.get(); - // wait for bg_thread to complete - if (task_.bg_thread_.joinable()) { - task_.bg_thread_.join(); + Status bg_thread_status = onnxruntime::contrib::OrtTasks::GetInstance().WaitForStatus(run_id); + + std::thread bg_thread; + { + std::lock_guard lock(bg_threads_mutex_); + std::swap(bg_thread, bg_threads_[run_id]); + bg_threads_.erase(run_id); } + // wait for bg_thread to complete + ORT_ENFORCE(bg_thread.joinable()); + bg_thread.join(); + onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id); + return bg_thread_status; } + +void InferenceSession::CancelBackgroundTask(int64_t run_id) { + LOGS(*session_logger_, WARNING) << "Canceling background task with run_id " << run_id; + + // resume background thread with terminate = true + onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, {}, true); + + // wait for bg_thread to complete + std::thread bg_thread; + { + std::lock_guard lock(bg_threads_mutex_); + std::swap(bg_thread, bg_threads_[run_id]); + bg_threads_.erase(run_id); + } + ORT_ENFORCE(bg_thread.joinable()); + bg_thread.join(); + onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id); +} #endif template diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index ee3e0fba38..bb5e6e91d7 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -304,13 +304,15 @@ class InferenceSession { #ifdef ENABLE_TRAINING // For ORTModule.forward() - virtual common::Status RunInBackgroundAndWaitForYield(RunOptions& run_options, IOBinding& io_binding, - std::vector& user_outputs) ORT_MUST_USE_RESULT; + virtual common::Status RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding, + std::vector& user_outputs, + int64_t& run_id) ORT_MUST_USE_RESULT; // For ORTModule.backward() - common::Status ContinueRunInBackground(const std::vector& backward_output_grads) ORT_MUST_USE_RESULT; -#endif + common::Status ContinueRunInBackground(int64_t run_id, const std::vector& backward_output_grads) ORT_MUST_USE_RESULT; + void CancelBackgroundTask(int64_t run_id); +#endif /** * @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK. * @note lifetime of the returned pointer is valid as long as the Session object is live. @@ -676,18 +678,16 @@ class InferenceSession { std::vector ort_format_model_bytes_; #ifdef ENABLE_TRAINING - // background thread for RunInBackgroundAndWaitForYield - struct Task { - std::thread bg_thread_; - std::promise bg_thread_promise_; - std::future bg_thread_future_; - bool* terminate_flag_ = nullptr; - } task_; + // mutex for accessing bg_threads_ + std::mutex bg_threads_mutex_; + + // background threads for RunInBackgroundAndWaitForYield and ContinueRunInBackground + std::unordered_map bg_threads_; #endif + std::shared_ptr allocator_manager_; }; - struct SessionIOBinding { public: SessionIOBinding(InferenceSession* session); diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index a9567c7388..6bfe7a0a72 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -234,14 +234,15 @@ class Session: :param iobinding: the iobinding object that has graph inputs/outputs bind. :param run_options: See :class:`onnxruntime.RunOptions`. """ - return [OrtValue(ortvalue) for ortvalue in self._sess.run_forward(iobinding._iobinding, run_options)] + ortvalues, run_id = self._sess.run_forward(iobinding._iobinding, run_options) + return [OrtValue(ortvalue) for ortvalue in ortvalues], run_id - def run_backward(self, backward_output_grads): + def run_backward(self, backward_output_grads, run_id): """ Resume executing the backward subgraph starting from Yield Op. :param backward_output_grads: Output gradients for backward. """ - self._sess.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads]) + self._sess.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads], run_id) class InferenceSession(Session): diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index e4d8b9aa65..e7c71ad84c 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1222,8 +1222,7 @@ void addObjectMethods(py::module& m, Environment& env) { return ml_value; }) #ifdef ENABLE_TRAINING - .def_static("ortvalue_from_data_ptr", [](std::vector& shape, py::object& element_type, - OrtDevice& device, int64_t data_ptr) { + .def_static("ortvalue_from_data_ptr", [](std::vector& shape, py::object& element_type, OrtDevice& device, int64_t data_ptr) { ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is invalid"); PyArray_Descr* dtype; if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) { @@ -1313,7 +1312,7 @@ void addObjectMethods(py::module& m, Environment& env) { PyCapsule_New(dlmanaged_tensor, "dltensor", dlpack_capsule_destructor)); }) #endif -; + ; py::class_ session_io_binding(m, "SessionIOBinding"); session_io_binding @@ -1837,22 +1836,22 @@ including arg name, arg type (contains both type and shape).)pbdoc") throw std::runtime_error("Error in execution: " + status.ErrorMessage()); }) #ifdef ENABLE_TRAINING - .def("run_forward", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions& run_options) -> std::vector { + .def("run_forward", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions& run_options) -> py::tuple { std::vector module_outputs; - Status status = sess->GetSessionHandle()->RunInBackgroundAndWaitForYield(run_options, *io_binding.Get(), module_outputs); + int64_t run_id; + Status status = sess->GetSessionHandle()->RunInBackgroundAndWaitForYield(run_options, *io_binding.Get(), module_outputs, run_id); if (!status.IsOK()) { throw std::runtime_error("Error in execution: " + status.ErrorMessage()); } - - return module_outputs; + return py::make_tuple(module_outputs, run_id); }) - .def("run_backward", [](PyInferenceSession* sess, const std::vector& backward_output_grads) -> void { - Status status = sess->GetSessionHandle()->ContinueRunInBackground(backward_output_grads); + .def("run_backward", [](PyInferenceSession* sess, const std::vector& backward_output_grads, int64_t run_id) -> void { + Status status = sess->GetSessionHandle()->ContinueRunInBackground(run_id, backward_output_grads); if (!status.IsOK()) throw std::runtime_error("Error in execution: " + status.ErrorMessage()); }) #endif -; + ; py::enum_(m, "ArenaExtendStrategy", py::arithmetic()) .value("kNextPowerOfTwo", onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo) diff --git a/orttraining/orttraining/python/training/ortmodule.py b/orttraining/orttraining/python/training/ortmodule.py index cbbfc9d4d3..7224ef301d 100644 --- a/orttraining/orttraining/python/training/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule.py @@ -244,8 +244,10 @@ class ORTModule(torch.nn.Module): _create_iobinding(self._training_io_binding, inputs, self._onnx_training, self._device) # Run and return module outputs. - user_outputs = tuple(_ort_output_to_torch_tensor(forward_output) \ - for forward_output in self._training_session.run_forward(self._training_io_binding, self._run_options)) + forward_outputs, run_id = self._training_session.run_forward(self._training_io_binding, self._run_options) + user_outputs = tuple(_ort_output_to_torch_tensor(forward_output) for forward_output in forward_outputs) + ctx.run_id = run_id + return user_outputs @staticmethod @@ -261,7 +263,8 @@ class ORTModule(torch.nn.Module): grad_output.dtype), grad_output.device.type, _utils.get_device_index(grad_output.device), grad_output.data_ptr())) # Run and get results - self._training_session.run_backward(backward_grad_output_ortvalue) + run_id = ctx.run_id + self._training_session.run_backward(backward_grad_output_ortvalue, run_id) backward_outputs = self._training_io_binding.get_outputs() # Return input and initializer gradients diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index 48d2f531e0..1de1b10bba 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -26,29 +26,10 @@ def parse_arguments(): def run_ortmodule_api_tests(cwd, log): log.debug('Running: ORTModule-API tests') - class TestNameCollecterPlugin: - def __init__(self): - self.collected = set() + command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_ortmodule_api.py'] - def pytest_collection_modifyitems(self, items): - for item in items: - print('item.name: ', item.name) - self.collected.add(item.name) + run_subprocess(command, cwd=cwd, log=log).check_returncode() - import os - import pytest - plugin = TestNameCollecterPlugin() - print(cwd) - test_script_filename = os.path.join("orttraining_test_ortmodule_api.py") - pytest.main(['--collect-only', test_script_filename], plugins=[plugin]) - - # TODO: FIX THIS! - # Running tests in a loop one after another, - # because ORTModule doesn't support multiple run call at the same time - for test_name in plugin.collected: - run_subprocess([ - sys.executable, '-m', 'pytest', '-sv', - 'orttraining_test_ortmodule_api.py' + '::' + test_name], cwd=cwd).check_returncode() def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir): log.debug('Running: ORTModule POCNet for MNIST with --no-cuda arg {}.'.format(no_cuda)) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 5e9bf4520a..3df281e384 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -7,6 +7,7 @@ import torch from transformers import AutoConfig, BertForSequenceClassification from transformers.modeling_outputs import SequenceClassifierOutput import pytest +from time import sleep import warnings from unittest.mock import patch from collections import OrderedDict @@ -103,6 +104,13 @@ class NeuralNetSimplePositionalAndKeywordArguments(torch.nn.Module): return torch.mean(self.a) + 3 * y return torch.mean(self.a) + x +# TODO: This is a workaround for the problem that pytest is still cleaning up the previous test +# while the next task already start. +@pytest.fixture(autouse=True) +def run_before_tests(): + # wait for 50ms before starting the next test + sleep(0.05) + def _get_bert_for_sequence_classification_model(device, output_attentions = False, \ output_hidden_states = False, return_dict = True): """Returns the BertForSequenceClassification pretrained model""" @@ -341,7 +349,6 @@ def test_model_and_input_without_device(): # x = torch.randn(N, D_in, device=device) # y = model(x) -# TODO: Re-enable this Test when .to(), .cpu() and .cuda() are fixed @pytest.mark.parametrize("device", ['cuda', 'cpu']) def test_input_requires_grad_saved(device): N, D_in, H, D_out = 32, 784, 500, 10 @@ -363,17 +370,154 @@ def test_input_requires_grad_backward_creates_input_grad(device): s.backward() assert x.grad is not None -# TODO: Re-enable this Test when .to(), .cpu() and .cuda() are fixed -# @pytest.mark.parametrize("device", ['cuda', 'cpu']) -# def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device): -# N, D_in, H, D_out = 32, 784, 500, 10 -# model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) -# model = ORTModule(model) -# x = torch.randn(N, D_in, device=device, requires_grad=True) -# model(x.data) -# module_gradient_graph_builder = model._module_gradient_graph_builder -# model(x) -# assert module_gradient_graph_builder != model._module_gradient_graph_builder +def test_multiple_forward_only_calls(): + N, D_in, H, D_out = 32, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model = ORTModule(model) + for step in range(10): + x = torch.randn(N, D_in, device='cuda', requires_grad=False) + prediction1 = model(x) + +def test_multiple_overlapping_forward_backward_calls(): + N, D_in, H, D_out = 32, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model = ORTModule(model) + + for step in range(10): + x1 = torch.randn(N, D_in, device='cuda', requires_grad=True) + x2 = torch.randn(N, D_in, device='cuda', requires_grad=True) + assert x1.grad is None and x2.grad is None + + prediction1 = model(x1) + s1 = prediction1.sum() + + prediction2 = model(x2) + s2 = prediction2.sum() + + s1.backward() + s2.backward() + assert x1.grad is not None and x2.grad is not None + +def test_multiple_ortmodules_training(): + N, D_in, H, D_out = 32, 784, 500, 10 + model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model1 = ORTModule(model1) + model2 = ORTModule(model2) + + for step in range(10): + x1 = torch.randn(N, D_in, device='cuda', requires_grad=True) + x2 = torch.randn(N, D_in, device='cuda', requires_grad=True) + assert x1.grad is None and x2.grad is None + + prediction1 = model1(x1) + s1 = prediction1.sum() + + prediction2 = model2(x2) + s2 = prediction2.sum() + + s1.backward() + s2.backward() + + assert x1.grad is not None and x2.grad is not None + for param in model1.parameters(): + assert param.grad is not None + param.grad = None + for param in model2.parameters(): + assert param.grad is not None + param.grad = None + +def test_multiple_ortmodules_common_backbone_training(): + N, D_in, H, D_out = 32, 64, 500, 64 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + # model is the common backbone shared by model1 and model2 + model = ORTModule(model) + model1 = ORTModule(model1) + model2 = ORTModule(model2) + + for step in range(10): + x1 = torch.randn(N, D_in, device='cuda', requires_grad=True) + x2 = torch.randn(N, D_in, device='cuda', requires_grad=True) + assert x1.grad is None and x2.grad is None + + prediction1 = model1(model(x1)) + s1 = prediction1.sum() + s1.backward() + + prediction2 = model2(model(x2)) + s2 = prediction2.sum() + s2.backward() + + assert x1.grad is not None and x2.grad is not None + for param in model.parameters(): + assert param.grad is not None + param.grad = None + for param in model1.parameters(): + assert param.grad is not None + param.grad = None + for param in model2.parameters(): + assert param.grad is not None + param.grad = None + +def test_multiple_chained_ortmodules_training(): + N, D_in, H, D_out = 32, 128, 500, 128 + model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model1 = ORTModule(model1) + model2 = ORTModule(model2) + + all_params = list(model1.parameters()) + list(model2.parameters()) + + for step in range(10): + x = torch.randn(N, D_in, device='cuda', requires_grad=True) + output1 = model1(x) + output2 = model2(output1) + s = output2.sum() + s.backward() + + assert x.grad is not None + for param in all_params: + assert param.grad is not None + param.grad = None + +def test_mixed_nnmodule_ortmodules_training(): + N, D_in, H, D_out = 32, 128, 500, 128 + model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + model3 = NeuralNetMultiplePositionalArguments(D_in, H, D_out).to('cuda') + model1 = ORTModule(model1) + # model2 is intentionally left as nn.module + model3 = ORTModule(model3) + + all_params = list(model1.parameters()) + list(model2.parameters()) + list(model3.parameters()) + + for step in range(10): + x1 = torch.randn(N, D_in, device='cuda', requires_grad=True) + x2 = torch.randn(N, D_in, device='cuda', requires_grad=True) + + a1 = model1(x1) + a2 = model2(x2) + a3 = model3(torch.sin(a1), torch.cos(a2)) + loss = a3.sum() + loss.backward() + + assert x1.grad is not None and x2.grad is not None + for param in all_params: + assert param.grad is not None + param.grad = None + +@pytest.mark.parametrize("device", ['cuda', 'cpu']) +def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device): + N, D_in, H, D_out = 32, 784, 500, 10 + model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + model = ORTModule(model) + x = torch.randn(N, D_in, device=device, requires_grad=True) + model(x.data) + module_gradient_graph_builder = model._module_gradient_graph_builder + model(x) + assert module_gradient_graph_builder != model._module_gradient_graph_builder def test_gpu_reserved_memory_with_torch_no_grad(): device = 'cuda' diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/ort_tasks.cc b/orttraining/orttraining/training_ops/cpu/controlflow/ort_tasks.cc new file mode 100644 index 0000000000..767bede9da --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/controlflow/ort_tasks.cc @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ort_tasks.h" + +namespace onnxruntime { +namespace contrib { + +void OrtTasks::CreateBackgroundTask(int64_t run_id) { + std::lock_guard lock(mutex_); + ORT_ENFORCE(bg_tasks_.find(run_id) == bg_tasks_.end()); + bg_tasks_.insert(std::make_pair(run_id, std::make_unique())); +} + +void OrtTasks::RemoveTask(int64_t run_id) { + std::lock_guard lock(mutex_); + auto iter = bg_tasks_.find(run_id); + ORT_ENFORCE(iter != bg_tasks_.end()); + bg_tasks_.erase(iter); +} + +void OrtTasks::SetForwardOutputs(Status s, const std::vector& forward_outputs) { + int64_t run_id = hasher_(std::this_thread::get_id()); + + std::lock_guard lock(mutex_); + auto iter = bg_tasks_.find(run_id); + ORT_ENFORCE(iter != bg_tasks_.end()); + iter->second->forward_output_promise_.set_value(std::make_pair(s, forward_outputs)); +} + +ForwardReturnType OrtTasks::WaitForForwardOutputs(int64_t run_id) { + OrtTasks::Task* task; + { + std::lock_guard lock(mutex_); + auto iter = bg_tasks_.find(run_id); + ORT_ENFORCE(iter != bg_tasks_.end()); + task = (*iter).second.get(); + } + return task->forward_output_future_.get(); +} + +bool OrtTasks::ForwardOutputsIsValid() { + int64_t run_id = hasher_(std::this_thread::get_id()); + + std::lock_guard lock(mutex_); + auto iter = bg_tasks_.find(run_id); + ORT_ENFORCE(iter != bg_tasks_.end()); + return iter->second->forward_output_future_.valid(); +} + +void OrtTasks::SetBackwardInputs(int64_t run_id, const std::vector& backward_inputs, bool terminate) { + std::lock_guard lock(mutex_); + auto iter = bg_tasks_.find(run_id); + ORT_ENFORCE(iter != bg_tasks_.end()); + iter->second->backward_input_promise_.set_value(std::make_pair(terminate, backward_inputs)); +} + +BackwardReturnType OrtTasks::WaitForBackwardInputs() { + int64_t run_id = hasher_(std::this_thread::get_id()); + OrtTasks::Task* task; + { + std::lock_guard lock(mutex_); + auto iter = bg_tasks_.find(run_id); + ORT_ENFORCE(iter != bg_tasks_.end()); + task = (*iter).second.get(); + } + return task->backward_input_future_.get(); +} + +void OrtTasks::SetStatus(const Status& status) { + int64_t run_id = hasher_(std::this_thread::get_id()); + + std::lock_guard lock(mutex_); + auto iter = bg_tasks_.find(run_id); + ORT_ENFORCE(iter != bg_tasks_.end()); + iter->second->status_promise_.set_value(status); +} + +bool OrtTasks::TaskIsCompleted(int64_t run_id) { + std::lock_guard lock(mutex_); + auto iter = bg_tasks_.find(run_id); + ORT_ENFORCE(iter != bg_tasks_.end()); + // if status_future has been invalidated, the task is completed + return !iter->second->status_future_.valid(); +} + +Status OrtTasks::WaitForStatus(int64_t run_id) { + OrtTasks::Task* task; + { + std::lock_guard lock(mutex_); + auto iter = bg_tasks_.find(run_id); + ORT_ENFORCE(iter != bg_tasks_.end()); + task = (*iter).second.get(); + } + return task->status_future_.get(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/ort_tasks.h b/orttraining/orttraining/training_ops/cpu/controlflow/ort_tasks.h new file mode 100644 index 0000000000..bb78e0a935 --- /dev/null +++ b/orttraining/orttraining/training_ops/cpu/controlflow/ort_tasks.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/ml_value.h" + +#include +#include + + +namespace onnxruntime { +namespace contrib { + +typedef std::pair> ForwardReturnType; +typedef std::pair> BackwardReturnType; + +class OrtTasks final { + public: + static OrtTasks& GetInstance() { + static OrtTasks* instance_ = new OrtTasks; + return *instance_; + } + + void CreateBackgroundTask(int64_t run_id); + void RemoveTask(int64_t run_id); + + void SetForwardOutputs(Status s, const std::vector& forward_outputs); + ForwardReturnType WaitForForwardOutputs(int64_t run_id); + bool ForwardOutputsIsValid(); + + void SetBackwardInputs(int64_t run_id, const std::vector& backward_inputs, bool terminate); + BackwardReturnType WaitForBackwardInputs(); + + void SetStatus(const Status& status); + Status WaitForStatus(int64_t run_id); + bool TaskIsCompleted(int64_t run_id); + + private: + OrtTasks() = default; + ~OrtTasks() = default; + OrtTasks(const OrtTasks&) = delete; + OrtTasks& operator=(const OrtTasks&) = delete; + + struct Task { + std::promise forward_output_promise_; + std::future forward_output_future_ = forward_output_promise_.get_future(); + + std::promise backward_input_promise_; + std::future backward_input_future_ = backward_input_promise_.get_future(); + + std::promise status_promise_; + std::future status_future_ = status_promise_.get_future(); + }; + + std::hash hasher_; + mutable std::mutex mutex_; + std::unordered_map> bg_tasks_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/controlflow/yield.cc b/orttraining/orttraining/training_ops/cpu/controlflow/yield.cc index 5fdc64f5ae..cdbd90ceab 100644 --- a/orttraining/orttraining/training_ops/cpu/controlflow/yield.cc +++ b/orttraining/orttraining/training_ops/cpu/controlflow/yield.cc @@ -2,8 +2,7 @@ // Licensed under the MIT License. #include "orttraining/training_ops/cpu/controlflow/yield.h" -#include "core/providers/cpu/controlflow/event_pool.h" -#include "core/providers/cpu/controlflow/message_queue.h" +#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h" #include "core/framework/op_kernel_context_internal.h" namespace onnxruntime { @@ -21,27 +20,26 @@ ONNX_OPERATOR_KERNEL_EX( Status YieldOp::Compute(OpKernelContext* ctx) const { auto* ctx_internal = static_cast(ctx); - for (int i_in = 0; i_in < ctx->InputCount(); ++i_in) { - onnxruntime::contrib::OrtMessageQueue::GetInstance().Push(*ctx_internal->GetInputMLValue(i_in)); + + std::vector forward_outputs; + forward_outputs.reserve(ctx->InputCount()); + for (int i = 0; i < ctx->InputCount(); ++i) { + forward_outputs.push_back(*ctx_internal->GetInputMLValue(i)); } - // Reset background event before returning to main thread - const int64_t background_thread_event_id = 1; - onnxruntime::contrib::OrtEventPool::GetInstance().ResetEvent(background_thread_event_id); + // return forward output and single that FW graph is completed + OrtTasks::GetInstance().SetForwardOutputs(Status::OK(), forward_outputs); - // single event for InferenceSession::RunInBackgroundAndWaitForYield() that FW graph is done - const int64_t main_thread_event_id = 0; - OrtEventPool::GetInstance().SignalEvent(main_thread_event_id); + // wait for data from SetBackwardInputs() to continue executing the BW graph + auto backward_inputs = OrtTasks::GetInstance().WaitForBackwardInputs(); + bool terminate = backward_inputs.first; - // wait for event from InferenceSession::ContinueRunInBackground() to continue the BW graph - OrtEventPool::GetInstance().WaitAndResetEvent(background_thread_event_id); - - if (ctx_internal->GetTerminateFlag()) { - LOGS(ctx->Logger(), WARNING) << "Resumed executing backward subgraph, terminate_flag is set to true."; + if (terminate) { + ORT_THROW("Terminating backward run, since the terminate is set to true."); } else { - // Get output grad from somewhere and prepare Op outputs. - for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) { - ctx_internal->SetOutputMLValue(i_out, OrtMessageQueue::GetInstance().Pop()); + ORT_ENFORCE(backward_inputs.second.size() == static_cast(ctx->OutputCount())); + for (int i = 0; i < ctx->OutputCount(); ++i) { + ctx_internal->SetOutputMLValue(i, backward_inputs.second[i]); } }