mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-20 02:07:56 +00:00
OrtModule v0.21 (#6395)
* ortmodule v0.2 * use pt module for eval * get user outputs in yield op * pass output grads to yield output without copy * Disable mem_pattern for ORTModule * Avoid allocating output buffer for Yield op * Change to WaitAndReset to avoid overriding signal * remove unnecessory signal/wait at the end of bg thread * Return Session.Run result as a std::future * export model with torch.no_grad() * Handle bg thread's early return in Forward call * Removed duplicated Yield kernel * Silence "CUDA kernel missing log" * Add missing transforms, clear iobinding (#6532) * revert ortmodule.py to a working state first * Apply ortmodule.py change from dev branch * Rename to YieldOp Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net> Co-authored-by: ashbhandare <ash.bhandare@gmail.com> Co-authored-by: Sherlock <baihan.huang@gmail.com>
This commit is contained in:
parent
bc0d04bf07
commit
eec602e48a
29 changed files with 761 additions and 516 deletions
|
|
@ -251,6 +251,7 @@ class OpKernelContext {
|
|||
const OrtValue* GetInputMLValue(int index) const;
|
||||
const OrtValue* GetImplicitInputMLValue(int index) const;
|
||||
OrtValue* GetOutputMLValue(int index);
|
||||
Status SetOutputMLValue(int index, const OrtValue& ort_value);
|
||||
|
||||
// Creates the OrtValue* based on the shape, if it does not exist
|
||||
// The parameter nnz is used only for sparse-tensors and indicates the
|
||||
|
|
|
|||
|
|
@ -45,6 +45,16 @@ OrtValue* IExecutionFrame::GetMutableNodeInputOrOutputMLValue(int index) {
|
|||
return const_cast<OrtValue*>(GetNodeInputOrOutputMLValue(index));
|
||||
}
|
||||
|
||||
Status IExecutionFrame::SetOutputMLValue(int index, const OrtValue& ort_value) {
|
||||
int ort_value_idx = GetNodeIdxToMLValueIdx(index);
|
||||
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx);
|
||||
}
|
||||
|
||||
all_values_[ort_value_idx] = ort_value;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TO DO: make it thread safe
|
||||
// This method is not thread safe!
|
||||
// Return S_OK and nullptr if index map to an value that is an unused optional input/output
|
||||
|
|
|
|||
|
|
@ -49,6 +49,9 @@ class IExecutionFrame {
|
|||
const OrtValue* GetNodeInputOrOutputMLValue(int index) const;
|
||||
OrtValue* GetMutableNodeInputOrOutputMLValue(int index);
|
||||
|
||||
// Override the index-th output with ort_value
|
||||
Status SetOutputMLValue(int index, const OrtValue& ort_value);
|
||||
|
||||
// TO DO: make it thread safe
|
||||
// This method is not thread safe!
|
||||
// Return S_OK and nullptr if index map to an value that is an unused optional input/output
|
||||
|
|
|
|||
|
|
@ -175,4 +175,13 @@ OrtValue* OpKernelContext::GetOutputMLValue(int index) {
|
|||
return execution_frame_->GetMutableNodeInputOrOutputMLValue(output_arg_index);
|
||||
}
|
||||
|
||||
Status OpKernelContext::SetOutputMLValue(int index, const OrtValue& ort_value) {
|
||||
if (index < 0 || index >= OutputCount()) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "Index out of range.");
|
||||
}
|
||||
|
||||
auto output_arg_index = GetOutputArgIndex(index);
|
||||
return execution_frame_->SetOutputMLValue(output_arg_index, ort_value);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -53,6 +53,10 @@ class OpKernelContextInternal : public OpKernelContext {
|
|||
return OpKernelContext::GetOutputMLValue(index);
|
||||
}
|
||||
|
||||
void SetOutputMLValue(int index, const OrtValue& ort_value) {
|
||||
ORT_ENFORCE(OpKernelContext::SetOutputMLValue(index, ort_value).IsOK());
|
||||
}
|
||||
|
||||
OrtValue* OutputMLValue(int index, const TensorShape& shape) {
|
||||
return OpKernelContext::OutputMLValue(index, shape);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@
|
|||
#include "core/optimizer/skip_layer_norm_fusion.h"
|
||||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
#include "core/optimizer/matmul_transpose_fusion.h"
|
||||
#include "orttraining/core/optimizer/bias_dropout_fusion.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class IExecutionProvider;
|
||||
|
|
@ -144,13 +146,17 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
|
||||
transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(cpu_cuda_acl_armnn_execution_providers));
|
||||
|
||||
std::unordered_set<std::string> cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider};
|
||||
const std::unordered_set<std::string> cuda_execution_providers = {onnxruntime::kCudaExecutionProvider};
|
||||
const std::unordered_set<std::string> cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider};
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(cpu_cuda_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<LayerNormFusion>(cpu_cuda_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<AttentionFusion>(cpu_cuda_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<EmbedLayerNormFusion>(cpu_cuda_execution_providers));
|
||||
|
||||
transformers.emplace_back(onnxruntime::make_unique<BiasDropoutFusion>(cuda_execution_providers));
|
||||
// TODO: This should be combined with MatMulScaleFusion and deprecate MatmulTransposeFusion
|
||||
transformers.emplace_back(onnxruntime::make_unique<MatmulTransposeFusion>(cpu_cuda_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<BiasGeluFusion>(cpu_cuda_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<BiasSoftmaxFusion>(cpu_cuda_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<SkipLayerNormFusion>(cpu_cuda_execution_providers));
|
||||
|
|
|
|||
|
|
@ -1906,7 +1906,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
|
||||
// none of the provided registries has a CUDA kernel for this node
|
||||
if (cuda_kernel_def == nullptr) {
|
||||
LOGS_DEFAULT(WARNING) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name();
|
||||
LOGS_DEFAULT(INFO) << "CUDA kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name();
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,8 @@
|
|||
#include "core/util/protobuf_parsing_utils.h"
|
||||
#include "core/util/thread_utils.h"
|
||||
|
||||
#include "orttraining/training_ops/cpu/controlflow/event_pool.h"
|
||||
#include "orttraining/training_ops/cpu/controlflow/message_queue.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::experimental;
|
||||
|
|
@ -372,6 +374,14 @@ InferenceSession::~InferenceSession() {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: find a better way to terminate the background thread
|
||||
// backward is not completed yet, set terminate_flag to True
|
||||
if (task_.bg_thread_future_.valid()) {
|
||||
*(task_.terminate_flag_) = true;
|
||||
Status s = ContinueRunInBackground({});
|
||||
ORT_UNUSED_PARAMETER(s);
|
||||
}
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
if (session_activity_started_)
|
||||
TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
|
||||
|
|
@ -1709,6 +1719,60 @@ common::Status InferenceSession::Run(IOBinding& io_binding) {
|
|||
return Run(run_options, io_binding);
|
||||
}
|
||||
|
||||
common::Status InferenceSession::RunInBackgroundAndWaitForYield(RunOptions& run_options, IOBinding& io_binding,
|
||||
std::vector<OrtValue>& user_outputs) {
|
||||
const int64_t main_thread_event_id = 0;
|
||||
onnxruntime::contrib::OrtEventPool::GetInstance().ResetEvent(0);
|
||||
|
||||
task_.terminate_flag_ = &(run_options.terminate);
|
||||
task_.bg_thread_promise_ = std::promise<Status>();
|
||||
task_.bg_thread_future_ = task_.bg_thread_promise_.get_future();
|
||||
task_.bg_thread_ = std::thread([&](std::promise<common::Status> result_promise) {
|
||||
common::Status s = Run(run_options, io_binding.GetInputNames(), io_binding.GetInputs(), io_binding.GetOutputNames(),
|
||||
&io_binding.GetOutputs(), &io_binding.GetOutputsDeviceInfo());
|
||||
|
||||
result_promise.set_value(s);
|
||||
|
||||
// signal main thread for background thread completion
|
||||
const int64_t main_thread_event_id = 0;
|
||||
onnxruntime::contrib::OrtEventPool::GetInstance().SignalEvent(main_thread_event_id);
|
||||
},
|
||||
std::move(task_.bg_thread_promise_));
|
||||
|
||||
// Wait for events from
|
||||
// 1. Yield op, if the bg thread sucessfully reached Yield's signal point
|
||||
// 2. The end of bg thread, if it hit execptions and returned earlier
|
||||
onnxruntime::contrib::OrtEventPool::GetInstance().WaitAndResetEvent(main_thread_event_id);
|
||||
|
||||
// background thread has completed without hitting Yield Op
|
||||
if (task_.bg_thread_future_.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) {
|
||||
Status bg_thread_status = task_.bg_thread_future_.get();
|
||||
task_.bg_thread_.join();
|
||||
return bg_thread_status;
|
||||
}
|
||||
|
||||
onnxruntime::contrib::OrtMessageQueue::GetInstance().PopAll(user_outputs);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status InferenceSession::ContinueRunInBackground(const std::vector<OrtValue>& backward_output_grads) {
|
||||
for (const auto& ort_value : backward_output_grads) {
|
||||
onnxruntime::contrib::OrtMessageQueue::GetInstance().Push(ort_value);
|
||||
}
|
||||
|
||||
// resume background thread
|
||||
const int64_t background_thread_event_id = 1;
|
||||
onnxruntime::contrib::OrtEventPool::GetInstance().SignalEvent(background_thread_event_id);
|
||||
|
||||
Status bg_thread_status = task_.bg_thread_future_.get();
|
||||
// wait for bg_thread to complete
|
||||
if (task_.bg_thread_.joinable()) {
|
||||
task_.bg_thread_.join();
|
||||
}
|
||||
|
||||
return bg_thread_status;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void InferenceSession::StartProfiling(const std::basic_string<T>& file_prefix) {
|
||||
std::basic_ostringstream<T> ss;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <thread>
|
||||
#include <future>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
|
|
@ -294,6 +296,13 @@ class InferenceSession {
|
|||
virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT;
|
||||
|
||||
// For ORTModule.forward()
|
||||
virtual common::Status RunInBackgroundAndWaitForYield(RunOptions& run_options, IOBinding& io_binding,
|
||||
std::vector<OrtValue>& user_outputs) ORT_MUST_USE_RESULT;
|
||||
|
||||
// For ORTModule.backward()
|
||||
common::Status ContinueRunInBackground(const std::vector<OrtValue>& backward_output_grads) ORT_MUST_USE_RESULT;
|
||||
|
||||
/**
|
||||
* @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK.
|
||||
* @note lifetime of the returned pointer is valid as long as the Session object is live.
|
||||
|
|
@ -654,7 +663,15 @@ class InferenceSession {
|
|||
// Longer term we may want to directly refer to offsets in this buffer for initializers so we don't need to copy
|
||||
// those into new OrtValue instances, at which point we won't free them until the InferenceSession goes away.
|
||||
std::vector<uint8_t> ort_format_model_bytes_;
|
||||
|
||||
|
||||
// background thread for RunInBackgroundAndWaitForYield
|
||||
struct Task {
|
||||
std::thread bg_thread_;
|
||||
std::promise<Status> bg_thread_promise_;
|
||||
std::future<Status> bg_thread_future_;
|
||||
bool* terminate_flag_ = nullptr;
|
||||
} task_;
|
||||
|
||||
std::shared_ptr<onnxruntime::AllocatorManager> allocator_manager_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "python/dlpack.h"
|
||||
|
||||
// this convertor will:
|
||||
// 1) take a OrtValue object and wrap it in the DLPack tensor
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
||||
DLManagedTensor* ort_value_to_dlpack(const OrtValue& ml_value);
|
||||
DLDataType get_dlpack_data_type(const OrtValue& ml_value);
|
||||
DLContext get_dlpack_context(const OrtValue& ml_value, const int64_t& device_id);
|
||||
|
||||
} // namespace python
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,16 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/dl_convertor.h"
|
||||
#include "python/dlpack_convertor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
||||
DLDataType get_dlpack_data_type(const OrtValue& ml_value) {
|
||||
ORT_ENFORCE(ml_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
|
||||
DLDataType get_dlpack_data_type(const OrtValue& ort_value) {
|
||||
ORT_ENFORCE(ort_value.IsTensor(), "Only tensor-type OrtValues are supported");
|
||||
DLDataType dtype;
|
||||
dtype.lanes = 1;
|
||||
const Tensor& tensor = ml_value.Get<Tensor>();
|
||||
const Tensor& tensor = ort_value.Get<Tensor>();
|
||||
switch (tensor.GetElementType()) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
dtype.code = DLDataTypeCode::kDLFloat;
|
||||
|
|
@ -73,11 +73,11 @@ DLDataType get_dlpack_data_type(const OrtValue& ml_value) {
|
|||
return dtype;
|
||||
}
|
||||
|
||||
DLContext get_dlpack_context(const OrtValue& ml_value, const int64_t& device_id) {
|
||||
ORT_ENFORCE(ml_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
|
||||
DLContext get_dlpack_context(const OrtValue& ort_value, const int64_t& device_id) {
|
||||
ORT_ENFORCE(ort_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
|
||||
DLContext ctx;
|
||||
ctx.device_id = device_id;
|
||||
const Tensor& tensor = ml_value.Get<Tensor>();
|
||||
ctx.device_id = static_cast<int>(device_id);
|
||||
const Tensor& tensor = ort_value.Get<Tensor>();
|
||||
const auto& location = tensor.Location();
|
||||
switch (location.device.Type()) {
|
||||
case OrtDevice::CPU:
|
||||
|
|
@ -102,17 +102,17 @@ void deleter(DLManagedTensor* arg) { delete static_cast<OrtDLManagedTensor*>(arg
|
|||
|
||||
// This function returns a shared_ptr to memory managed DLpack tensor
|
||||
// constructed out of OrtValue.
|
||||
DLManagedTensor* ort_value_to_dlpack(const OrtValue& ml_value) {
|
||||
ORT_ENFORCE(ml_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
|
||||
DLManagedTensor* ort_value_to_dlpack(const OrtValue& ort_value) {
|
||||
ORT_ENFORCE(ort_value.IsTensor(), "Only tensor type OrtValues are supported");
|
||||
OrtDLManagedTensor* ort_dlmanaged_tensor(new OrtDLManagedTensor);
|
||||
const Tensor& tensor = ml_value.Get<Tensor>();
|
||||
ort_dlmanaged_tensor->handle = ml_value;
|
||||
const Tensor& tensor = ort_value.Get<Tensor>();
|
||||
ort_dlmanaged_tensor->handle = ort_value;
|
||||
ort_dlmanaged_tensor->tensor.manager_ctx = ort_dlmanaged_tensor;
|
||||
ort_dlmanaged_tensor->tensor.deleter = &deleter;
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.data = const_cast<void*>(tensor.DataRaw());
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.ctx = get_dlpack_context(ml_value, tensor.Location().device.Id());
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.ndim = tensor.Shape().NumDimensions();
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.dtype = get_dlpack_data_type(ml_value);
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.ctx = get_dlpack_context(ort_value, tensor.Location().device.Id());
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.ndim = static_cast<int>(tensor.Shape().NumDimensions());
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.dtype = get_dlpack_data_type(ort_value);
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.shape =
|
||||
tensor.Shape().NumDimensions() > 0 ? const_cast<int64_t*>(&tensor.Shape()[0]) : nullptr;
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.strides = nullptr;
|
||||
19
onnxruntime/python/dlpack_convertor.h
Normal file
19
onnxruntime/python/dlpack_convertor.h
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "python/dlpack.h"
|
||||
|
||||
// This convertor will take an OrtValue and wrap it as a DLPack tensor
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
||||
DLManagedTensor* ort_value_to_dlpack(const OrtValue& ort_value);
|
||||
DLDataType get_dlpack_data_type(const OrtValue& ort_value);
|
||||
DLContext get_dlpack_context(const OrtValue& ort_value, const int64_t& device_id);
|
||||
|
||||
} // namespace python
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -228,6 +228,21 @@ class Session:
|
|||
"""
|
||||
self._sess.run_with_iobinding(iobinding._iobinding, run_options)
|
||||
|
||||
def run_forward(self, iobinding, run_options):
|
||||
"""
|
||||
Compute the forward subgraph until it hits the Yield Op.
|
||||
:param iobinding: the iobinding object that has graph inputs/outputs bind.
|
||||
:param run_options: See :class:`onnxruntime.RunOptions`.
|
||||
"""
|
||||
return [OrtValue(ortvalue) for ortvalue in self._sess.run_forward(iobinding._iobinding, run_options)]
|
||||
|
||||
def run_backward(self, backward_output_grads):
|
||||
"""
|
||||
Resume executing the backward subgraph starting from Yield Op.
|
||||
:param backward_output_grads: Output gradients for backward.
|
||||
"""
|
||||
self._sess.run_backward([ortvalue._ortvalue for ortvalue in backward_output_grads])
|
||||
|
||||
|
||||
class InferenceSession(Session):
|
||||
"""
|
||||
|
|
@ -486,6 +501,23 @@ class OrtValue:
|
|||
return OrtValue(C.OrtValue.ortvalue_from_shape_and_type(shape, element_type,
|
||||
C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id)))
|
||||
|
||||
@staticmethod
|
||||
def ortvalue_from_data_ptr(shape=None, element_type=None, device_type='cpu', device_id=0, buffer_ptr=None):
|
||||
'''
|
||||
Factory method to construct an OrtValue (which holds a Tensor) from given buffer_ptr
|
||||
:param shape: List of integers indicating the shape of the OrtValue
|
||||
:param element_type: The data type of the elements in the OrtValue (numpy type)
|
||||
:param device_type: e.g. cpu, cuda, cpu by default
|
||||
:param device_id: device id, e.g. 0
|
||||
:param buffer_ptr: data buffer pointer
|
||||
'''
|
||||
if shape is None or element_type is None or buffer_ptr is None:
|
||||
raise ValueError("`element_type`, `shape` and `buffer_ptr` are to be provided")
|
||||
|
||||
return OrtValue(C.OrtValue.ortvalue_from_data_ptr(shape, element_type,
|
||||
C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id),
|
||||
buffer_ptr))
|
||||
|
||||
def data_ptr(self):
|
||||
'''
|
||||
Returns the address of the first element in the OrtValue's data buffer
|
||||
|
|
@ -524,4 +556,8 @@ class OrtValue:
|
|||
return self._ortvalue.numpy()
|
||||
|
||||
def to_dlpack(self):
|
||||
'''
|
||||
Returns a DLPack object from the OrtValue.
|
||||
Valid only for OrtValues holding Tensors. Throws for OrtValues holding non-Tensors.
|
||||
'''
|
||||
return self._ortvalue.to_dlpack()
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#include "core/platform/env.h"
|
||||
#include "core/session/IOBinding.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "python/dl_convertor.h"
|
||||
#include "python/dlpack_convertor.h"
|
||||
|
||||
// execution provider factory creator headers
|
||||
#include "core/providers/cpu/cpu_provider_factory_creator.h"
|
||||
|
|
@ -1216,6 +1216,27 @@ void addObjectMethods(py::module& m, Environment& env) {
|
|||
|
||||
return ml_value;
|
||||
})
|
||||
.def_static("ortvalue_from_data_ptr", [](std::vector<int64_t>& shape, py::object& element_type,
|
||||
OrtDevice& device, int64_t data_ptr) {
|
||||
ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is invalid");
|
||||
PyArray_Descr* dtype;
|
||||
if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) {
|
||||
throw std::runtime_error("Not a valid numpy type");
|
||||
}
|
||||
|
||||
int type_num = dtype->type_num;
|
||||
Py_DECREF(dtype);
|
||||
|
||||
OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id());
|
||||
std::unique_ptr<Tensor> p_tensor = onnxruntime::make_unique<Tensor>(NumpyTypeToOnnxRuntimeType(type_num), shape,
|
||||
reinterpret_cast<void*>(data_ptr), info);
|
||||
|
||||
auto ort_value = onnxruntime::make_unique<OrtValue>();
|
||||
ort_value->Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(),
|
||||
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
|
||||
|
||||
return ort_value;
|
||||
})
|
||||
.def("data_ptr", [](OrtValue* ml_value) -> int64_t {
|
||||
// TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors
|
||||
ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported");
|
||||
|
|
@ -1278,8 +1299,8 @@ void addObjectMethods(py::module& m, Environment& env) {
|
|||
#endif
|
||||
return obj;
|
||||
})
|
||||
.def("to_dlpack", [](OrtValue* ml_value) -> py::object {
|
||||
DLManagedTensor* dlmanaged_tensor = ort_value_to_dlpack(*ml_value);
|
||||
.def("to_dlpack", [](OrtValue* ort_value) -> py::object {
|
||||
DLManagedTensor* dlmanaged_tensor = ort_value_to_dlpack(*ort_value);
|
||||
return py::reinterpret_steal<py::object>(
|
||||
PyCapsule_New(dlmanaged_tensor, "dltensor", dlpack_capsule_destructor));
|
||||
});
|
||||
|
|
@ -1804,6 +1825,20 @@ including arg name, arg type (contains both type and shape).)pbdoc")
|
|||
status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
|
||||
})
|
||||
.def("run_forward", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions& run_options) -> std::vector<OrtValue> {
|
||||
std::vector<OrtValue> module_outputs;
|
||||
Status status = sess->GetSessionHandle()->RunInBackgroundAndWaitForYield(run_options, *io_binding.Get(), module_outputs);
|
||||
if (!status.IsOK()) {
|
||||
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
|
||||
}
|
||||
|
||||
return module_outputs;
|
||||
})
|
||||
.def("run_backward", [](PyInferenceSession* sess, const std::vector<OrtValue>& backward_output_grads) -> void {
|
||||
Status status = sess->GetSessionHandle()->ContinueRunInBackground(backward_output_grads);
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error("Error in execution: " + status.ErrorMessage());
|
||||
});
|
||||
|
||||
py::enum_<onnxruntime::ArenaExtendStrategy>(m, "ArenaExtendStrategy", py::arithmetic())
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/model.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
#include "orttraining/core/framework/module_gradient_graph_builder.h"
|
||||
|
|
@ -14,51 +13,26 @@ namespace training {
|
|||
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
void GetInputAndOutputNames(const Node& node, std::unordered_set<std::string>& input_names,
|
||||
std::unordered_set<std::string>& output_names) {
|
||||
std::for_each(node.InputDefs().begin(), node.InputDefs().end(),
|
||||
[&input_names](const NodeArg* node_arg) { input_names.insert(node_arg->Name()); });
|
||||
std::for_each(node.OutputDefs().begin(), node.OutputDefs().end(),
|
||||
[&output_names](const NodeArg* node_arg) { output_names.insert(node_arg->Name()); });
|
||||
}
|
||||
|
||||
void RemoveNodes(Graph& graph, const std::vector<Node*>& nodes_to_remove) {
|
||||
for (Node* node_to_remove : nodes_to_remove) {
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *node_to_remove);
|
||||
graph.RemoveNode(node_to_remove->Index());
|
||||
}
|
||||
}
|
||||
|
||||
void FilterInitializers(Graph& graph, const std::unordered_set<std::string>& input_names) {
|
||||
const auto& initializers = graph.GetAllInitializedTensors();
|
||||
std::unordered_set<std::string> initializer_names_to_remove;
|
||||
for (const auto& initializer : initializers) {
|
||||
if (input_names.find(initializer.first) == input_names.end()) {
|
||||
initializer_names_to_remove.insert(initializer.first);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& initializer_name : initializer_names_to_remove) {
|
||||
graph.RemoveInitializedTensor(initializer_name);
|
||||
}
|
||||
}
|
||||
|
||||
Status ModuleGradientGraphBuilder::Initialize(std::istream& model_istream,
|
||||
const ModuleGradientGraphBuilderConfiguration& config) {
|
||||
// We need to apply the pre-training transformers before the gradient graph builder so we can build
|
||||
// an optimized gradient graph. The constant folding transformer depends on concrete shapes, without
|
||||
// constant folding with concrete shapes, shapes of some intermediate tensors will fail to infer.
|
||||
// This means we need to "apply transformers -> build gradient graph -> split" each time we have different
|
||||
// concrete input shapes. So this init func is just to save the original graph and config.
|
||||
// Save the model and config.
|
||||
ONNX_NAMESPACE::ModelProto model_proto;
|
||||
ORT_RETURN_IF_ERROR(Model::Load(model_istream, &model_proto));
|
||||
ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_, nullptr, *logger_));
|
||||
config_ = config;
|
||||
|
||||
// Handle original model inputs, outputs and trainable initializers.
|
||||
// We need to move the trainable initializers to graph inputs and keep the order in config,
|
||||
// it's possible that the graph already has some trainable initializers in graph inputs,
|
||||
// so we need to NOT assign these trainable initializers to the user inputs list.
|
||||
Graph& graph = model_->MainGraph();
|
||||
std::unordered_set<std::string> initializer_names_to_train_set(config.initializer_names_to_train.begin(),
|
||||
config.initializer_names_to_train.end());
|
||||
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
for (auto& node_arg : graph_inputs) {
|
||||
split_graphs_info_.user_input_names.emplace_back(node_arg->Name());
|
||||
if (initializer_names_to_train_set.find(node_arg->Name()) == initializer_names_to_train_set.end()) {
|
||||
split_graphs_info_.user_input_names.emplace_back(node_arg->Name());
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<const NodeArg*>& graph_outputs = graph.GetOutputs();
|
||||
|
|
@ -69,35 +43,65 @@ Status ModuleGradientGraphBuilder::Initialize(std::istream& model_istream,
|
|||
split_graphs_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(),
|
||||
config.initializer_names_to_train.end());
|
||||
|
||||
// Remove the training initializers from the graph and move them to input to save memory.
|
||||
std::vector<const NodeArg*> input_args;
|
||||
for (const auto& input_name : split_graphs_info_.user_input_names) {
|
||||
input_args.emplace_back(graph.GetNodeArg(input_name));
|
||||
}
|
||||
|
||||
// Remove the training initializers from the graph and move them to graph inputs.
|
||||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
input_args.emplace_back(graph.GetNodeArg(initializer_name));
|
||||
graph.RemoveInitializedTensor(initializer_name);
|
||||
}
|
||||
|
||||
graph.SetInputs(input_args);
|
||||
|
||||
config_ = config;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ModuleGradientGraphBuilder::BuildAndSplit(const std::vector<std::vector<int64_t>>& input_shapes) {
|
||||
// Build the gradient graphs from original graph.
|
||||
// Since the input shapes may differ, and the graph optimizers (mainly constant folding) may fold this
|
||||
// shape info to constants, the optimized graph (before gradient graph building) can not be shared.
|
||||
// So each time we need to start from the beginning, i.e., 1) replace input shapes, 2) apply graph optimizers,
|
||||
// 3) build gradient graph, and finally 4) adjust the graph inputs and outputs.
|
||||
Status ModuleGradientGraphBuilder::Build(const std::vector<std::vector<int64_t>>* input_shapes_ptr) {
|
||||
// Make a copy of the original model.
|
||||
auto model_proto = model_->ToProto();
|
||||
std::shared_ptr<onnxruntime::Model> model_copied;
|
||||
ORT_RETURN_IF_ERROR(Model::Load(model_proto, model_copied, nullptr, *logger_));
|
||||
Graph& graph = model_copied->MainGraph();
|
||||
ORT_RETURN_IF_ERROR(Model::Load(model_proto, gradient_model_, nullptr, *logger_));
|
||||
|
||||
// Replace the input shapes.
|
||||
// Replace the user input shapes if input_shapes_ptr is not null_ptr.
|
||||
if (input_shapes_ptr) {
|
||||
SetConcreteInputShapes(*input_shapes_ptr);
|
||||
}
|
||||
|
||||
// Build the gradient graph.
|
||||
ORT_RETURN_IF_ERROR(BuildGradientGraph());
|
||||
|
||||
// Add Yield Op.
|
||||
AddYieldOp();
|
||||
|
||||
// Reorder outputs.
|
||||
ReorderOutputs();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string ModuleGradientGraphBuilder::GetGradientModel() const {
|
||||
std::string model_str;
|
||||
if (!gradient_model_->ToProto().SerializeToString(&model_str)) {
|
||||
ORT_THROW("Fail to serialize gradient model to string.");
|
||||
}
|
||||
|
||||
return model_str;
|
||||
}
|
||||
|
||||
void ModuleGradientGraphBuilder::SetConcreteInputShapes(const std::vector<std::vector<int64_t>>& input_shapes) {
|
||||
ORT_ENFORCE(input_shapes.size() == split_graphs_info_.user_input_names.size(),
|
||||
"The size of concrete input shapes and the size of user inputs does not match.");
|
||||
Graph& gradient_graph = gradient_model_->MainGraph();
|
||||
std::vector<const NodeArg*> input_args;
|
||||
size_t input_index = 0;
|
||||
for (const auto& input_name : split_graphs_info_.user_input_names) {
|
||||
NodeArg* input_node_arg = graph.GetNodeArg(input_name);
|
||||
NodeArg* input_node_arg = gradient_graph.GetNodeArg(input_name);
|
||||
ONNX_NAMESPACE::TensorShapeProto new_shape;
|
||||
for (size_t i = 0; i < input_shapes[input_index].size(); i++) {
|
||||
new_shape.add_dim()->set_dim_value(input_shapes[input_index][i]);
|
||||
|
|
@ -109,15 +113,19 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(const std::vector<std::vector<i
|
|||
}
|
||||
|
||||
// Move over all training initializer inputs. They already have the concrete shapes.
|
||||
const std::vector<const NodeArg*>& graph_inputs = graph.GetInputsIncludingInitializers();
|
||||
const std::vector<const NodeArg*>& graph_inputs = gradient_graph.GetInputsIncludingInitializers();
|
||||
for (; input_index < graph_inputs.size(); input_index++) {
|
||||
input_args.emplace_back(graph_inputs[input_index]);
|
||||
}
|
||||
|
||||
graph.SetInputs(input_args);
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
gradient_graph.SetInputs(input_args);
|
||||
}
|
||||
|
||||
Status ModuleGradientGraphBuilder::BuildGradientGraph() {
|
||||
// Resolve original graph, register and apply transformers for pre-training.
|
||||
Graph& gradient_graph = gradient_model_->MainGraph();
|
||||
ORT_RETURN_IF_ERROR(gradient_graph.Resolve());
|
||||
|
||||
// Register and apply transformers for pre-training.
|
||||
const TrainingSession::TrainingConfiguration::GraphTransformerConfiguration graph_transformer_config{};
|
||||
GraphTransformerManager graph_transformation_mgr{2};
|
||||
std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
|
||||
|
|
@ -143,235 +151,118 @@ Status ModuleGradientGraphBuilder::BuildAndSplit(const std::vector<std::vector<i
|
|||
}
|
||||
|
||||
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *logger_));
|
||||
ORT_RETURN_IF_ERROR(
|
||||
graph_transformation_mgr.ApplyTransformers(gradient_graph, static_cast<TransformerLevel>(i), *logger_));
|
||||
}
|
||||
|
||||
// Build gradient graph.
|
||||
GradientGraphConfiguration gradient_graph_config{};
|
||||
gradient_graph_config.use_invertible_layernorm_grad = config_.use_invertible_layernorm_grad;
|
||||
gradient_graph_config.set_gradients_as_graph_outputs = config_.set_gradients_as_graph_outputs;
|
||||
gradient_graph_config.set_gradients_as_graph_outputs = true;
|
||||
std::unordered_set<std::string> y_node_arg_names(split_graphs_info_.user_output_names.begin(),
|
||||
split_graphs_info_.user_output_names.end());
|
||||
GradientGraphBuilder grad_graph_builder(&graph, y_node_arg_names, x_node_arg_names, "",
|
||||
GradientGraphBuilder grad_graph_builder(&gradient_graph, y_node_arg_names, x_node_arg_names, "",
|
||||
gradient_graph_config, *logger_);
|
||||
|
||||
ORT_RETURN_IF_ERROR(grad_graph_builder.Build());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Fix inputs/outputs related to gradients.
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::unordered_set<std::string> input_names;
|
||||
std::unordered_set<std::string> output_names;
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
GetInputAndOutputNames(node, input_names, output_names);
|
||||
void ModuleGradientGraphBuilder::AddYieldOp() {
|
||||
Graph& gradient_graph = gradient_model_->MainGraph();
|
||||
GraphViewer gradient_graph_viewer(gradient_graph);
|
||||
const auto& gradient_node_topology_list = gradient_graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::unordered_set<std::string> user_output_grad_names_set;
|
||||
for (const auto& name : split_graphs_info_.user_output_names) {
|
||||
user_output_grad_names_set.insert(name + "_grad");
|
||||
}
|
||||
|
||||
input_args.clear();
|
||||
for (const NodeArg* input_node_arg: graph.GetInputsIncludingInitializers()) {
|
||||
input_args.emplace_back(input_node_arg);
|
||||
}
|
||||
|
||||
// Add the entry points of gradients (normally loss_gard) to the graph inputs. Using the order of graph outputs.
|
||||
split_graphs_info_.user_output_grad_names.clear();
|
||||
split_graphs_info_.backward_output_grad_names.clear();
|
||||
for (const auto& output_name : split_graphs_info_.user_output_names) {
|
||||
std::string output_gradient_name = output_name + "_grad";
|
||||
if (input_names.find(output_gradient_name) != input_names.end()) {
|
||||
split_graphs_info_.user_output_grad_names.emplace_back(output_gradient_name);
|
||||
// Only add to graph input when it's not an output of a node.
|
||||
if (output_names.find(output_gradient_name) == output_names.end()) {
|
||||
split_graphs_info_.backward_output_grad_names.emplace_back(output_gradient_name);
|
||||
NodeArg* output_gradient_node_arg = graph.GetNodeArg(output_gradient_name);
|
||||
output_gradient_node_arg->UpdateTypeAndShape(*graph.GetNodeArg(output_name), true, true, *logger_);
|
||||
input_args.emplace_back(output_gradient_node_arg);
|
||||
// If an NodeArg is output of one of nodes, it's not the user output gradient needed by backward graph.
|
||||
std::unordered_set<std::string> non_backward_user_output_grad_names;
|
||||
for (auto node_index : gradient_node_topology_list) {
|
||||
auto& node = *gradient_graph.GetNode(node_index);
|
||||
for (const auto& node_arg : node.OutputDefs()) {
|
||||
if (user_output_grad_names_set.find(node_arg->Name()) != user_output_grad_names_set.end()) {
|
||||
non_backward_user_output_grad_names.insert(node_arg->Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
graph.SetInputs(input_args);
|
||||
// Yield inputs include all user outputs, those require output gradients come first, so Yield Op can use their shapes
|
||||
// to infer Op output shapes.
|
||||
std::vector<std::string> user_output_names_require_grad;
|
||||
std::vector<std::string> user_output_names_no_grad;
|
||||
split_graphs_info_.backward_output_grad_names.clear();
|
||||
for (const auto& name : split_graphs_info_.user_output_names) {
|
||||
std::string grad_name = name + "_grad";
|
||||
if (non_backward_user_output_grad_names.find(grad_name) == non_backward_user_output_grad_names.end()) {
|
||||
user_output_names_require_grad.emplace_back(name);
|
||||
split_graphs_info_.backward_output_grad_names.emplace_back(grad_name);
|
||||
} else {
|
||||
user_output_names_no_grad.emplace_back(name);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const NodeArg*> output_args;
|
||||
for (auto& output_name : split_graphs_info_.user_output_names) {
|
||||
output_args.emplace_back(graph.GetNodeArg(output_name));
|
||||
// Reorder the user outputs.
|
||||
split_graphs_info_.user_output_names.clear();
|
||||
for (const auto& name : user_output_names_require_grad) {
|
||||
split_graphs_info_.user_output_names.emplace_back(name);
|
||||
}
|
||||
|
||||
for (const auto& name : user_output_names_no_grad) {
|
||||
split_graphs_info_.user_output_names.emplace_back(name);
|
||||
}
|
||||
|
||||
std::vector<NodeArg*> yield_input_node_args;
|
||||
std::vector<NodeArg*> yield_output_node_args;
|
||||
for (const auto& name : split_graphs_info_.user_output_names) {
|
||||
yield_input_node_args.emplace_back(gradient_graph.GetNodeArg(name));
|
||||
}
|
||||
|
||||
for (const auto& name : split_graphs_info_.backward_output_grad_names) {
|
||||
yield_output_node_args.emplace_back(gradient_graph.GetNodeArg(name));
|
||||
}
|
||||
|
||||
gradient_graph.AddNode("YieldOp", "YieldOp", "Yield Op", yield_input_node_args, yield_output_node_args, {}, kMSDomain);
|
||||
}
|
||||
|
||||
void ModuleGradientGraphBuilder::ReorderOutputs() {
|
||||
// Adjust gradient graph outputs by the following order:
|
||||
// 1. user input grads if required, with same order of user inputs,
|
||||
// 2. trainable initailizer grads, with same order of trainable initializers.
|
||||
Graph& gradient_graph = gradient_model_->MainGraph();
|
||||
const std::vector<const NodeArg*>& gradient_graph_outputs = gradient_graph.GetOutputs();
|
||||
std::unordered_map<std::string, const NodeArg*> gradient_output_arg_map;
|
||||
for (auto& node_arg : gradient_graph_outputs) {
|
||||
gradient_output_arg_map[node_arg->Name()] = node_arg;
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> user_input_require_grad_set(config_.input_names_require_grad.begin(),
|
||||
config_.input_names_require_grad.end());
|
||||
|
||||
std::vector<const NodeArg*> new_output_args;
|
||||
split_graphs_info_.user_input_grad_names.clear();
|
||||
for (const auto& input_name : split_graphs_info_.user_input_names) {
|
||||
if (user_input_require_grad_set.find(input_name) != user_input_require_grad_set.end()) {
|
||||
std::string input_gradient_name = input_name + "_grad";
|
||||
ORT_ENFORCE(gradient_output_arg_map.find(input_gradient_name) != gradient_output_arg_map.end(),
|
||||
"Required user input grad is not found on gradient graph.");
|
||||
split_graphs_info_.user_input_grad_names[input_name] = input_gradient_name;
|
||||
new_output_args.emplace_back(gradient_output_arg_map[input_gradient_name]);
|
||||
}
|
||||
}
|
||||
|
||||
// Add initializer gradients to graph outputs.
|
||||
split_graphs_info_.initializer_grad_names_to_train.clear();
|
||||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
std::string initializer_gradient_name = initializer_name + "_grad";
|
||||
if (output_names.find(initializer_gradient_name) != output_names.end()) {
|
||||
split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name);
|
||||
output_args.emplace_back(graph.GetNodeArg(initializer_gradient_name));
|
||||
}
|
||||
ORT_ENFORCE(gradient_output_arg_map.find(initializer_gradient_name) != gradient_output_arg_map.end(),
|
||||
"Trainable initializer grad is not found on gradient graph.");
|
||||
split_graphs_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name);
|
||||
new_output_args.emplace_back(gradient_output_arg_map[initializer_gradient_name]);
|
||||
}
|
||||
|
||||
// Add input gradients to graph outputs if it's required.
|
||||
for (const auto& input_name : config_.input_names_require_grad) {
|
||||
std::string input_gradient_name = input_name + "_grad";
|
||||
if (output_names.find(input_gradient_name) != output_names.end()) {
|
||||
output_args.emplace_back(graph.GetNodeArg(input_gradient_name));
|
||||
}
|
||||
}
|
||||
|
||||
graph.SetOutputs(output_args);
|
||||
graph.Resolve();
|
||||
|
||||
// Run the transformers again mainly for backward part, e.g., constant fold from those Shape nodes in backward graph.
|
||||
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *logger_));
|
||||
}
|
||||
|
||||
// Create two copies of gradient model for forward and backward models respectively.
|
||||
auto gradient_model_proto = model_copied->ToProto();
|
||||
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, forward_model_, nullptr, *logger_));
|
||||
ORT_RETURN_IF_ERROR(Model::Load(gradient_model_proto, backward_model_, nullptr, *logger_));
|
||||
|
||||
// Split the graph in the copies of gradient model.
|
||||
ORT_RETURN_IF_ERROR(Split());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string SerializeModel(const std::shared_ptr<onnxruntime::Model>& model, const std::string& tag) {
|
||||
std::string model_str;
|
||||
if (!model->ToProto().SerializeToString(&model_str)) {
|
||||
ORT_THROW("Fail to serialize", tag, "model to string.");
|
||||
}
|
||||
|
||||
return model_str;
|
||||
}
|
||||
|
||||
std::string ModuleGradientGraphBuilder::GetForwardModel() const { return SerializeModel(forward_model_, "forward"); }
|
||||
|
||||
std::string ModuleGradientGraphBuilder::GetBackwardModel() const { return SerializeModel(backward_model_, "backward"); }
|
||||
|
||||
Status ModuleGradientGraphBuilder::Split() {
|
||||
// Get forward model, also collect some information for backward model generation.
|
||||
Graph& forward_graph = forward_model_->MainGraph();
|
||||
GraphViewer forward_graph_viewer(forward_graph);
|
||||
const auto& forward_node_topology_list = forward_graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::vector<Node*> forward_nodes_to_remove;
|
||||
std::unordered_set<std::string> forward_input_names;
|
||||
std::unordered_set<std::string> forward_output_names;
|
||||
std::unordered_set<std::string> backward_input_names;
|
||||
std::unordered_set<std::string> backward_output_names;
|
||||
for (auto node_index : forward_node_topology_list) {
|
||||
auto& node = *forward_graph.GetNode(node_index);
|
||||
// Currently we are using node description to distinguish the forward and backward nodes.
|
||||
if (node.Description() == "Backward pass") {
|
||||
forward_nodes_to_remove.emplace_back(&node);
|
||||
GetInputAndOutputNames(node, backward_input_names, backward_output_names);
|
||||
} else {
|
||||
GetInputAndOutputNames(node, forward_input_names, forward_output_names);
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_set<std::string> intermediate_arg_names;
|
||||
for (const auto& forward_output_name : forward_output_names) {
|
||||
if (backward_input_names.find(forward_output_name) != backward_input_names.end()) {
|
||||
intermediate_arg_names.insert(forward_output_name);
|
||||
}
|
||||
}
|
||||
|
||||
RemoveNodes(forward_graph, forward_nodes_to_remove);
|
||||
FilterInitializers(forward_graph, forward_input_names);
|
||||
|
||||
// All user inputs should be also part of the forward graph inputs.
|
||||
std::vector<const NodeArg*> forward_input_args;
|
||||
for (const auto& input_name : split_graphs_info_.user_input_names) {
|
||||
forward_input_args.emplace_back(forward_graph.GetNodeArg(input_name));
|
||||
}
|
||||
|
||||
// Add initializers to forward graph inputs.
|
||||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
forward_input_args.emplace_back(forward_graph.GetNodeArg(initializer_name));
|
||||
}
|
||||
|
||||
forward_graph.SetInputs(forward_input_args);
|
||||
|
||||
// All user outputs should be also part of the forward graph outputs.
|
||||
std::vector<const NodeArg*> forward_output_args;
|
||||
for (const auto& output_name : split_graphs_info_.user_output_names) {
|
||||
forward_output_args.emplace_back(forward_graph.GetNodeArg(output_name));
|
||||
}
|
||||
|
||||
// Add intermediate args to forward graph outputs.
|
||||
split_graphs_info_.intermediate_tensor_names.clear();
|
||||
for (const auto& intermediate_arg_name : intermediate_arg_names) {
|
||||
// Ignore the user outputs.
|
||||
if (std::find(split_graphs_info_.user_output_names.begin(), split_graphs_info_.user_output_names.end(),
|
||||
intermediate_arg_name) == split_graphs_info_.user_output_names.end()) {
|
||||
split_graphs_info_.intermediate_tensor_names.emplace_back(intermediate_arg_name);
|
||||
forward_output_args.emplace_back(forward_graph.GetNodeArg(intermediate_arg_name));
|
||||
}
|
||||
}
|
||||
|
||||
forward_graph.SetOutputs(forward_output_args);
|
||||
forward_graph.Resolve();
|
||||
|
||||
// Get backward graph.
|
||||
Graph& backward_graph = backward_model_->MainGraph();
|
||||
GraphViewer backward_graph_viewer(backward_graph);
|
||||
const auto& backward_node_topology_list = backward_graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::vector<Node*> backward_nodes_to_remove;
|
||||
for (auto node_index : backward_node_topology_list) {
|
||||
auto& node = *backward_graph.GetNode(node_index);
|
||||
if (node.Description() != "Backward pass") {
|
||||
backward_nodes_to_remove.emplace_back(&node);
|
||||
}
|
||||
}
|
||||
|
||||
RemoveNodes(backward_graph, backward_nodes_to_remove);
|
||||
FilterInitializers(backward_graph, backward_input_names);
|
||||
|
||||
// User inputs to backward graph inputs.
|
||||
split_graphs_info_.backward_user_input_names.clear();
|
||||
std::vector<const NodeArg*> backward_input_args;
|
||||
for (const auto& input_name : split_graphs_info_.user_input_names) {
|
||||
// Only takes those in the backward inputs.
|
||||
if (backward_input_names.find(input_name) != backward_input_names.end()) {
|
||||
split_graphs_info_.backward_user_input_names.emplace_back(input_name);
|
||||
backward_input_args.emplace_back(backward_graph.GetNodeArg(input_name));
|
||||
}
|
||||
}
|
||||
|
||||
// Add initializer args to backward graph inputs if any node uses them.
|
||||
split_graphs_info_.backward_intializer_names_as_input.clear();
|
||||
for (const auto& initializer_name : split_graphs_info_.initializer_names_to_train) {
|
||||
// Some initializers will be inputs for backward graph.
|
||||
if (backward_input_names.find(initializer_name) != backward_input_names.end()) {
|
||||
split_graphs_info_.backward_intializer_names_as_input.emplace_back(initializer_name);
|
||||
backward_input_args.emplace_back(backward_graph.GetNodeArg(initializer_name));
|
||||
backward_graph.RemoveInitializedTensor(initializer_name);
|
||||
}
|
||||
}
|
||||
|
||||
// Add intermediate args to backward graph inputs.
|
||||
for (const auto& intermediate_arg_name : split_graphs_info_.intermediate_tensor_names) {
|
||||
NodeArg* intermediate_node_arg = backward_graph.GetNodeArg(intermediate_arg_name);
|
||||
intermediate_node_arg->UpdateTypeAndShape(*forward_graph.GetNodeArg(intermediate_arg_name), true, true, *logger_);
|
||||
backward_input_args.emplace_back(intermediate_node_arg);
|
||||
}
|
||||
|
||||
// Grad of user outputs to backward graph inputs.
|
||||
for (const auto& output_grad_name : split_graphs_info_.backward_output_grad_names) {
|
||||
backward_input_args.emplace_back(backward_graph.GetNodeArg(output_grad_name));
|
||||
}
|
||||
|
||||
backward_graph.SetInputs(backward_input_args);
|
||||
|
||||
// Exclude user outputs from the backward graph.
|
||||
const std::vector<const NodeArg*>& backward_graph_outputs = backward_graph.GetOutputs();
|
||||
std::vector<const NodeArg*> backward_output_args;
|
||||
for (auto& node_arg : backward_graph_outputs) {
|
||||
if (backward_output_names.find(node_arg->Name()) != backward_output_names.end()) {
|
||||
backward_output_args.emplace_back(node_arg);
|
||||
}
|
||||
}
|
||||
|
||||
backward_graph.SetOutputs(backward_output_args);
|
||||
backward_graph.Resolve();
|
||||
return Status::OK();
|
||||
gradient_graph.SetOutputs(new_output_args);
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@
|
|||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "core/common/status.h"
|
||||
#include "core/graph/model.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
|
||||
|
|
@ -20,47 +23,77 @@ struct ModuleGradientGraphBuilderConfiguration {
|
|||
|
||||
// Gradient graph configuration.
|
||||
bool use_invertible_layernorm_grad = false;
|
||||
bool set_gradients_as_graph_outputs = false;
|
||||
|
||||
// TODO: add GraphTransformerConfiguration
|
||||
// TODO: add mixed precision config
|
||||
// TODO: do we need to support graph with loss?
|
||||
};
|
||||
|
||||
/**
|
||||
* The information of split graphs for frontend.
|
||||
*/
|
||||
struct SplitGraphsInfo {
|
||||
// The user inputs.
|
||||
std::vector<std::string> user_input_names{};
|
||||
// Map from user input names to corresponding user input grad names for those user inputs that require grad.
|
||||
std::unordered_map<std::string, std::string> user_input_grad_names{};
|
||||
// Trainable initializers.
|
||||
std::vector<std::string> initializer_names_to_train{};
|
||||
// Trainable initializer grad names, ordered according to initializer_names_to_train.
|
||||
std::vector<std::string> initializer_grad_names_to_train{};
|
||||
// The user outputs.
|
||||
std::vector<std::string> user_output_names{};
|
||||
std::vector<std::string> backward_user_input_names{};
|
||||
std::vector<std::string> backward_intializer_names_as_input{};
|
||||
std::vector<std::string> intermediate_tensor_names{};
|
||||
std::vector<std::string> user_output_grad_names{};
|
||||
// The user output grad names that are actual required by the backward graph.
|
||||
std::vector<std::string> backward_output_grad_names{};
|
||||
};
|
||||
|
||||
class ModuleGradientGraphBuilder {
|
||||
public:
|
||||
/**
|
||||
* Initialize the builder. It saves the initial model and the configuration.
|
||||
* It also removes the trainable initializers from initial model and move them to graph inputs.
|
||||
* @param model_istream The initial model as input stream.
|
||||
* @param config The configuration to control the builder.
|
||||
* @return The status of the initialization.
|
||||
*/
|
||||
Status Initialize(std::istream& model_istream, const ModuleGradientGraphBuilderConfiguration& config);
|
||||
Status BuildAndSplit(const std::vector<std::vector<int64_t>>& input_shapes);
|
||||
|
||||
std::string GetForwardModel() const;
|
||||
std::string GetBackwardModel() const;
|
||||
/**
|
||||
* Build the gradient graph and split it to forward and backward graphs.
|
||||
* @param input_shapes_ptr The pointer to vector of concrete shapes of the user inputs.
|
||||
* @return The status of the gradient graph building and forward/backward graphs splitting.
|
||||
*/
|
||||
Status Build(const std::vector<std::vector<int64_t>>* input_shapes_ptr = nullptr);
|
||||
|
||||
/**
|
||||
* Get gradient model.
|
||||
* @return The gradient model serialized to string.
|
||||
*/
|
||||
std::string GetGradientModel() const;
|
||||
|
||||
/**
|
||||
* Get the split graphs information.
|
||||
* @return The split graphs information.
|
||||
*/
|
||||
SplitGraphsInfo GetSplitGraphsInfo() const { return split_graphs_info_; }
|
||||
|
||||
private:
|
||||
Status Split();
|
||||
// Set concrete shapes for graph inputs.
|
||||
void SetConcreteInputShapes(const std::vector<std::vector<int64_t>>& input_shapes);
|
||||
|
||||
// Build gradient graph.
|
||||
Status BuildGradientGraph();
|
||||
|
||||
// Add Yield Op.
|
||||
void AddYieldOp();
|
||||
|
||||
// Reorder gradient graph outputs.
|
||||
void ReorderOutputs();
|
||||
|
||||
std::shared_ptr<onnxruntime::Model> model_;
|
||||
std::shared_ptr<onnxruntime::Model> forward_model_;
|
||||
std::shared_ptr<onnxruntime::Model> backward_model_;
|
||||
std::shared_ptr<onnxruntime::Model> gradient_model_;
|
||||
SplitGraphsInfo split_graphs_info_;
|
||||
|
||||
ModuleGradientGraphBuilderConfiguration config_;
|
||||
const logging::Logger* logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now.
|
||||
const logging::Logger* logger_ = &logging::LoggingManager::DefaultLogger(); // use default logger for now.
|
||||
};
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -2208,6 +2208,33 @@ Return true if all elements are true and false otherwise.
|
|||
propagateShapeFromInputToOutput(ctx, i + 1, i);
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(YieldOp)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
|
||||
.SetDoc("Yield Op.")
|
||||
.Input(0, "outputs", "Module outputs to be returned to pytorch.", "T", OpSchema::Variadic,
|
||||
/*is_homogeneous*/ false,
|
||||
/*min_arity*/ 1)
|
||||
.Output(0, "outputs_grad", "Gradient of outputs returned from pytorch.", "T", OpSchema::Variadic,
|
||||
/*is_homogeneous*/ false,
|
||||
/*min_arity*/ 1)
|
||||
.TypeConstraint("T", OpSchema::all_tensor_types(), "Allow inputs and outputs to be any kind of tensor.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
// Assume the outputs and gradients are one-to-one matching
|
||||
// TODO: The contrain is relaxed for now
|
||||
// ORT_ENFORCE(ctx.getNumInputs() == ctx.getNumOutputs(), "Yield op doesn't have the same number of inputs and output");
|
||||
|
||||
for (size_t i = 0; i < ctx.getNumOutputs(); ++i) {
|
||||
propagateElemTypeFromInputToOutput(ctx, i, i);
|
||||
auto typeProto = ctx.getInputType(i);
|
||||
if (!hasShape(*typeProto)) {
|
||||
continue;
|
||||
}
|
||||
propagateShapeFromInputToOutput(ctx, i, i);
|
||||
}
|
||||
});
|
||||
}
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -475,47 +475,45 @@ void addObjectMethodsForTraining(py::module& m) {
|
|||
});
|
||||
|
||||
py::class_<ModuleGradientGraphBuilderConfiguration> module_gradient_graph_builder_config(
|
||||
m, "ModuleGradientGraphBuilderConfiguration", R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
|
||||
m, "ModuleGradientGraphBuilderConfiguration",
|
||||
R"pbdoc(Configuration information for module gradient graph builder.)pbdoc");
|
||||
module_gradient_graph_builder_config.def(py::init())
|
||||
.def_readwrite("initializer_names_to_train", &ModuleGradientGraphBuilderConfiguration::initializer_names_to_train)
|
||||
.def_readwrite("input_names_require_grad", &ModuleGradientGraphBuilderConfiguration::input_names_require_grad)
|
||||
.def_readwrite("use_invertible_layernorm_grad", &ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad)
|
||||
.def_readwrite("set_gradients_as_graph_outputs", &ModuleGradientGraphBuilderConfiguration::set_gradients_as_graph_outputs);
|
||||
.def_readwrite("use_invertible_layernorm_grad",
|
||||
&ModuleGradientGraphBuilderConfiguration::use_invertible_layernorm_grad);
|
||||
|
||||
py::class_<SplitGraphsInfo> split_graphs_info(
|
||||
m, "SplitGraphsInfo", R"pbdoc(The information of split graphs for frontend.)pbdoc");
|
||||
py::class_<SplitGraphsInfo> split_graphs_info(m, "SplitGraphsInfo",
|
||||
R"pbdoc(The information of split graphs for frontend.)pbdoc");
|
||||
split_graphs_info.def(py::init())
|
||||
.def_readwrite("user_input_names", &SplitGraphsInfo::user_input_names)
|
||||
.def_readwrite("user_input_grad_names", &SplitGraphsInfo::user_input_grad_names)
|
||||
.def_readwrite("initializer_names_to_train", &SplitGraphsInfo::initializer_names_to_train)
|
||||
.def_readwrite("initializer_grad_names_to_train", &SplitGraphsInfo::initializer_grad_names_to_train)
|
||||
.def_readwrite("user_output_names", &SplitGraphsInfo::user_output_names)
|
||||
.def_readwrite("backward_user_input_names", &SplitGraphsInfo::backward_user_input_names)
|
||||
.def_readwrite("backward_intializer_names_as_input", &SplitGraphsInfo::backward_intializer_names_as_input)
|
||||
.def_readwrite("intermediate_tensor_names", &SplitGraphsInfo::intermediate_tensor_names)
|
||||
.def_readwrite("user_output_grad_names", &SplitGraphsInfo::user_output_grad_names)
|
||||
.def_readwrite("backward_output_grad_names", &SplitGraphsInfo::backward_output_grad_names);
|
||||
|
||||
py::class_<ModuleGradientGraphBuilder> module_gradient_graph_builder(m, "ModuleGradientGraphBuilder");
|
||||
module_gradient_graph_builder
|
||||
.def(py::init([]() {
|
||||
return onnxruntime::make_unique<ModuleGradientGraphBuilder>();
|
||||
}))
|
||||
.def("initialize", [](ModuleGradientGraphBuilder* module_gradient_graph_builder,
|
||||
const py::bytes& serialized_model,
|
||||
const ModuleGradientGraphBuilderConfiguration& config) {
|
||||
std::istringstream buffer(serialized_model);
|
||||
ORT_THROW_IF_ERROR(module_gradient_graph_builder->Initialize(buffer, config));
|
||||
})
|
||||
.def("build_and_split", [](ModuleGradientGraphBuilder* module_gradient_graph_builder,
|
||||
const std::vector<std::vector<int64_t>>& input_shapes) {
|
||||
ORT_THROW_IF_ERROR(module_gradient_graph_builder->BuildAndSplit(input_shapes));
|
||||
})
|
||||
.def("get_forward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
return py::bytes(module_gradient_graph_builder->GetForwardModel());
|
||||
})
|
||||
.def("get_backward_model", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
return py::bytes(module_gradient_graph_builder->GetBackwardModel());
|
||||
})
|
||||
module_gradient_graph_builder.def(py::init([]() { return onnxruntime::make_unique<ModuleGradientGraphBuilder>(); }))
|
||||
.def("initialize",
|
||||
[](ModuleGradientGraphBuilder* module_gradient_graph_builder, const py::bytes& serialized_model,
|
||||
const ModuleGradientGraphBuilderConfiguration& config) {
|
||||
std::istringstream buffer(serialized_model);
|
||||
ORT_THROW_IF_ERROR(module_gradient_graph_builder->Initialize(buffer, config));
|
||||
})
|
||||
.def("build",
|
||||
[](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
ORT_THROW_IF_ERROR(module_gradient_graph_builder->Build());
|
||||
})
|
||||
.def("build",
|
||||
[](ModuleGradientGraphBuilder* module_gradient_graph_builder,
|
||||
const std::vector<std::vector<int64_t>>& input_shapes) {
|
||||
ORT_THROW_IF_ERROR(module_gradient_graph_builder->Build(&input_shapes));
|
||||
})
|
||||
.def("get_gradient_model",
|
||||
[](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
return py::bytes(module_gradient_graph_builder->GetGradientModel());
|
||||
})
|
||||
.def("get_split_graphs_info", [](ModuleGradientGraphBuilder* module_gradient_graph_builder) {
|
||||
return module_gradient_graph_builder->GetSplitGraphsInfo();
|
||||
});
|
||||
|
|
|
|||
|
|
@ -65,9 +65,7 @@ def _create_iobinding(io_binding, inputs, model, device):
|
|||
inputs[idx].data_ptr())
|
||||
|
||||
for value_info in model.graph.output:
|
||||
io_binding.bind_output(value_info.name, device.type,
|
||||
device_id=_get_device_index(device))
|
||||
|
||||
io_binding.bind_output(value_info.name, device.type, device_id=_get_device_index(device))
|
||||
|
||||
def _onnx_value_info_to_buffer_tensor(value_info, device):
|
||||
'''Create a torch zeroed tensor with the same shape and type of `value_info`'''
|
||||
|
|
@ -113,10 +111,9 @@ def _extract_input_information(module, *inputs, **kwargs):
|
|||
class ORTModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, module):
|
||||
assert isinstance(module, torch.nn.Module), "'module' mst be a torch.nn.Module"
|
||||
assert isinstance(module, torch.nn.Module), "'module' must be a torch.nn.Module"
|
||||
super(ORTModule, self).__init__()
|
||||
|
||||
self._export_again = False
|
||||
# TODO: This is incorrect when different layers may be in different devices
|
||||
self._device = next(module.parameters()).device
|
||||
self._device_changed = False
|
||||
|
|
@ -124,21 +121,18 @@ class ORTModule(torch.nn.Module):
|
|||
# User module is wrapped to use its initializers and save computed gradients
|
||||
self._original_module = module
|
||||
self._onnx_training = None
|
||||
self._is_training = True
|
||||
|
||||
# Related to training graph split/shape inference
|
||||
self._current_input_shape = None
|
||||
self._module_gradient_graph_builder = None
|
||||
self._input_names_require_grad = None
|
||||
|
||||
# Forward pass
|
||||
self._onnx_forward = None
|
||||
self._forward_session = None
|
||||
self._forward_io_binding = None
|
||||
|
||||
# Backward pass
|
||||
self._onnx_backward = None
|
||||
self._backward_session = None
|
||||
self._backward_io_binding = None
|
||||
# Gradient model
|
||||
self._onnx_gradient = None
|
||||
self._gradient_session = None
|
||||
self._gradient_io_binding = None
|
||||
self._run_options = None
|
||||
|
||||
# Log level
|
||||
self._loglevel = getattr(logging, 'WARNING')
|
||||
|
|
@ -148,21 +142,21 @@ class ORTModule(torch.nn.Module):
|
|||
self._save_onnx_prefix = ''
|
||||
|
||||
def _initialize_module_gradient_graph_builder(self):
|
||||
|
||||
# TODO: PyTorch exporter bug: changes the initializer order
|
||||
initializer_names = [p[0] for p in self._original_module.named_parameters()]
|
||||
|
||||
# Build full training graph and split in forward/backward
|
||||
# Build full training graph
|
||||
grad_builder_config = C.ModuleGradientGraphBuilderConfiguration()
|
||||
grad_builder_config.initializer_names_to_train = initializer_names
|
||||
grad_builder_config.input_names_require_grad = self._input_names_require_grad
|
||||
self._module_gradient_graph_builder = C.ModuleGradientGraphBuilder()
|
||||
self._module_gradient_graph_builder.initialize(self._onnx_training.SerializeToString(), grad_builder_config)
|
||||
|
||||
def _build_training_graph(self, *inputs, **kwargs):
|
||||
def _get_forward_graph_and_init_gradient_graph_builder(self, *inputs, **kwargs):
|
||||
input_names, dynamic_axes, self._input_names_require_grad = \
|
||||
_extract_input_information(self._original_module, *inputs, **kwargs)
|
||||
self._onnx_training = self._get_forward_graph(input_names, dynamic_axes, *inputs, **kwargs)
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_training, self._save_onnx_prefix + '_full_training.onnx')
|
||||
|
||||
|
|
@ -179,27 +173,29 @@ class ORTModule(torch.nn.Module):
|
|||
providers = ["CPUExecutionProvider"]
|
||||
provider_options = [{}]
|
||||
|
||||
self._forward_session = onnxruntime.InferenceSession(
|
||||
self._onnx_forward.SerializeToString(), providers=providers, provider_options=provider_options)
|
||||
self._backward_session = onnxruntime.InferenceSession(
|
||||
self._onnx_backward.SerializeToString(), providers=providers, provider_options=provider_options)
|
||||
session_options = onnxruntime.SessionOptions()
|
||||
session_options.enable_mem_pattern = False
|
||||
session_options.use_deterministic_compute = False
|
||||
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
|
||||
session_options.log_severity_level = 1
|
||||
|
||||
self._gradient_session = onnxruntime.InferenceSession(
|
||||
self._onnx_gradient.SerializeToString(), session_options, providers=providers, provider_options=provider_options)
|
||||
|
||||
# Use this global run_options for now
|
||||
self._run_options = C.RunOptions()
|
||||
|
||||
# IO binding
|
||||
# TODO: we should try to reuse the output buffers as some of the output tensors are same sizes, expecially the backward graph outputs.
|
||||
self._forward_io_binding = self._forward_session.io_binding()
|
||||
self._backward_io_binding = self._backward_session.io_binding()
|
||||
self._gradient_io_binding = self._gradient_session.io_binding()
|
||||
|
||||
def _split_training_graph(self, *inputs, **kwargs):
|
||||
# Perform shape inference and re-split forward/backward graph for batches with different shapes
|
||||
self._module_gradient_graph_builder.build_and_split(self._current_input_shape)
|
||||
self._onnx_forward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_forward_model())
|
||||
self._onnx_backward = onnx.load_model_from_string(self._module_gradient_graph_builder.get_backward_model())
|
||||
def _build_training_graph(self, *inputs, **kwargs):
|
||||
self._module_gradient_graph_builder.build(self._current_input_shape)
|
||||
self._onnx_gradient = onnx.load_model_from_string(self._module_gradient_graph_builder.get_gradient_model())
|
||||
self._onnx_graphs_info = self._module_gradient_graph_builder.get_split_graphs_info()
|
||||
self._create_training_session()
|
||||
|
||||
if self._save_onnx:
|
||||
onnx.save(self._onnx_forward, self._save_onnx_prefix + '_forward.onnx')
|
||||
onnx.save(self._onnx_backward, self._save_onnx_prefix + '_backward.onnx')
|
||||
onnx.save(self._onnx_gradient, self._save_onnx_prefix + '_gradient.onnx')
|
||||
|
||||
def cpu(self: T) -> T:
|
||||
'''Thin layer to capture device for ORTModule IO bindings'''
|
||||
|
|
@ -248,17 +244,28 @@ class ORTModule(torch.nn.Module):
|
|||
self._device = torch.device(device_str)
|
||||
return super(ORTModule, self).to(*args, **kwargs)
|
||||
|
||||
def eval(self: T) -> T:
|
||||
self._is_training = False
|
||||
self._original_module.eval()
|
||||
|
||||
def train(self: T, mode: bool = True) -> T:
|
||||
self._is_training = mode
|
||||
self._original_module.train(mode)
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
'''Forward pass starts here and continues at `_ORTModuleFunction.forward`
|
||||
|
||||
ONNX model is exported the first time this method is executed.
|
||||
Next, a full training graph is splitted in forward and backward graph which are used
|
||||
to instantiate ONNX Runtime InferenceSession`s
|
||||
Next, we build a full training graph with module_gradient_graph_builder.
|
||||
Finally, we instantiate the ONNX Runtime InferenceSession.
|
||||
'''
|
||||
# TODO: using pytorch for evaluation for now. We will use ORT for evaluation latter.
|
||||
if not self._is_training:
|
||||
return self._original_module(*inputs, **kwargs)
|
||||
|
||||
# Exporting module to ONNX for the first time
|
||||
if not self._onnx_training:
|
||||
self._build_training_graph(*inputs, **kwargs)
|
||||
self._get_forward_graph_and_init_gradient_graph_builder(*inputs, **kwargs)
|
||||
|
||||
_, _, input_names_require_grad = _extract_input_information(self._original_module, *inputs, **kwargs)
|
||||
# If inputs requiring gradient change from one call to forward to the next, the module_gradient_graph_builder
|
||||
|
|
@ -270,10 +277,13 @@ class ORTModule(torch.nn.Module):
|
|||
new_input_shape = [list(input.size()) for input in inputs if input is not None]
|
||||
if self._current_input_shape is None or self._current_input_shape != new_input_shape:
|
||||
self._current_input_shape = new_input_shape
|
||||
self._split_training_graph(*inputs, **kwargs)
|
||||
elif self._device_changed:
|
||||
self._build_training_graph()
|
||||
self._create_training_session()
|
||||
self._device_changed = False
|
||||
# TODO: disabled for now, since it caused a bug in NVBert fp32 run
|
||||
# When creating a new InferenceSession, there is a bug for destructing the original InferenceSession
|
||||
# elif self._device_changed:
|
||||
# self._create_training_session()
|
||||
# self._device_changed = False
|
||||
|
||||
# Use a custom torch.autograd.Function to associate self.backward_graph as the
|
||||
# gradient implementation for self.forward_graph.
|
||||
|
|
@ -284,80 +294,57 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
TODO: **kwargs are not supported
|
||||
|
||||
Model outputs are returned to the user
|
||||
The following tensors are stashed (in order) for backward pass
|
||||
* (Partial) user input
|
||||
* (Partial) Initializers
|
||||
* Intermediate tensors
|
||||
Module outputs are returned to the user
|
||||
'''
|
||||
|
||||
# Use IO binding
|
||||
_create_iobinding(self._forward_io_binding, inputs,
|
||||
self._onnx_forward,
|
||||
self._device)
|
||||
_create_iobinding(self._gradient_io_binding, inputs, self._onnx_gradient, self._device)
|
||||
|
||||
# Run
|
||||
self._forward_session.run_with_iobinding(self._forward_io_binding)
|
||||
forward_outputs = self._forward_io_binding.get_outputs()
|
||||
|
||||
# Stash tensors needed by backward
|
||||
forward_input_dict = self._convert_forward_input_list_to_dict(*inputs)
|
||||
ctx_inputs = tuple(forward_input_dict[name] \
|
||||
for name in self._onnx_graphs_info.backward_user_input_names)
|
||||
ctx_initializers = tuple(forward_input_dict[name] \
|
||||
for name in self._onnx_graphs_info.backward_intializer_names_as_input)
|
||||
ctx_intermediates = tuple(_ort_output_to_torch_tensor(forward_output) \
|
||||
for forward_output in forward_outputs[len(self._onnx_graphs_info.user_output_names):])
|
||||
ctx.save_for_backward(*[*ctx_inputs, *ctx_initializers, *ctx_intermediates])
|
||||
|
||||
# Return model output
|
||||
# Run and return module outputs.
|
||||
user_outputs = tuple(_ort_output_to_torch_tensor(forward_output) \
|
||||
for forward_output in forward_outputs[:len(self._onnx_graphs_info.user_output_names)])
|
||||
for forward_output in self._gradient_session.run_forward(self._gradient_io_binding, self._run_options))
|
||||
return user_outputs[0] if len(user_outputs) == 1 else user_outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_output):
|
||||
'''Performs backward pass based on grad wrt output and internal state
|
||||
|
||||
Internal state is composed of:
|
||||
* Tensor stashed (in a particular order) during forward:
|
||||
* (partial) user input, (partial) initializers and intermediate tensors
|
||||
|
||||
TODO: Input gradient is hard-coded to torch.tensor([1.])
|
||||
'''Performs backward pass based on grad wrt module output
|
||||
'''
|
||||
|
||||
# Use IO binding
|
||||
grad_output_dict = dict(zip(self._onnx_graphs_info.user_output_grad_names, grad_output))
|
||||
backward_grad_output = tuple(grad_output_dict[name] for name in self._onnx_graphs_info.backward_output_grad_names)
|
||||
_create_iobinding(self._backward_io_binding, [*ctx.saved_tensors, *backward_grad_output],
|
||||
self._onnx_backward,
|
||||
self._device)
|
||||
# Push user output grads to ONNX backend.
|
||||
backward_grad_output_ortvalue = []
|
||||
for grad_output in grad_output[:len(self._onnx_graphs_info.backward_output_grad_names)]:
|
||||
backward_grad_output_ortvalue.append(onnxruntime.OrtValue.ortvalue_from_data_ptr(list(grad_output.size()), _utils.dtype_torch_to_numpy(
|
||||
grad_output.dtype), grad_output.device.type, _get_device_index(grad_output.device), grad_output.data_ptr()))
|
||||
|
||||
# Run
|
||||
self._backward_session.run_with_iobinding(self._backward_io_binding)
|
||||
backward_outputs = self._backward_io_binding.get_outputs()
|
||||
# Run and get results
|
||||
self._gradient_session.run_backward(backward_grad_output_ortvalue)
|
||||
backward_outputs = self._gradient_io_binding.get_outputs()
|
||||
|
||||
# Return input and initializer gradients
|
||||
num_initializers = len(self._onnx_graphs_info.initializer_grad_names_to_train)
|
||||
num_user_input_grads = len(self._input_names_require_grad)
|
||||
|
||||
results = []
|
||||
for input_name in self._onnx_graphs_info.user_input_names:
|
||||
try:
|
||||
# Append to the results the backward output for each input that required grad
|
||||
results.append(_ort_output_to_torch_tensor(
|
||||
backward_outputs[num_initializers + self._input_names_require_grad.index(input_name)]))
|
||||
backward_outputs[self._input_names_require_grad.index(input_name)]))
|
||||
except ValueError:
|
||||
# input_name is not found in the self._input_names_require_grad list
|
||||
# Append None to results for each input that did not require grad
|
||||
results.append(None)
|
||||
# Append backward ouput for all trained initializers
|
||||
results += [_ort_output_to_torch_tensor(backward_output)
|
||||
for backward_output in backward_outputs[:num_initializers]]
|
||||
# Append gradients of initializer to results
|
||||
results += [_ort_output_to_torch_tensor(backward_output)
|
||||
for backward_output in backward_outputs[num_user_input_grads:]]
|
||||
return tuple(results)
|
||||
|
||||
proc_inputs = [data for data in inputs if data is not None]
|
||||
return _ORTModuleFunction.apply(*self._convert_forward_input_to_list(*proc_inputs, **kwargs))
|
||||
|
||||
return _ORTModuleFunction.apply(*self._convert_gradient_graph_input_to_list(*proc_inputs, **kwargs))
|
||||
|
||||
@_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__)
|
||||
def _convert_forward_input_to_list(self, *inputs, **kwargs):
|
||||
def _convert_gradient_graph_input_to_list(self, *inputs, **kwargs):
|
||||
'''Creates forward `*inputs` list from user input and PyTorch initializers
|
||||
|
||||
TODO: **kwargs is not supported
|
||||
|
|
@ -379,63 +366,6 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
return result
|
||||
|
||||
@_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__)
|
||||
def _convert_forward_input_list_to_dict(self, *inputs):
|
||||
'''Convert forward `*inputs` list to dict
|
||||
|
||||
TODO: Input gradient is being ignored for MVP
|
||||
'''
|
||||
# Dictionary containing both inputs and initializers
|
||||
forward_input_names = [*self._onnx_graphs_info.user_input_names,
|
||||
*self._onnx_graphs_info.initializer_names_to_train]
|
||||
return dict(zip(forward_input_names, inputs))
|
||||
|
||||
@_utils.timeit(enabled=__TEMP_ENABLE_METHOD_TIMING__)
|
||||
def _convert_backward_input_list_to_dict(self, *inputs):
|
||||
'''Convert backward `*inputs` list to dict
|
||||
|
||||
ONNX Runtime backward requires dict as input, which is composed of:
|
||||
* User input
|
||||
Although not necessary, all user inputs are used for simplicity
|
||||
* (Partial) Initializers
|
||||
init_begin = len(user_input)
|
||||
init_count = len(Pre-computed list of initializer)
|
||||
* Intermediate tensors
|
||||
* Gradient wrt outputs
|
||||
'''
|
||||
|
||||
# Dictionary containing both inputs and initializers
|
||||
result = {}
|
||||
|
||||
backward_user_input = self._onnx_graphs_info.backward_user_input_names
|
||||
backward_intializer = self._onnx_graphs_info.backward_intializer_names_as_input
|
||||
intermediate = self._onnx_graphs_info.intermediate_tensor_names
|
||||
backward_output_grad_names = self._onnx_graphs_info.backward_output_grad_names
|
||||
|
||||
# Extract info about stashed input and grad output
|
||||
# Inputs
|
||||
inputs_pos = 0
|
||||
for idx, name in enumerate(backward_user_input):
|
||||
result.update({ name : inputs[idx]})
|
||||
inputs_pos += 1
|
||||
|
||||
# Initializers
|
||||
for idx, name in enumerate(backward_intializer, inputs_pos):
|
||||
result.update({name: inputs[idx]})
|
||||
inputs_pos += 1
|
||||
|
||||
# Intermediate
|
||||
for idx, name in enumerate(intermediate, inputs_pos):
|
||||
result.update({name: inputs[idx]})
|
||||
inputs_pos += 1
|
||||
|
||||
# Grad outputs
|
||||
for idx, name in enumerate(backward_output_grad_names, inputs_pos):
|
||||
result.update({name: inputs[idx]})
|
||||
inputs_pos += 1
|
||||
|
||||
return result
|
||||
|
||||
def _get_forward_graph(self, input_names, dynamic_axes, *inputs, **kwargs):
|
||||
'''Exports PyTorch `module` to ONNX with training flag, using `*inputs` as input
|
||||
|
||||
|
|
@ -467,5 +397,9 @@ class ORTModule(torch.nn.Module):
|
|||
do_constant_folding=False,
|
||||
training=torch.onnx.TrainingMode.TRAINING,
|
||||
dynamic_axes=dynamic_axes)
|
||||
|
||||
# TODO: this step might not be needed when we use the torch external allocator
|
||||
# clear cache after model export
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return onnx.load_model_from_string(f.getvalue())
|
||||
|
|
|
|||
|
|
@ -26,10 +26,29 @@ def parse_arguments():
|
|||
def run_ortmodule_api_tests(cwd, log):
|
||||
log.debug('Running: ORTModule-API tests')
|
||||
|
||||
command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_ortmodule_api.py']
|
||||
class TestNameCollecterPlugin:
|
||||
def __init__(self):
|
||||
self.collected = set()
|
||||
|
||||
run_subprocess(command, cwd=cwd, log=log).check_returncode()
|
||||
def pytest_collection_modifyitems(self, items):
|
||||
for item in items:
|
||||
print('item.name: ', item.name)
|
||||
self.collected.add(item.name)
|
||||
|
||||
import os
|
||||
import pytest
|
||||
plugin = TestNameCollecterPlugin()
|
||||
print(cwd)
|
||||
test_script_filename = os.path.join("orttraining_test_ortmodule_api.py")
|
||||
pytest.main(['--collect-only', test_script_filename], plugins=[plugin])
|
||||
|
||||
# TODO: FIX THIS!
|
||||
# Running tests in a loop one after another,
|
||||
# because ORTModule doesn't support multiple run call at the same time
|
||||
for test_name in plugin.collected:
|
||||
run_subprocess([
|
||||
sys.executable, '-m', 'pytest',
|
||||
'orttraining_test_ortmodule_api.py', '-sv', '-k', test_name], cwd=cwd).check_returncode()
|
||||
|
||||
def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir):
|
||||
log.debug('Running: ORTModule POCNet for MNIST with --no-cuda arg {}.'.format(no_cuda))
|
||||
|
|
|
|||
|
|
@ -269,20 +269,22 @@ def test_model_to_device_and_back_to_original(original_device, to_device):
|
|||
for _, parameter_value in model.named_parameters():
|
||||
assert parameter_value.device.type == original_device
|
||||
|
||||
def test_model_with_different_devices_same_session():
|
||||
N, D_in, H, D_out = 64, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
|
||||
model = ORTModule(model)
|
||||
# TODO: Fix the following Unit Test
|
||||
# @pytest.mark.skip(reason="TODO: ORTModule.to(device) is disabled for now")
|
||||
# def test_model_with_different_devices_same_session():
|
||||
# N, D_in, H, D_out = 64, 784, 500, 10
|
||||
# model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
|
||||
# model = ORTModule(model)
|
||||
|
||||
for i in range(5):
|
||||
if i % 2 == 0:
|
||||
device = 'cpu'
|
||||
else:
|
||||
device = 'cuda'
|
||||
# for i in range(5):
|
||||
# if i % 2 == 0:
|
||||
# device = 'cpu'
|
||||
# else:
|
||||
# device = 'cuda'
|
||||
|
||||
model.to(device)
|
||||
x = torch.randn(N, D_in, device=device)
|
||||
y = model(x)
|
||||
# model.to(device)
|
||||
# x = torch.randn(N, D_in, device=device)
|
||||
# y = model(x)
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
|
||||
def test_input_requires_grad_saved(device):
|
||||
|
|
@ -305,16 +307,18 @@ def test_input_requires_grad_backward_creates_input_grad(device):
|
|||
s.backward()
|
||||
assert x.grad is not None
|
||||
|
||||
@pytest.mark.parametrize("device", ['cuda', 'cpu'])
|
||||
def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device):
|
||||
N, D_in, H, D_out = 32, 784, 500, 10
|
||||
model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
model = ORTModule(model)
|
||||
x = torch.randn(N, D_in, device=device, requires_grad=True)
|
||||
model(x.data)
|
||||
module_gradient_graph_builder = model._module_gradient_graph_builder
|
||||
model(x)
|
||||
assert module_gradient_graph_builder != model._module_gradient_graph_builder
|
||||
# TODO: Fix the following Unit Test
|
||||
# @pytest.mark.parametrize("device", ['cuda', 'cpu'])
|
||||
# @pytest.mark.skip(reason="ORTModule doesn't support multiple consecutive forward calls.")
|
||||
# def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device):
|
||||
# N, D_in, H, D_out = 32, 784, 500, 10
|
||||
# model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
# model = ORTModule(model)
|
||||
# x = torch.randn(N, D_in, device=device, requires_grad=True)
|
||||
# model(x.data)
|
||||
# module_gradient_graph_builder = model._module_gradient_graph_builder
|
||||
# model(x)
|
||||
# assert module_gradient_graph_builder != model._module_gradient_graph_builder
|
||||
|
||||
def test_gpu_reserved_memory_with_torch_no_grad():
|
||||
device = 'cuda'
|
||||
|
|
@ -328,20 +332,21 @@ def test_gpu_reserved_memory_with_torch_no_grad():
|
|||
model_with_no_grad = ORTModule(model_with_no_grad)
|
||||
mem_reserved_before_export = torch.cuda.memory_reserved(device)
|
||||
model_with_no_grad(x, y, None, None, None, None, z)
|
||||
mem_reserved_after_export_with_torch_no_grad = torch.cuda.memory_reserved(device)
|
||||
mem_reserved_after_export = torch.cuda.memory_reserved(device)
|
||||
assert mem_reserved_before_export == mem_reserved_after_export
|
||||
|
||||
del model_with_no_grad
|
||||
torch.cuda.empty_cache()
|
||||
mem_reserved_after_cache_empty = torch.cuda.memory_reserved(device)
|
||||
assert mem_reserved_before_export == mem_reserved_after_cache_empty
|
||||
|
||||
# Create another model and get the memory_reserved when torch.no_grad has not been enabled
|
||||
# after export
|
||||
# Create another model and get the memory_reserved when torch.no_grad and torch.cuda.empty_cache
|
||||
# has not been enabled after export
|
||||
model_without_no_grad = _get_bert_for_sequence_classification_model(device)
|
||||
model_without_no_grad = ORTModule(model_without_no_grad)
|
||||
mem_reserved_after_export_without_torch_no_grad = 0
|
||||
with patch('torch.no_grad'):
|
||||
model_without_no_grad(x, y, None, None, None, None, z)
|
||||
mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device)
|
||||
|
||||
assert mem_reserved_after_export_with_torch_no_grad < mem_reserved_after_export_without_torch_no_grad
|
||||
assert mem_reserved_before_export < mem_reserved_after_export_with_torch_no_grad
|
||||
with patch('torch.no_grad'), patch('torch.cuda.empty_cache'):
|
||||
model_without_no_grad(x, y, None, None, None, None, z)
|
||||
mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device)
|
||||
assert mem_reserved_after_export < mem_reserved_after_export_without_torch_no_grad
|
||||
|
||||
|
|
@ -7,11 +7,7 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
|
||||
void OrtEventPool::CheckRange(const int64_t id) const {
|
||||
ORT_ENFORCE(
|
||||
id >= 0 && id < MaxNumItems,
|
||||
"Got id ", id,
|
||||
". It should be in a range from 0 to ",
|
||||
MaxNumItems, ".");
|
||||
ORT_ENFORCE(id >= 0 && id < MaxNumItems, "Got id ", id, ". It should be in a range from 0 to ", MaxNumItems, ".");
|
||||
}
|
||||
|
||||
void OrtEventPool::SignalEvent(int64_t id) {
|
||||
|
|
@ -27,6 +23,19 @@ bool OrtEventPool::QueryEvent(int64_t id) const {
|
|||
return pool_[id].signaled.load();
|
||||
}
|
||||
|
||||
void OrtEventPool::WaitAndResetEvent(int64_t id) {
|
||||
CheckRange(id);
|
||||
std::unique_lock<std::mutex> lock(pool_[id].mutex);
|
||||
pool_[id].cv.wait(lock, [this, id] { return pool_[id].signaled.load(); });
|
||||
pool_[id].signaled.store(false);
|
||||
};
|
||||
|
||||
void OrtEventPool::ResetEvent(int64_t id) {
|
||||
CheckRange(id);
|
||||
std::unique_lock<std::mutex> lock(pool_[id].mutex);
|
||||
pool_[id].signaled.store(false);
|
||||
};
|
||||
|
||||
void OrtEventPool::WaitEvent(int64_t id) const {
|
||||
CheckRange(id);
|
||||
std::unique_lock<std::mutex> lock(pool_[id].mutex);
|
||||
|
|
|
|||
|
|
@ -21,7 +21,9 @@ class OrtEventPool final {
|
|||
}
|
||||
void SignalEvent(int64_t id);
|
||||
bool QueryEvent(int64_t id) const;
|
||||
void WaitAndResetEvent(int64_t id);
|
||||
void WaitEvent(int64_t id) const;
|
||||
void ResetEvent(int64_t id);
|
||||
void ResetAllEvents();
|
||||
|
||||
static size_t GetPoolSize() {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class OrtMessageQueue final {
|
||||
public:
|
||||
static OrtMessageQueue& GetInstance() {
|
||||
static OrtMessageQueue instance_;
|
||||
return instance_;
|
||||
}
|
||||
|
||||
void Push(const OrtValue& ort_value) { ort_values.emplace(ort_value); }
|
||||
OrtValue Pop() {
|
||||
OrtValue ort_value = ort_values.front();
|
||||
ort_values.pop();
|
||||
return ort_value;
|
||||
}
|
||||
|
||||
void PopAll(std::vector<OrtValue>& results) {
|
||||
while (!ort_values.empty()) {
|
||||
OrtValue ort_value = ort_values.front();
|
||||
ort_values.pop();
|
||||
results.emplace_back(ort_value);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OrtMessageQueue() = default;
|
||||
~OrtMessageQueue() = default;
|
||||
OrtMessageQueue(const OrtMessageQueue&) = delete;
|
||||
OrtMessageQueue& operator=(const OrtMessageQueue&) = delete;
|
||||
|
||||
std::queue<OrtValue> ort_values;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cpu/controlflow/yield.h"
|
||||
#include "orttraining/training_ops/cpu/controlflow/event_pool.h"
|
||||
#include "orttraining/training_ops/cpu/controlflow/message_queue.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
YieldOp,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.VariadicAlias(0, 0), // TODO: this is a hack to avoid allocating output buffer
|
||||
YieldOp);
|
||||
|
||||
Status YieldOp::Compute(OpKernelContext* ctx) const {
|
||||
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
for (int i_in = 0; i_in < ctx->InputCount(); ++i_in) {
|
||||
onnxruntime::contrib::OrtMessageQueue::GetInstance().Push(*ctx_internal->GetInputMLValue(i_in));
|
||||
}
|
||||
|
||||
// Reset background event before returning to main thread
|
||||
const int64_t background_thread_event_id = 1;
|
||||
onnxruntime::contrib::OrtEventPool::GetInstance().ResetEvent(background_thread_event_id);
|
||||
|
||||
// single event for InferenceSession::RunInBackgroundAndWaitForYield() that FW graph is done
|
||||
const int64_t main_thread_event_id = 0;
|
||||
OrtEventPool::GetInstance().SignalEvent(main_thread_event_id);
|
||||
|
||||
// wait for event from InferenceSession::ContinueRunInBackground() to continue the BW graph
|
||||
OrtEventPool::GetInstance().WaitAndResetEvent(background_thread_event_id);
|
||||
|
||||
if (ctx_internal->GetTerminateFlag()) {
|
||||
LOGS(ctx->Logger(), WARNING) << "Resumed executing backward subgraph, terminate_flag is set to true.";
|
||||
} else {
|
||||
// Get output grad from somewhere and prepare Op outputs.
|
||||
for (int i_out = 0; i_out < ctx->OutputCount(); ++i_out) {
|
||||
ctx_internal->SetOutputMLValue(i_out, OrtMessageQueue::GetInstance().Pop());
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
19
orttraining/orttraining/training_ops/cpu/controlflow/yield.h
Normal file
19
orttraining/orttraining/training_ops/cpu/controlflow/yield.h
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class YieldOp final : public OpKernel {
|
||||
public:
|
||||
YieldOp(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -94,6 +94,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Recv)
|
|||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, RecordEvent);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WaitEvent);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, YieldOp);
|
||||
|
||||
Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -181,7 +182,8 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
#endif
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, RecordEvent)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WaitEvent)>};
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WaitEvent)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, YieldOp)>};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cpu/controlflow/yield.h"
|
||||
#include "core/providers/cuda/cuda_fwd.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
YieldOp,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
|
||||
.VariadicAlias(0, 0), // TODO: this is a hack to avoid allocating output buffer
|
||||
onnxruntime::contrib::YieldOp);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -164,6 +164,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Adas
|
|||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, YieldOp);
|
||||
|
||||
#ifdef ORT_USE_NCCL
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce);
|
||||
|
|
@ -326,6 +327,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, YieldOp)>,
|
||||
|
||||
#ifdef ORT_USE_NCCL
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclAllReduce)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue