OrtModule v0.21 (#6395)

* ortmodule v0.2

* use pt module for eval

* get user outputs in yield op

* pass output grads to yield output without copy

* Disable mem_pattern for ORTModule

* Avoid allocating output buffer for Yield op

* Change to WaitAndReset to avoid overriding signal

* remove unnecessory signal/wait at the end of bg thread

* Return Session.Run result as a std::future

* export model with torch.no_grad()

* Handle bg thread's early return in Forward call

* Removed duplicated Yield kernel

* Silence "CUDA kernel missing log"

* Add missing transforms, clear iobinding (#6532)

* revert ortmodule.py to a working state first

* Apply ortmodule.py change from dev branch

* Rename to YieldOp

Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: ashbhandare <ash.bhandare@gmail.com>
Co-authored-by: Sherlock <baihan.huang@gmail.com>
This commit is contained in:
Vincent Wang 2021-02-11 05:27:15 +08:00 committed by GitHub
parent bc0d04bf07
commit eec602e48a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 761 additions and 516 deletions

View file

@ -251,6 +251,7 @@ class OpKernelContext {
const OrtValue* GetInputMLValue(int index) const;
const OrtValue* GetImplicitInputMLValue(int index) const;
OrtValue* GetOutputMLValue(int index);
Status SetOutputMLValue(int index, const OrtValue& ort_value);
// Creates the OrtValue* based on the shape, if it does not exist
// The parameter nnz is used only for sparse-tensors and indicates the

View file

@ -45,6 +45,16 @@ OrtValue* IExecutionFrame::GetMutableNodeInputOrOutputMLValue(int index) {
return const_cast<OrtValue*>(GetNodeInputOrOutputMLValue(index));
}
Status IExecutionFrame::SetOutputMLValue(int index, const OrtValue& ort_value) {
int ort_value_idx = GetNodeIdxToMLValueIdx(index);
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx);
}
all_values_[ort_value_idx] = ort_value;
return Status::OK();
}
// TO DO: make it thread safe
// This method is not thread safe!
// Return S_OK and nullptr if index map to an value that is an unused optional input/output

View file

@ -49,6 +49,9 @@ class IExecutionFrame {
const OrtValue* GetNodeInputOrOutputMLValue(int index) const;
OrtValue* GetMutableNodeInputOrOutputMLValue(int index);
// Override the index-th output with ort_value
Status SetOutputMLValue(int index, const OrtValue& ort_value);
// TO DO: make it thread safe
// This method is not thread safe!
// Return S_OK and nullptr if index map to an value that is an unused optional input/output

View file

@ -175,4 +175,13 @@ OrtValue* OpKernelContext::GetOutputMLValue(int index) {
return execution_frame_->GetMutableNodeInputOrOutputMLValue(output_arg_index);
}
Status OpKernelContext::SetOutputMLValue(int index, const OrtValue& ort_value) {
if (index < 0 || index >= OutputCount()) {
return Status(common::ONNXRUNTIME, common::FAIL, "Index out of range.");
}
auto output_arg_index = GetOutputArgIndex(index);
return execution_frame_->SetOutputMLValue(output_arg_index, ort_value);
}
} // namespace onnxruntime

View file

@ -53,6 +53,10 @@ class OpKernelContextInternal : public OpKernelContext {
return OpKernelContext::GetOutputMLValue(index);
}
void SetOutputMLValue(int index, const OrtValue& ort_value) {
ORT_ENFORCE(OpKernelContext::SetOutputMLValue(index, ort_value).IsOK());
}
OrtValue* OutputMLValue(int index, const TensorShape& shape) {
return OpKernelContext::OutputMLValue(index, shape);
}

View file

@ -37,6 +37,8 @@
#include "core/optimizer/skip_layer_norm_fusion.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/matmul_transpose_fusion.h"
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
namespace onnxruntime {
class IExecutionProvider;
@ -144,13 +146,17 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(cpu_cuda_acl_armnn_execution_providers));
std::unordered_set<std::string> cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider};
const std::unordered_set<std::string> cuda_execution_providers = {onnxruntime::kCudaExecutionProvider};
const std::unordered_set<std::string> cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider};
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<AttentionFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<EmbedLayerNormFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<BiasDropoutFusion>(cuda_execution_providers));
// TODO: This should be combined with MatMulScaleFusion and deprecate MatmulTransposeFusion
transformers.emplace_back(onnxruntime::make_unique<MatmulTransposeFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<BiasGeluFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<BiasSoftmaxFusion>(cpu_cuda_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<SkipLayerNormFusion>(cpu_cuda_execution_providers));

View file

@ -1906,7 +1906,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
// none of the provided registries has a CUDA kernel for this node
if (cuda_kernel_def == nullptr) {
LOGS_DEFAULT(WARNING) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name();
LOGS_DEFAULT(INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name();
continue;
}

View file

@ -52,6 +52,8 @@
#include "core/util/protobuf_parsing_utils.h"
#include "core/util/thread_utils.h"
#include "orttraining/training_ops/cpu/controlflow/event_pool.h"
#include "orttraining/training_ops/cpu/controlflow/message_queue.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::experimental;
@ -372,6 +374,14 @@ InferenceSession::~InferenceSession() {
}
}
// TODO: find a better way to terminate the background thread
// backward is not completed yet, set terminate_flag to True
if (task_.bg_thread_future_.valid()) {
*(task_.terminate_flag_) = true;
Status s = ContinueRunInBackground({});
ORT_UNUSED_PARAMETER(s);
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
if (session_activity_started_)
TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
@ -1709,6 +1719,60 @@ common::Status InferenceSession::Run(IOBinding& io_binding) {
return Run(run_options, io_binding);
}
common::Status InferenceSession::RunInBackgroundAndWaitForYield(RunOptions& run_options, IOBinding& io_binding,
std::vector<OrtValue>& user_outputs) {
const int64_t main_thread_event_id = 0;
onnxruntime::contrib::OrtEventPool::GetInstance().ResetEvent(0);
task_.terminate_flag_ = &(run_options.terminate);
task_.bg_thread_promise_ = std::promise<Status>();
task_.bg_thread_future_ = task_.bg_thread_promise_.get_future();
task_.bg_thread_ = std::thread([&](std::promise<common::Status> result_promise) {
common::Status s = Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(),
&io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo());
result_promise.set_value(s);
// signal main thread for background thread completion
const int64_t main_thread_event_id = 0;
onnxruntime::contrib::OrtEventPool::GetInstance().SignalEvent(main_thread_event_id);
},
std::move(task_.bg_thread_promise_));
// Wait for events from
// 1. Yield op, if the bg thread sucessfully reached Yield's signal point
// 2. The end of bg thread, if it hit execptions and returned earlier
onnxruntime::contrib::OrtEventPool::GetInstance().WaitAndResetEvent(main_thread_event_id);
// background thread has completed without hitting Yield Op
if (task_.bg_thread_future_.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) {
Status bg_thread_status = task_.bg_thread_future_.get();
task_.bg_thread_.join();
return bg_thread_status;
}
onnxruntime::contrib::OrtMessageQueue::GetInstance().PopAll(user_outputs);
return Status::OK();
}
common::Status InferenceSession::ContinueRunInBackground(const std::vector<OrtValue>& backward_output_grads) {
for (const auto& ort_value : backward_output_grads) {
onnxruntime::contrib::OrtMessageQueue::GetInstance().Push(ort_value);
}
// resume background thread
const int64_t background_thread_event_id = 1;
onnxruntime::contrib::OrtEventPool::GetInstance().SignalEvent(background_thread_event_id);
Status bg_thread_status = task_.bg_thread_future_.get();
// wait for bg_thread to complete
if (task_.bg_thread_.joinable()) {
task_.bg_thread_.join();
}
return bg_thread_status;
}
template <typename T>
void InferenceSession::StartProfiling(const std::basic_string<T>& file_prefix) {
std::basic_ostringstream<T> ss;

View file

@ -5,6 +5,8 @@
#include <string>
#include <unordered_map>
#include <thread>
#include <future>
#include "core/common/common.h"
#include "core/common/logging/logging.h"
@ -294,6 +296,13 @@ class InferenceSession {
virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
// For ORTModule.forward()
virtual common::Status RunInBackgroundAndWaitForYield(RunOptions& run_options, IOBinding& io_binding,
std::vector<OrtValue>& user_outputs) ORT_MUST_USE_RESULT;
// For ORTModule.backward()
common::Status ContinueRunInBackground(const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
/**
* @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.
* @note lifetime of the returned pointer is valid as long as the Session object is live.
@ -654,7 +663,15 @@ class InferenceSession {
// Longer term we may want to directly refer to offsets in this buffer for initializers so we don't need to copy
// those into new OrtValue instances, at which point we won't free them until the InferenceSession goes away.
std::vector<uint8_t> ort_format_model_bytes_;
// background thread for RunInBackgroundAndWaitForYield
struct Task {
std::thread bg_thread_;
std::promise<Status> bg_thread_promise_;
std::future<Status> bg_thread_future_;
bool* terminate_flag_ = nullptr;
} task_;
std::shared_ptr<onnxruntime::AllocatorManager> allocator_manager_;
};

View file

@ -1,20 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/ml_value.h"
#include "python/dlpack.h"
// this convertor will:
// 1) take a OrtValue object and wrap it in the DLPack tensor
namespace onnxruntime {
namespace python {
DLManagedTensor* ort_value_to_dlpack(const OrtValue& ml_value);
DLDataType get_dlpack_data_type(const OrtValue& ml_value);
DLContext get_dlpack_context(const OrtValue& ml_value, const int64_t& device_id);
} // namespace python
} // namespace onnxruntime

View file

@ -1,16 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/dl_convertor.h"
#include "python/dlpack_convertor.h"
namespace onnxruntime {
namespace python {
DLDataType get_dlpack_data_type(const OrtValue& ml_value) {
ORT_ENFORCE(ml_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
DLDataType get_dlpack_data_type(const OrtValue& ort_value) {
ORT_ENFORCE(ort_value.IsTensor(), "Only tensor-type OrtValues are supported");
DLDataType dtype;
dtype.lanes = 1;
const Tensor& tensor = ml_value.Get<Tensor>();
const Tensor& tensor = ort_value.Get<Tensor>();
switch (tensor.GetElementType()) {
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
dtype.code = DLDataTypeCode::kDLFloat;
@ -73,11 +73,11 @@ DLDataType get_dlpack_data_type(const OrtValue& ml_value) {
return dtype;
}
DLContext get_dlpack_context(const OrtValue& ml_value, const int64_t& device_id) {
ORT_ENFORCE(ml_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
DLContext get_dlpack_context(const OrtValue& ort_value, const int64_t& device_id) {
ORT_ENFORCE(ort_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
DLContext ctx;
ctx.device_id = device_id;
const Tensor& tensor = ml_value.Get<Tensor>();
ctx.device_id = static_cast<int>(device_id);
const Tensor& tensor = ort_value.Get<Tensor>();
const auto& location = tensor.Location();
switch (location.device.Type()) {
case OrtDevice::CPU:
@ -102,17 +102,17 @@ void deleter(DLManagedTensor* arg) { delete static_cast<OrtDLManagedTensor*>(arg
// This function returns a shared_ptr to memory managed DLpack tensor
// constructed out of OrtValue.
DLManagedTensor* ort_value_to_dlpack(const OrtValue& ml_value) {
ORT_ENFORCE(ml_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
DLManagedTensor* ort_value_to_dlpack(const OrtValue& ort_value) {
ORT_ENFORCE(ort_value.IsTensor(), "Only tensor type OrtValues are supported");
OrtDLManagedTensor* ort_dlmanaged_tensor(new OrtDLManagedTensor);
const Tensor& tensor = ml_value.Get<Tensor>();
ort_dlmanaged_tensor->handle = ml_value;
const Tensor& tensor = ort_value.Get<Tensor>();
ort_dlmanaged_tensor->handle = ort_value;
ort_dlmanaged_tensor->tensor.manager_ctx = ort_dlmanaged_tensor;
ort_dlmanaged_tensor->tensor.deleter = &deleter;
ort_dlmanaged_tensor->tensor.dl_tensor.data = const_cast<void*>(tensor.DataRaw());
ort_dlmanaged_tensor->tensor.dl_tensor.ctx = get_dlpack_context(ml_value, tensor.Location().device.Id());
ort_dlmanaged_tensor->tensor.dl_tensor.ndim = tensor.Shape().NumDimensions();
ort_dlmanaged_tensor->tensor.dl_tensor.dtype = get_dlpack_data_type(ml_value);
ort_dlmanaged_tensor->tensor.dl_tensor.ctx = get_dlpack_context(ort_value, tensor.Location().device.Id());
ort_dlmanaged_tensor->tensor.dl_tensor.ndim = static_cast<int>(tensor.Shape().NumDimensions());
ort_dlmanaged_tensor->tensor.dl_tensor.dtype = get_dlpack_data_type(ort_value);
ort_dlmanaged_tensor->tensor.dl_tensor.shape =
tensor.Shape().NumDimensions() > 0 ? const_cast<int64_t*>(&tensor.Shape()[0]) : nullptr;
ort_dlmanaged_tensor->tensor.dl_tensor.strides = nullptr;

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/ml_value.h"
#include "python/dlpack.h"
// This convertor will take an OrtValue and wrap it as a DLPack tensor
namespace onnxruntime {
namespace python {
DLManagedTensor* ort_value_to_dlpack(const OrtValue& ort_value);
DLDataType get_dlpack_data_type(const OrtValue& ort_value);
DLContext get_dlpack_context(const OrtValue& ort_value, const int64_t& device_id);
} // namespace python
} // namespace onnxruntime

View file

@ -228,6 +228,21 @@ class Session:
"""
self._sess.run_with_iobinding(iobinding._iobinding, run_options)
def run_forward(self, iobinding, run_options):
"""
Compute the forward subgraph until it hits the Yield Op.
:param iobinding: the iobinding object that has graph inputs/outputs bind.
:param run_options: See :class:`onnxruntime.RunOptions`.
"""
return [OrtValue(ortvalue) for ortvalue in self._sess.run_forward(iobinding._iobinding, run_options)]
def run_backward(self, backward_output_grads):
"""
Resume executing the backward subgraph starting from Yield Op.
:param backward_output_grads: Output gradients for backward.
"""
self._sess.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads])
class InferenceSession(Session):
"""
@ -486,6 +501,23 @@ class OrtValue:
return OrtValue(C.OrtValue.ortvalue_from_shape_and_type(shape, element_type,
C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id)))
@staticmethod
def ortvalue_from_data_ptr(shape=None, element_type=None, device_type='cpu', device_id=0, buffer_ptr=None):
'''
Factory method to construct an OrtValue (which holds a Tensor) from given buffer_ptr
:param shape: List of integers indicating the shape of the OrtValue
:param element_type: The data type of the elements in the OrtValue (numpy type)
:param device_type: e.g. cpu, cuda, cpu by default
:param device_id: device id, e.g. 0
:param buffer_ptr: data buffer pointer
'''
if shape is None or element_type is None or buffer_ptr is None:
raise ValueError("`element_type`, `shape` and `buffer_ptr` are to be provided")
return OrtValue(C.OrtValue.ortvalue_from_data_ptr(shape, element_type,
C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id),
buffer_ptr))
def data_ptr(self):
'''
Returns the address of the first element in the OrtValue's data buffer
@ -524,4 +556,8 @@ class OrtValue:
return self._ortvalue.numpy()
def to_dlpack(self):
'''
Returns a DLPack object from the OrtValue.
Valid only for OrtValues holding Tensors. Throws for OrtValues holding non-Tensors.
'''
return self._ortvalue.to_dlpack()

View file

@ -25,7 +25,7 @@
#include "core/platform/env.h"
#include "core/session/IOBinding.h"
#include "core/session/abi_session_options_impl.h"
#include "python/dl_convertor.h"
#include "python/dlpack_convertor.h"
// execution provider factory creator headers
#include "core/providers/cpu/cpu_provider_factory_creator.h"
@ -1216,6 +1216,27 @@ void addObjectMethods(py::module& m, Environment& env) {
return ml_value;
})
.def_static("ortvalue_from_data_ptr", [](std::vector<int64_t>& shape, py::object& element_type,
OrtDevice& device, int64_t data_ptr) {
ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is invalid");
PyArray_Descr* dtype;
if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) {
throw std::runtime_error("Not a valid numpy type");
}
int type_num = dtype->type_num;
Py_DECREF(dtype);
OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id());
std::unique_ptr<Tensor> p_tensor = onnxruntime::make_unique<Tensor>(NumpyTypeToOnnxRuntimeType(type_num), shape,
reinterpret_cast<void*>(data_ptr), info);
auto ort_value = onnxruntime::make_unique<OrtValue>();
ort_value->Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(),
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
return ort_value;
})
.def("data_ptr", [](OrtValue* ml_value) -> int64_t {
// TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors
ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported");
@ -1278,8 +1299,8 @@ void addObjectMethods(py::module& m, Environment& env) {
#endif
return obj;
})
.def("to_dlpack", [](OrtValue* ml_value) -> py::object {
DLManagedTensor* dlmanaged_tensor = ort_value_to_dlpack(*ml_value);
.def("to_dlpack", [](OrtValue* ort_value) -> py::object {
DLManagedTensor* dlmanaged_tensor = ort_value_to_dlpack(*ort_value);
return py::reinterpret_steal<py::object>(
PyCapsule_New(dlmanaged_tensor, "dltensor", dlpack_capsule_destructor));
});
@ -1804,6 +1825,20 @@ including arg name, arg type (contains both type and shape).)pbdoc")
status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
})
.def("run_forward", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions& run_options) -> std::vector<OrtValue> {
std::vector<OrtValue> module_outputs;
Status status = sess->GetSessionHandle()->RunInBackgroundAndWaitForYield(run_options, *io_binding.Get(), module_outputs);
if (!status.IsOK()) {
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
}
return module_outputs;
})
.def("run_backward", [](PyInferenceSession* sess, const std::vector<OrtValue>& backward_output_grads) -> void {
Status status = sess->GetSessionHandle()->ContinueRunInBackground(backward_output_grads);
if (!status.IsOK())
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
});
py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())

View file

@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/model.h"
#include "core/graph/graph_utils.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "orttraining/core/framework/module_gradient_graph_builder.h"
@ -14,51 +13,26 @@ namespace training {
using namespace onnxruntime::common;
void GetInputAndOutputNames(const Node& node, std::unordered_set<std::string>& input_names,
std::unordered_set<std::string>& output_names) {
std::for_each(node.InputDefs().begin(), node.InputDefs().end(),
[&input_names](const NodeArg* node_arg) { input_names.insert(node_arg->Name()); });
std::for_each(node.OutputDefs().begin(), node.OutputDefs().end(),
[&output_names](const NodeArg* node_arg) { output_names.insert(node_arg->Name()); });
}
void RemoveNodes(Graph& graph, const std::vector<Node*>& nodes_to_remove) {
for (Node* node_to_remove : nodes_to_remove) {
graph_utils::RemoveNodeOutputEdges(graph, *node_to_remove);
graph.RemoveNode(node_to_remove->Index());
}
}
void FilterInitializers(Graph& graph, const std::unordered_set<std::string>& input_names) {
const auto& initializers = graph.GetAllInitializedTensors();
std::unordered_set<std::string> initializer_names_to_remove;
for (const auto& initializer : initializers) {
if (input_names.find(initializer.first) == input_names.end()) {
initializer_names_to_remove.insert(initializer.first);
}
}
for (const auto& initializer_name : initializer_names_to_remove) {
graph.RemoveInitializedTensor(initializer_name);
}
}
Status ModuleGradientGraphBuilder::Initialize(std::istream& model_istream,
const ModuleGradientGraphBuilderConfiguration& config) {
// We need to apply the pre-training transformers before the gradient graph builder so we can build
// an optimized gradient graph. The constant folding transformer depends on concrete shapes, without
// constant folding with concrete shapes, shapes of some intermediate tensors will fail to infer.
// This means we need to "apply transformers -> build gradient graph -> split" each time we have different
// concrete input shapes. So this init func is just to save the original graph and config.
// Save the model and config.
ONNX_NAMESPACE::ModelProto model_proto;
ORT_RETURN_IF_ERROR(Model::Load(model_istream, &model_proto));
ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_));
config_ = config;
// Handle original model inputs, outputs and trainable initializers.
// We need to move the trainable initializers to graph inputs and keep the order in config,
// it's possible that the graph already has some trainable initializers in graph inputs,
// so we need to NOT assign these trainable initializers to the user inputs list.
Graph& graph = model_->MainGraph();
std::unordered_set<std::string> initializer_names_to_train_set(config.initializer_names_to_train.begin(),
config.initializer_names_to_train.end());
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
for (auto& node_arg : graph_inputs) {
split_graphs_info_.user_input_names.emplace_back(node_arg->Name());
if (initializer_names_to_train_set.find(node_arg->Name()) == initializer_names_to_train_set.end()) {
split_graphs_info_.user_input_names.emplace_back(node_arg->Name());
}
}
const std::vector<const NodeArg*>& graph_outputs = graph.GetOutputs();
@ -69,35 +43,65 @@ Status ModuleGradientGraphBuilder::Initialize(std::istream& model_istream,
split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(),
config.initializer_names_to_train.end());
// Remove the training initializers from the graph and move them to input to save memory.
std::vector<const NodeArg*> input_args;
for (const auto& input_name : split_graphs_info_.user_input_names) {
input_args.emplace_back(graph.GetNodeArg(input_name));
}
// Remove the training initializers from the graph and move them to graph inputs.
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
input_args.emplace_back(graph.GetNodeArg(initializer_name));
graph.RemoveInitializedTensor(initializer_name);
}
graph.SetInputs(input_args);
config_ = config;
return Status::OK();
}
Status ModuleGradientGraphBuilder::BuildAndSplit(const std::vector<std::vector<int64_t>>& input_shapes) {
// Build the gradient graphs from original graph.
// Since the input shapes may differ, and the graph optimizers (mainly constant folding) may fold this
// shape info to constants, the optimized graph (before gradient graph building) can not be shared.
// So each time we need to start from the beginning, i.e., 1) replace input shapes, 2) apply graph optimizers,
// 3) build gradient graph, and finally 4) adjust the graph inputs and outputs.
Status ModuleGradientGraphBuilder::Build(const std::vector<std::vector<int64_t>>* input_shapes_ptr) {
// Make a copy of the original model.
auto model_proto = model_->ToProto();
std::shared_ptr<onnxruntime::Model> model_copied;
ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_copied, nullptr, *logger_));
Graph& graph = model_copied->MainGraph();
ORT_RETURN_IF_ERROR(Model::Load(model_proto, gradient_model_, nullptr, *logger_));
// Replace the input shapes.
// Replace the user input shapes if input_shapes_ptr is not null_ptr.
if (input_shapes_ptr) {
SetConcreteInputShapes(*input_shapes_ptr);
}
// Build the gradient graph.
ORT_RETURN_IF_ERROR(BuildGradientGraph());
// Add Yield Op.
AddYieldOp();
// Reorder outputs.
ReorderOutputs();
return Status::OK();
}
std::string ModuleGradientGraphBuilder::GetGradientModel() const {
std::string model_str;
if (!gradient_model_->ToProto().SerializeToString(&model_str)) {
ORT_THROW("Fail to serialize gradient model to string.");
}
return model_str;
}
void ModuleGradientGraphBuilder::SetConcreteInputShapes(const std::vector<std::vector<int64_t>>& input_shapes) {
ORT_ENFORCE(input_shapes.size() == split_graphs_info_.user_input_names.size(),
"The size of concrete input shapes and the size of user inputs does not match.");
Graph& gradient_graph = gradient_model_->MainGraph();
std::vector<const NodeArg*> input_args;
size_t input_index = 0;
for (const auto& input_name : split_graphs_info_.user_input_names) {
NodeArg* input_node_arg = graph.GetNodeArg(input_name);
NodeArg* input_node_arg = gradient_graph.GetNodeArg(input_name);
ONNX_NAMESPACE::TensorShapeProto new_shape;
for (size_t i = 0; i < input_shapes[input_index].size(); i++) {
new_shape.add_dim()->set_dim_value(input_shapes[input_index][i]);
@ -109,15 +113,19 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(const std::vector<std::vector<i
}
// Move over all training initializer inputs. They already have the concrete shapes.
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
const std::vector<const NodeArg*>& graph_inputs = gradient_graph.GetInputsIncludingInitializers();
for (; input_index < graph_inputs.size(); input_index++) {
input_args.emplace_back(graph_inputs[input_index]);
}
graph.SetInputs(input_args);
ORT_RETURN_IF_ERROR(graph.Resolve());
gradient_graph.SetInputs(input_args);
}
Status ModuleGradientGraphBuilder::BuildGradientGraph() {
// Resolve original graph, register and apply transformers for pre-training.
Graph& gradient_graph = gradient_model_->MainGraph();
ORT_RETURN_IF_ERROR(gradient_graph.Resolve());
// Register and apply transformers for pre-training.
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
GraphTransformerManager graph_transformation_mgr{2};
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
@ -143,235 +151,118 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(const std::vector<std::vector<i
}
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *logger_));
ORT_RETURN_IF_ERROR(
graph_transformation_mgr.ApplyTransformers(gradient_graph, static_cast<TransformerLevel>(i), *logger_));
}
// Build gradient graph.
GradientGraphConfiguration gradient_graph_config{};
gradient_graph_config.use_invertible_layernorm_grad = config_.use_invertible_layernorm_grad;
gradient_graph_config.set_gradients_as_graph_outputs = config_.set_gradients_as_graph_outputs;
gradient_graph_config.set_gradients_as_graph_outputs = true;
std::unordered_set<std::string> y_node_arg_names(split_graphs_info_.user_output_names.begin(),
split_graphs_info_.user_output_names.end());
GradientGraphBuilder grad_graph_builder(&graph, y_node_arg_names, x_node_arg_names, "",
GradientGraphBuilder grad_graph_builder(&gradient_graph, y_node_arg_names, x_node_arg_names, "",
gradient_graph_config, *logger_);
ORT_RETURN_IF_ERROR(grad_graph_builder.Build());
return Status::OK();
}
// Fix inputs/outputs related to gradients.
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
std::unordered_set<std::string> input_names;
std::unordered_set<std::string> output_names;
for (auto node_index : node_topology_list) {
auto& node = *graph.GetNode(node_index);
GetInputAndOutputNames(node, input_names, output_names);
void ModuleGradientGraphBuilder::AddYieldOp() {
Graph& gradient_graph = gradient_model_->MainGraph();
GraphViewer gradient_graph_viewer(gradient_graph);
const auto& gradient_node_topology_list = gradient_graph_viewer.GetNodesInTopologicalOrder();
std::unordered_set<std::string> user_output_grad_names_set;
for (const auto& name : split_graphs_info_.user_output_names) {
user_output_grad_names_set.insert(name + "_grad");
}
input_args.clear();
for (const NodeArg* input_node_arg: graph.GetInputsIncludingInitializers()) {
input_args.emplace_back(input_node_arg);
}
// Add the entry points of gradients (normally loss_gard) to the graph inputs. Using the order of graph outputs.
split_graphs_info_.user_output_grad_names.clear();
split_graphs_info_.backward_output_grad_names.clear();
for (const auto& output_name : split_graphs_info_.user_output_names) {
std::string output_gradient_name = output_name + "_grad";
if (input_names.find(output_gradient_name) != input_names.end()) {
split_graphs_info_.user_output_grad_names.emplace_back(output_gradient_name);
// Only add to graph input when it's not an output of a node.
if (output_names.find(output_gradient_name) == output_names.end()) {
split_graphs_info_.backward_output_grad_names.emplace_back(output_gradient_name);
NodeArg* output_gradient_node_arg = graph.GetNodeArg(output_gradient_name);
output_gradient_node_arg->UpdateTypeAndShape(*graph.GetNodeArg(output_name), true, true, *logger_);
input_args.emplace_back(output_gradient_node_arg);
// If an NodeArg is output of one of nodes, it's not the user output gradient needed by backward graph.
std::unordered_set<std::string> non_backward_user_output_grad_names;
for (auto node_index : gradient_node_topology_list) {
auto& node = *gradient_graph.GetNode(node_index);
for (const auto& node_arg : node.OutputDefs()) {
if (user_output_grad_names_set.find(node_arg->Name()) != user_output_grad_names_set.end()) {
non_backward_user_output_grad_names.insert(node_arg->Name());
}
}
}
graph.SetInputs(input_args);
// Yield inputs include all user outputs, those require output gradients come first, so Yield Op can use their shapes
// to infer Op output shapes.
std::vector<std::string> user_output_names_require_grad;
std::vector<std::string> user_output_names_no_grad;
split_graphs_info_.backward_output_grad_names.clear();
for (const auto& name : split_graphs_info_.user_output_names) {
std::string grad_name = name + "_grad";
if (non_backward_user_output_grad_names.find(grad_name) == non_backward_user_output_grad_names.end()) {
user_output_names_require_grad.emplace_back(name);
split_graphs_info_.backward_output_grad_names.emplace_back(grad_name);
} else {
user_output_names_no_grad.emplace_back(name);
}
}
std::vector<const NodeArg*> output_args;
for (auto& output_name : split_graphs_info_.user_output_names) {
output_args.emplace_back(graph.GetNodeArg(output_name));
// Reorder the user outputs.
split_graphs_info_.user_output_names.clear();
for (const auto& name : user_output_names_require_grad) {
split_graphs_info_.user_output_names.emplace_back(name);
}
for (const auto& name : user_output_names_no_grad) {
split_graphs_info_.user_output_names.emplace_back(name);
}
std::vector<NodeArg*> yield_input_node_args;
std::vector<NodeArg*> yield_output_node_args;
for (const auto& name : split_graphs_info_.user_output_names) {
yield_input_node_args.emplace_back(gradient_graph.GetNodeArg(name));
}
for (const auto& name : split_graphs_info_.backward_output_grad_names) {
yield_output_node_args.emplace_back(gradient_graph.GetNodeArg(name));
}
gradient_graph.AddNode("YieldOp", "YieldOp", "Yield Op", yield_input_node_args, yield_output_node_args, {}, kMSDomain);
}
void ModuleGradientGraphBuilder::ReorderOutputs() {
// Adjust gradient graph outputs by the following order:
// 1. user input grads if required, with same order of user inputs,
// 2. trainable initailizer grads, with same order of trainable initializers.
Graph& gradient_graph = gradient_model_->MainGraph();
const std::vector<const NodeArg*>& gradient_graph_outputs = gradient_graph.GetOutputs();
std::unordered_map<std::string, const NodeArg*> gradient_output_arg_map;
for (auto& node_arg : gradient_graph_outputs) {
gradient_output_arg_map[node_arg->Name()] = node_arg;
}
std::unordered_set<std::string> user_input_require_grad_set(config_.input_names_require_grad.begin(),
config_.input_names_require_grad.end());
std::vector<const NodeArg*> new_output_args;
split_graphs_info_.user_input_grad_names.clear();
for (const auto& input_name : split_graphs_info_.user_input_names) {
if (user_input_require_grad_set.find(input_name) != user_input_require_grad_set.end()) {
std::string input_gradient_name = input_name + "_grad";
ORT_ENFORCE(gradient_output_arg_map.find(input_gradient_name) != gradient_output_arg_map.end(),
"Required user input grad is not found on gradient graph.");
split_graphs_info_.user_input_grad_names[input_name] = input_gradient_name;
new_output_args.emplace_back(gradient_output_arg_map[input_gradient_name]);
}
}
// Add initializer gradients to graph outputs.
split_graphs_info_.initializer_grad_names_to_train.clear();
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
std::string initializer_gradient_name = initializer_name + "_grad";
if (output_names.find(initializer_gradient_name) != output_names.end()) {
split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name);
output_args.emplace_back(graph.GetNodeArg(initializer_gradient_name));
}
ORT_ENFORCE(gradient_output_arg_map.find(initializer_gradient_name) != gradient_output_arg_map.end(),
"Trainable initializer grad is not found on gradient graph.");
split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name);
new_output_args.emplace_back(gradient_output_arg_map[initializer_gradient_name]);
}
// Add input gradients to graph outputs if it's required.
for (const auto& input_name : config_.input_names_require_grad) {
std::string input_gradient_name = input_name + "_grad";
if (output_names.find(input_gradient_name) != output_names.end()) {
output_args.emplace_back(graph.GetNodeArg(input_gradient_name));
}
}
graph.SetOutputs(output_args);
graph.Resolve();
// Run the transformers again mainly for backward part, e.g., constant fold from those Shape nodes in backward graph.
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *logger_));
}
// Create two copies of gradient model for forward and backward models respectively.
auto gradient_model_proto = model_copied->ToProto();
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, forward_model_, nullptr, *logger_));
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_));
// Split the graph in the copies of gradient model.
ORT_RETURN_IF_ERROR(Split());
return Status::OK();
}
std::string SerializeModel(const std::shared_ptr<onnxruntime::Model>& model, const std::string& tag) {
std::string model_str;
if (!model->ToProto().SerializeToString(&model_str)) {
ORT_THROW("Fail to serialize", tag, "model to string.");
}
return model_str;
}
std::string ModuleGradientGraphBuilder::GetForwardModel() const { return SerializeModel(forward_model_, "forward"); }
std::string ModuleGradientGraphBuilder::GetBackwardModel() const { return SerializeModel(backward_model_, "backward"); }
Status ModuleGradientGraphBuilder::Split() {
// Get forward model, also collect some information for backward model generation.
Graph& forward_graph = forward_model_->MainGraph();
GraphViewer forward_graph_viewer(forward_graph);
const auto& forward_node_topology_list = forward_graph_viewer.GetNodesInTopologicalOrder();
std::vector<Node*> forward_nodes_to_remove;
std::unordered_set<std::string> forward_input_names;
std::unordered_set<std::string> forward_output_names;
std::unordered_set<std::string> backward_input_names;
std::unordered_set<std::string> backward_output_names;
for (auto node_index : forward_node_topology_list) {
auto& node = *forward_graph.GetNode(node_index);
// Currently we are using node description to distinguish the forward and backward nodes.
if (node.Description() == "Backward pass") {
forward_nodes_to_remove.emplace_back(&node);
GetInputAndOutputNames(node, backward_input_names, backward_output_names);
} else {
GetInputAndOutputNames(node, forward_input_names, forward_output_names);
}
}
std::unordered_set<std::string> intermediate_arg_names;
for (const auto& forward_output_name : forward_output_names) {
if (backward_input_names.find(forward_output_name) != backward_input_names.end()) {
intermediate_arg_names.insert(forward_output_name);
}
}
RemoveNodes(forward_graph, forward_nodes_to_remove);
FilterInitializers(forward_graph, forward_input_names);
// All user inputs should be also part of the forward graph inputs.
std::vector<const NodeArg*> forward_input_args;
for (const auto& input_name : split_graphs_info_.user_input_names) {
forward_input_args.emplace_back(forward_graph.GetNodeArg(input_name));
}
// Add initializers to forward graph inputs.
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
forward_input_args.emplace_back(forward_graph.GetNodeArg(initializer_name));
}
forward_graph.SetInputs(forward_input_args);
// All user outputs should be also part of the forward graph outputs.
std::vector<const NodeArg*> forward_output_args;
for (const auto& output_name : split_graphs_info_.user_output_names) {
forward_output_args.emplace_back(forward_graph.GetNodeArg(output_name));
}
// Add intermediate args to forward graph outputs.
split_graphs_info_.intermediate_tensor_names.clear();
for (const auto& intermediate_arg_name : intermediate_arg_names) {
// Ignore the user outputs.
if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(),
intermediate_arg_name) == split_graphs_info_.user_output_names.end()) {
split_graphs_info_.intermediate_tensor_names.emplace_back(intermediate_arg_name);
forward_output_args.emplace_back(forward_graph.GetNodeArg(intermediate_arg_name));
}
}
forward_graph.SetOutputs(forward_output_args);
forward_graph.Resolve();
// Get backward graph.
Graph& backward_graph = backward_model_->MainGraph();
GraphViewer backward_graph_viewer(backward_graph);
const auto& backward_node_topology_list = backward_graph_viewer.GetNodesInTopologicalOrder();
std::vector<Node*> backward_nodes_to_remove;
for (auto node_index : backward_node_topology_list) {
auto& node = *backward_graph.GetNode(node_index);
if (node.Description() != "Backward pass") {
backward_nodes_to_remove.emplace_back(&node);
}
}
RemoveNodes(backward_graph, backward_nodes_to_remove);
FilterInitializers(backward_graph, backward_input_names);
// User inputs to backward graph inputs.
split_graphs_info_.backward_user_input_names.clear();
std::vector<const NodeArg*> backward_input_args;
for (const auto& input_name : split_graphs_info_.user_input_names) {
// Only takes those in the backward inputs.
if (backward_input_names.find(input_name) != backward_input_names.end()) {
split_graphs_info_.backward_user_input_names.emplace_back(input_name);
backward_input_args.emplace_back(backward_graph.GetNodeArg(input_name));
}
}
// Add initializer args to backward graph inputs if any node uses them.
split_graphs_info_.backward_intializer_names_as_input.clear();
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
// Some initializers will be inputs for backward graph.
if (backward_input_names.find(initializer_name) != backward_input_names.end()) {
split_graphs_info_.backward_intializer_names_as_input.emplace_back(initializer_name);
backward_input_args.emplace_back(backward_graph.GetNodeArg(initializer_name));
backward_graph.RemoveInitializedTensor(initializer_name);
}
}
// Add intermediate args to backward graph inputs.
for (const auto& intermediate_arg_name : split_graphs_info_.intermediate_tensor_names) {
NodeArg* intermediate_node_arg = backward_graph.GetNodeArg(intermediate_arg_name);
intermediate_node_arg->UpdateTypeAndShape(*forward_graph.GetNodeArg(intermediate_arg_name), true, true, *logger_);
backward_input_args.emplace_back(intermediate_node_arg);
}
// Grad of user outputs to backward graph inputs.
for (const auto& output_grad_name : split_graphs_info_.backward_output_grad_names) {
backward_input_args.emplace_back(backward_graph.GetNodeArg(output_grad_name));
}
backward_graph.SetInputs(backward_input_args);
// Exclude user outputs from the backward graph.
const std::vector<const NodeArg*>& backward_graph_outputs = backward_graph.GetOutputs();
std::vector<const NodeArg*> backward_output_args;
for (auto& node_arg : backward_graph_outputs) {
if (backward_output_names.find(node_arg->Name()) != backward_output_names.end()) {
backward_output_args.emplace_back(node_arg);
}
}
backward_graph.SetOutputs(backward_output_args);
backward_graph.Resolve();
return Status::OK();
gradient_graph.SetOutputs(new_output_args);
}
} // namespace training

View file

@ -6,6 +6,9 @@
#include <string>
#include <unordered_set>
#include "core/common/status.h"
#include "core/graph/model.h"
namespace onnxruntime {
namespace training {
@ -20,47 +23,77 @@ struct ModuleGradientGraphBuilderConfiguration {
// Gradient graph configuration.
bool use_invertible_layernorm_grad = false;
bool set_gradients_as_graph_outputs = false;
// TODO: add GraphTransformerConfiguration
// TODO: add mixed precision config
// TODO: do we need to support graph with loss?
};
/**
* The information of split graphs for frontend.
*/
struct SplitGraphsInfo {
// The user inputs.
std::vector<std::string> user_input_names{};
// Map from user input names to corresponding user input grad names for those user inputs that require grad.
std::unordered_map<std::string, std::string> user_input_grad_names{};
// Trainable initializers.
std::vector<std::string> initializer_names_to_train{};
// Trainable initializer grad names, ordered according to initializer_names_to_train.
std::vector<std::string> initializer_grad_names_to_train{};
// The user outputs.
std::vector<std::string> user_output_names{};
std::vector<std::string> backward_user_input_names{};
std::vector<std::string> backward_intializer_names_as_input{};
std::vector<std::string> intermediate_tensor_names{};
std::vector<std::string> user_output_grad_names{};
// The user output grad names that are actual required by the backward graph.
std::vector<std::string> backward_output_grad_names{};
};
class ModuleGradientGraphBuilder {
public:
/**
* Initialize the builder. It saves the initial model and the configuration.
* It also removes the trainable initializers from initial model and move them to graph inputs.
* @param model_istream The initial model as input stream.
* @param config The configuration to control the builder.
* @return The status of the initialization.
*/
Status Initialize(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config);
Status BuildAndSplit(const std::vector<std::vector<int64_t>>& input_shapes);
std::string GetForwardModel() const;
std::string GetBackwardModel() const;
/**
* Build the gradient graph and split it to forward and backward graphs.
* @param input_shapes_ptr The pointer to vector of concrete shapes of the user inputs.
* @return The status of the gradient graph building and forward/backward graphs splitting.
*/
Status Build(const std::vector<std::vector<int64_t>>* input_shapes_ptr = nullptr);
/**
* Get gradient model.
* @return The gradient model serialized to string.
*/
std::string GetGradientModel() const;
/**
* Get the split graphs information.
* @return The split graphs information.
*/
SplitGraphsInfo GetSplitGraphsInfo() const { return split_graphs_info_; }
private:
Status Split();
// Set concrete shapes for graph inputs.
void SetConcreteInputShapes(const std::vector<std::vector<int64_t>>& input_shapes);
// Build gradient graph.
Status BuildGradientGraph();
// Add Yield Op.
void AddYieldOp();
// Reorder gradient graph outputs.
void ReorderOutputs();
std::shared_ptr<onnxruntime::Model> model_;
std::shared_ptr<onnxruntime::Model> forward_model_;
std::shared_ptr<onnxruntime::Model> backward_model_;
std::shared_ptr<onnxruntime::Model> gradient_model_;
SplitGraphsInfo split_graphs_info_;
ModuleGradientGraphBuilderConfiguration config_;
const logging::Logger* logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now.
const logging::Logger* logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now.
};
} // namespace training

View file

@ -2208,6 +2208,33 @@ Return true if all elements are true and false otherwise.
propagateShapeFromInputToOutput(ctx, i + 1, i);
}
});
ONNX_CONTRIB_OPERATOR_SCHEMA(YieldOp)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.SetDoc("Yield Op.")
.Input(0, "outputs", "Module outputs to be returned to pytorch.", "T", OpSchema::Variadic,
/*is_homogeneous*/ false,
/*min_arity*/ 1)
.Output(0, "outputs_grad", "Gradient of outputs returned from pytorch.", "T", OpSchema::Variadic,
/*is_homogeneous*/ false,
/*min_arity*/ 1)
.TypeConstraint("T", OpSchema::all_tensor_types(), "Allow inputs and outputs to be any kind of tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Assume the outputs and gradients are one-to-one matching
// TODO: The contrain is relaxed for now
// ORT_ENFORCE(ctx.getNumInputs() == ctx.getNumOutputs(), "Yield op doesn't have the same number of inputs and output");
for (size_t i = 0; i < ctx.getNumOutputs(); ++i) {
propagateElemTypeFromInputToOutput(ctx, i, i);
auto typeProto = ctx.getInputType(i);
if (!hasShape(*typeProto)) {
continue;
}
propagateShapeFromInputToOutput(ctx, i, i);
}
});
}
} // namespace training
} // namespace onnxruntime

View file

@ -475,47 +475,45 @@ void addObjectMethodsForTraining(py::module& m) {
});
py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
m, "ModuleGradientGraphBuilderConfiguration", R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
m, "ModuleGradientGraphBuilderConfiguration",
R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
module_gradient_graph_builder_config.def(py::init())
.def_readwrite("initializer_names_to_train", &ModuleGradientGraphBuilderConfiguration::initializer_names_to_train)
.def_readwrite("input_names_require_grad", &ModuleGradientGraphBuilderConfiguration::input_names_require_grad)
.def_readwrite("use_invertible_layernorm_grad", &ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad)
.def_readwrite("set_gradients_as_graph_outputs", &ModuleGradientGraphBuilderConfiguration::set_gradients_as_graph_outputs);
.def_readwrite("use_invertible_layernorm_grad",
&ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad);
py::class_<SplitGraphsInfo> split_graphs_info(
m, "SplitGraphsInfo", R"pbdoc(The information of split graphs for frontend.)pbdoc");
py::class_<SplitGraphsInfo> split_graphs_info(m, "SplitGraphsInfo",
R"pbdoc(The information of split graphs for frontend.)pbdoc");
split_graphs_info.def(py::init())
.def_readwrite("user_input_names", &SplitGraphsInfo::user_input_names)
.def_readwrite("user_input_grad_names", &SplitGraphsInfo::user_input_grad_names)
.def_readwrite("initializer_names_to_train", &SplitGraphsInfo::initializer_names_to_train)
.def_readwrite("initializer_grad_names_to_train", &SplitGraphsInfo::initializer_grad_names_to_train)
.def_readwrite("user_output_names", &SplitGraphsInfo::user_output_names)
.def_readwrite("backward_user_input_names", &SplitGraphsInfo::backward_user_input_names)
.def_readwrite("backward_intializer_names_as_input", &SplitGraphsInfo::backward_intializer_names_as_input)
.def_readwrite("intermediate_tensor_names", &SplitGraphsInfo::intermediate_tensor_names)
.def_readwrite("user_output_grad_names", &SplitGraphsInfo::user_output_grad_names)
.def_readwrite("backward_output_grad_names", &SplitGraphsInfo::backward_output_grad_names);
py::class_<ModuleGradientGraphBuilder> module_gradient_graph_builder(m, "ModuleGradientGraphBuilder");
module_gradient_graph_builder
.def(py::init([]() {
return onnxruntime::make_unique<ModuleGradientGraphBuilder>();
}))
.def("initialize", [](ModuleGradientGraphBuilder* module_gradient_graph_builder,
const py::bytes& serialized_model,
const ModuleGradientGraphBuilderConfiguration& config) {
std::istringstream buffer(serialized_model);
ORT_THROW_IF_ERROR(module_gradient_graph_builder->Initialize(buffer, config));
})
.def("build_and_split", [](ModuleGradientGraphBuilder* module_gradient_graph_builder,
const std::vector<std::vector<int64_t>>& input_shapes) {
ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(input_shapes));
})
.def("get_forward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return py::bytes(module_gradient_graph_builder->GetForwardModel());
})
.def("get_backward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return py::bytes(module_gradient_graph_builder->GetBackwardModel());
})
module_gradient_graph_builder.def(py::init([]() { return onnxruntime::make_unique<ModuleGradientGraphBuilder>(); }))
.def("initialize",
[](ModuleGradientGraphBuilder* module_gradient_graph_builder, const py::bytes& serialized_model,
const ModuleGradientGraphBuilderConfiguration& config) {
std::istringstream buffer(serialized_model);
ORT_THROW_IF_ERROR(module_gradient_graph_builder->Initialize(buffer, config));
})
.def("build",
[](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
ORT_THROW_IF_ERROR(module_gradient_graph_builder->Build());
})
.def("build",
[](ModuleGradientGraphBuilder* module_gradient_graph_builder,
const std::vector<std::vector<int64_t>>& input_shapes) {
ORT_THROW_IF_ERROR(module_gradient_graph_builder->Build(&input_shapes));
})
.def("get_gradient_model",
[](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return py::bytes(module_gradient_graph_builder->GetGradientModel());
})
.def("get_split_graphs_info", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
return module_gradient_graph_builder->GetSplitGraphsInfo();
});

View file

@ -65,9 +65,7 @@ def _create_iobinding(io_binding, inputs, model, device):
inputs[idx].data_ptr())
for value_info in model.graph.output:
io_binding.bind_output(value_info.name, device.type,
device_id=_get_device_index(device))
io_binding.bind_output(value_info.name, device.type, device_id=_get_device_index(device))
def _onnx_value_info_to_buffer_tensor(value_info, device):
'''Create a torch zeroed tensor with the same shape and type of `value_info`'''
@ -113,10 +111,9 @@ def _extract_input_information(module, *inputs, **kwargs):
class ORTModule(torch.nn.Module):
def __init__(self, module):
assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module"
assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module"
super(ORTModule, self).__init__()
self._export_again = False
# TODO: This is incorrect when different layers may be in different devices
self._device = next(module.parameters()).device
self._device_changed = False
@ -124,21 +121,18 @@ class ORTModule(torch.nn.Module):
# User module is wrapped to use its initializers and save computed gradients
self._original_module = module
self._onnx_training = None
self._is_training = True
# Related to training graph split/shape inference
self._current_input_shape = None
self._module_gradient_graph_builder = None
self._input_names_require_grad = None
# Forward pass
self._onnx_forward = None
self._forward_session = None
self._forward_io_binding = None
# Backward pass
self._onnx_backward = None
self._backward_session = None
self._backward_io_binding = None
# Gradient model
self._onnx_gradient = None
self._gradient_session = None
self._gradient_io_binding = None
self._run_options = None
# Log level
self._loglevel = getattr(logging, 'WARNING')
@ -148,21 +142,21 @@ class ORTModule(torch.nn.Module):
self._save_onnx_prefix = ''
def _initialize_module_gradient_graph_builder(self):
# TODO: PyTorch exporter bug: changes the initializer order
initializer_names = [p[0] for p in self._original_module.named_parameters()]
# Build full training graph and split in forward/backward
# Build full training graph
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
grad_builder_config.initializer_names_to_train = initializer_names
grad_builder_config.input_names_require_grad = self._input_names_require_grad
self._module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
self._module_gradient_graph_builder.initialize(self._onnx_training.SerializeToString(), grad_builder_config)
def _build_training_graph(self, *inputs, **kwargs):
def _get_forward_graph_and_init_gradient_graph_builder(self, *inputs, **kwargs):
input_names, dynamic_axes, self._input_names_require_grad = \
_extract_input_information(self._original_module, *inputs, **kwargs)
self._onnx_training = self._get_forward_graph(input_names, dynamic_axes, *inputs, **kwargs)
if self._save_onnx:
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
@ -179,27 +173,29 @@ class ORTModule(torch.nn.Module):
providers = ["CPUExecutionProvider"]
provider_options = [{}]
self._forward_session = onnxruntime.InferenceSession(
self._onnx_forward.SerializeToString(), providers=providers, provider_options=provider_options)
self._backward_session = onnxruntime.InferenceSession(
self._onnx_backward.SerializeToString(), providers=providers, provider_options=provider_options)
session_options = onnxruntime.SessionOptions()
session_options.enable_mem_pattern = False
session_options.use_deterministic_compute = False
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
session_options.log_severity_level = 1
self._gradient_session = onnxruntime.InferenceSession(
self._onnx_gradient.SerializeToString(), session_options, providers=providers, provider_options=provider_options)
# Use this global run_options for now
self._run_options = C.RunOptions()
# IO binding
# TODO: we should try to reuse the output buffers as some of the output tensors are same sizes, expecially the backward graph outputs.
self._forward_io_binding = self._forward_session.io_binding()
self._backward_io_binding = self._backward_session.io_binding()
self._gradient_io_binding = self._gradient_session.io_binding()
def _split_training_graph(self, *inputs, **kwargs):
# Perform shape inference and re-split forward/backward graph for batches with different shapes
self._module_gradient_graph_builder.build_and_split(self._current_input_shape)
self._onnx_forward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_forward_model())
self._onnx_backward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_backward_model())
def _build_training_graph(self, *inputs, **kwargs):
self._module_gradient_graph_builder.build(self._current_input_shape)
self._onnx_gradient = onnx.load_model_from_string(self._module_gradient_graph_builder.get_gradient_model())
self._onnx_graphs_info = self._module_gradient_graph_builder.get_split_graphs_info()
self._create_training_session()
if self._save_onnx:
onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')
onnx.save(self._onnx_gradient, self._save_onnx_prefix + '_gradient.onnx')
def cpu(self: T) -> T:
'''Thin layer to capture device for ORTModule IO bindings'''
@ -248,17 +244,28 @@ class ORTModule(torch.nn.Module):
self._device = torch.device(device_str)
return super(ORTModule, self).to(*args, **kwargs)
def eval(self: T) -> T:
self._is_training = False
self._original_module.eval()
def train(self: T, mode: bool = True) -> T:
self._is_training = mode
self._original_module.train(mode)
def forward(self, *inputs, **kwargs):
'''Forward pass starts here and continues at `_ORTModuleFunction.forward`
ONNX model is exported the first time this method is executed.
Next, a full training graph is splitted in forward and backward graph which are used
to instantiate ONNX Runtime InferenceSession`s
Next, we build a full training graph with module_gradient_graph_builder.
Finally, we instantiate the ONNX Runtime InferenceSession.
'''
# TODO: using pytorch for evaluation for now. We will use ORT for evaluation latter.
if not self._is_training:
return self._original_module(*inputs, **kwargs)
# Exporting module to ONNX for the first time
if not self._onnx_training:
self._build_training_graph(*inputs, **kwargs)
self._get_forward_graph_and_init_gradient_graph_builder(*inputs, **kwargs)
_, _, input_names_require_grad = _extract_input_information(self._original_module, *inputs, **kwargs)
# If inputs requiring gradient change from one call to forward to the next, the module_gradient_graph_builder
@ -270,10 +277,13 @@ class ORTModule(torch.nn.Module):
new_input_shape = [list(input.size()) for input in inputs if input is not None]
if self._current_input_shape is None or self._current_input_shape != new_input_shape:
self._current_input_shape = new_input_shape
self._split_training_graph(*inputs, **kwargs)
elif self._device_changed:
self._build_training_graph()
self._create_training_session()
self._device_changed = False
# TODO: disabled for now, since it caused a bug in NVBert fp32 run
# When creating a new InferenceSession, there is a bug for destructing the original InferenceSession
# elif self._device_changed:
# self._create_training_session()
# self._device_changed = False
# Use a custom torch.autograd.Function to associate self.backward_graph as the
# gradient implementation for self.forward_graph.
@ -284,80 +294,57 @@ class ORTModule(torch.nn.Module):
TODO: **kwargs are not supported
Model outputs are returned to the user
The following tensors are stashed (in order) for backward pass
* (Partial) user input
* (Partial) Initializers
* Intermediate tensors
Module outputs are returned to the user
'''
# Use IO binding
_create_iobinding(self._forward_io_binding, inputs,
self._onnx_forward,
self._device)
_create_iobinding(self._gradient_io_binding, inputs, self._onnx_gradient, self._device)
# Run
self._forward_session.run_with_iobinding(self._forward_io_binding)
forward_outputs = self._forward_io_binding.get_outputs()
# Stash tensors needed by backward
forward_input_dict = self._convert_forward_input_list_to_dict(*inputs)
ctx_inputs = tuple(forward_input_dict[name] \
for name in self._onnx_graphs_info.backward_user_input_names)
ctx_initializers = tuple(forward_input_dict[name] \
for name in self._onnx_graphs_info.backward_intializer_names_as_input)
ctx_intermediates = tuple(_ort_output_to_torch_tensor(forward_output) \
for forward_output in forward_outputs[len(self._onnx_graphs_info.user_output_names):])
ctx.save_for_backward(*[*ctx_inputs, *ctx_initializers, *ctx_intermediates])
# Return model output
# Run and return module outputs.
user_outputs = tuple(_ort_output_to_torch_tensor(forward_output) \
for forward_output in forward_outputs[:len(self._onnx_graphs_info.user_output_names)])
for forward_output in self._gradient_session.run_forward(self._gradient_io_binding, self._run_options))
return user_outputs[0] if len(user_outputs) == 1 else user_outputs
@staticmethod
def backward(ctx, *grad_output):
'''Performs backward pass based on grad wrt output and internal state
Internal state is composed of:
* Tensor stashed (in a particular order) during forward:
* (partial) user input, (partial) initializers and intermediate tensors
TODO: Input gradient is hard-coded to torch.tensor([1.])
'''Performs backward pass based on grad wrt module output
'''
# Use IO binding
grad_output_dict = dict(zip(self._onnx_graphs_info.user_output_grad_names, grad_output))
backward_grad_output = tuple(grad_output_dict[name] for name in self._onnx_graphs_info.backward_output_grad_names)
_create_iobinding(self._backward_io_binding, [*ctx.saved_tensors, *backward_grad_output],
self._onnx_backward,
self._device)
# Push user output grads to ONNX backend.
backward_grad_output_ortvalue = []
for grad_output in grad_output[:len(self._onnx_graphs_info.backward_output_grad_names)]:
backward_grad_output_ortvalue.append(onnxruntime.OrtValue.ortvalue_from_data_ptr(list(grad_output.size()), _utils.dtype_torch_to_numpy(
grad_output.dtype), grad_output.device.type, _get_device_index(grad_output.device), grad_output.data_ptr()))
# Run
self._backward_session.run_with_iobinding(self._backward_io_binding)
backward_outputs = self._backward_io_binding.get_outputs()
# Run and get results
self._gradient_session.run_backward(backward_grad_output_ortvalue)
backward_outputs = self._gradient_io_binding.get_outputs()
# Return input and initializer gradients
num_initializers = len(self._onnx_graphs_info.initializer_grad_names_to_train)
num_user_input_grads = len(self._input_names_require_grad)
results = []
for input_name in self._onnx_graphs_info.user_input_names:
try:
# Append to the results the backward output for each input that required grad
results.append(_ort_output_to_torch_tensor(
backward_outputs[num_initializers + self._input_names_require_grad.index(input_name)]))
backward_outputs[self._input_names_require_grad.index(input_name)]))
except ValueError:
# input_name is not found in the self._input_names_require_grad list
# Append None to results for each input that did not require grad
results.append(None)
# Append backward ouput for all trained initializers
results += [_ort_output_to_torch_tensor(backward_output)
for backward_output in backward_outputs[:num_initializers]]
# Append gradients of initializer to results
results += [_ort_output_to_torch_tensor(backward_output)
for backward_output in backward_outputs[num_user_input_grads:]]
return tuple(results)
proc_inputs = [data for data in inputs if data is not None]
return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*proc_inputs, **kwargs))
return _ORTModuleFunction.apply(*self._convert_gradient_graph_input_to_list(*proc_inputs, **kwargs))
@_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__)
def _convert_forward_input_to_list(self, *inputs, **kwargs):
def _convert_gradient_graph_input_to_list(self, *inputs, **kwargs):
'''Creates forward `*inputs` list from user input and PyTorch initializers
TODO: **kwargs is not supported
@ -379,63 +366,6 @@ class ORTModule(torch.nn.Module):
return result
@_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__)
def _convert_forward_input_list_to_dict(self, *inputs):
'''Convert forward `*inputs` list to dict
TODO: Input gradient is being ignored for MVP
'''
# Dictionary containing both inputs and initializers
forward_input_names = [*self._onnx_graphs_info.user_input_names,
*self._onnx_graphs_info.initializer_names_to_train]
return dict(zip(forward_input_names, inputs))
@_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__)
def _convert_backward_input_list_to_dict(self, *inputs):
'''Convert backward `*inputs` list to dict
ONNX Runtime backward requires dict as input, which is composed of:
* User input
Although not necessary, all user inputs are used for simplicity
* (Partial) Initializers
init_begin = len(user_input)
init_count = len(Pre-computed list of initializer)
* Intermediate tensors
* Gradient wrt outputs
'''
# Dictionary containing both inputs and initializers
result = {}
backward_user_input = self._onnx_graphs_info.backward_user_input_names
backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input
intermediate = self._onnx_graphs_info.intermediate_tensor_names
backward_output_grad_names = self._onnx_graphs_info.backward_output_grad_names
# Extract info about stashed input and grad output
# Inputs
inputs_pos = 0
for idx, name in enumerate(backward_user_input):
result.update({ name : inputs[idx]})
inputs_pos += 1
# Initializers
for idx, name in enumerate(backward_intializer, inputs_pos):
result.update({name: inputs[idx]})
inputs_pos += 1
# Intermediate
for idx, name in enumerate(intermediate, inputs_pos):
result.update({name: inputs[idx]})
inputs_pos += 1
# Grad outputs
for idx, name in enumerate(backward_output_grad_names, inputs_pos):
result.update({name: inputs[idx]})
inputs_pos += 1
return result
def _get_forward_graph(self, input_names, dynamic_axes, *inputs, **kwargs):
'''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input
@ -467,5 +397,9 @@ class ORTModule(torch.nn.Module):
do_constant_folding=False,
training=torch.onnx.TrainingMode.TRAINING,
dynamic_axes=dynamic_axes)
# TODO: this step might not be needed when we use the torch external allocator
# clear cache after model export
torch.cuda.empty_cache()
return onnx.load_model_from_string(f.getvalue())

View file

@ -26,10 +26,29 @@ def parse_arguments():
def run_ortmodule_api_tests(cwd, log):
log.debug('Running: ORTModule-API tests')
command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_ortmodule_api.py']
class TestNameCollecterPlugin:
def __init__(self):
self.collected = set()
run_subprocess(command, cwd=cwd, log=log).check_returncode()
def pytest_collection_modifyitems(self, items):
for item in items:
print('item.name: ', item.name)
self.collected.add(item.name)
import os
import pytest
plugin = TestNameCollecterPlugin()
print(cwd)
test_script_filename = os.path.join("orttraining_test_ortmodule_api.py")
pytest.main(['--collect-only', test_script_filename], plugins=[plugin])
# TODO: FIX THIS!
# Running tests in a loop one after another,
# because ORTModule doesn't support multiple run call at the same time
for test_name in plugin.collected:
run_subprocess([
sys.executable, '-m', 'pytest',
'orttraining_test_ortmodule_api.py', '-sv', '-k', test_name], cwd=cwd).check_returncode()
def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir):
log.debug('Running: ORTModule POCNet for MNIST with --no-cuda arg {}.'.format(no_cuda))

View file

@ -269,20 +269,22 @@ def test_model_to_device_and_back_to_original(original_device, to_device):
for _, parameter_value in model.named_parameters():
assert parameter_value.device.type == original_device
def test_model_with_different_devices_same_session():
N, D_in, H, D_out = 64, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
model = ORTModule(model)
# TODO: Fix the following Unit Test
# @pytest.mark.skip(reason="TODO: ORTModule.to(device) is disabled for now")
# def test_model_with_different_devices_same_session():
# N, D_in, H, D_out = 64, 784, 500, 10
# model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
# model = ORTModule(model)
for i in range(5):
if i % 2 == 0:
device = 'cpu'
else:
device = 'cuda'
# for i in range(5):
# if i % 2 == 0:
# device = 'cpu'
# else:
# device = 'cuda'
model.to(device)
x = torch.randn(N, D_in, device=device)
y = model(x)
# model.to(device)
# x = torch.randn(N, D_in, device=device)
# y = model(x)
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
def test_input_requires_grad_saved(device):
@ -305,16 +307,18 @@ def test_input_requires_grad_backward_creates_input_grad(device):
s.backward()
assert x.grad is not None
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device):
N, D_in, H, D_out = 32, 784, 500, 10
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
model = ORTModule(model)
x = torch.randn(N, D_in, device=device, requires_grad=True)
model(x.data)
module_gradient_graph_builder = model._module_gradient_graph_builder
model(x)
assert module_gradient_graph_builder != model._module_gradient_graph_builder
# TODO: Fix the following Unit Test
# @pytest.mark.parametrize("device", ['cuda', 'cpu'])
# @pytest.mark.skip(reason="ORTModule doesn't support multiple consecutive forward calls.")
# def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device):
# N, D_in, H, D_out = 32, 784, 500, 10
# model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
# model = ORTModule(model)
# x = torch.randn(N, D_in, device=device, requires_grad=True)
# model(x.data)
# module_gradient_graph_builder = model._module_gradient_graph_builder
# model(x)
# assert module_gradient_graph_builder != model._module_gradient_graph_builder
def test_gpu_reserved_memory_with_torch_no_grad():
device = 'cuda'
@ -328,20 +332,21 @@ def test_gpu_reserved_memory_with_torch_no_grad():
model_with_no_grad = ORTModule(model_with_no_grad)
mem_reserved_before_export = torch.cuda.memory_reserved(device)
model_with_no_grad(x, y, None, None, None, None, z)
mem_reserved_after_export_with_torch_no_grad = torch.cuda.memory_reserved(device)
mem_reserved_after_export = torch.cuda.memory_reserved(device)
assert mem_reserved_before_export == mem_reserved_after_export
del model_with_no_grad
torch.cuda.empty_cache()
mem_reserved_after_cache_empty = torch.cuda.memory_reserved(device)
assert mem_reserved_before_export == mem_reserved_after_cache_empty
# Create another model and get the memory_reserved when torch.no_grad has not been enabled
# after export
# Create another model and get the memory_reserved when torch.no_grad and torch.cuda.empty_cache
# has not been enabled after export
model_without_no_grad = _get_bert_for_sequence_classification_model(device)
model_without_no_grad = ORTModule(model_without_no_grad)
mem_reserved_after_export_without_torch_no_grad = 0
with patch('torch.no_grad'):
model_without_no_grad(x, y, None, None, None, None, z)
mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device)
assert mem_reserved_after_export_with_torch_no_grad < mem_reserved_after_export_without_torch_no_grad
assert mem_reserved_before_export < mem_reserved_after_export_with_torch_no_grad
with patch('torch.no_grad'), patch('torch.cuda.empty_cache'):
model_without_no_grad(x, y, None, None, None, None, z)
mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device)
assert mem_reserved_after_export < mem_reserved_after_export_without_torch_no_grad

View file

@ -7,11 +7,7 @@ namespace onnxruntime {
namespace contrib {
void OrtEventPool::CheckRange(const int64_t id) const {
ORT_ENFORCE(
id >= 0 && id < MaxNumItems,
"Got id ", id,
". It should be in a range from 0 to ",
MaxNumItems, ".");
ORT_ENFORCE(id >= 0 && id < MaxNumItems, "Got id ", id, ". It should be in a range from 0 to ", MaxNumItems, ".");
}
void OrtEventPool::SignalEvent(int64_t id) {
@ -27,6 +23,19 @@ bool OrtEventPool::QueryEvent(int64_t id) const {
return pool_[id].signaled.load();
}
void OrtEventPool::WaitAndResetEvent(int64_t id) {
CheckRange(id);
std::unique_lock<std::mutex> lock(pool_[id].mutex);
pool_[id].cv.wait(lock, [this, id] { return pool_[id].signaled.load(); });
pool_[id].signaled.store(false);
};
void OrtEventPool::ResetEvent(int64_t id) {
CheckRange(id);
std::unique_lock<std::mutex> lock(pool_[id].mutex);
pool_[id].signaled.store(false);
};
void OrtEventPool::WaitEvent(int64_t id) const {
CheckRange(id);
std::unique_lock<std::mutex> lock(pool_[id].mutex);

View file

@ -21,7 +21,9 @@ class OrtEventPool final {
}
void SignalEvent(int64_t id);
bool QueryEvent(int64_t id) const;
void WaitAndResetEvent(int64_t id);
void WaitEvent(int64_t id) const;
void ResetEvent(int64_t id);
void ResetAllEvents();
static size_t GetPoolSize() {

View file

@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <queue>
#include <vector>
#include "core/common/common.h"
#include "core/framework/ml_value.h"
namespace onnxruntime {
namespace contrib {
class OrtMessageQueue final {
public:
static OrtMessageQueue& GetInstance() {
static OrtMessageQueue instance_;
return instance_;
}
void Push(const OrtValue& ort_value) { ort_values.emplace(ort_value); }
OrtValue Pop() {
OrtValue ort_value = ort_values.front();
ort_values.pop();
return ort_value;
}
void PopAll(std::vector<OrtValue>& results) {
while (!ort_values.empty()) {
OrtValue ort_value = ort_values.front();
ort_values.pop();
results.emplace_back(ort_value);
}
}
private:
OrtMessageQueue() = default;
~OrtMessageQueue() = default;
OrtMessageQueue(const OrtMessageQueue&) = delete;
OrtMessageQueue& operator=(const OrtMessageQueue&) = delete;
std::queue<OrtValue> ort_values;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cpu/controlflow/yield.h"
#include "orttraining/training_ops/cpu/controlflow/event_pool.h"
#include "orttraining/training_ops/cpu/controlflow/message_queue.h"
#include "core/framework/op_kernel_context_internal.h"
namespace onnxruntime {
namespace contrib {
ONNX_OPERATOR_KERNEL_EX(
YieldOp,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.VariadicAlias(0, 0), // TODO: this is a hack to avoid allocating output buffer
YieldOp);
Status YieldOp::Compute(OpKernelContext* ctx) const {
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
for (int i_in = 0; i_in < ctx->InputCount(); ++i_in) {
onnxruntime::contrib::OrtMessageQueue::GetInstance().Push(*ctx_internal->GetInputMLValue(i_in));
}
// Reset background event before returning to main thread
const int64_t background_thread_event_id = 1;
onnxruntime::contrib::OrtEventPool::GetInstance().ResetEvent(background_thread_event_id);
// single event for InferenceSession::RunInBackgroundAndWaitForYield() that FW graph is done
const int64_t main_thread_event_id = 0;
OrtEventPool::GetInstance().SignalEvent(main_thread_event_id);
// wait for event from InferenceSession::ContinueRunInBackground() to continue the BW graph
OrtEventPool::GetInstance().WaitAndResetEvent(background_thread_event_id);
if (ctx_internal->GetTerminateFlag()) {
LOGS(ctx->Logger(), WARNING) << "Resumed executing backward subgraph, terminate_flag is set to true.";
} else {
// Get output grad from somewhere and prepare Op outputs.
for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
ctx_internal->SetOutputMLValue(i_out, OrtMessageQueue::GetInstance().Pop());
}
}
return Status::OK();
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
namespace onnxruntime {
namespace contrib {
class YieldOp final : public OpKernel {
public:
YieldOp(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext* context) const override;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -94,6 +94,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Recv)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, RecordEvent);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WaitEvent);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, YieldOp);
Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@ -181,7 +182,8 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, RecordEvent)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WaitEvent)>};
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WaitEvent)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, YieldOp)>};
for (auto& function_table_entry : function_table) {
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/training_ops/cpu/controlflow/yield.h"
#include "core/providers/cuda/cuda_fwd.h"
namespace onnxruntime {
namespace cuda {
ONNX_OPERATOR_KERNEL_EX(
YieldOp,
kMSDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.VariadicAlias(0, 0), // TODO: this is a hack to avoid allocating output buffer
onnxruntime::contrib::YieldOp);
} // namespace cuda
} // namespace onnxruntime

View file

@ -164,6 +164,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Adas
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, YieldOp);
#ifdef ORT_USE_NCCL
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce);
@ -326,6 +327,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, YieldOp)>,
#ifdef ORT_USE_NCCL
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce)>,